GateIO: Add websocket subscription configuration (#1599)

* GateIO: Switch TestMain to using testexch.Setup

* GateIO: Test config updates

* GateIO: Privatise and rename genSubs

* GateIO: Subscription configuration
This commit is contained in:
Gareth Kirwan
2024-10-01 00:39:24 +01:00
committed by GitHub
parent a09fefedd3
commit bfd499f0c9
9 changed files with 261 additions and 252 deletions

View File

@@ -15,7 +15,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/key"
"github.com/thrasher-corp/gocryptotrader/config"
"github.com/thrasher-corp/gocryptotrader/core"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
@@ -24,7 +23,9 @@ import (
"github.com/thrasher-corp/gocryptotrader/exchanges/kline"
"github.com/thrasher-corp/gocryptotrader/exchanges/order"
"github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange"
testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions"
"github.com/thrasher-corp/gocryptotrader/portfolio/withdraw"
)
@@ -36,30 +37,20 @@ const (
canManipulateRealOrders = false
)
var g = &Gateio{}
var g *Gateio
func TestMain(m *testing.M) {
g.SetDefaults()
cfg := config.GetConfig()
err := cfg.LoadConfig("../../testdata/configtest.json", true)
if err != nil {
log.Fatal("GateIO load config error", err)
g = new(Gateio)
if err := testexch.Setup(g); err != nil {
log.Fatal(err)
}
gConf, err := cfg.GetExchangeConfig("GateIO")
if err != nil {
log.Fatal("GateIO Setup() init error")
}
gConf.API.AuthenticatedSupport = true
gConf.API.AuthenticatedWebsocketSupport = true
gConf.API.Credentials.Key = apiKey
gConf.API.Credentials.Secret = apiSecret
g.Websocket = sharedtestvalues.NewTestWebsocket()
gConf.Features.Enabled.FillsFeed = true
gConf.Features.Enabled.TradeFeed = true
err = g.Setup(gConf)
if err != nil {
log.Fatal("GateIO setup error", err)
if apiKey != "" && apiSecret != "" {
g.API.AuthenticatedSupport = true
g.API.AuthenticatedWebsocketSupport = true
g.SetCredentials(apiKey, apiSecret, "", "", "", "")
}
os.Exit(m.Run())
}
@@ -2963,12 +2954,66 @@ func TestFuturesCandlestickPushData(t *testing.T) {
}
}
func TestGenerateDefaultSubscriptions(t *testing.T) {
// TestGenerateSubscriptions exercises generateSubscriptions
func TestGenerateSubscriptions(t *testing.T) {
t.Parallel()
if _, err := g.GenerateDefaultSubscriptions(); err != nil {
t.Error(err)
g := new(Gateio) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes
require.NoError(t, testexch.Setup(g), "Test instance Setup must not error")
g.Websocket.SetCanUseAuthenticatedEndpoints(true)
g.Features.Subscriptions = append(g.Features.Subscriptions, &subscription.Subscription{
Enabled: true, Channel: spotOrderbookChannel, Asset: asset.Spot, Interval: kline.ThousandMilliseconds, Levels: 5,
})
subs, err := g.generateSubscriptions()
require.NoError(t, err, "generateSubscriptions must not error")
exp := subscription.List{}
for _, s := range g.Features.Subscriptions {
for _, a := range g.GetAssetTypes(true) {
if s.Asset != asset.All && s.Asset != a {
continue
}
pairs, err := g.GetEnabledPairs(a)
require.NoErrorf(t, err, "GetEnabledPairs %s must not error", a)
pairs = common.SortStrings(pairs).Format(currency.PairFormat{Uppercase: true, Delimiter: "_"})
s := s.Clone() //nolint:govet // Intentional lexical scope shadow
s.Asset = a
if singleSymbolChannel(channelName(s)) {
for i := range pairs {
s := s.Clone() //nolint:govet // Intentional lexical scope shadow
switch s.Channel {
case subscription.CandlesChannel:
s.QualifiedChannel = "5m," + pairs[i].String()
case subscription.OrderbookChannel:
s.QualifiedChannel = pairs[i].String() + ",100ms"
case spotOrderbookChannel:
s.QualifiedChannel = pairs[i].String() + ",5,1000ms"
}
s.Pairs = pairs[i : i+1]
exp = append(exp, s)
}
} else {
s.Pairs = pairs
s.QualifiedChannel = pairs.Join()
exp = append(exp, s)
}
}
}
testsubs.EqualLists(t, exp, subs)
}
func TestSubscribe(t *testing.T) {
t.Parallel()
g := new(Gateio) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes
require.NoError(t, testexch.Setup(g), "Test instance Setup must not error")
subs, err := g.Features.Subscriptions.ExpandTemplates(g)
require.NoError(t, err, "ExpandTemplates must not error")
g.Features.Subscriptions = subscription.List{}
testexch.SetupWs(t, g)
err = g.Subscribe(subs)
require.NoError(t, err, "Subscribe must not error")
}
func TestGenerateDeliveryFuturesDefaultSubscriptions(t *testing.T) {
t.Parallel()
if _, err := g.GenerateDeliveryFuturesDefaultSubscriptions(); err != nil {

View File

@@ -11,8 +11,10 @@ import (
"net/http"
"strconv"
"strings"
"text/template"
"time"
"github.com/Masterminds/sprig/v3"
"github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/currency"
@@ -50,14 +52,25 @@ const (
crossMarginLoanChannel = "spot.cross_loan"
)
var defaultSubscriptions = []string{
spotTickerChannel,
spotCandlesticksChannel,
spotOrderbookTickerChannel,
var defaultSubscriptions = subscription.List{
{Enabled: true, Channel: subscription.TickerChannel, Asset: asset.Spot},
{Enabled: true, Channel: subscription.CandlesChannel, Asset: asset.Spot, Interval: kline.FiveMin},
{Enabled: true, Channel: subscription.OrderbookChannel, Asset: asset.Spot, Interval: kline.HundredMilliseconds},
{Enabled: true, Channel: spotBalancesChannel, Asset: asset.Spot, Authenticated: true},
{Enabled: true, Channel: crossMarginBalanceChannel, Asset: asset.CrossMargin, Authenticated: true},
{Enabled: true, Channel: marginBalancesChannel, Asset: asset.Margin, Authenticated: true},
{Enabled: false, Channel: subscription.AllTradesChannel, Asset: asset.Spot},
}
var fetchedCurrencyPairSnapshotOrderbook = make(map[string]bool)
var subscriptionNames = map[string]string{
subscription.TickerChannel: spotTickerChannel,
subscription.OrderbookChannel: spotOrderbookUpdateChannel,
subscription.CandlesChannel: spotCandlesticksChannel,
subscription.AllTradesChannel: spotTradesChannel,
}
// WsConnect initiates a websocket connection
func (g *Gateio) WsConnect() error {
if !g.Websocket.IsEnabled() || !g.IsEnabled() {
@@ -86,8 +99,8 @@ func (g *Gateio) WsConnect() error {
return nil
}
func (g *Gateio) generateWsSignature(secret, event, channel string, dtime time.Time) (string, error) {
msg := "channel=" + channel + "&event=" + event + "&time=" + strconv.FormatInt(dtime.Unix(), 10)
func (g *Gateio) generateWsSignature(secret, event, channel string, t int64) (string, error) {
msg := "channel=" + channel + "&event=" + event + "&time=" + strconv.FormatInt(t, 10)
mac := hmac.New(sha512.New, []byte(secret))
if _, err := mac.Write([]byte(msg)); err != nil {
return "", err
@@ -628,232 +641,91 @@ func (g *Gateio) processCrossMarginLoans(data []byte) error {
return nil
}
// GenerateDefaultSubscriptions returns default subscriptions
func (g *Gateio) GenerateDefaultSubscriptions() (subscription.List, error) {
channelsToSubscribe := defaultSubscriptions
if g.Websocket.CanUseAuthenticatedEndpoints() {
channelsToSubscribe = append(channelsToSubscribe, []string{
crossMarginBalanceChannel,
marginBalancesChannel,
spotBalancesChannel}...)
}
if g.IsSaveTradeDataEnabled() || g.IsTradeFeedEnabled() {
channelsToSubscribe = append(channelsToSubscribe, spotTradesChannel)
}
var subscriptions subscription.List
var err error
for i := range channelsToSubscribe {
var pairs []currency.Pair
var assetType asset.Item
switch channelsToSubscribe[i] {
case marginBalancesChannel:
assetType = asset.Margin
pairs, err = g.GetEnabledPairs(asset.Margin)
case crossMarginBalanceChannel:
assetType = asset.CrossMargin
pairs, err = g.GetEnabledPairs(asset.CrossMargin)
default:
assetType = asset.Spot
pairs, err = g.GetEnabledPairs(asset.Spot)
}
if err != nil {
if errors.Is(err, asset.ErrNotEnabled) {
continue // Skip if asset is not enabled.
}
return nil, err
}
for j := range pairs {
params := make(map[string]interface{})
switch channelsToSubscribe[i] {
case spotOrderbookChannel:
params["level"] = 100
params["interval"] = kline.HundredMilliseconds
case spotCandlesticksChannel:
params["interval"] = kline.FiveMin
case spotOrderbookUpdateChannel:
params["interval"] = kline.HundredMilliseconds
}
fpair, err := g.FormatExchangeCurrency(pairs[j], asset.Spot)
if err != nil {
return nil, err
}
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: channelsToSubscribe[i],
Pairs: currency.Pairs{fpair.Upper()},
Asset: assetType,
Params: params,
})
}
}
return subscriptions, nil
// generateSubscriptions returns configured subscriptions
func (g *Gateio) generateSubscriptions() (subscription.List, error) {
return g.Features.Subscriptions.ExpandTemplates(g)
}
// handleSubscription sends a websocket message to receive data from the channel
func (g *Gateio) handleSubscription(event string, channelsToSubscribe subscription.List) error {
payloads, err := g.generatePayload(event, channelsToSubscribe)
if err != nil {
return err
}
// GetSubscriptionTemplate returns a subscription channel template
func (g *Gateio) GetSubscriptionTemplate(_ *subscription.Subscription) (*template.Template, error) {
return template.New("master.tmpl").Funcs(sprig.FuncMap()).Funcs(template.FuncMap{
"channelName": channelName,
"singleSymbolChannel": singleSymbolChannel,
"interval": g.GetIntervalString,
}).Parse(subTplText)
}
// manageSubs sends a websocket message to subscribe or unsubscribe from a list of channel
func (g *Gateio) manageSubs(event string, subs subscription.List) error {
var errs error
for k := range payloads {
result, err := g.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.Unset, payloads[k].ID, payloads[k])
if err != nil {
errs = common.AppendError(errs, err)
continue
}
var resp WsEventResponse
if err = json.Unmarshal(result, &resp); err != nil {
errs = common.AppendError(errs, err)
} else {
if resp.Error != nil && resp.Error.Code != 0 {
errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s error code: %d message: %s", payloads[k].Event, payloads[k].Channel, resp.Error.Code, resp.Error.Message))
continue
}
if payloads[k].Event == "subscribe" {
err = g.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[k])
} else {
err = g.Websocket.RemoveSubscriptions(channelsToSubscribe[k])
}
subs, errs = subs.ExpandTemplates(g)
if errs != nil {
return errs
}
for _, s := range subs {
if err := func() error {
msg, err := g.manageSubReq(event, s)
if err != nil {
errs = common.AppendError(errs, err)
return err
}
result, err := g.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.Unset, msg.ID, msg)
if err != nil {
return err
}
var resp WsEventResponse
if err := json.Unmarshal(result, &resp); err != nil {
return err
}
if resp.Error != nil && resp.Error.Code != 0 {
return fmt.Errorf("(%d) %s", resp.Error.Code, resp.Error.Message)
}
if event == "unsubscribe" {
return g.Websocket.RemoveSubscriptions(s)
}
return g.Websocket.AddSuccessfulSubscriptions(s)
}(); err != nil {
errs = common.AppendError(errs, fmt.Errorf("%s %s %s: %w", s.Channel, s.Asset, s.Pairs, err))
}
}
return errs
}
func (g *Gateio) generatePayload(event string, channelsToSubscribe subscription.List) ([]WsInput, error) {
if len(channelsToSubscribe) == 0 {
return nil, errors.New("cannot generate payload, no channels supplied")
// manageSubReq constructs the subscription management message for a subscription
func (g *Gateio) manageSubReq(event string, s *subscription.Subscription) (*WsInput, error) {
req := &WsInput{
ID: g.Websocket.Conn.GenerateMessageID(false),
Event: event,
Channel: channelName(s),
Time: time.Now().Unix(),
Payload: strings.Split(s.QualifiedChannel, ","),
}
var creds *account.Credentials
var err error
if g.Websocket.CanUseAuthenticatedEndpoints() {
creds, err = g.GetCredentials(context.TODO())
if s.Authenticated {
creds, err := g.GetCredentials(context.TODO())
if err != nil {
return nil, err
}
sig, err := g.generateWsSignature(creds.Secret, event, req.Channel, req.Time)
if err != nil {
return nil, err
}
req.Auth = &WsAuthInput{
Method: "api_key",
Key: creds.Key,
Sign: sig,
}
}
var batch *[]string
var intervalString string
payloads := make([]WsInput, 0, len(channelsToSubscribe))
for i := range channelsToSubscribe {
if len(channelsToSubscribe[i].Pairs) != 1 {
return nil, subscription.ErrNotSinglePair
}
var auth *WsAuthInput
timestamp := time.Now()
channelsToSubscribe[i].Pairs[0].Delimiter = currency.UnderscoreDelimiter
params := []string{channelsToSubscribe[i].Pairs[0].String()}
switch channelsToSubscribe[i].Channel {
case spotOrderbookChannel:
interval, okay := channelsToSubscribe[i].Params["interval"].(kline.Interval)
if !okay {
return nil, errors.New("invalid interval parameter")
}
level, okay := channelsToSubscribe[i].Params["level"].(int)
if !okay {
return nil, errors.New("invalid spot order level")
}
intervalString, err = g.GetIntervalString(interval)
if err != nil {
return nil, err
}
params = append(params,
strconv.Itoa(level),
intervalString,
)
case spotCandlesticksChannel:
interval, ok := channelsToSubscribe[i].Params["interval"].(kline.Interval)
if !ok {
return nil, errors.New("missing spot candlesticks interval")
}
intervalString, err = g.GetIntervalString(interval)
if err != nil {
return nil, err
}
params = append(
[]string{intervalString},
params...)
}
switch channelsToSubscribe[i].Channel {
case spotUserTradesChannel,
spotBalancesChannel,
marginBalancesChannel,
spotFundingBalanceChannel,
crossMarginBalanceChannel,
crossMarginLoanChannel:
if !g.Websocket.CanUseAuthenticatedEndpoints() {
continue
}
value, ok := channelsToSubscribe[i].Params["user"].(string)
if ok {
params = append(
[]string{value},
params...)
}
var sigTemp string
sigTemp, err = g.generateWsSignature(creds.Secret, event, channelsToSubscribe[i].Channel, timestamp)
if err != nil {
return nil, err
}
auth = &WsAuthInput{
Method: "api_key",
Key: creds.Key,
Sign: sigTemp,
}
case spotOrderbookUpdateChannel:
interval, ok := channelsToSubscribe[i].Params["interval"].(kline.Interval)
if !ok {
return nil, errors.New("missing spot orderbook interval")
}
intervalString, err = g.GetIntervalString(interval)
if err != nil {
return nil, err
}
params = append(params, intervalString)
}
payload := WsInput{
ID: g.Websocket.Conn.GenerateMessageID(false),
Event: event,
Channel: channelsToSubscribe[i].Channel,
Payload: params,
Auth: auth,
Time: timestamp.Unix(),
}
if channelsToSubscribe[i].Channel == "spot.book_ticker" {
// To get all orderbook assets subscribed it needs to be batched and
// only spot.book_ticker can be batched, if not it will take about
// half an hour for initial sync.
if batch != nil {
*batch = append(*batch, params...)
} else {
// Sets up pointer to the field for the outbound payload.
payloads = append(payloads, payload)
batch = &payloads[len(payloads)-1].Payload
}
continue
}
payloads = append(payloads, payload)
}
return payloads, nil
return req, nil
}
// Subscribe sends a websocket message to stop receiving data from the channel
func (g *Gateio) Subscribe(channelsToUnsubscribe subscription.List) error {
return g.handleSubscription("subscribe", channelsToUnsubscribe)
func (g *Gateio) Subscribe(subs subscription.List) error {
return g.manageSubs("subscribe", subs)
}
// Unsubscribe sends a websocket message to stop receiving data from the channel
func (g *Gateio) Unsubscribe(channelsToUnsubscribe subscription.List) error {
return g.handleSubscription("unsubscribe", channelsToUnsubscribe)
func (g *Gateio) Unsubscribe(subs subscription.List) error {
return g.manageSubs("unsubscribe", subs)
}
func (g *Gateio) listOfAssetsCurrencyPairEnabledFor(cp currency.Pair) map[asset.Item]bool {
@@ -870,8 +742,43 @@ func (g *Gateio) listOfAssetsCurrencyPairEnabledFor(cp currency.Pair) map[asset.
return assetPairEnabled
}
// GenerateWebsocketMessageID generates a message ID for the individual
// connection.
// GenerateWebsocketMessageID generates a message ID for the individual connection
func (g *Gateio) GenerateWebsocketMessageID(bool) int64 {
return g.Counter.IncrementAndGet()
}
// channelName converts global channel names to gateio specific channel names
func channelName(s *subscription.Subscription) string {
if name, ok := subscriptionNames[s.Channel]; ok {
return name
}
return s.Channel
}
// singleSymbolChannel returns if the channel should be fanned out into single symbol requests
func singleSymbolChannel(name string) bool {
switch name {
case spotCandlesticksChannel, spotOrderbookUpdateChannel, spotOrderbookChannel:
return true
}
return false
}
const subTplText = `
{{- with $name := channelName $.S }}
{{- range $asset, $pairs := $.AssetPairs }}
{{- if singleSymbolChannel $name }}
{{- range $i, $p := $pairs -}}
{{- if eq $name "spot.candlesticks" }}{{ interval $.S.Interval -}} , {{- end }}
{{- $p }}
{{- if eq "spot.order_book" $name -}} , {{- $.S.Levels }}{{ end }}
{{- if hasPrefix "spot.order_book" $name -}} , {{- interval $.S.Interval }}{{ end }}
{{- $.PairSeparator }}
{{- end }}
{{- $.AssetSeparator }}
{{- else }}
{{- $pairs.Join }}
{{- end }}
{{- end }}
{{- end }}
`

View File

@@ -146,6 +146,7 @@ func (g *Gateio) SetDefaults() {
GlobalResultLimit: 1000,
},
},
Subscriptions: defaultSubscriptions.Clone(),
}
g.Requester, err = request.New(g.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
@@ -217,7 +218,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error {
Connector: g.WsConnect,
Subscriber: g.Subscribe,
Unsubscriber: g.Unsubscribe,
GenerateSubscriptions: g.GenerateDefaultSubscriptions,
GenerateSubscriptions: g.generateSubscriptions,
Features: &g.Features.Supports.WebsocketCapabilities,
FillsFeed: g.Features.Enabled.FillsFeed,
TradeFeed: g.Features.Enabled.TradeFeed,

View File

@@ -266,7 +266,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscrib
params = append([]string{value}, params...)
}
var sigTemp string
sigTemp, err = g.generateWsSignature(creds.Secret, event, channelsToSubscribe[i].Channel, timestamp)
sigTemp, err = g.generateWsSignature(creds.Secret, event, channelsToSubscribe[i].Channel, timestamp.Unix())
if err != nil {
return [2][]WsInput{}, err
}

View File

@@ -351,7 +351,7 @@ func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe subscr
params...)
}
var sigTemp string
sigTemp, err = g.generateWsSignature(creds.Secret, event, channelsToSubscribe[i].Channel, timestamp)
sigTemp, err = g.generateWsSignature(creds.Secret, event, channelsToSubscribe[i].Channel, timestamp.Unix())
if err != nil {
return [2][]WsInput{}, err
}

View File

@@ -238,7 +238,7 @@ func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe subscr
return nil, err
}
var sigTemp string
sigTemp, err = g.generateWsSignature(creds.Secret, event, channelsToSubscribe[i].Channel, timestamp)
sigTemp, err = g.generateWsSignature(creds.Secret, event, channelsToSubscribe[i].Channel, timestamp.Unix())
if err != nil {
return nil, err
}