From 90187a3a5a9007aa836bb400630503b2ac523137 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Wed, 6 Aug 2025 10:42:35 +1000 Subject: [PATCH] stream/match: allow a single connection to maintain its own match lookup for multi-connection (#1613) * gateio: Add multi asset websocket support WIP. * meow * Add tests and shenanigans * integrate flushing and for enabling/disabling pairs from rpc shenanigans * some changes * linter: fixes strikes again. * Change name ConnectionAssociation -> ConnectionCandidate for better clarity on purpose. Change connections map to point to candidate to track subscriptions for future dynamic connections holder and drop struct ConnectionDetails. * Add subscription tests (state functional) * glorious:nits + proxy handling * Spelling * linter: fixerino * instead of nil, dont do nil. * clean up nils * cya nils * don't need to set URL or check if its running * stream match update * update tests * linter: fix * glorious: nits + handle context cancellations * stop ping handler routine leak * * Fix bug where reader routine on error that is not a disconnection error but websocket frame error or anything really makes the reader routine return and then connection never cycles and the buffer gets filled. * Handle reconnection via an errors.Is check which is simpler and in that scope allow for quick disconnect reconnect without waiting for connection cycle. * Dial now uses code from DialContext but just calls context.Background() * Don't allow reader to return on parse binary response error. Just output error and return a non nil response * Allow rollback on connect on any error across all connections * fix shadow jutsu * glorious/gk: nitters - adds in ws mock server * linter: fix * fix deadlock on connection as the previous channel had no reader and would hang connection reader for eternity. * glorious: whooops * gk: nits * Leak issue and edge case * Websocket: Add SendMessageReturnResponses * whooooooopsie * gk: nitssssss * Update exchanges/stream/stream_match.go Co-authored-by: Gareth Kirwan * Update exchanges/stream/stream_match_test.go Co-authored-by: Gareth Kirwan * linter: appease the linter gods * gk: nits * gk: drain brain * glorious: nits * glorious: nits * glorious: nits * start to decouple match from a global reference to a connection * Update exchanges/stream/websocket.go Co-authored-by: Scott * glorious: nits * add tests * linter: fix * After merge * Add error connection info * Fix edge case where it does not reconnect made by an already closed connection * stream coverage * glorious: nits * glorious: nits removed asset error handling in stream package * linter: fix * rm block * Add basic readme * fix asset enabled flush cycle for multi connection * spella: fix * linter: fix * Add glorious suggestions, fix some race thing * reinstate name before any routine gets spawned * stop on error in mock tests * glorious: nits * glorious: nits found in CI build * Add test for drain, bumped wait times as there seems to be something happening on macos CI builds, used context.WithTimeout because its instant. * mutex across shutdown and connect for protection * lint: fix * test time withoffset, reinstate stop * fix whoops * const trafficCheckInterval; rm testmain * y * fix lint * bump time check window * stream: fix intermittant test failures while testing routines and remove code that is not needed. * spells * cant do what I did * protect race due to routine. * update testURL * use mock websocket connection instead of test URL's * linter: fix * remove url because its throwing errors on CI builds * connections drop all the time, don't need to worry about not being able to echo back ws data as it can be easily reviewed _test file side. * remove another superfluous url thats not really set up for this * spawn overwatch routine when there is no errors, inline checker instead of waiting for a time period, add sleep inline with echo handler as this is really quick and wanted to ensure that latency is handing correctly * linter: fixerino uperino * glorious: panix * linter: things * whoops * match naming with master changes * stream: Add tests * gk: nits on potential blockage in test * gk; nits assert value --------- Co-authored-by: shazbert Co-authored-by: Gareth Kirwan Co-authored-by: Scott --- exchange/websocket/connection.go | 9 ++- exchange/websocket/manager.go | 12 +++- exchange/websocket/manager_test.go | 23 +++++-- exchange/websocket/subscriptions_test.go | 2 +- exchanges/gateio/gateio_test.go | 68 ++++++++++---------- exchanges/gateio/gateio_websocket.go | 6 +- exchanges/gateio/gateio_websocket_futures.go | 6 +- exchanges/gateio/gateio_websocket_option.go | 4 +- exchanges/gateio/gateio_wrapper.go | 12 ++-- 9 files changed, 85 insertions(+), 57 deletions(-) diff --git a/exchange/websocket/connection.go b/exchange/websocket/connection.go index 529e20e1..20b28a3f 100644 --- a/exchange/websocket/connection.go +++ b/exchange/websocket/connection.go @@ -55,6 +55,8 @@ type Connection interface { SetProxy(string) GetURL() string Shutdown() error + // RequireMatchWithData routes incoming data using the connection specific match system to the correct handler + RequireMatchWithData(signature any, incoming []byte) error } // ConnectionSetup defines variables for an individual stream connection @@ -85,7 +87,7 @@ type ConnectionSetup struct { // Handler defines the function that will be called when a message is // received from the exchange's websocket server. This function should // handle the incoming message and pass it to the appropriate data handler. - Handler func(ctx context.Context, incoming []byte) error + Handler func(ctx context.Context, conn Connection, incoming []byte) error // BespokeGenerateMessageID is a function that returns a unique message ID. // This is useful for when an exchange connection requires a unique or // structured message ID for each message sent. @@ -473,3 +475,8 @@ func removeURLQueryString(url string) string { } return url } + +// RequireMatchWithData routes incoming data using the connection specific match system to the correct handler +func (c *connection) RequireMatchWithData(signature any, incoming []byte) error { + return c.Match.RequireMatchWithData(signature, incoming) +} diff --git a/exchange/websocket/manager.go b/exchange/websocket/manager.go index 0fafdfc9..70440f93 100644 --- a/exchange/websocket/manager.go +++ b/exchange/websocket/manager.go @@ -387,6 +387,12 @@ func (m *Manager) getConnectionFromSetup(c *ConnectionSetup) *connection { if c.URL != "" { connectionURL = c.URL } + match := m.Match + if m.useMultiConnectionManagement { + // If we are using multi connection management, we can decouple + // the match from the global match and have a match per connection. + match = NewMatch() + } return &connection{ ExchangeName: m.exchangeName, URL: connectionURL, @@ -397,7 +403,7 @@ func (m *Manager) getConnectionFromSetup(c *ConnectionSetup) *connection { readMessageErrors: m.ReadMessageErrors, shutdown: m.ShutdownC, Wg: &m.Wg, - Match: m.Match, + Match: match, RateLimit: c.RateLimit, Reporter: c.ConnectionLevelReporter, bespokeGenerateMessageID: c.BespokeGenerateMessageID, @@ -867,14 +873,14 @@ func checkWebsocketURL(s string) error { } // Reader reads and handles data from a specific connection -func (m *Manager) Reader(ctx context.Context, conn Connection, handler func(ctx context.Context, message []byte) error) { +func (m *Manager) Reader(ctx context.Context, conn Connection, handler func(ctx context.Context, conn Connection, message []byte) error) { defer m.Wg.Done() for { resp := conn.ReadMessage() if resp.Raw == nil { return // Connection has been closed } - if err := handler(ctx, resp.Raw); err != nil { + if err := handler(ctx, conn, resp.Raw); err != nil { m.DataHandler <- fmt.Errorf("connection URL:[%v] error: %w", conn.GetURL(), err) } } diff --git a/exchange/websocket/manager_test.go b/exchange/websocket/manager_test.go index 68dc5e8f..83d28aa8 100644 --- a/exchange/websocket/manager_test.go +++ b/exchange/websocket/manager_test.go @@ -262,7 +262,7 @@ func TestConnectionMessageErrors(t *testing.T) { err = ws.Connect() require.ErrorIs(t, err, errWebsocketDataHandlerUnset) - ws.connectionManager[0].setup.Handler = func(context.Context, []byte) error { + ws.connectionManager[0].setup.Handler = func(context.Context, Connection, []byte) error { return errDastardlyReason } err = ws.Connect() @@ -280,7 +280,7 @@ func TestConnectionMessageErrors(t *testing.T) { err = ws.Connect() require.ErrorIs(t, err, errDastardlyReason) - ws.connectionManager[0].setup.Handler = func(context.Context, []byte) error { + ws.connectionManager[0].setup.Handler = func(context.Context, Connection, []byte) error { return errDastardlyReason } err = ws.Connect() @@ -958,7 +958,7 @@ func TestFlushChannels(t *testing.T) { GenerateSubscriptions: newgen.generateSubs, Subscriber: func(context.Context, Connection, subscription.List) error { return nil }, Unsubscriber: func(context.Context, Connection, subscription.List) error { return nil }, - Handler: func(context.Context, []byte) error { return nil }, + Handler: func(context.Context, Connection, []byte) error { return nil }, } require.NoError(t, w.SetupNewConnection(amazingCandidate)) require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotAdded, "Must error when no subscriptions are added to the subscription store") @@ -1066,7 +1066,7 @@ func TestSetupNewConnection(t *testing.T) { err = multi.SetupNewConnection(connSetup) require.ErrorIs(t, err, errWebsocketDataHandlerUnset) - connSetup.Handler = func(context.Context, []byte) error { return nil } + connSetup.Handler = func(context.Context, Connection, []byte) error { return nil } connSetup.MessageFilter = []string{"slices are super naughty and not comparable"} err = multi.SetupNewConnection(connSetup) require.ErrorIs(t, err, errMessageFilterNotComparable) @@ -1336,3 +1336,18 @@ func TestGetConnection(t *testing.T) { require.NoError(t, err) assert.Same(t, expected, conn) } + +func TestWebsocketConnectionRequireMatchWithData(t *testing.T) { + t.Parallel() + ws := connection{Match: NewMatch()} + err := ws.RequireMatchWithData(0, nil) + require.ErrorIs(t, err, ErrSignatureNotMatched) + + ch, err := ws.Match.Set(0, 1) + require.NoError(t, err) + + err = ws.RequireMatchWithData(0, []byte("test")) + require.NoError(t, err) + require.Len(t, ch, 1, "must have one item in channel") + assert.Equal(t, []byte("test"), <-ch) +} diff --git a/exchange/websocket/subscriptions_test.go b/exchange/websocket/subscriptions_test.go index 00125573..7fc160e3 100644 --- a/exchange/websocket/subscriptions_test.go +++ b/exchange/websocket/subscriptions_test.go @@ -68,7 +68,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { Unsubscriber: func(ctx context.Context, c Connection, s subscription.List) error { return currySimpleUnsubConn(multi)(ctx, c, s) }, - Handler: func(context.Context, []byte) error { return nil }, + Handler: func(context.Context, Connection, []byte) error { return nil }, } require.NoError(t, multi.SetupNewConnection(amazingCandidate)) diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index f4743f57..e9458f7c 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -1981,7 +1981,7 @@ const wsTickerPushDataJSON = `{"time": 1606291803, "channel": "spot.tickers", "e func TestWsTickerPushData(t *testing.T) { t.Parallel() - if err := e.WsHandleSpotData(t.Context(), []byte(wsTickerPushDataJSON)); err != nil { + if err := e.WsHandleSpotData(t.Context(), nil, []byte(wsTickerPushDataJSON)); err != nil { t.Errorf("%s websocket ticker push data error: %v", e.Name, err) } } @@ -1990,7 +1990,7 @@ const wsTradePushDataJSON = `{ "time": 1606292218, "channel": "spot.trades", "ev func TestWsTradePushData(t *testing.T) { t.Parallel() - if err := e.WsHandleSpotData(t.Context(), []byte(wsTradePushDataJSON)); err != nil { + if err := e.WsHandleSpotData(t.Context(), nil, []byte(wsTradePushDataJSON)); err != nil { t.Errorf("%s websocket trade push data error: %v", e.Name, err) } } @@ -1999,7 +1999,7 @@ const wsCandlestickPushDataJSON = `{"time": 1606292600, "channel": "spot.candles func TestWsCandlestickPushData(t *testing.T) { t.Parallel() - if err := e.WsHandleSpotData(t.Context(), []byte(wsCandlestickPushDataJSON)); err != nil { + if err := e.WsHandleSpotData(t.Context(), nil, []byte(wsCandlestickPushDataJSON)); err != nil { t.Errorf("%s websocket candlestick push data error: %v", e.Name, err) } } @@ -2008,7 +2008,7 @@ const wsOrderbookTickerJSON = `{"time": 1606293275, "channel": "spot.book_ticker func TestWsOrderbookTickerPushData(t *testing.T) { t.Parallel() - if err := e.WsHandleSpotData(t.Context(), []byte(wsOrderbookTickerJSON)); err != nil { + if err := e.WsHandleSpotData(t.Context(), nil, []byte(wsOrderbookTickerJSON)); err != nil { t.Errorf("%s websocket orderbook push data error: %v", e.Name, err) } } @@ -2020,11 +2020,11 @@ const ( func TestWsOrderbookSnapshotPushData(t *testing.T) { t.Parallel() - err := e.WsHandleSpotData(t.Context(), []byte(wsOrderbookSnapshotPushDataJSON)) + err := e.WsHandleSpotData(t.Context(), nil, []byte(wsOrderbookSnapshotPushDataJSON)) if err != nil { t.Errorf("%s websocket orderbook snapshot push data error: %v", e.Name, err) } - if err = e.WsHandleSpotData(t.Context(), []byte(wsOrderbookUpdatePushDataJSON)); err != nil { + if err = e.WsHandleSpotData(t.Context(), nil, []byte(wsOrderbookUpdatePushDataJSON)); err != nil { t.Errorf("%s websocket orderbook update push data error: %v", e.Name, err) } } @@ -2033,7 +2033,7 @@ const wsSpotOrderPushDataJSON = `{"time": 1605175506, "channel": "spot.orders", func TestWsPushOrders(t *testing.T) { t.Parallel() - if err := e.WsHandleSpotData(t.Context(), []byte(wsSpotOrderPushDataJSON)); err != nil { + if err := e.WsHandleSpotData(t.Context(), nil, []byte(wsSpotOrderPushDataJSON)); err != nil { t.Errorf("%s websocket orders push data error: %v", e.Name, err) } } @@ -2042,7 +2042,7 @@ const wsUserTradePushDataJSON = `{"time": 1605176741, "channel": "spot.usertrade func TestWsUserTradesPushDataJSON(t *testing.T) { t.Parallel() - if err := e.WsHandleSpotData(t.Context(), []byte(wsUserTradePushDataJSON)); err != nil { + if err := e.WsHandleSpotData(t.Context(), nil, []byte(wsUserTradePushDataJSON)); err != nil { t.Errorf("%s websocket users trade push data error: %v", e.Name, err) } } @@ -2052,7 +2052,7 @@ const wsBalancesPushDataJSON = `{"time": 1605248616, "channel": "spot.balances", func TestBalancesPushData(t *testing.T) { t.Parallel() ctx := account.DeployCredentialsToContext(t.Context(), &account.Credentials{Key: "test", Secret: "test"}) - if err := e.WsHandleSpotData(ctx, []byte(wsBalancesPushDataJSON)); err != nil { + if err := e.WsHandleSpotData(ctx, nil, []byte(wsBalancesPushDataJSON)); err != nil { t.Errorf("%s websocket balances push data error: %v", e.Name, err) } } @@ -2061,7 +2061,7 @@ const wsMarginBalancePushDataJSON = `{"time": 1605248616, "channel": "spot.fundi func TestMarginBalancePushData(t *testing.T) { t.Parallel() - if err := e.WsHandleSpotData(t.Context(), []byte(wsMarginBalancePushDataJSON)); err != nil { + if err := e.WsHandleSpotData(t.Context(), nil, []byte(wsMarginBalancePushDataJSON)); err != nil { t.Errorf("%s websocket margin balance push data error: %v", e.Name, err) } } @@ -2071,7 +2071,7 @@ const wsCrossMarginBalancePushDataJSON = `{"time": 1605248616,"channel": "spot.c func TestCrossMarginBalancePushData(t *testing.T) { t.Parallel() ctx := account.DeployCredentialsToContext(t.Context(), &account.Credentials{Key: "test", Secret: "test"}) - if err := e.WsHandleSpotData(ctx, []byte(wsCrossMarginBalancePushDataJSON)); err != nil { + if err := e.WsHandleSpotData(ctx, nil, []byte(wsCrossMarginBalancePushDataJSON)); err != nil { t.Errorf("%s websocket cross margin balance push data error: %v", e.Name, err) } } @@ -2080,7 +2080,7 @@ const wsCrossMarginBalanceLoan = `{ "time":1658289372, "channel":"spot.cross_loa func TestCrossMarginBalanceLoan(t *testing.T) { t.Parallel() - if err := e.WsHandleSpotData(t.Context(), []byte(wsCrossMarginBalanceLoan)); err != nil { + if err := e.WsHandleSpotData(t.Context(), nil, []byte(wsCrossMarginBalanceLoan)); err != nil { t.Errorf("%s websocket cross margin loan push data error: %v", e.Name, err) } } @@ -2094,7 +2094,7 @@ func TestFuturesDataHandler(t *testing.T) { if strings.Contains(string(m), "futures.balances") { ctx = account.DeployCredentialsToContext(ctx, &account.Credentials{Key: "test", Secret: "test"}) } - return e.WsHandleFuturesData(ctx, m, asset.CoinMarginedFutures) + return e.WsHandleFuturesData(ctx, nil, m, asset.CoinMarginedFutures) }) close(e.Websocket.DataHandler) assert.Len(t, e.Websocket.DataHandler, 14, "Should see the correct number of messages") @@ -2111,7 +2111,7 @@ const optionsContractTickerPushDataJSON = `{"time": 1630576352, "channel": "opti func TestOptionsContractTickerPushData(t *testing.T) { t.Parallel() - if err := e.WsHandleOptionsData(t.Context(), []byte(optionsContractTickerPushDataJSON)); err != nil { + if err := e.WsHandleOptionsData(t.Context(), nil, []byte(optionsContractTickerPushDataJSON)); err != nil { t.Errorf("%s websocket options contract ticker push data failed with error %v", e.Name, err) } } @@ -2120,7 +2120,7 @@ const optionsUnderlyingTickerPushDataJSON = `{"time": 1630576352, "channel": "op func TestOptionsUnderlyingTickerPushData(t *testing.T) { t.Parallel() - if err := e.WsHandleOptionsData(t.Context(), []byte(optionsUnderlyingTickerPushDataJSON)); err != nil { + if err := e.WsHandleOptionsData(t.Context(), nil, []byte(optionsUnderlyingTickerPushDataJSON)); err != nil { t.Errorf("%s websocket options underlying ticker push data error: %v", e.Name, err) } } @@ -2129,7 +2129,7 @@ const optionsContractTradesPushDataJSON = `{"time": 1630576356, "channel": "opti func TestOptionsContractTradesPushData(t *testing.T) { t.Parallel() - if err := e.WsHandleOptionsData(t.Context(), []byte(optionsContractTradesPushDataJSON)); err != nil { + if err := e.WsHandleOptionsData(t.Context(), nil, []byte(optionsContractTradesPushDataJSON)); err != nil { t.Errorf("%s websocket contract trades push data error: %v", e.Name, err) } } @@ -2138,7 +2138,7 @@ const optionsUnderlyingTradesPushDataJSON = `{"time": 1630576356, "channel": "op func TestOptionsUnderlyingTradesPushData(t *testing.T) { t.Parallel() - if err := e.WsHandleOptionsData(t.Context(), []byte(optionsUnderlyingTradesPushDataJSON)); err != nil { + if err := e.WsHandleOptionsData(t.Context(), nil, []byte(optionsUnderlyingTradesPushDataJSON)); err != nil { t.Errorf("%s websocket underlying trades push data error: %v", e.Name, err) } } @@ -2147,7 +2147,7 @@ const optionsUnderlyingPricePushDataJSON = `{ "time": 1630576356, "channel": "op func TestOptionsUnderlyingPricePushData(t *testing.T) { t.Parallel() - if err := e.WsHandleOptionsData(t.Context(), []byte(optionsUnderlyingPricePushDataJSON)); err != nil { + if err := e.WsHandleOptionsData(t.Context(), nil, []byte(optionsUnderlyingPricePushDataJSON)); err != nil { t.Errorf("%s websocket underlying price push data error: %v", e.Name, err) } } @@ -2156,7 +2156,7 @@ const optionsMarkPricePushDataJSON = `{ "time": 1630576356, "channel": "options. func TestOptionsMarkPricePushData(t *testing.T) { t.Parallel() - if err := e.WsHandleOptionsData(t.Context(), []byte(optionsMarkPricePushDataJSON)); err != nil { + if err := e.WsHandleOptionsData(t.Context(), nil, []byte(optionsMarkPricePushDataJSON)); err != nil { t.Errorf("%s websocket mark price push data error: %v", e.Name, err) } } @@ -2165,7 +2165,7 @@ const optionsSettlementsPushDataJSON = `{ "time": 1630576356, "channel": "option func TestSettlementsPushData(t *testing.T) { t.Parallel() - if err := e.WsHandleOptionsData(t.Context(), []byte(optionsSettlementsPushDataJSON)); err != nil { + if err := e.WsHandleOptionsData(t.Context(), nil, []byte(optionsSettlementsPushDataJSON)); err != nil { t.Errorf("%s websocket options settlements push data error: %v", e.Name, err) } } @@ -2174,7 +2174,7 @@ const optionsContractPushDataJSON = `{"time": 1630576356, "channel": "options.co func TestOptionsContractPushData(t *testing.T) { t.Parallel() - if err := e.WsHandleOptionsData(t.Context(), []byte(optionsContractPushDataJSON)); err != nil { + if err := e.WsHandleOptionsData(t.Context(), nil, []byte(optionsContractPushDataJSON)); err != nil { t.Errorf("%s websocket options contracts push data error: %v", e.Name, err) } } @@ -2186,10 +2186,10 @@ const ( func TestOptionsCandlesticksPushData(t *testing.T) { t.Parallel() - if err := e.WsHandleOptionsData(t.Context(), []byte(optionsContractCandlesticksPushDataJSON)); err != nil { + if err := e.WsHandleOptionsData(t.Context(), nil, []byte(optionsContractCandlesticksPushDataJSON)); err != nil { t.Errorf("%s websocket options contracts candlestick push data error: %v", e.Name, err) } - if err := e.WsHandleOptionsData(t.Context(), []byte(optionsUnderlyingCandlesticksPushDataJSON)); err != nil { + if err := e.WsHandleOptionsData(t.Context(), nil, []byte(optionsUnderlyingCandlesticksPushDataJSON)); err != nil { t.Errorf("%s websocket options underlying candlestick push data error: %v", e.Name, err) } } @@ -2204,17 +2204,17 @@ const ( func TestOptionsOrderbookPushData(t *testing.T) { t.Parallel() p := getPair(t, asset.Options) - assert.NoError(t, e.WsHandleOptionsData(t.Context(), []byte(optionsOrderbookTickerPushDataJSON))) - assert.NoError(t, e.WsHandleOptionsData(t.Context(), fmt.Appendf(nil, optionsOrderbookUpdatePushDataJSON, p.Upper().String()))) - assert.NoError(t, e.WsHandleOptionsData(t.Context(), []byte(optionsOrderbookSnapshotPushDataJSON))) - assert.NoError(t, e.WsHandleOptionsData(t.Context(), []byte(optionsOrderbookSnapshotUpdateEventPushDataJSON))) + assert.NoError(t, e.WsHandleOptionsData(t.Context(), nil, []byte(optionsOrderbookTickerPushDataJSON))) + assert.NoError(t, e.WsHandleOptionsData(t.Context(), nil, fmt.Appendf(nil, optionsOrderbookUpdatePushDataJSON, p.Upper().String()))) + assert.NoError(t, e.WsHandleOptionsData(t.Context(), nil, []byte(optionsOrderbookSnapshotPushDataJSON))) + assert.NoError(t, e.WsHandleOptionsData(t.Context(), nil, []byte(optionsOrderbookSnapshotUpdateEventPushDataJSON))) } const optionsOrderPushDataJSON = `{"time": 1630654851,"channel": "options.orders", "event": "update", "result": [ { "contract": "BTC_USDT-20211130-65000-C", "create_time": 1637897000, "fill_price": 0, "finish_as": "cancelled", "iceberg": 0, "id": 106, "is_close": false, "is_liq": false, "is_reduce_only": false, "left": -10, "mkfr": 0.0004, "price": 15000, "refr": 0, "refu": 0, "size": -10, "status": "finished", "text": "web", "tif": "gtc", "tkfr": 0.0004, "underlying": "BTC_USDT", "user": "9xxx", "time": 1639051907,"time_ms": 1639051907000}]}` func TestOptionsOrderPushData(t *testing.T) { t.Parallel() - if err := e.WsHandleOptionsData(t.Context(), []byte(optionsOrderPushDataJSON)); err != nil { + if err := e.WsHandleOptionsData(t.Context(), nil, []byte(optionsOrderPushDataJSON)); err != nil { t.Errorf("%s websocket options orders push data error: %v", e.Name, err) } } @@ -2223,7 +2223,7 @@ const optionsUsersTradesPushDataJSON = `{ "time": 1639144214, "channel": "option func TestOptionUserTradesPushData(t *testing.T) { t.Parallel() - if err := e.WsHandleOptionsData(t.Context(), []byte(optionsUsersTradesPushDataJSON)); err != nil { + if err := e.WsHandleOptionsData(t.Context(), nil, []byte(optionsUsersTradesPushDataJSON)); err != nil { t.Errorf("%s websocket options orders push data error: %v", e.Name, err) } } @@ -2232,7 +2232,7 @@ const optionsLiquidatesPushDataJSON = `{ "channel": "options.liquidates", "event func TestOptionsLiquidatesPushData(t *testing.T) { t.Parallel() - if err := e.WsHandleOptionsData(t.Context(), []byte(optionsLiquidatesPushDataJSON)); err != nil { + if err := e.WsHandleOptionsData(t.Context(), nil, []byte(optionsLiquidatesPushDataJSON)); err != nil { t.Errorf("%s websocket options liquidates push data error: %v", e.Name, err) } } @@ -2241,7 +2241,7 @@ const optionsSettlementPushDataJSON = `{ "channel": "options.user_settlements", func TestOptionsSettlementPushData(t *testing.T) { t.Parallel() - if err := e.WsHandleOptionsData(t.Context(), []byte(optionsSettlementPushDataJSON)); err != nil { + if err := e.WsHandleOptionsData(t.Context(), nil, []byte(optionsSettlementPushDataJSON)); err != nil { t.Errorf("%s websocket options settlement push data error: %v", e.Name, err) } } @@ -2250,7 +2250,7 @@ const optionsPositionClosePushDataJSON = `{"channel": "options.position_closes", func TestOptionsPositionClosePushData(t *testing.T) { t.Parallel() - if err := e.WsHandleOptionsData(t.Context(), []byte(optionsPositionClosePushDataJSON)); err != nil { + if err := e.WsHandleOptionsData(t.Context(), nil, []byte(optionsPositionClosePushDataJSON)); err != nil { t.Errorf("%s websocket options position close push data error: %v", e.Name, err) } } @@ -2260,7 +2260,7 @@ const optionsBalancePushDataJSON = `{ "channel": "options.balances", "event": "u func TestOptionsBalancePushData(t *testing.T) { t.Parallel() ctx := account.DeployCredentialsToContext(t.Context(), &account.Credentials{Key: "test", Secret: "test"}) - if err := e.WsHandleOptionsData(ctx, []byte(optionsBalancePushDataJSON)); err != nil { + if err := e.WsHandleOptionsData(ctx, nil, []byte(optionsBalancePushDataJSON)); err != nil { t.Errorf("%s websocket options balance push data error: %v", e.Name, err) } } @@ -2269,7 +2269,7 @@ const optionsPositionPushDataJSON = `{"time": 1630654851, "channel": "options.po func TestOptionsPositionPushData(t *testing.T) { t.Parallel() - if err := e.WsHandleOptionsData(t.Context(), []byte(optionsPositionPushDataJSON)); err != nil { + if err := e.WsHandleOptionsData(t.Context(), nil, []byte(optionsPositionPushDataJSON)); err != nil { t.Errorf("%s websocket options position push data error: %v", e.Name, err) } } diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index e32fb76d..06818953 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -166,18 +166,18 @@ func (e *Exchange) generateWsSignature(secret, event, channel string, t int64) ( } // WsHandleSpotData handles spot data -func (e *Exchange) WsHandleSpotData(ctx context.Context, respRaw []byte) error { +func (e *Exchange) WsHandleSpotData(ctx context.Context, conn websocket.Connection, respRaw []byte) error { push, err := parseWSHeader(respRaw) if err != nil { return err } if push.RequestID != "" { - return e.Websocket.Match.RequireMatchWithData(push.RequestID, respRaw) + return conn.RequireMatchWithData(push.RequestID, respRaw) } if push.Event == subscribeEvent || push.Event == unsubscribeEvent { - return e.Websocket.Match.RequireMatchWithData(push.ID, respRaw) + return conn.RequireMatchWithData(push.ID, respRaw) } switch push.Channel { // TODO: Convert function params below to only use push.Result diff --git a/exchanges/gateio/gateio_websocket_futures.go b/exchanges/gateio/gateio_websocket_futures.go index f01a7cfb..6cdba717 100644 --- a/exchanges/gateio/gateio_websocket_futures.go +++ b/exchanges/gateio/gateio_websocket_futures.go @@ -143,18 +143,18 @@ func (e *Exchange) FuturesUnsubscribe(ctx context.Context, conn websocket.Connec } // WsHandleFuturesData handles futures websocket data -func (e *Exchange) WsHandleFuturesData(ctx context.Context, respRaw []byte, a asset.Item) error { +func (e *Exchange) WsHandleFuturesData(ctx context.Context, conn websocket.Connection, respRaw []byte, a asset.Item) error { push, err := parseWSHeader(respRaw) if err != nil { return err } if push.RequestID != "" { - return e.Websocket.Match.RequireMatchWithData(push.RequestID, respRaw) + return conn.RequireMatchWithData(push.RequestID, respRaw) } if push.Event == subscribeEvent || push.Event == unsubscribeEvent { - return e.Websocket.Match.RequireMatchWithData(push.ID, respRaw) + return conn.RequireMatchWithData(push.ID, respRaw) } switch push.Channel { diff --git a/exchanges/gateio/gateio_websocket_option.go b/exchanges/gateio/gateio_websocket_option.go index 5cdc45cb..4c4799db 100644 --- a/exchanges/gateio/gateio_websocket_option.go +++ b/exchanges/gateio/gateio_websocket_option.go @@ -295,14 +295,14 @@ func (e *Exchange) OptionsUnsubscribe(ctx context.Context, conn websocket.Connec } // WsHandleOptionsData handles options websocket data -func (e *Exchange) WsHandleOptionsData(ctx context.Context, respRaw []byte) error { +func (e *Exchange) WsHandleOptionsData(ctx context.Context, conn websocket.Connection, respRaw []byte) error { push, err := parseWSHeader(respRaw) if err != nil { return err } if push.Event == subscribeEvent || push.Event == unsubscribeEvent { - return e.Websocket.Match.RequireMatchWithData(push.ID, respRaw) + return conn.RequireMatchWithData(push.ID, respRaw) } switch push.Channel { diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index ecf4df04..18044949 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -228,8 +228,8 @@ func (e *Exchange) Setup(exch *config.Exchange) error { URL: usdtFuturesWebsocketURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Handler: func(ctx context.Context, incoming []byte) error { - return e.WsHandleFuturesData(ctx, incoming, asset.USDTMarginedFutures) + Handler: func(ctx context.Context, conn websocket.Connection, incoming []byte) error { + return e.WsHandleFuturesData(ctx, conn, incoming, asset.USDTMarginedFutures) }, Subscriber: e.FuturesSubscribe, Unsubscriber: e.FuturesUnsubscribe, @@ -250,8 +250,8 @@ func (e *Exchange) Setup(exch *config.Exchange) error { URL: btcFuturesWebsocketURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Handler: func(ctx context.Context, incoming []byte) error { - return e.WsHandleFuturesData(ctx, incoming, asset.CoinMarginedFutures) + Handler: func(ctx context.Context, conn websocket.Connection, incoming []byte) error { + return e.WsHandleFuturesData(ctx, conn, incoming, asset.CoinMarginedFutures) }, Subscriber: e.FuturesSubscribe, Unsubscriber: e.FuturesUnsubscribe, @@ -272,8 +272,8 @@ func (e *Exchange) Setup(exch *config.Exchange) error { URL: deliveryRealUSDTTradingURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Handler: func(ctx context.Context, incoming []byte) error { - return e.WsHandleFuturesData(ctx, incoming, asset.DeliveryFutures) + Handler: func(ctx context.Context, conn websocket.Connection, incoming []byte) error { + return e.WsHandleFuturesData(ctx, conn, incoming, asset.DeliveryFutures) }, Subscriber: e.DeliveryFuturesSubscribe, Unsubscriber: e.DeliveryFuturesUnsubscribe,