From 50448ec6a04731e5876884c734dcf634e8a68e08 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Fri, 20 Dec 2024 13:50:31 +1100 Subject: [PATCH] websocket/gateio: Add request functions for websocket multi-connection [SPOT] (#1598) * gateio: Add multi asset websocket support WIP. * meow * Add tests and shenanigans * integrate flushing and for enabling/disabling pairs from rpc shenanigans * some changes * linter: fixes strikes again. * Change name ConnectionAssociation -> ConnectionCandidate for better clarity on purpose. Change connections map to point to candidate to track subscriptions for future dynamic connections holder and drop struct ConnectionDetails. * Add subscription tests (state functional) * glorious:nits + proxy handling * Spelling * linter: fixerino * instead of nil, dont do nil. * clean up nils * cya nils * don't need to set URL or check if its running * stream match update * update tests * linter: fix * glorious: nits + handle context cancellations * stop ping handler routine leak * * Fix bug where reader routine on error that is not a disconnection error but websocket frame error or anything really makes the reader routine return and then connection never cycles and the buffer gets filled. * Handle reconnection via an errors.Is check which is simpler and in that scope allow for quick disconnect reconnect without waiting for connection cycle. * Dial now uses code from DialContext but just calls context.Background() * Don't allow reader to return on parse binary response error. Just output error and return a non nil response * Allow rollback on connect on any error across all connections * fix shadow jutsu * glorious/gk: nitters - adds in ws mock server * linter: fix * fix deadlock on connection as the previous channel had no reader and would hang connection reader for eternity. * glorious: whooops * gk: nits * Leak issue and edge case * Websocket: Add SendMessageReturnResponses * whooooooopsie * gk: nitssssss * Update exchanges/stream/stream_match.go Co-authored-by: Gareth Kirwan * Update exchanges/stream/stream_match_test.go Co-authored-by: Gareth Kirwan * linter: appease the linter gods * gk: nits * gk: drain brain * started * more changes before merge match pr * gateio: still building out * gateio: finish spot * fix up tests in gateio * Add tests for stream package * rm unused field * glorious: nits * rn files, specifically set function names to asset and offload routing to websocket type. * linter: fix * glorious: nits * add counter and update gateio * fix collision issue * Update exchanges/stream/websocket.go Co-authored-by: Scott * glorious: nits * add tests * linter: fix * After merge * Add error connection info * upgrade to upstream merge * Fix edge case where it does not reconnect made by an already closed connection * stream coverage * glorious: nits * glorious: nits removed asset error handling in stream package * linter: fix * rm block * Add basic readme * fix asset enabled flush cycle for multi connection * spella: fix * linter: fix * Add glorious suggestions, fix some race thing * reinstate name before any routine gets spawned * stop on error in mock tests * glorious: nits * glorious: nits found in CI build * Add test for drain, bumped wait times as there seems to be something happening on macos CI builds, used context.WithTimeout because its instant. * mutex across shutdown and connect for protection * lint: fix * test time withoffset, reinstate stop * fix whoops * const trafficCheckInterval; rm testmain * y * fix lint * bump time check window * stream: fix intermittant test failures while testing routines and remove code that is not needed. * spells * cant do what I did * protect race due to routine. * update testURL * use mock websocket connection instead of test URL's * linter: fix * remove url because its throwing errors on CI builds * connections drop all the time, don't need to worry about not being able to echo back ws data as it can be easily reviewed _test file side. * remove another superfluous url thats not really set up for this * spawn overwatch routine when there is no errors, inline checker instead of waiting for a time period, add sleep inline with echo handler as this is really quick and wanted to ensure that latency is handing correctly * linter: fixerino uperino * glorious: panix * linter: things * whoops * dont need to make consecutive Unix() calls * websocket: fix potential panic on error and no responses and adding waitForResponses * rm json parser and handle in json package instead * linter: fix * linter: fix again * * change field name OutboundRequestSignature to WrapperDefinedConnectionSignature for agnostic inbound and outbound connections. * change method name GetOutboundConnection to GetConnection for agnostic inbound and outbound connections. * drop outbound field map for improved performance just using a range and field check (less complex as well) * change field name connections to connectionToWrapper for better clarity * spells and magic and wands * glorious: nits * comparable check for signature * mv err var * glorious: nits and stuff * attempt to fix race * glorious: nits * gk: nits; engine log cleanup * gk: nits; OCD * gk: nits; move function change file names * gk: nits; :rocket: * gk: nits; convert variadic function and message inspection to interface and include a specific function for that handling so as to not need nil on every call * gk: nits; continued * gk: engine nits; rm loaded exchange * gk: nits; drop WebsocketLoginResponse * stream: Add match method EnsureMatchWithData * gk: nits; rn Inspect to IsFinal * gk: nits; rn to MessageFilter * linter: fix * gateio: update rate limit definitions (cherry-pick) * Add test and missing * Shared REST rate limit definitions with Websocket service, set lookup item to nil for systems that do not require rate limiting; add glorious nit * integrate rate limits for websocket trading spot * bitstamp: fix issue * glorious: nits * ch name and commentary * fix bug add test * rm a thing * fix test * Update engine/engine.go Co-authored-by: Adrian Gallagher * thrasher: nits * Update exchanges/stream/stream_match_test.go Co-authored-by: Adrian Gallagher * Update exchanges/stream/stream_match_test.go Co-authored-by: Adrian Gallagher * GK: nits rn websocket functions * explicit function names for single to multi outbound orders * linter: fix --------- Co-authored-by: shazbert Co-authored-by: Gareth Kirwan Co-authored-by: Scott Co-authored-by: Adrian Gallagher --- engine/engine.go | 46 +--- exchanges/bitfinex/bitfinex_websocket.go | 24 +- exchanges/bitmex/bitmex_websocket.go | 5 +- exchanges/bitstamp/bitstamp_websocket.go | 5 +- exchanges/exchange.go | 8 +- exchanges/gateio/gateio_types.go | 13 +- exchanges/gateio/gateio_websocket.go | 72 ++++- ...o => gateio_websocket_delivery_futures.go} | 3 +- ...futures.go => gateio_websocket_futures.go} | 3 +- ...s_option.go => gateio_websocket_option.go} | 3 +- .../gateio/gateio_websocket_request_spot.go | 224 ++++++++++++++++ .../gateio_websocket_request_spot_test.go | 248 ++++++++++++++++++ .../gateio/gateio_websocket_request_types.go | 143 ++++++++++ exchanges/gateio/gateio_wrapper.go | 6 + exchanges/kraken/kraken_websocket.go | 2 +- exchanges/kucoin/kucoin_websocket.go | 8 +- exchanges/stream/stream_match.go | 13 + exchanges/stream/stream_match_test.go | 15 ++ exchanges/stream/stream_types.go | 13 + exchanges/stream/websocket.go | 69 ++++- exchanges/stream/websocket_connection.go | 41 ++- exchanges/stream/websocket_test.go | 88 ++++++- exchanges/stream/websocket_types.go | 2 +- 23 files changed, 951 insertions(+), 103 deletions(-) rename exchanges/gateio/{gateio_ws_delivery_futures.go => gateio_websocket_delivery_futures.go} (98%) rename exchanges/gateio/{gateio_ws_futures.go => gateio_websocket_futures.go} (99%) rename exchanges/gateio/{gateio_ws_option.go => gateio_websocket_option.go} (99%) create mode 100644 exchanges/gateio/gateio_websocket_request_spot.go create mode 100644 exchanges/gateio/gateio_websocket_request_spot_test.go create mode 100644 exchanges/gateio/gateio_websocket_request_types.go diff --git a/engine/engine.go b/engine/engine.go index bd9dba18..8ad14d87 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -791,17 +791,11 @@ func (bot *Engine) LoadExchange(name string) error { localWG.Wait() if !bot.Settings.EnableExchangeHTTPRateLimiter { - gctlog.Warnf(gctlog.ExchangeSys, - "Loaded exchange %s rate limiting has been turned off.\n", - exch.GetName(), - ) err = exch.DisableRateLimiter() if err != nil { - gctlog.Errorf(gctlog.ExchangeSys, - "Loaded exchange %s rate limiting cannot be turned off: %s.\n", - exch.GetName(), - err, - ) + gctlog.Errorf(gctlog.ExchangeSys, "%s error disabling rate limiter: %v", exch.GetName(), err) + } else { + gctlog.Warnf(gctlog.ExchangeSys, "%s rate limiting has been turned off", exch.GetName()) } } @@ -820,29 +814,18 @@ func (bot *Engine) LoadExchange(name string) error { return err } - base := exch.GetBase() - if base.API.AuthenticatedSupport || - base.API.AuthenticatedWebsocketSupport { - assetTypes := base.GetAssetTypes(false) - var useAsset asset.Item - for a := range assetTypes { - err = base.CurrencyPairs.IsAssetEnabled(assetTypes[a]) - if err != nil { - continue - } - useAsset = assetTypes[a] - break - } - err = exch.ValidateAPICredentials(context.TODO(), useAsset) + b := exch.GetBase() + if b.API.AuthenticatedSupport || b.API.AuthenticatedWebsocketSupport { + err = exch.ValidateAPICredentials(context.TODO(), asset.Spot) if err != nil { - gctlog.Warnf(gctlog.ExchangeSys, - "%s: Cannot validate credentials, authenticated support has been disabled, Error: %s\n", - base.Name, - err) - base.API.AuthenticatedSupport = false - base.API.AuthenticatedWebsocketSupport = false + gctlog.Warnf(gctlog.ExchangeSys, "%s: Error validating credentials: %v", b.Name, err) + b.API.AuthenticatedSupport = false + b.API.AuthenticatedWebsocketSupport = false exchCfg.API.AuthenticatedSupport = false exchCfg.API.AuthenticatedWebsocketSupport = false + if b.Websocket != nil { + b.Websocket.SetCanUseAuthenticatedEndpoints(false) + } } } @@ -854,10 +837,7 @@ func (bot *Engine) dryRunParamInteraction(param string) { return } - gctlog.Warnf(gctlog.Global, - "Command line argument '-%s' induces dry run mode."+ - " Set -dryrun=false if you wish to override this.", - param) + gctlog.Warnf(gctlog.Global, "Command line argument '-%s' induces dry run mode. Set -dryrun=false if you wish to override this.", param) if !bot.Settings.EnableDryRun { bot.Settings.EnableDryRun = true diff --git a/exchanges/bitfinex/bitfinex_websocket.go b/exchanges/bitfinex/bitfinex_websocket.go index 74c04afc..ee1b854b 100644 --- a/exchanges/bitfinex/bitfinex_websocket.go +++ b/exchanges/bitfinex/bitfinex_websocket.go @@ -456,17 +456,20 @@ func (b *Bitfinex) handleWSEvent(respRaw []byte) error { if err != nil { return fmt.Errorf("%w 'chanId': %w from message: %s", errParsingWSField, err, respRaw) } - if !b.Websocket.Match.IncomingWithData("unsubscribe:"+chanID, respRaw) { - return fmt.Errorf("%w: unsubscribe:%v", stream.ErrNoMessageListener, chanID) + err = b.Websocket.Match.RequireMatchWithData("unsubscribe:"+chanID, respRaw) + if err != nil { + return fmt.Errorf("%w: unsubscribe:%v", err, chanID) } case wsEventError: if subID, err := jsonparser.GetUnsafeString(respRaw, "subId"); err == nil { - if !b.Websocket.Match.IncomingWithData("subscribe:"+subID, respRaw) { - return fmt.Errorf("%w: subscribe:%v", stream.ErrNoMessageListener, subID) + err = b.Websocket.Match.RequireMatchWithData("subscribe:"+subID, respRaw) + if err != nil { + return fmt.Errorf("%w: subscribe:%v", err, subID) } } else if chanID, err := jsonparser.GetUnsafeString(respRaw, "chanId"); err == nil { - if !b.Websocket.Match.IncomingWithData("unsubscribe:"+chanID, respRaw) { - return fmt.Errorf("%w: unsubscribe:%v", stream.ErrNoMessageListener, chanID) + err = b.Websocket.Match.RequireMatchWithData("unsubscribe:"+chanID, respRaw) + if err != nil { + return fmt.Errorf("%w: unsubscribe:%v", err, chanID) } } else { return fmt.Errorf("unknown channel error; Message: %s", respRaw) @@ -531,17 +534,16 @@ func (b *Bitfinex) handleWSSubscribed(respRaw []byte) error { c.Key = int(chanID) // subscribeToChan removes the old subID keyed Subscription - if err := b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, c); err != nil { + err = b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, c) + if err != nil { return fmt.Errorf("%w: %w subID: %s", stream.ErrSubscriptionFailure, err, subID) } if b.Verbose { log.Debugf(log.ExchangeSys, "%s Subscribed to Channel: %s Pair: %s ChannelID: %d\n", b.Name, c.Channel, c.Pairs, chanID) } - if !b.Websocket.Match.IncomingWithData("subscribe:"+subID, respRaw) { - return fmt.Errorf("%w: subscribe:%v", stream.ErrNoMessageListener, subID) - } - return nil + + return b.Websocket.Match.RequireMatchWithData("subscribe:"+subID, respRaw) } func (b *Bitfinex) handleWSChannelUpdate(s *subscription.Subscription, eventType string, d []interface{}) error { diff --git a/exchanges/bitmex/bitmex_websocket.go b/exchanges/bitmex/bitmex_websocket.go index f539787c..d322d5a6 100644 --- a/exchanges/bitmex/bitmex_websocket.go +++ b/exchanges/bitmex/bitmex_websocket.go @@ -170,8 +170,9 @@ func (b *Bitmex) wsHandleData(respRaw []byte) error { if e2 != nil { return fmt.Errorf("%w parsing stream", e2) } - if !b.Websocket.Match.IncomingWithData(op+":"+streamID, msg) { - return fmt.Errorf("%w: %s:%s", stream.ErrNoMessageListener, op, streamID) + err = b.Websocket.Match.RequireMatchWithData(op+":"+streamID, msg) + if err != nil { + return fmt.Errorf("%w: %s:%s", err, op, streamID) } return nil } diff --git a/exchanges/bitstamp/bitstamp_websocket.go b/exchanges/bitstamp/bitstamp_websocket.go index de45f2fb..8fe4f2fc 100644 --- a/exchanges/bitstamp/bitstamp_websocket.go +++ b/exchanges/bitstamp/bitstamp_websocket.go @@ -135,10 +135,7 @@ func (b *Bitstamp) handleWSSubscription(event string, respRaw []byte) error { return fmt.Errorf("%w `channel`: %w", errParsingWSField, err) } event = strings.TrimSuffix(event, "scription_succeeded") - if !b.Websocket.Match.IncomingWithData(event+":"+channel, respRaw) { - return fmt.Errorf("%w: %s", stream.ErrNoMessageListener, event+":"+channel) - } - return nil + return b.Websocket.Match.RequireMatchWithData(event+":"+channel, respRaw) } func (b *Bitstamp) handleWSTrade(msg []byte) error { diff --git a/exchanges/exchange.go b/exchanges/exchange.go index 938923e2..4369b1dd 100644 --- a/exchanges/exchange.go +++ b/exchanges/exchange.go @@ -975,8 +975,7 @@ func (b *Base) SupportsAsset(a asset.Item) bool { // PrintEnabledPairs prints the exchanges enabled asset pairs func (b *Base) PrintEnabledPairs() { for k, v := range b.CurrencyPairs.Pairs { - log.Infof(log.ExchangeSys, "%s Asset type %v:\n\t Enabled pairs: %v", - b.Name, strings.ToUpper(k.String()), v.Enabled) + log.Infof(log.ExchangeSys, "%s Asset type %v:\n\t Enabled pairs: %v", b.Name, strings.ToUpper(k.String()), v.Enabled) } } @@ -987,10 +986,7 @@ func (b *Base) GetBase() *Base { return b } // for validation of API credentials func (b *Base) CheckTransientError(err error) error { if _, ok := err.(net.Error); ok { - log.Warnf(log.ExchangeSys, - "%s net error captured, will not disable authentication %s", - b.Name, - err) + log.Warnf(log.ExchangeSys, "%s net error captured, will not disable authentication %s", b.Name, err) return nil } return err diff --git a/exchanges/gateio/gateio_types.go b/exchanges/gateio/gateio_types.go index 80a2c8b8..0577f6d2 100644 --- a/exchanges/gateio/gateio_types.go +++ b/exchanges/gateio/gateio_types.go @@ -2008,12 +2008,13 @@ type WsEventResponse struct { // WsResponse represents generalized websocket push data from the server. type WsResponse struct { - ID int64 `json:"id"` - Time types.Time `json:"time"` - TimeMs types.Time `json:"time_ms"` - Channel string `json:"channel"` - Event string `json:"event"` - Result json.RawMessage `json:"result"` + ID int64 `json:"id"` + Time types.Time `json:"time"` + TimeMs types.Time `json:"time_ms"` + Channel string `json:"channel"` + Event string `json:"event"` + Result json.RawMessage `json:"result"` + RequestID string `json:"request_id"` } // WsTicker websocket ticker information. diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 932768ec..ea53f120 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -97,6 +97,64 @@ func (g *Gateio) WsConnectSpot(ctx context.Context, conn stream.Connection) erro return nil } +// authenticateSpot sends an authentication message to the websocket connection +func (g *Gateio) authenticateSpot(ctx context.Context, conn stream.Connection) error { + return g.websocketLogin(ctx, conn, "spot.login") +} + +// websocketLogin authenticates the websocket connection +func (g *Gateio) websocketLogin(ctx context.Context, conn stream.Connection, channel string) error { + if conn == nil { + return fmt.Errorf("%w: %T", common.ErrNilPointer, conn) + } + + if channel == "" { + return errChannelEmpty + } + + creds, err := g.GetCredentials(ctx) + if err != nil { + return err + } + + tn := time.Now().Unix() + msg := "api\n" + channel + "\n" + "\n" + strconv.FormatInt(tn, 10) + mac := hmac.New(sha512.New, []byte(creds.Secret)) + if _, err = mac.Write([]byte(msg)); err != nil { + return err + } + signature := hex.EncodeToString(mac.Sum(nil)) + + payload := WebsocketPayload{ + RequestID: strconv.FormatInt(conn.GenerateMessageID(false), 10), + APIKey: creds.Key, + Signature: signature, + Timestamp: strconv.FormatInt(tn, 10), + } + + req := WebsocketRequest{Time: tn, Channel: channel, Event: "api", Payload: payload} + + resp, err := conn.SendMessageReturnResponse(ctx, websocketRateLimitNotNeededEPL, req.Payload.RequestID, req) + if err != nil { + return err + } + + var inbound WebsocketAPIResponse + if err := json.Unmarshal(resp, &inbound); err != nil { + return err + } + + if inbound.Header.Status != "200" { + var wsErr WebsocketErrors + if err := json.Unmarshal(inbound.Data, &wsErr.Errors); err != nil { + return err + } + return fmt.Errorf("%s: %s", wsErr.Errors.Label, wsErr.Errors.Message) + } + + return nil +} + func (g *Gateio) generateWsSignature(secret, event, channel string, t int64) (string, error) { msg := "channel=" + channel + "&event=" + event + "&time=" + strconv.FormatInt(t, 10) mac := hmac.New(sha512.New, []byte(secret)) @@ -109,21 +167,21 @@ func (g *Gateio) generateWsSignature(secret, event, channel string, t int64) (st // WsHandleSpotData handles spot data func (g *Gateio) WsHandleSpotData(_ context.Context, respRaw []byte) error { var push WsResponse - err := json.Unmarshal(respRaw, &push) - if err != nil { + if err := json.Unmarshal(respRaw, &push); err != nil { return err } + if push.RequestID != "" { + return g.Websocket.Match.RequireMatchWithData(push.RequestID, respRaw) + } + if push.Event == subscribeEvent || push.Event == unsubscribeEvent { - if !g.Websocket.Match.IncomingWithData(push.ID, respRaw) { - return fmt.Errorf("couldn't match subscription message with ID: %d", push.ID) - } - return nil + return g.Websocket.Match.RequireMatchWithData(push.ID, respRaw) } switch push.Channel { // TODO: Convert function params below to only use push.Result case spotTickerChannel: - return g.processTicker(push.Result, push.Time.Time()) + return g.processTicker(push.Result, push.TimeMs.Time()) case spotTradesChannel: return g.processTrades(push.Result) case spotCandlesticksChannel: diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_websocket_delivery_futures.go similarity index 98% rename from exchanges/gateio/gateio_ws_delivery_futures.go rename to exchanges/gateio/gateio_websocket_delivery_futures.go index d4e1c8ab..eeb1a706 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_websocket_delivery_futures.go @@ -13,7 +13,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" - "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" ) @@ -55,7 +54,7 @@ func (g *Gateio) WsDeliveryFuturesConnect(ctx context.Context, conn stream.Conne if err != nil { return err } - conn.SetupPingHandler(request.Unset, stream.PingHandler{ + conn.SetupPingHandler(websocketRateLimitNotNeededEPL, stream.PingHandler{ Websocket: true, Delay: time.Second * 5, MessageType: websocket.PingMessage, diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_websocket_futures.go similarity index 99% rename from exchanges/gateio/gateio_ws_futures.go rename to exchanges/gateio/gateio_websocket_futures.go index 520eb816..ba393277 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_websocket_futures.go @@ -19,7 +19,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" - "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" @@ -76,7 +75,7 @@ func (g *Gateio) WsFuturesConnect(ctx context.Context, conn stream.Connection) e if err != nil { return err } - conn.SetupPingHandler(request.Unset, stream.PingHandler{ + conn.SetupPingHandler(websocketRateLimitNotNeededEPL, stream.PingHandler{ Websocket: true, MessageType: websocket.PingMessage, Delay: time.Second * 15, diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_websocket_option.go similarity index 99% rename from exchanges/gateio/gateio_ws_option.go rename to exchanges/gateio/gateio_websocket_option.go index 40bc63fe..091a9ad1 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_websocket_option.go @@ -18,7 +18,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" - "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" @@ -85,7 +84,7 @@ func (g *Gateio) WsOptionsConnect(ctx context.Context, conn stream.Connection) e if err != nil { return err } - conn.SetupPingHandler(request.Unset, stream.PingHandler{ + conn.SetupPingHandler(websocketRateLimitNotNeededEPL, stream.PingHandler{ Websocket: true, Delay: time.Second * 5, MessageType: websocket.PingMessage, diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go new file mode 100644 index 00000000..25ac1ea1 --- /dev/null +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -0,0 +1,224 @@ +package gateio + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/exchanges/asset" + "github.com/thrasher-corp/gocryptotrader/exchanges/order" + "github.com/thrasher-corp/gocryptotrader/exchanges/request" +) + +var ( + errOrdersEmpty = errors.New("orders cannot be empty") + errNoOrdersToCancel = errors.New("no orders to cancel") + errChannelEmpty = errors.New("channel cannot be empty") +) + +// WebsocketSpotSubmitOrder submits an order via the websocket connection +func (g *Gateio) WebsocketSpotSubmitOrder(ctx context.Context, order *WebsocketOrder) ([]WebsocketOrderResponse, error) { + return g.WebsocketSpotSubmitOrders(ctx, []WebsocketOrder{*order}) +} + +// WebsocketSpotSubmitOrders submits orders via the websocket connection. You can +// send multiple orders in a single request. But only for one asset route. +func (g *Gateio) WebsocketSpotSubmitOrders(ctx context.Context, orders []WebsocketOrder) ([]WebsocketOrderResponse, error) { + if len(orders) == 0 { + return nil, errOrdersEmpty + } + + for i := range orders { + if orders[i].Text == "" { + // API requires Text field, or it will be rejected + orders[i].Text = "t-" + strconv.FormatInt(g.Counter.IncrementAndGet(), 10) + } + if orders[i].CurrencyPair == "" { + return nil, currency.ErrCurrencyPairEmpty + } + if orders[i].Side == "" { + return nil, order.ErrSideIsInvalid + } + if orders[i].Amount == "" { + return nil, errInvalidAmount + } + if orders[i].Type == "limit" && orders[i].Price == "" { + return nil, errInvalidPrice + } + } + + if len(orders) == 1 { + var singleResponse WebsocketOrderResponse + return []WebsocketOrderResponse{singleResponse}, g.SendWebsocketRequest(ctx, spotPlaceOrderEPL, "spot.order_place", asset.Spot, orders[0], &singleResponse, 2) + } + var resp []WebsocketOrderResponse + return resp, g.SendWebsocketRequest(ctx, spotBatchOrdersEPL, "spot.order_place", asset.Spot, orders, &resp, 2) +} + +// WebsocketSpotCancelOrder cancels an order via the websocket connection +func (g *Gateio) WebsocketSpotCancelOrder(ctx context.Context, orderID string, pair currency.Pair, account string) (*WebsocketOrderResponse, error) { + if orderID == "" { + return nil, order.ErrOrderIDNotSet + } + if pair.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + + params := &WebsocketOrderRequest{OrderID: orderID, Pair: pair.String(), Account: account} + + var resp WebsocketOrderResponse + return &resp, g.SendWebsocketRequest(ctx, spotCancelSingleOrderEPL, "spot.order_cancel", asset.Spot, params, &resp, 1) +} + +// WebsocketSpotCancelAllOrdersByIDs cancels multiple orders via the websocket +func (g *Gateio) WebsocketSpotCancelAllOrdersByIDs(ctx context.Context, o []WebsocketOrderBatchRequest) ([]WebsocketCancellAllResponse, error) { + if len(o) == 0 { + return nil, errNoOrdersToCancel + } + + for i := range o { + if o[i].OrderID == "" { + return nil, order.ErrOrderIDNotSet + } + if o[i].Pair.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + } + + var resp []WebsocketCancellAllResponse + return resp, g.SendWebsocketRequest(ctx, spotCancelBatchOrdersEPL, "spot.order_cancel_ids", asset.Spot, o, &resp, 2) +} + +// WebsocketSpotCancelAllOrdersByPair cancels all orders for a specific pair +func (g *Gateio) WebsocketSpotCancelAllOrdersByPair(ctx context.Context, pair currency.Pair, side order.Side, account string) ([]WebsocketOrderResponse, error) { + if !pair.IsEmpty() && side == order.UnknownSide { + // This case will cancel all orders for every pair, this can be introduced later + return nil, fmt.Errorf("'%v' %w while pair is set", side, order.ErrSideIsInvalid) + } + + sideStr := "" + if side != order.UnknownSide { + sideStr = side.Lower() + } + + params := &WebsocketCancelParam{ + Pair: pair, + Side: sideStr, + Account: account, + } + + var resp []WebsocketOrderResponse + return resp, g.SendWebsocketRequest(ctx, spotCancelAllOpenOrdersEPL, "spot.order_cancel_cp", asset.Spot, params, &resp, 1) +} + +// WebsocketSpotAmendOrder amends an order via the websocket connection +func (g *Gateio) WebsocketSpotAmendOrder(ctx context.Context, amend *WebsocketAmendOrder) (*WebsocketOrderResponse, error) { + if amend == nil { + return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, amend) + } + + if amend.OrderID == "" { + return nil, order.ErrOrderIDNotSet + } + + if amend.Pair.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + + if amend.Amount == "" && amend.Price == "" { + return nil, fmt.Errorf("%w: amount or price must be set", errInvalidAmount) + } + + var resp WebsocketOrderResponse + return &resp, g.SendWebsocketRequest(ctx, spotAmendOrderEPL, "spot.order_amend", asset.Spot, amend, &resp, 1) +} + +// WebsocketSpotGetOrderStatus gets the status of an order via the websocket connection +func (g *Gateio) WebsocketSpotGetOrderStatus(ctx context.Context, orderID string, pair currency.Pair, account string) (*WebsocketOrderResponse, error) { + if orderID == "" { + return nil, order.ErrOrderIDNotSet + } + if pair.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + + params := &WebsocketOrderRequest{OrderID: orderID, Pair: pair.String(), Account: account} + + var resp WebsocketOrderResponse + return &resp, g.SendWebsocketRequest(ctx, spotGetOrdersEPL, "spot.order_status", asset.Spot, params, &resp, 1) +} + +// funnelResult is used to unmarshal the result of a websocket request back to the required caller type +type funnelResult struct { + Result any `json:"result"` +} + +// SendWebsocketRequest sends a websocket request to the exchange +func (g *Gateio) SendWebsocketRequest(ctx context.Context, epl request.EndpointLimit, channel string, connSignature, params, result any, expectedResponses int) error { + paramPayload, err := json.Marshal(params) + if err != nil { + return err + } + + conn, err := g.Websocket.GetConnection(connSignature) + if err != nil { + return err + } + + tn := time.Now().Unix() + req := &WebsocketRequest{ + Time: tn, + Channel: channel, + Event: "api", + Payload: WebsocketPayload{ + // This request ID associated with the payload is the match to the + // response. + RequestID: strconv.FormatInt(conn.GenerateMessageID(false), 10), + RequestParam: paramPayload, + Timestamp: strconv.FormatInt(tn, 10), + }, + } + + responses, err := conn.SendMessageReturnResponsesWithInspector(ctx, epl, req.Payload.RequestID, req, expectedResponses, wsRespAckInspector{}) + if err != nil { + return err + } + + if len(responses) == 0 { + return common.ErrNoResponse + } + + var inbound WebsocketAPIResponse + // The last response is the one we want to unmarshal, the other is just + // an ack. If the request fails on the ACK then we can unmarshal the error + // from that as the next response won't come anyway. + endResponse := responses[len(responses)-1] + + if err := json.Unmarshal(endResponse, &inbound); err != nil { + return err + } + + if inbound.Header.Status != "200" { + var wsErr WebsocketErrors + if err := json.Unmarshal(inbound.Data, &wsErr); err != nil { + return err + } + return fmt.Errorf("%s: %s", wsErr.Errors.Label, wsErr.Errors.Message) + } + + return json.Unmarshal(inbound.Data, &funnelResult{Result: result}) +} + +type wsRespAckInspector struct{} + +// IsFinal checks the payload for an ack, it returns true if the payload does not contain an ack. +// This will force the cancellation of further waiting for responses. +func (wsRespAckInspector) IsFinal(data []byte) bool { + return !strings.Contains(string(data), "ack") +} diff --git a/exchanges/gateio/gateio_websocket_request_spot_test.go b/exchanges/gateio/gateio_websocket_request_spot_test.go new file mode 100644 index 00000000..7933c117 --- /dev/null +++ b/exchanges/gateio/gateio_websocket_request_spot_test.go @@ -0,0 +1,248 @@ +package gateio + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/config" + "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/exchanges/asset" + "github.com/thrasher-corp/gocryptotrader/exchanges/order" + "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" + "github.com/thrasher-corp/gocryptotrader/exchanges/stream" + testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" +) + +func TestWebsocketLogin(t *testing.T) { + t.Parallel() + err := g.websocketLogin(context.Background(), nil, "") + require.ErrorIs(t, err, common.ErrNilPointer) + + err = g.websocketLogin(context.Background(), &stream.WebsocketConnection{}, "") + require.ErrorIs(t, err, errChannelEmpty) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + demonstrationConn, err := g.Websocket.GetConnection(asset.Spot) + require.NoError(t, err) + + err = g.websocketLogin(context.Background(), demonstrationConn, "spot.login") + require.NoError(t, err) +} + +func TestWebsocketSpotSubmitOrder(t *testing.T) { + t.Parallel() + _, err := g.WebsocketSpotSubmitOrder(context.Background(), &WebsocketOrder{}) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + out := &WebsocketOrder{CurrencyPair: "BTC_USDT"} + _, err = g.WebsocketSpotSubmitOrder(context.Background(), out) + require.ErrorIs(t, err, order.ErrSideIsInvalid) + out.Side = strings.ToLower(order.Buy.String()) + _, err = g.WebsocketSpotSubmitOrder(context.Background(), out) + require.ErrorIs(t, err, errInvalidAmount) + out.Amount = "0.0003" + out.Type = "limit" + _, err = g.WebsocketSpotSubmitOrder(context.Background(), out) + require.ErrorIs(t, err, errInvalidPrice) + out.Price = "20000" + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + got, err := g.WebsocketSpotSubmitOrder(context.Background(), out) + require.NoError(t, err) + require.NotEmpty(t, got) +} + +func TestWebsocketSpotSubmitOrders(t *testing.T) { + t.Parallel() + _, err := g.WebsocketSpotSubmitOrders(context.Background(), nil) + require.ErrorIs(t, err, errOrdersEmpty) + _, err = g.WebsocketSpotSubmitOrders(context.Background(), make([]WebsocketOrder, 1)) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + out := WebsocketOrder{CurrencyPair: "BTC_USDT"} + _, err = g.WebsocketSpotSubmitOrders(context.Background(), []WebsocketOrder{out}) + require.ErrorIs(t, err, order.ErrSideIsInvalid) + out.Side = strings.ToLower(order.Buy.String()) + _, err = g.WebsocketSpotSubmitOrders(context.Background(), []WebsocketOrder{out}) + require.ErrorIs(t, err, errInvalidAmount) + out.Amount = "0.0003" + out.Type = "limit" + _, err = g.WebsocketSpotSubmitOrders(context.Background(), []WebsocketOrder{out}) + require.ErrorIs(t, err, errInvalidPrice) + out.Price = "20000" + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + // test single order + got, err := g.WebsocketSpotSubmitOrders(context.Background(), []WebsocketOrder{out}) + require.NoError(t, err) + require.NotEmpty(t, got) + + // test batch orders + got, err = g.WebsocketSpotSubmitOrders(context.Background(), []WebsocketOrder{out, out}) + require.NoError(t, err) + require.NotEmpty(t, got) +} + +func TestWebsocketSpotCancelOrder(t *testing.T) { + t.Parallel() + _, err := g.WebsocketSpotCancelOrder(context.Background(), "", currency.EMPTYPAIR, "") + require.ErrorIs(t, err, order.ErrOrderIDNotSet) + _, err = g.WebsocketSpotCancelOrder(context.Background(), "1337", currency.EMPTYPAIR, "") + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + btcusdt, err := currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + got, err := g.WebsocketSpotCancelOrder(context.Background(), "644913098758", btcusdt, "") + require.NoError(t, err) + require.NotEmpty(t, got) +} + +func TestWebsocketSpotCancelAllOrdersByIDs(t *testing.T) { + t.Parallel() + _, err := g.WebsocketSpotCancelAllOrdersByIDs(context.Background(), []WebsocketOrderBatchRequest{}) + require.ErrorIs(t, err, errNoOrdersToCancel) + out := WebsocketOrderBatchRequest{} + _, err = g.WebsocketSpotCancelAllOrdersByIDs(context.Background(), []WebsocketOrderBatchRequest{out}) + require.ErrorIs(t, err, order.ErrOrderIDNotSet) + out.OrderID = "1337" + _, err = g.WebsocketSpotCancelAllOrdersByIDs(context.Background(), []WebsocketOrderBatchRequest{out}) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + out.Pair, err = currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + out.OrderID = "644913101755" + got, err := g.WebsocketSpotCancelAllOrdersByIDs(context.Background(), []WebsocketOrderBatchRequest{out}) + require.NoError(t, err) + require.NotEmpty(t, got) +} + +func TestWebsocketSpotCancelAllOrdersByPair(t *testing.T) { + t.Parallel() + pair, err := currency.NewPairFromString("LTC_USDT") + require.NoError(t, err) + + _, err = g.WebsocketSpotCancelAllOrdersByPair(context.Background(), pair, 0, "") + require.ErrorIs(t, err, order.ErrSideIsInvalid) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + got, err := g.WebsocketSpotCancelAllOrdersByPair(context.Background(), currency.EMPTYPAIR, order.Buy, "") + require.NoError(t, err) + require.NotEmpty(t, got) +} + +func TestWebsocketSpotAmendOrder(t *testing.T) { + t.Parallel() + + _, err := g.WebsocketSpotAmendOrder(context.Background(), nil) + require.ErrorIs(t, err, common.ErrNilPointer) + + amend := &WebsocketAmendOrder{} + _, err = g.WebsocketSpotAmendOrder(context.Background(), amend) + require.ErrorIs(t, err, order.ErrOrderIDNotSet) + + amend.OrderID = "1337" + _, err = g.WebsocketSpotAmendOrder(context.Background(), amend) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + amend.Pair, err = currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + + _, err = g.WebsocketSpotAmendOrder(context.Background(), amend) + require.ErrorIs(t, err, errInvalidAmount) + + amend.Amount = "0.0004" + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + amend.OrderID = "645029162673" + got, err := g.WebsocketSpotAmendOrder(context.Background(), amend) + require.NoError(t, err) + require.NotEmpty(t, got) +} + +func TestWebsocketSpotGetOrderStatus(t *testing.T) { + t.Parallel() + + _, err := g.WebsocketSpotGetOrderStatus(context.Background(), "", currency.EMPTYPAIR, "") + require.ErrorIs(t, err, order.ErrOrderIDNotSet) + + _, err = g.WebsocketSpotGetOrderStatus(context.Background(), "1337", currency.EMPTYPAIR, "") + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + pair, err := currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + + got, err := g.WebsocketSpotGetOrderStatus(context.Background(), "644999650452", pair, "") + require.NoError(t, err) + require.NotEmpty(t, got) +} + +// getWebsocketInstance returns a websocket instance copy for testing. +// This restricts the pairs to a single pair per asset type to reduce test time. +func getWebsocketInstance(t *testing.T, g *Gateio) *Gateio { + t.Helper() + + cpy := new(Gateio) + cpy.SetDefaults() + gConf, err := config.GetConfig().GetExchangeConfig("GateIO") + require.NoError(t, err) + gConf.API.AuthenticatedSupport = true + gConf.API.AuthenticatedWebsocketSupport = true + gConf.API.Credentials.Key = apiKey + gConf.API.Credentials.Secret = apiSecret + + require.NoError(t, cpy.Setup(gConf), "Test instance Setup must not error") + cpy.CurrencyPairs.Load(&g.CurrencyPairs) + + for _, a := range cpy.GetAssetTypes(true) { + if a != asset.Spot { + require.NoError(t, cpy.CurrencyPairs.SetAssetEnabled(a, false)) + continue + } + avail, err := cpy.GetAvailablePairs(a) + require.NoError(t, err) + if len(avail) > 1 { + avail = avail[:1] + } + require.NoError(t, cpy.SetPairs(avail, a, true)) + } + require.NoError(t, cpy.Websocket.Connect()) + return cpy +} diff --git a/exchanges/gateio/gateio_websocket_request_types.go b/exchanges/gateio/gateio_websocket_request_types.go new file mode 100644 index 00000000..165eea41 --- /dev/null +++ b/exchanges/gateio/gateio_websocket_request_types.go @@ -0,0 +1,143 @@ +package gateio + +import ( + "encoding/json" + + "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/types" +) + +// WebsocketAPIResponse defines a general websocket response for api calls +type WebsocketAPIResponse struct { + Header Header `json:"header"` + Data json.RawMessage `json:"data"` +} + +// Header defines a websocket header +type Header struct { + ResponseTime types.Time `json:"response_time"` + Status string `json:"status"` + Channel string `json:"channel"` + Event string `json:"event"` + ClientID string `json:"client_id"` + ConnectionID string `json:"conn_id"` + TraceID string `json:"trace_id"` +} + +// WebsocketRequest defines a websocket request +type WebsocketRequest struct { + Time int64 `json:"time,omitempty"` + ID int64 `json:"id,omitempty"` + Channel string `json:"channel"` + Event string `json:"event"` + Payload WebsocketPayload `json:"payload"` +} + +// WebsocketPayload defines an individualised websocket payload +type WebsocketPayload struct { + RequestID string `json:"req_id,omitempty"` + // APIKey and signature are only required in the initial login request + // which is done when the connection is established. + APIKey string `json:"api_key,omitempty"` + Timestamp string `json:"timestamp,omitempty"` + Signature string `json:"signature,omitempty"` + RequestParam json.RawMessage `json:"req_param,omitempty"` +} + +// WebsocketErrors defines a websocket error +type WebsocketErrors struct { + Errors struct { + Label string `json:"label"` + Message string `json:"message"` + } `json:"errs"` +} + +// WebsocketOrder defines a websocket order +type WebsocketOrder struct { + Text string `json:"text"` + CurrencyPair string `json:"currency_pair,omitempty"` + Type string `json:"type,omitempty"` + Account string `json:"account,omitempty"` + Side string `json:"side,omitempty"` + Amount string `json:"amount,omitempty"` + Price string `json:"price,omitempty"` + TimeInForce string `json:"time_in_force,omitempty"` + Iceberg string `json:"iceberg,omitempty"` + AutoBorrow bool `json:"auto_borrow,omitempty"` + AutoRepay bool `json:"auto_repay,omitempty"` + StpAct string `json:"stp_act,omitempty"` +} + +// WebsocketOrderResponse defines a websocket order response +type WebsocketOrderResponse struct { + Left types.Number `json:"left"` + UpdateTime types.Time `json:"update_time"` + Amount types.Number `json:"amount"` + CreateTime types.Time `json:"create_time"` + Price types.Number `json:"price"` + FinishAs string `json:"finish_as"` + TimeInForce string `json:"time_in_force"` + CurrencyPair currency.Pair `json:"currency_pair"` + Type string `json:"type"` + Account string `json:"account"` + Side string `json:"side"` + AmendText string `json:"amend_text"` + Text string `json:"text"` + Status string `json:"status"` + Iceberg types.Number `json:"iceberg"` + FilledTotal types.Number `json:"filled_total"` + ID string `json:"id"` + FillPrice types.Number `json:"fill_price"` + UpdateTimeMs types.Time `json:"update_time_ms"` + CreateTimeMs types.Time `json:"create_time_ms"` + Fee types.Number `json:"fee"` + FeeCurrency currency.Code `json:"fee_currency"` + PointFee types.Number `json:"point_fee"` + GTFee types.Number `json:"gt_fee"` + GTMakerFee types.Number `json:"gt_maker_fee"` + GTTakerFee types.Number `json:"gt_taker_fee"` + GTDiscount bool `json:"gt_discount"` + RebatedFee types.Number `json:"rebated_fee"` + RebatedFeeCurrency currency.Code `json:"rebated_fee_currency"` + STPID int `json:"stp_id"` + STPAct string `json:"stp_act"` +} + +// WebsocketOrderBatchRequest defines a websocket order batch request +type WebsocketOrderBatchRequest struct { + OrderID string `json:"id"` // This require id tag not order_id + Pair currency.Pair `json:"currency_pair"` + Account string `json:"account,omitempty"` +} + +// WebsocketOrderRequest defines a websocket order request +type WebsocketOrderRequest struct { + OrderID string `json:"order_id"` // This requires order_id tag + Pair string `json:"pair"` + Account string `json:"account,omitempty"` +} + +// WebsocketCancellAllResponse defines a websocket order cancel response +type WebsocketCancellAllResponse struct { + Pair currency.Pair `json:"currency_pair"` + Label string `json:"label"` + Message string `json:"message"` + Succeeded bool `json:"succeeded"` +} + +// WebsocketCancelParam is a struct to hold the parameters for cancelling orders +type WebsocketCancelParam struct { + Pair currency.Pair `json:"pair"` + Side string `json:"side"` + Account string `json:"account,omitempty"` +} + +// WebsocketAmendOrder defines a websocket amend order +type WebsocketAmendOrder struct { + OrderID string `json:"order_id"` + Pair currency.Pair `json:"currency_pair"` + Account string `json:"account,omitempty"` + AmendText string `json:"amend_text,omitempty"` + Price string `json:"price,omitempty"` + Amount string `json:"amount,omitempty"` +} diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index dad42e4d..9d9929f6 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -218,6 +218,8 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Unsubscriber: g.Unsubscribe, GenerateSubscriptions: g.generateSubscriptionsSpot, Connector: g.WsConnectSpot, + Authenticate: g.authenticateSpot, + MessageFilter: asset.Spot, BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { @@ -235,6 +237,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Unsubscriber: g.FuturesUnsubscribe, GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) }, Connector: g.WsFuturesConnect, + MessageFilter: asset.USDTMarginedFutures, BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { @@ -253,6 +256,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Unsubscriber: g.FuturesUnsubscribe, GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) }, Connector: g.WsFuturesConnect, + MessageFilter: asset.CoinMarginedFutures, BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { @@ -272,6 +276,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Unsubscriber: g.DeliveryFuturesUnsubscribe, GenerateSubscriptions: g.GenerateDeliveryFuturesDefaultSubscriptions, Connector: g.WsDeliveryFuturesConnect, + MessageFilter: asset.DeliveryFutures, BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { @@ -288,6 +293,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Unsubscriber: g.OptionsUnsubscribe, GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions, Connector: g.WsOptionsConnect, + MessageFilter: asset.Options, BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) } diff --git a/exchanges/kraken/kraken_websocket.go b/exchanges/kraken/kraken_websocket.go index 22088590..5b276055 100644 --- a/exchanges/kraken/kraken_websocket.go +++ b/exchanges/kraken/kraken_websocket.go @@ -231,7 +231,7 @@ func (k *Kraken) wsHandleData(respRaw []byte) error { return nil case krakenWsCancelOrderStatus, krakenWsCancelAllOrderStatus, krakenWsAddOrderStatus, krakenWsSubscriptionStatus: // All of these should have found a listener already - return fmt.Errorf("%w: %s %v", stream.ErrNoMessageListener, event, reqID) + return fmt.Errorf("%w: %s %v", stream.ErrSignatureNotMatched, event, reqID) case krakenWsSystemStatus: return k.wsProcessSystemStatus(respRaw) default: diff --git a/exchanges/kucoin/kucoin_websocket.go b/exchanges/kucoin/kucoin_websocket.go index 8b4aa535..e364b2b8 100644 --- a/exchanges/kucoin/kucoin_websocket.go +++ b/exchanges/kucoin/kucoin_websocket.go @@ -215,18 +215,14 @@ func (ku *Kucoin) wsReadData() { // wsHandleData processes a websocket incoming data. func (ku *Kucoin) wsHandleData(respData []byte) error { resp := WsPushData{} - err := json.Unmarshal(respData, &resp) - if err != nil { + if err := json.Unmarshal(respData, &resp); err != nil { return err } if resp.Type == "pong" || resp.Type == "welcome" { return nil } if resp.ID != "" { - if !ku.Websocket.Match.IncomingWithData("msgID:"+resp.ID, respData) { - return fmt.Errorf("%w: %s", stream.ErrNoMessageListener, resp.ID) - } - return nil + return ku.Websocket.Match.RequireMatchWithData("msgID:"+resp.ID, respData) } topicInfo := strings.Split(resp.Topic, ":") switch topicInfo[0] { diff --git a/exchanges/stream/stream_match.go b/exchanges/stream/stream_match.go index a7b8a10b..430688a8 100644 --- a/exchanges/stream/stream_match.go +++ b/exchanges/stream/stream_match.go @@ -2,9 +2,13 @@ package stream import ( "errors" + "fmt" "sync" ) +// ErrSignatureNotMatched is returned when a signature does not match a request +var ErrSignatureNotMatched = errors.New("websocket response to request signature not matched") + var ( errSignatureCollision = errors.New("signature collision") errInvalidBufferSize = errors.New("buffer size must be positive") @@ -47,6 +51,15 @@ func (m *Match) IncomingWithData(signature any, data []byte) bool { return true } +// RequireMatchWithData validates that incoming data matches a request's signature. +// If a match is found, the data is processed; otherwise, it returns an error. +func (m *Match) RequireMatchWithData(signature any, data []byte) error { + if m.IncomingWithData(signature, data) { + return nil + } + return fmt.Errorf("'%v' %w with data %v", signature, ErrSignatureNotMatched, string(data)) +} + // Set the signature response channel for incoming data func (m *Match) Set(signature any, bufSize int) (<-chan []byte, error) { if bufSize <= 0 { diff --git a/exchanges/stream/stream_match_test.go b/exchanges/stream/stream_match_test.go index b7a21b23..8053cb37 100644 --- a/exchanges/stream/stream_match_test.go +++ b/exchanges/stream/stream_match_test.go @@ -51,3 +51,18 @@ func TestRemoveSignature(t *testing.T) { t.Fatal("Should be able to read from a closed channel") } } + +func TestRequireMatchWithData(t *testing.T) { + t.Parallel() + match := NewMatch() + err := match.RequireMatchWithData("hello", []byte("world")) + require.ErrorIs(t, err, ErrSignatureNotMatched, "Must error on unmatched signature") + assert.Contains(t, err.Error(), "world", "Should contain the data in the error message") + assert.Contains(t, err.Error(), "hello", "Should contain the signature in the error message") + + ch, err := match.Set("hello", 1) + require.NoError(t, err, "Set must not error") + err = match.RequireMatchWithData("hello", []byte("world")) + require.NoError(t, err, "Must not error on matched signature") + assert.Equal(t, "world", string(<-ch)) +} diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index 2cbf0a2f..832e74c5 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -27,6 +27,8 @@ type Connection interface { SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature any, request any) ([]byte, error) // SendMessageReturnResponses will send a WS message to the connection and wait for N responses SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature any, request any, expected int) ([][]byte, error) + // SendMessageReturnResponsesWithInspector will send a WS message to the connection and wait for N responses with message inspection + SendMessageReturnResponsesWithInspector(ctx context.Context, epl request.EndpointLimit, signature any, request any, expected int, messageInspector Inspector) ([][]byte, error) // SendRawMessage sends a message over the connection without JSON encoding it SendRawMessage(ctx context.Context, epl request.EndpointLimit, messageType int, message []byte) error // SendJSONMessage sends a JSON encoded message over the connection @@ -37,6 +39,12 @@ type Connection interface { Shutdown() error } +// Inspector is used to verify messages via SendMessageReturnResponsesWithInspection +// It inspects the []bytes websocket message and returns true if the message is the final message in a sequence of expected messages +type Inspector interface { + IsFinal([]byte) bool +} + // Response defines generalised data from the stream connection type Response struct { Type int @@ -76,6 +84,11 @@ type ConnectionSetup struct { // This is useful for when an exchange connection requires a unique or // structured message ID for each message sent. BespokeGenerateMessageID func(highPrecision bool) int64 + // Authenticate will be called to authenticate the connection + Authenticate func(ctx context.Context, conn Connection) error + // MessageFilter defines the criteria used to match messages to a specific connection. + // The filter enables precise routing and handling of messages for distinct connection contexts. + MessageFilter any } // ConnectionWrapper contains the connection setup details to be used when diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 309db9a7..737c0ead 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/url" + "reflect" "slices" "sync" "time" @@ -27,8 +28,10 @@ var ( ErrUnsubscribeFailure = errors.New("unsubscribe failure") ErrAlreadyDisabled = errors.New("websocket already disabled") ErrNotConnected = errors.New("websocket is not connected") - ErrNoMessageListener = errors.New("websocket listener not found for message") ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature") + ErrRequestRouteNotFound = errors.New("request route not found") + ErrSignatureNotSet = errors.New("signature not set") + ErrRequestPayloadNotSet = errors.New("request payload not set") ) // Private websocket errors @@ -64,6 +67,9 @@ var ( errConnectionWrapperDuplication = errors.New("connection wrapper duplication") errCannotChangeConnectionURL = errors.New("cannot change connection URL when using multi connection management") errExchangeConfigEmpty = errors.New("exchange config is empty") + errCannotObtainOutboundConnection = errors.New("cannot obtain outbound connection") + errMessageFilterNotSet = errors.New("message filter not set") + errMessageFilterNotComparable = errors.New("message filter is not comparable") ) var globalReporter Reporter @@ -259,13 +265,19 @@ func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { return fmt.Errorf("%w: %w", errConnSetup, errWebsocketDataHandlerUnset) } + if c.MessageFilter != nil && !reflect.TypeOf(c.MessageFilter).Comparable() { + return errMessageFilterNotComparable + } + for x := range w.connectionManager { - if w.connectionManager[x].Setup.URL == c.URL { + // Below allows for multiple connections to the same URL with different outbound request signatures. This + // allows for easier determination of inbound and outbound messages. e.g. Gateio cross_margin, margin on + // a spot connection. + if w.connectionManager[x].Setup.URL == c.URL && c.MessageFilter == w.connectionManager[x].Setup.MessageFilter { return fmt.Errorf("%w: %w", errConnSetup, errConnectionWrapperDuplication) } } - - w.connectionManager = append(w.connectionManager, ConnectionWrapper{ + w.connectionManager = append(w.connectionManager, &ConnectionWrapper{ Setup: c, Subscriptions: subscription.NewStore(), }) @@ -422,12 +434,21 @@ func (w *Websocket) connect() error { break } - w.connections[conn] = &w.connectionManager[i] + w.connections[conn] = w.connectionManager[i] w.connectionManager[i].Connection = conn w.Wg.Add(1) go w.Reader(context.TODO(), conn, w.connectionManager[i].Setup.Handler) + if w.connectionManager[i].Setup.Authenticate != nil && w.CanUseAuthenticatedEndpoints() { + err = w.connectionManager[i].Setup.Authenticate(context.TODO(), conn) + if err != nil { + // Opted to not fail entirely here for POC. This should be + // revisited and handled more gracefully. + log.Errorf(log.WebsocketMgr, "%s websocket: [conn:%d] [URL:%s] failed to authenticate %v", w.exchangeName, i+1, conn.URL, err) + } + } + err = w.connectionManager[i].Setup.Subscriber(context.TODO(), conn, subs) if err != nil { multiConnectFatalError = fmt.Errorf("%v Error subscribing %w", w.exchangeName, err) @@ -633,7 +654,7 @@ func (w *Websocket) FlushChannels() error { } w.Wg.Add(1) go w.Reader(context.TODO(), conn, w.connectionManager[x].Setup.Handler) - w.connections[conn] = &w.connectionManager[x] + w.connections[conn] = w.connectionManager[x] w.connectionManager[x].Connection = conn } @@ -1064,7 +1085,7 @@ func (w *Websocket) checkSubscriptions(conn Connection, subs subscription.List) if s.State() == subscription.ResubscribingState { continue } - if found := w.subscriptions.Get(s); found != nil { + if found := subscriptionStore.Get(s); found != nil { return fmt.Errorf("%w: %s", subscription.ErrDuplicate, s) } } @@ -1241,3 +1262,37 @@ func signalReceived(ch chan struct{}) bool { return false } } + +// GetConnection returns a connection by message filter (defined in exchange package _wrapper.go websocket connection) +// for request and response handling in a multi connection context. +func (w *Websocket) GetConnection(messageFilter any) (Connection, error) { + if w == nil { + return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, w) + } + + if messageFilter == nil { + return nil, errMessageFilterNotSet + } + + w.m.Lock() + defer w.m.Unlock() + + if !w.useMultiConnectionManagement { + return nil, fmt.Errorf("%s: multi connection management not enabled %w please use exported Conn and AuthConn fields", w.exchangeName, errCannotObtainOutboundConnection) + } + + if !w.IsConnected() { + return nil, ErrNotConnected + } + + for _, wrapper := range w.connectionManager { + if wrapper.Setup.MessageFilter == messageFilter { + if wrapper.Connection == nil { + return nil, fmt.Errorf("%s: %s %w associated with message filter: '%v'", w.exchangeName, wrapper.Setup.URL, ErrNotConnected, messageFilter) + } + return wrapper.Connection, nil + } + } + + return nil, fmt.Errorf("%s: %w associated with message filter: '%v'", w.exchangeName, ErrRequestRouteNotFound, messageFilter) +} diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 1f6f5e60..55fd7168 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -304,6 +304,12 @@ func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, epl // SendMessageReturnResponses will send a WS message to the connection and wait for N responses // An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int) ([][]byte, error) { + return w.SendMessageReturnResponsesWithInspector(ctx, epl, signature, payload, expected, nil) +} + +// SendMessageReturnResponsesWithInspector will send a WS message to the connection and wait for N responses +// An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked +func (w *WebsocketConnection) SendMessageReturnResponsesWithInspector(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int, messageInspector Inspector) ([][]byte, error) { outbound, err := json.Marshal(payload) if err != nil { return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err) @@ -320,28 +326,43 @@ func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, ep return nil, err } + resps, err := w.waitForResponses(ctx, signature, ch, expected, messageInspector) + if err != nil { + return nil, err + } + + if w.Reporter != nil { + w.Reporter.Latency(w.ExchangeName, outbound, time.Since(start)) + } + + return resps, err +} + +// waitForResponses waits for N responses from a channel +func (w *WebsocketConnection) waitForResponses(ctx context.Context, signature any, ch <-chan []byte, expected int, messageInspector Inspector) ([][]byte, error) { timeout := time.NewTimer(w.ResponseMaxLimit * time.Duration(expected)) + defer timeout.Stop() resps := make([][]byte, 0, expected) - for err == nil && len(resps) < expected { +inspection: + for range expected { select { case resp := <-ch: resps = append(resps, resp) + // Checks recently received message to determine if this is in fact the final message in a sequence of messages. + if messageInspector != nil && messageInspector.IsFinal(resp) { + w.Match.RemoveSignature(signature) + break inspection + } case <-timeout.C: w.Match.RemoveSignature(signature) - err = fmt.Errorf("%s %w %v", w.ExchangeName, ErrSignatureTimeout, signature) + return nil, fmt.Errorf("%s %w %v", w.ExchangeName, ErrSignatureTimeout, signature) case <-ctx.Done(): w.Match.RemoveSignature(signature) - err = ctx.Err() + return nil, ctx.Err() } } - timeout.Stop() - - if err == nil && w.Reporter != nil { - w.Reporter.Latency(w.ExchangeName, outbound, time.Since(start)) - } - // Only check context verbosity. If the exchange is verbose, it will log the responses in the ReadMessage() call. if request.IsVerbose(ctx, false) { for i := range resps { @@ -349,7 +370,7 @@ func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, ep } } - return resps, err + return resps, nil } func removeURLQueryString(url string) string { diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 2904bcca..b6f3a762 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -223,13 +223,16 @@ func TestConnectionMessageErrors(t *testing.T) { assert.ErrorIs(t, err, errNoPendingConnections, "Connect should error correctly") ws.useMultiConnectionManagement = true + ws.SetCanUseAuthenticatedEndpoints(true) mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) defer mock.Close() - ws.connectionManager = []ConnectionWrapper{{Setup: &ConnectionSetup{URL: "ws" + mock.URL[len("http"):] + "/ws"}}} + ws.connectionManager = []*ConnectionWrapper{{Setup: &ConnectionSetup{URL: "ws" + mock.URL[len("http"):] + "/ws"}}} err = ws.Connect() require.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset) + ws.connectionManager[0].Setup.Authenticate = func(context.Context, Connection) error { return errDastardlyReason } + ws.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) { return nil, errDastardlyReason } @@ -371,7 +374,7 @@ func TestWebsocket(t *testing.T) { ws.useMultiConnectionManagement = true - ws.connectionManager = []ConnectionWrapper{{Setup: &ConnectionSetup{URL: "ws://demos.kaazing.com/echo"}, Connection: &WebsocketConnection{}}} + ws.connectionManager = []*ConnectionWrapper{{Setup: &ConnectionSetup{URL: "ws://demos.kaazing.com/echo"}, Connection: &WebsocketConnection{}}} err = ws.SetProxyAddress("https://192.168.0.1:1337") require.NoError(t, err) } @@ -464,7 +467,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { amazingConn := multi.getConnectionFromSetup(amazingCandidate) multi.connections = map[Connection]*ConnectionWrapper{ - amazingConn: &multi.connectionManager[0], + amazingConn: multi.connectionManager[0], } subs, err = amazingCandidate.GenerateSubscriptions() @@ -761,8 +764,43 @@ func TestSendMessageReturnResponse(t *testing.T) { wc.ResponseMaxLimit = 1 _, err = wc.SendMessageReturnResponse(context.Background(), request.Unset, "123", req) assert.ErrorIs(t, err, ErrSignatureTimeout, "SendMessageReturnResponse should error when request ID not found") + + _, err = wc.SendMessageReturnResponsesWithInspector(context.Background(), request.Unset, "123", req, 1, inspection{}) + assert.ErrorIs(t, err, ErrSignatureTimeout, "SendMessageReturnResponse should error when request ID not found") } +func TestWaitForResponses(t *testing.T) { + t.Parallel() + dummy := &WebsocketConnection{ + ResponseMaxLimit: time.Nanosecond, + Match: NewMatch(), + } + _, err := dummy.waitForResponses(context.Background(), "silly", nil, 1, inspection{}) + require.ErrorIs(t, err, ErrSignatureTimeout) + + dummy.ResponseMaxLimit = time.Second + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = dummy.waitForResponses(ctx, "silly", nil, 1, inspection{}) + require.ErrorIs(t, err, context.Canceled) + + // test break early and hit verbose path + ch := make(chan []byte, 1) + ch <- []byte("hello") + ctx = request.WithVerbose(context.Background()) + + got, err := dummy.waitForResponses(ctx, "silly", ch, 2, inspection{breakEarly: true}) + require.NoError(t, err) + require.Len(t, got, 1) + assert.Equal(t, "hello", string(got[0])) +} + +type inspection struct { + breakEarly bool +} + +func (i inspection) IsFinal([]byte) bool { return i.breakEarly } + type reporter struct { name string msg []byte @@ -1229,6 +1267,11 @@ func TestSetupNewConnection(t *testing.T) { require.ErrorIs(t, err, errWebsocketDataHandlerUnset) connSetup.Handler = func(context.Context, []byte) error { return nil } + connSetup.MessageFilter = []string{"slices are super naughty and not comparable"} + err = multi.SetupNewConnection(connSetup) + require.ErrorIs(t, err, errMessageFilterNotComparable) + + connSetup.MessageFilter = "comparable string signature" err = multi.SetupNewConnection(connSetup) require.NoError(t, err) @@ -1484,3 +1527,42 @@ func TestMonitorTraffic(t *testing.T) { ws.TrafficAlert <- struct{}{} require.False(t, innerShell()) } + +func TestGetConnection(t *testing.T) { + t.Parallel() + var ws *Websocket + _, err := ws.GetConnection(nil) + require.ErrorIs(t, err, common.ErrNilPointer) + + ws = &Websocket{} + + _, err = ws.GetConnection(nil) + require.ErrorIs(t, err, errMessageFilterNotSet) + + _, err = ws.GetConnection("testURL") + require.ErrorIs(t, err, errCannotObtainOutboundConnection) + + ws.useMultiConnectionManagement = true + + _, err = ws.GetConnection("testURL") + require.ErrorIs(t, err, ErrNotConnected) + + ws.setState(connectedState) + + _, err = ws.GetConnection("testURL") + require.ErrorIs(t, err, ErrRequestRouteNotFound) + + ws.connectionManager = []*ConnectionWrapper{{ + Setup: &ConnectionSetup{MessageFilter: "testURL", URL: "testURL"}, + }} + + _, err = ws.GetConnection("testURL") + require.ErrorIs(t, err, ErrNotConnected) + + expected := &WebsocketConnection{} + ws.connectionManager[0].Connection = expected + + conn, err := ws.GetConnection("testURL") + require.NoError(t, err) + assert.Same(t, expected, conn) +} diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 27a5c819..26b20f1e 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -54,7 +54,7 @@ type Websocket struct { // For example, separate connections can be used for Spot, Margin, and Futures trading. This structure is especially useful // for exchanges that differentiate between trading pairs by using different connection endpoints or protocols for various asset classes. // If an exchange does not require such differentiation, all connections may be managed under a single ConnectionWrapper. - connectionManager []ConnectionWrapper + connectionManager []*ConnectionWrapper // connections holds a look up table for all connections to their corresponding ConnectionWrapper and subscription holder connections map[Connection]*ConnectionWrapper