From ac731ce283f126e0a4184d377599da30dfd27d9d Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 10 Oct 2024 15:09:52 +1100 Subject: [PATCH] websocket/gateio: Support multi connection management and integrate with GateIO (#1580) * gateio: Add multi asset websocket support WIP. * meow * Add tests and shenanigans * integrate flushing and for enabling/disabling pairs from rpc shenanigans * some changes * linter: fixes strikes again. * Change name ConnectionAssociation -> ConnectionCandidate for better clarity on purpose. Change connections map to point to candidate to track subscriptions for future dynamic connections holder and drop struct ConnectionDetails. * Add subscription tests (state functional) * glorious:nits + proxy handling * Spelling * linter: fixerino * instead of nil, dont do nil. * clean up nils * cya nils * don't need to set URL or check if its running * stop ping handler routine leak * * Fix bug where reader routine on error that is not a disconnection error but websocket frame error or anything really makes the reader routine return and then connection never cycles and the buffer gets filled. * Handle reconnection via an errors.Is check which is simpler and in that scope allow for quick disconnect reconnect without waiting for connection cycle. * Dial now uses code from DialContext but just calls context.Background() * Don't allow reader to return on parse binary response error. Just output error and return a non nil response * Allow rollback on connect on any error across all connections * fix shadow jutsu * glorious/gk: nitters - adds in ws mock server * linter: fix * fix deadlock on connection as the previous channel had no reader and would hang connection reader for eternity. * gk: nits * Leak issue and edge case * gk: nits * gk: drain brain * glorious: nits * Update exchanges/stream/websocket.go Co-authored-by: Scott * glorious: nits * add tests * linter: fix * After merge * Add error connection info * Fix edge case where it does not reconnect made by an already closed connection * stream coverage * glorious: nits * glorious: nits removed asset error handling in stream package * linter: fix * rm block * Add basic readme * fix asset enabled flush cycle for multi connection * spella: fix * linter: fix * Add glorious suggestions, fix some race thing * reinstate name before any routine gets spawned * stop on error in mock tests * glorious: nits * glorious: nits found in CI build * Add test for drain, bumped wait times as there seems to be something happening on macos CI builds, used context.WithTimeout because its instant. * mutex across shutdown and connect for protection * lint: fix * test time withoffset, reinstate stop * fix whoops * const trafficCheckInterval; rm testmain * y * fix lint * bump time check window * stream: fix intermittant test failures while testing routines and remove code that is not needed. * spells * cant do what I did * protect race due to routine. * update testURL * use mock websocket connection instead of test URL's * linter: fix * remove url because its throwing errors on CI builds * connections drop all the time, don't need to worry about not being able to echo back ws data as it can be easily reviewed _test file side. * remove another superfluous url thats not really set up for this * spawn overwatch routine when there is no errors, inline checker instead of waiting for a time period, add sleep inline with echo handler as this is really quick and wanted to ensure that latency is handing correctly * linter: fixerino uperino * glorious: panix * linter: things * whoops * defer lock and use functions that don't require locking in SetProxyAddress * lint: fix * thrasher: nits --------- Co-authored-by: shazbert Co-authored-by: Scott --- docs/ADD_NEW_EXCHANGE.md | 6 +- engine/rpcserver.go | 7 + exchanges/binance/binance_test.go | 5 +- exchanges/binance/binance_websocket.go | 6 +- exchanges/binance/binance_wrapper.go | 2 +- exchanges/binanceus/binanceus_websocket.go | 4 +- exchanges/binanceus/binanceus_wrapper.go | 2 +- exchanges/bitfinex/bitfinex_test.go | 16 +- exchanges/bitfinex/bitfinex_websocket.go | 10 +- exchanges/bitfinex/bitfinex_wrapper.go | 4 +- exchanges/bithumb/bithumb_websocket.go | 2 +- exchanges/bithumb/bithumb_wrapper.go | 2 +- exchanges/bitmex/bitmex_websocket.go | 4 +- exchanges/bitmex/bitmex_wrapper.go | 2 +- exchanges/bitstamp/bitstamp_websocket.go | 4 +- exchanges/bitstamp/bitstamp_wrapper.go | 2 +- exchanges/btcmarkets/btcmarkets_websocket.go | 4 +- exchanges/btcmarkets/btcmarkets_wrapper.go | 2 +- exchanges/btse/btse_websocket.go | 4 +- exchanges/btse/btse_wrapper.go | 2 +- exchanges/bybit/bybit_wrapper.go | 4 +- .../coinbasepro/coinbasepro_websocket.go | 4 +- exchanges/coinbasepro/coinbasepro_wrapper.go | 2 +- exchanges/coinut/coinut_websocket.go | 4 +- exchanges/coinut/coinut_wrapper.go | 2 +- exchanges/deribit/deribit_wrapper.go | 2 +- exchanges/exchange.go | 4 +- exchanges/gateio/gateio_test.go | 137 ++- exchanges/gateio/gateio_websocket.go | 111 +- exchanges/gateio/gateio_wrapper.go | 119 +- .../gateio/gateio_ws_delivery_futures.go | 198 +--- exchanges/gateio/gateio_ws_futures.go | 245 ++-- exchanges/gateio/gateio_ws_option.go | 103 +- exchanges/gemini/gemini_websocket.go | 4 +- exchanges/gemini/gemini_wrapper.go | 4 +- exchanges/hitbtc/hitbtc_websocket.go | 4 +- exchanges/hitbtc/hitbtc_wrapper.go | 2 +- exchanges/huobi/huobi_websocket.go | 4 +- exchanges/huobi/huobi_wrapper.go | 4 +- exchanges/kraken/kraken_test.go | 11 +- exchanges/kraken/kraken_websocket.go | 8 +- exchanges/kraken/kraken_wrapper.go | 4 +- exchanges/kucoin/kucoin_websocket.go | 4 +- exchanges/kucoin/kucoin_wrapper.go | 2 +- exchanges/okx/okx_websocket.go | 12 +- exchanges/okx/okx_wrapper.go | 4 +- exchanges/poloniex/poloniex_websocket.go | 4 +- exchanges/poloniex/poloniex_wrapper.go | 2 +- exchanges/protocol/features.go | 65 +- exchanges/stream/README.md | 137 +++ exchanges/stream/stream_types.go | 39 +- exchanges/stream/websocket.go | 1029 +++++++++++------ exchanges/stream/websocket_connection.go | 60 +- exchanges/stream/websocket_test.go | 936 +++++++++------ exchanges/stream/websocket_types.go | 22 +- internal/testing/exchange/exchange.go | 33 - internal/testing/exchange/exchange_test.go | 3 +- internal/testing/websocket/mock.go | 49 + 58 files changed, 1996 insertions(+), 1475 deletions(-) create mode 100644 exchanges/stream/README.md create mode 100644 internal/testing/websocket/mock.go diff --git a/docs/ADD_NEW_EXCHANGE.md b/docs/ADD_NEW_EXCHANGE.md index 1d534f47..365f9ed7 100644 --- a/docs/ADD_NEW_EXCHANGE.md +++ b/docs/ADD_NEW_EXCHANGE.md @@ -798,7 +798,7 @@ channels: continue } // When we have a successful subscription, we can alert our internal management system of the success. - f.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[i]) + f.Websocket.AddSuccessfulSubscriptions(f.Websocket.Conn, channelsToSubscribe[i]) } return errs } @@ -1038,7 +1038,7 @@ channels: continue } // When we have a successful unsubscription, we can alert our internal management system of the success. - f.Websocket.RemoveSubscriptions(channelsToUnsubscribe[i]) + f.Websocket.RemoveSubscriptions(f.Websocket.Conn, channelsToUnsubscribe[i]) } if errs != nil { return errs @@ -1098,7 +1098,7 @@ func (f *FTX) Setup(exch *config.Exchange) error { return err } // Sets up a new connection for the websocket, there are two separate connections denoted by the ConnectionSetup struct auth bool. - return f.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return f.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, // RateLimit int64 rudimentary rate limit that sleeps connection in milliseconds before sending designated payload diff --git a/engine/rpcserver.go b/engine/rpcserver.go index 43b5e381..72e7e76e 100644 --- a/engine/rpcserver.go +++ b/engine/rpcserver.go @@ -2935,6 +2935,13 @@ func (s *RPCServer) SetExchangeAsset(_ context.Context, r *gctrpc.SetExchangeAss return nil, err } + if base.IsWebsocketEnabled() && base.Websocket.IsConnected() { + err = exch.FlushWebsocketChannels() + if err != nil { + return nil, err + } + } + return &gctrpc.GenericResponse{Status: MsgStatusSuccess}, nil } diff --git a/exchanges/binance/binance_test.go b/exchanges/binance/binance_test.go index f7dbbc17..5058b267 100644 --- a/exchanges/binance/binance_test.go +++ b/exchanges/binance/binance_test.go @@ -30,6 +30,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" + mockws "github.com/thrasher-corp/gocryptotrader/internal/testing/websocket" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -1989,7 +1990,7 @@ func TestSubscribe(t *testing.T) { require.ElementsMatch(tb, req.Params, exp, "Params should have correct channels") return w.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf(`{"result":null,"id":%d}`, req.ID))) } - b = testexch.MockWsInstance[Binance](t, testexch.CurryWsMockUpgrader(t, mock)) + b = testexch.MockWsInstance[Binance](t, mockws.CurryWsMockUpgrader(t, mock)) } else { testexch.SetupWs(t, b) } @@ -2011,7 +2012,7 @@ func TestSubscribeBadResp(t *testing.T) { require.NoError(tb, err, "Unmarshal should not error") return w.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf(`{"result":{"error":"carrots"},"id":%d}`, req.ID))) } - b := testexch.MockWsInstance[Binance](t, testexch.CurryWsMockUpgrader(t, mock)) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + b := testexch.MockWsInstance[Binance](t, mockws.CurryWsMockUpgrader(t, mock)) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes err := b.Subscribe(channels) assert.ErrorIs(t, err, common.ErrUnknownError, "Subscribe should error correctly") assert.ErrorContains(t, err, "carrots", "Subscribe should error containing the carrots") diff --git a/exchanges/binance/binance_websocket.go b/exchanges/binance/binance_websocket.go index 25df2c58..91b68a11 100644 --- a/exchanges/binance/binance_websocket.go +++ b/exchanges/binance/binance_websocket.go @@ -563,7 +563,7 @@ func (b *Binance) Unsubscribe(channels subscription.List) error { // manageSubs subscribes or unsubscribes from a list of subscriptions func (b *Binance) manageSubs(op string, subs subscription.List) error { if op == wsSubscribeMethod { - if err := b.Websocket.AddSubscriptions(subs...); err != nil { // Note: AddSubscription will set state to subscribing + if err := b.Websocket.AddSubscriptions(b.Websocket.Conn, subs...); err != nil { // Note: AddSubscription will set state to subscribing return err } } else { @@ -592,7 +592,7 @@ func (b *Binance) manageSubs(op string, subs subscription.List) error { b.Websocket.DataHandler <- err if op == wsSubscribeMethod { - if err2 := b.Websocket.RemoveSubscriptions(subs...); err2 != nil { + if err2 := b.Websocket.RemoveSubscriptions(b.Websocket.Conn, subs...); err2 != nil { err = common.AppendError(err, err2) } } @@ -600,7 +600,7 @@ func (b *Binance) manageSubs(op string, subs subscription.List) error { if op == wsSubscribeMethod { err = common.AppendError(err, subs.SetStates(subscription.SubscribedState)) } else { - err = b.Websocket.RemoveSubscriptions(subs...) + err = b.Websocket.RemoveSubscriptions(b.Websocket.Conn, subs...) } } diff --git a/exchanges/binance/binance_wrapper.go b/exchanges/binance/binance_wrapper.go index 7e6fe658..261ed151 100644 --- a/exchanges/binance/binance_wrapper.go +++ b/exchanges/binance/binance_wrapper.go @@ -255,7 +255,7 @@ func (b *Binance) Setup(exch *config.Exchange) error { return err } - return b.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, RateLimit: request.NewWeightedRateLimitByDuration(250 * time.Millisecond), diff --git a/exchanges/binanceus/binanceus_websocket.go b/exchanges/binanceus/binanceus_websocket.go index 859736cb..15101a0a 100644 --- a/exchanges/binanceus/binanceus_websocket.go +++ b/exchanges/binanceus/binanceus_websocket.go @@ -590,7 +590,7 @@ func (bi *Binanceus) Subscribe(channelsToSubscribe subscription.List) error { return err } } - return bi.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe...) + return bi.Websocket.AddSuccessfulSubscriptions(bi.Websocket.Conn, channelsToSubscribe...) } // Unsubscribe unsubscribes from a set of channels @@ -614,7 +614,7 @@ func (bi *Binanceus) Unsubscribe(channelsToUnsubscribe subscription.List) error return err } } - return bi.Websocket.RemoveSubscriptions(channelsToUnsubscribe...) + return bi.Websocket.RemoveSubscriptions(bi.Websocket.Conn, channelsToUnsubscribe...) } func (bi *Binanceus) setupOrderbookManager() { diff --git a/exchanges/binanceus/binanceus_wrapper.go b/exchanges/binanceus/binanceus_wrapper.go index bbe7d6fc..d3992d4e 100644 --- a/exchanges/binanceus/binanceus_wrapper.go +++ b/exchanges/binanceus/binanceus_wrapper.go @@ -185,7 +185,7 @@ func (bi *Binanceus) Setup(exch *config.Exchange) error { return err } - return bi.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return bi.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, RateLimit: request.NewWeightedRateLimitByDuration(300 * time.Millisecond), diff --git a/exchanges/bitfinex/bitfinex_test.go b/exchanges/bitfinex/bitfinex_test.go index dd6ccef0..1445a3ac 100644 --- a/exchanges/bitfinex/bitfinex_test.go +++ b/exchanges/bitfinex/bitfinex_test.go @@ -1324,7 +1324,7 @@ func TestWsSubscribedResponse(t *testing.T) { assert.ErrorContains(t, err, "waiter1", "Should error containing subID if") } - err = b.Websocket.AddSubscriptions(&subscription.Subscription{Key: "waiter1"}) + err = b.Websocket.AddSubscriptions(b.Websocket.Conn, &subscription.Subscription{Key: "waiter1"}) require.NoError(t, err, "AddSubscriptions must not error") err = b.wsHandleData([]byte(`{"event":"subscribed","channel":"ticker","chanId":224555,"subId":"waiter1","symbol":"tBTCUSD","pair":"BTCUSD"}`)) assert.NoError(t, err, "wsHandleData should not error") @@ -1337,7 +1337,7 @@ func TestWsSubscribedResponse(t *testing.T) { } func TestWsOrderBook(t *testing.T) { - err := b.Websocket.AddSubscriptions(&subscription.Subscription{Key: 23405, Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsBook}) + err := b.Websocket.AddSubscriptions(b.Websocket.Conn, &subscription.Subscription{Key: 23405, Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsBook}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON := `[23405,[[38334303613,9348.8,0.53],[38334308111,9348.8,5.98979404],[38331335157,9344.1,1.28965787],[38334302803,9343.8,0.08230094],[38334279092,9343,0.8],[38334307036,9342.938663676,0.8],[38332749107,9342.9,0.2],[38332277330,9342.8,0.85],[38329406786,9342,0.1432012],[38332841570,9341.947288638,0.3],[38332163238,9341.7,0.3],[38334303384,9341.6,0.324],[38332464840,9341.4,0.5],[38331935870,9341.2,0.5],[38334312082,9340.9,0.02126899],[38334261292,9340.8,0.26763],[38334138680,9340.625455254,0.12],[38333896802,9339.8,0.85],[38331627527,9338.9,1.57863959],[38334186713,9338.9,0.26769],[38334305819,9338.8,2.999],[38334211180,9338.75285796,3.999],[38334310699,9337.8,0.10679883],[38334307414,9337.5,1],[38334179822,9337.1,0.26773],[38334306600,9336.659955102,1.79],[38334299667,9336.6,1.1],[38334306452,9336.6,0.13979771],[38325672859,9336.3,1.25],[38334311646,9336.2,1],[38334258509,9336.1,0.37],[38334310592,9336,1.79],[38334310378,9335.6,1.43],[38334132444,9335.2,0.26777],[38331367325,9335,0.07],[38334310703,9335,0.10680562],[38334298209,9334.7,0.08757301],[38334304857,9334.456899462,0.291],[38334309940,9334.088390727,0.0725],[38334310377,9333.7,1.2868],[38334297615,9333.607784,0.1108],[38334095188,9333.3,0.26785],[38334228913,9332.7,0.40861186],[38334300526,9332.363996604,0.3884],[38334310701,9332.2,0.10680562],[38334303548,9332.005382871,0.07],[38334311798,9331.8,0.41285228],[38334301012,9331.7,1.7952],[38334089877,9331.4,0.2679],[38321942150,9331.2,0.2],[38334310670,9330,1.069],[38334063096,9329.6,0.26796],[38334310700,9329.4,0.10680562],[38334310404,9329.3,1],[38334281630,9329.1,6.57150597],[38334036864,9327.7,0.26801],[38334310702,9326.6,0.10680562],[38334311799,9326.1,0.50220625],[38334164163,9326,0.219638],[38334309722,9326,1.5],[38333051682,9325.8,0.26807],[38334302027,9325.7,0.75],[38334203435,9325.366592,0.32397696],[38321967613,9325,0.05],[38334298787,9324.9,0.3],[38334301719,9324.8,3.6227592],[38331316716,9324.763454646,0.71442],[38334310698,9323.8,0.10680562],[38334035499,9323.7,0.23431017],[38334223472,9322.670551788,0.42150603],[38334163459,9322.560399006,0.143967],[38321825171,9320.8,2],[38334075805,9320.467496148,0.30772633],[38334075800,9319.916732238,0.61457592],[38333682302,9319.7,0.0011],[38331323088,9319.116771762,0.12913],[38333677480,9319,0.0199],[38334277797,9318.6,0.89],[38325235155,9318.041088,1.20249],[38334310910,9317.82382938,1.79],[38334311811,9317.2,0.61079138],[38334311812,9317.2,0.71937652],[38333298214,9317.1,50],[38334306359,9317,1.79],[38325531545,9316.382823951,0.21263],[38333727253,9316.3,0.02316372],[38333298213,9316.1,45],[38333836479,9316,2.135],[38324520465,9315.9,2.7681],[38334307411,9315.5,1],[38330313617,9315.3,0.84455],[38334077770,9315.294024,0.01248397],[38334286663,9315.294024,1],[38325533762,9315.290315394,2.40498],[38334310018,9315.2,3],[38333682617,9314.6,0.0011],[38334304794,9314.6,0.76364676],[38334304798,9314.3,0.69242113],[38332915733,9313.8,0.0199],[38334084411,9312.8,1],[38334311893,9350.1,-1.015],[38334302734,9350.3,-0.26737],[38334300732,9350.8,-5.2],[38333957619,9351,-0.90677089],[38334300521,9351,-1.6457],[38334301600,9351.012829557,-0.0523],[38334308878,9351.7,-2.5],[38334299570,9351.921544,-0.1015],[38334279367,9352.1,-0.26732],[38334299569,9352.411802928,-0.4036],[38334202773,9353.4,-0.02139404],[38333918472,9353.7,-1.96412776],[38334278782,9354,-0.26731],[38334278606,9355,-1.2785],[38334302105,9355.439221251,-0.79191542],[38313897370,9355.569409242,-0.43363],[38334292995,9355.584296,-0.0979],[38334216989,9355.8,-0.03686414],[38333894025,9355.9,-0.26721],[38334293798,9355.936691952,-0.4311],[38331159479,9356,-0.4204022],[38333918888,9356.1,-1.10885563],[38334298205,9356.4,-0.20124428],[38328427481,9356.5,-0.1],[38333343289,9356.6,-0.41034213],[38334297205,9356.6,-0.08835018],[38334277927,9356.741101161,-0.0737],[38334311645,9356.8,-0.5],[38334309002,9356.9,-5],[38334309736,9357,-0.10680107],[38334306448,9357.4,-0.18645275],[38333693302,9357.7,-0.2672],[38332815159,9357.8,-0.0011],[38331239824,9358.2,-0.02],[38334271608,9358.3,-2.999],[38334311971,9358.4,-0.55],[38333919260,9358.5,-1.9972841],[38334265365,9358.5,-1.7841],[38334277960,9359,-3],[38334274601,9359.020969848,-3],[38326848839,9359.1,-0.84],[38334291080,9359.247048,-0.16199869],[38326848844,9359.4,-1.84],[38333680200,9359.6,-0.26713],[38331326606,9359.8,-0.84454],[38334309738,9359.8,-0.10680107],[38331314707,9359.9,-0.2],[38333919803,9360.9,-1.41177599],[38323651149,9361.33417827,-0.71442],[38333656906,9361.5,-0.26705],[38334035500,9361.5,-0.40861586],[38334091886,9362.4,-6.85940815],[38334269617,9362.5,-4],[38323629409,9362.545858872,-2.40497],[38334309737,9362.7,-0.10680107],[38334312380,9362.7,-3],[38325280830,9362.8,-1.75123],[38326622800,9362.8,-1.05145],[38333175230,9363,-0.0011],[38326848745,9363.2,-0.79],[38334308960,9363.206775564,-0.12],[38333920234,9363.3,-1.25318113],[38326848843,9363.4,-1.29],[38331239823,9363.4,-0.02],[38333209613,9363.4,-0.26719],[38334299964,9364,-0.05583123],[38323470224,9364.161816648,-0.12912],[38334284711,9365,-0.21346019],[38334299594,9365,-2.6757062],[38323211816,9365.073132585,-0.21262],[38334312456,9365.1,-0.11167861],[38333209612,9365.2,-0.26719],[38327770474,9365.3,-0.0073],[38334298788,9365.3,-0.3],[38334075803,9365.409831204,-0.30772637],[38334309740,9365.5,-0.10680107],[38326608767,9365.7,-2.76809],[38333920657,9365.7,-1.25848083],[38329594226,9366.6,-0.02587],[38334311813,9366.7,-4.72290945],[38316386301,9367.39258128,-2.37581],[38334302026,9367.4,-4.5],[38334228915,9367.9,-0.81725458],[38333921381,9368.1,-1.72213641],[38333175678,9368.2,-0.0011],[38334301150,9368.2,-2.654604],[38334297208,9368.3,-0.78036466],[38334309739,9368.3,-0.10680107],[38331227515,9368.7,-0.02],[38331184470,9369,-0.003975],[38334203436,9369.319616,-0.32397695],[38334269964,9369.7,-0.5],[38328386732,9370,-4.11759935],[38332719555,9370,-0.025],[38333921935,9370.5,-1.2224398],[38334258511,9370.5,-0.35],[38326848842,9370.8,-0.34],[38333985038,9370.9,-0.8551502],[38334283018,9370.9,-1],[38326848744,9371,-1.34]],5]` err = b.wsHandleData([]byte(pressXToJSON)) @@ -1355,7 +1355,7 @@ func TestWsOrderBook(t *testing.T) { } func TestWsTradeResponse(t *testing.T) { - err := b.Websocket.AddSubscriptions(&subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsTrades, Key: 18788}) + err := b.Websocket.AddSubscriptions(b.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsTrades, Key: 18788}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON := `[18788,[[412685577,1580268444802,11.1998,176.3],[412685575,1580268444802,5,176.29952759],[412685574,1580268374717,1.99069999,176.41],[412685573,1580268374717,1.00930001,176.41],[412685572,1580268358760,0.9907,176.47],[412685571,1580268324362,0.5505,176.44],[412685570,1580268297270,-0.39040819,176.39],[412685568,1580268297270,-0.39780162,176.46475676],[412685567,1580268283470,-0.09,176.41],[412685566,1580268256536,-2.31310783,176.48],[412685565,1580268256536,-0.59669217,176.49],[412685564,1580268256536,-0.9902,176.49],[412685562,1580268194474,0.9902,176.55],[412685561,1580268186215,0.1,176.6],[412685560,1580268185964,-2.17096773,176.5],[412685559,1580268185964,-1.82903227,176.51],[412685558,1580268181215,2.098914,176.53],[412685557,1580268169844,16.7302,176.55],[412685556,1580268169844,3.25,176.54],[412685555,1580268155725,0.23576115,176.45],[412685553,1580268155725,3,176.44596249],[412685552,1580268155725,3.25,176.44],[412685551,1580268155725,5,176.44],[412685550,1580268155725,0.65830078,176.41],[412685549,1580268155725,0.45063807,176.41],[412685548,1580268153825,-0.67604704,176.39],[412685547,1580268145713,2.5883,176.41],[412685543,1580268087513,12.92927,176.33],[412685542,1580268087513,0.40083,176.33],[412685533,1580268005756,-0.17096773,176.32]]]` err = b.wsHandleData([]byte(pressXToJSON)) @@ -1365,7 +1365,7 @@ func TestWsTradeResponse(t *testing.T) { } func TestWsTickerResponse(t *testing.T) { - err := b.Websocket.AddSubscriptions(&subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsTicker, Key: 11534}) + err := b.Websocket.AddSubscriptions(b.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsTicker, Key: 11534}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON := `[11534,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` err = b.wsHandleData([]byte(pressXToJSON)) @@ -1376,7 +1376,7 @@ func TestWsTickerResponse(t *testing.T) { if err != nil { t.Error(err) } - err = b.Websocket.AddSubscriptions(&subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{pair}, Channel: wsTicker, Key: 123412}) + err = b.Websocket.AddSubscriptions(b.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{pair}, Channel: wsTicker, Key: 123412}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON = `[123412,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` err = b.wsHandleData([]byte(pressXToJSON)) @@ -1387,7 +1387,7 @@ func TestWsTickerResponse(t *testing.T) { if err != nil { t.Error(err) } - err = b.Websocket.AddSubscriptions(&subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{pair}, Channel: wsTicker, Key: 123413}) + err = b.Websocket.AddSubscriptions(b.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{pair}, Channel: wsTicker, Key: 123413}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON = `[123413,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` err = b.wsHandleData([]byte(pressXToJSON)) @@ -1398,7 +1398,7 @@ func TestWsTickerResponse(t *testing.T) { if err != nil { t.Error(err) } - err = b.Websocket.AddSubscriptions(&subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{pair}, Channel: wsTicker, Key: 123414}) + err = b.Websocket.AddSubscriptions(b.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{pair}, Channel: wsTicker, Key: 123414}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON = `[123414,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` err = b.wsHandleData([]byte(pressXToJSON)) @@ -1408,7 +1408,7 @@ func TestWsTickerResponse(t *testing.T) { } func TestWsCandleResponse(t *testing.T) { - err := b.Websocket.AddSubscriptions(&subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsCandles, Key: 343351}) + err := b.Websocket.AddSubscriptions(b.Websocket.Conn, &subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsCandles, Key: 343351}) require.NoError(t, err, "AddSubscriptions must not error") pressXToJSON := `[343351,[[1574698260000,7379.785503,7383.8,7388.3,7379.785503,1.68829482]]]` err = b.wsHandleData([]byte(pressXToJSON)) diff --git a/exchanges/bitfinex/bitfinex_websocket.go b/exchanges/bitfinex/bitfinex_websocket.go index f60bd896..b4cb9904 100644 --- a/exchanges/bitfinex/bitfinex_websocket.go +++ b/exchanges/bitfinex/bitfinex_websocket.go @@ -512,7 +512,7 @@ func (b *Bitfinex) handleWSSubscribed(respRaw []byte) error { c.Key = int(chanID) // subscribeToChan removes the old subID keyed Subscription - if err := b.Websocket.AddSuccessfulSubscriptions(c); err != nil { + if err := b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, c); err != nil { return fmt.Errorf("%w: %w subID: %s", stream.ErrSubscriptionFailure, err, subID) } @@ -1661,7 +1661,7 @@ func (b *Bitfinex) resubOrderbook(c *subscription.Subscription) error { // Resub will block so we have to do this in a goro go func() { - if err := b.Websocket.ResubscribeToChannel(c); err != nil { + if err := b.Websocket.ResubscribeToChannel(b.Websocket.Conn, c); err != nil { log.Errorf(log.ExchangeSys, "%s error resubscribing orderbook: %v", b.Name, err) } }() @@ -1748,13 +1748,13 @@ func (b *Bitfinex) subscribeToChan(chans subscription.List) error { // Add a temporary Key so we can find this Sub when we get the resp without delay or context switch // Otherwise we might drop the first messages after the subscribed resp c.Key = subID // Note subID string type avoids conflicts with later chanID key - if err = b.Websocket.AddSubscriptions(c); err != nil { + if err = b.Websocket.AddSubscriptions(b.Websocket.Conn, c); err != nil { return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrSubscriptionFailure, c.Channel, c.Pairs, err) } // Always remove the temporary subscription keyed by subID defer func() { - _ = b.Websocket.RemoveSubscriptions(c) + _ = b.Websocket.RemoveSubscriptions(b.Websocket.Conn, c) }() respRaw, err := b.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.Unset, "subscribe:"+subID, req) @@ -1861,7 +1861,7 @@ func (b *Bitfinex) unsubscribeFromChan(chans subscription.List) error { return wErr } - return b.Websocket.RemoveSubscriptions(c) + return b.Websocket.RemoveSubscriptions(b.Websocket.Conn, c) } // getErrResp takes a json response string and looks for an error event type diff --git a/exchanges/bitfinex/bitfinex_wrapper.go b/exchanges/bitfinex/bitfinex_wrapper.go index 82e50075..cf52cb3a 100644 --- a/exchanges/bitfinex/bitfinex_wrapper.go +++ b/exchanges/bitfinex/bitfinex_wrapper.go @@ -218,7 +218,7 @@ func (b *Bitfinex) Setup(exch *config.Exchange) error { return err } - err = b.Websocket.SetupNewConnection(stream.ConnectionSetup{ + err = b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, URL: publicBitfinexWebsocketEndpoint, @@ -227,7 +227,7 @@ func (b *Bitfinex) Setup(exch *config.Exchange) error { return err } - return b.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, URL: authenticatedBitfinexWebsocketEndpoint, diff --git a/exchanges/bithumb/bithumb_websocket.go b/exchanges/bithumb/bithumb_websocket.go index 461f6dcb..ad66c19e 100644 --- a/exchanges/bithumb/bithumb_websocket.go +++ b/exchanges/bithumb/bithumb_websocket.go @@ -207,7 +207,7 @@ func (b *Bithumb) Subscribe(channelsToSubscribe subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, req) if err == nil { - err = b.Websocket.AddSuccessfulSubscriptions(s) + err = b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, s) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/bithumb/bithumb_wrapper.go b/exchanges/bithumb/bithumb_wrapper.go index 7c6fdb5c..ff6f0ec1 100644 --- a/exchanges/bithumb/bithumb_wrapper.go +++ b/exchanges/bithumb/bithumb_wrapper.go @@ -167,7 +167,7 @@ func (b *Bithumb) Setup(exch *config.Exchange) error { return err } - return b.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, RateLimit: request.NewWeightedRateLimitByDuration(time.Second), diff --git a/exchanges/bitmex/bitmex_websocket.go b/exchanges/bitmex/bitmex_websocket.go index 8c4dbf16..c2c5c71f 100644 --- a/exchanges/bitmex/bitmex_websocket.go +++ b/exchanges/bitmex/bitmex_websocket.go @@ -602,7 +602,7 @@ func (b *Bitmex) Subscribe(subs subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, req) if err == nil { - err = b.Websocket.AddSuccessfulSubscriptions(subs...) + err = b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, subs...) } return err } @@ -621,7 +621,7 @@ func (b *Bitmex) Unsubscribe(subs subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, req) if err == nil { - err = b.Websocket.RemoveSubscriptions(subs...) + err = b.Websocket.RemoveSubscriptions(b.Websocket.Conn, subs...) } return err } diff --git a/exchanges/bitmex/bitmex_wrapper.go b/exchanges/bitmex/bitmex_wrapper.go index 0d20a361..78a13945 100644 --- a/exchanges/bitmex/bitmex_wrapper.go +++ b/exchanges/bitmex/bitmex_wrapper.go @@ -208,7 +208,7 @@ func (b *Bitmex) Setup(exch *config.Exchange) error { if err != nil { return err } - return b.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, URL: bitmexWSURL, diff --git a/exchanges/bitstamp/bitstamp_websocket.go b/exchanges/bitstamp/bitstamp_websocket.go index 5effb2a6..08dbbf25 100644 --- a/exchanges/bitstamp/bitstamp_websocket.go +++ b/exchanges/bitstamp/bitstamp_websocket.go @@ -295,7 +295,7 @@ func (b *Bitstamp) Subscribe(channelsToSubscribe subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, req) if err == nil { - err = b.Websocket.AddSuccessfulSubscriptions(s) + err = b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, s) } if err != nil { errs = common.AppendError(errs, err) @@ -317,7 +317,7 @@ func (b *Bitstamp) Unsubscribe(channelsToUnsubscribe subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, req) if err == nil { - err = b.Websocket.RemoveSubscriptions(s) + err = b.Websocket.RemoveSubscriptions(b.Websocket.Conn, s) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/bitstamp/bitstamp_wrapper.go b/exchanges/bitstamp/bitstamp_wrapper.go index 5ce0156d..f8ff0632 100644 --- a/exchanges/bitstamp/bitstamp_wrapper.go +++ b/exchanges/bitstamp/bitstamp_wrapper.go @@ -163,7 +163,7 @@ func (b *Bitstamp) Setup(exch *config.Exchange) error { return err } - return b.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: b.Websocket.GetWebsocketURL(), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/btcmarkets/btcmarkets_websocket.go b/exchanges/btcmarkets/btcmarkets_websocket.go index 26b66e64..333f77f0 100644 --- a/exchanges/btcmarkets/btcmarkets_websocket.go +++ b/exchanges/btcmarkets/btcmarkets_websocket.go @@ -377,7 +377,7 @@ func (b *BTCMarkets) Subscribe(subs subscription.List) error { err := b.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, r) if err == nil { - err = b.Websocket.AddSuccessfulSubscriptions(s) + err = b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, s) } if err != nil { errs = common.AppendError(errs, err) @@ -417,7 +417,7 @@ func (b *BTCMarkets) Unsubscribe(subs subscription.List) error { err := b.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, req) if err == nil { - err = b.Websocket.RemoveSubscriptions(s) + err = b.Websocket.RemoveSubscriptions(b.Websocket.Conn, s) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/btcmarkets/btcmarkets_wrapper.go b/exchanges/btcmarkets/btcmarkets_wrapper.go index e189ddab..0e97b9e7 100644 --- a/exchanges/btcmarkets/btcmarkets_wrapper.go +++ b/exchanges/btcmarkets/btcmarkets_wrapper.go @@ -172,7 +172,7 @@ func (b *BTCMarkets) Setup(exch *config.Exchange) error { return err } - return b.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, }) diff --git a/exchanges/btse/btse_websocket.go b/exchanges/btse/btse_websocket.go index 9399c9ce..bbe52cf6 100644 --- a/exchanges/btse/btse_websocket.go +++ b/exchanges/btse/btse_websocket.go @@ -395,7 +395,7 @@ func (b *BTSE) Subscribe(channelsToSubscribe subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, sub) if err == nil { - err = b.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe...) + err = b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, channelsToSubscribe...) } return err } @@ -410,7 +410,7 @@ func (b *BTSE) Unsubscribe(channelsToUnsubscribe subscription.List) error { } err := b.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, unSub) if err == nil { - err = b.Websocket.RemoveSubscriptions(channelsToUnsubscribe...) + err = b.Websocket.RemoveSubscriptions(b.Websocket.Conn, channelsToUnsubscribe...) } return err } diff --git a/exchanges/btse/btse_wrapper.go b/exchanges/btse/btse_wrapper.go index 27df6de1..3b897a4a 100644 --- a/exchanges/btse/btse_wrapper.go +++ b/exchanges/btse/btse_wrapper.go @@ -197,7 +197,7 @@ func (b *BTSE) Setup(exch *config.Exchange) error { return err } - return b.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, }) diff --git a/exchanges/bybit/bybit_wrapper.go b/exchanges/bybit/bybit_wrapper.go index e0b1ea0b..a0420063 100644 --- a/exchanges/bybit/bybit_wrapper.go +++ b/exchanges/bybit/bybit_wrapper.go @@ -252,7 +252,7 @@ func (by *Bybit) Setup(exch *config.Exchange) error { if err != nil { return err } - err = by.Websocket.SetupNewConnection(stream.ConnectionSetup{ + err = by.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: by.Websocket.GetWebsocketURL(), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: bybitWebsocketTimer, @@ -261,7 +261,7 @@ func (by *Bybit) Setup(exch *config.Exchange) error { return err } - return by.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return by.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: websocketPrivate, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/coinbasepro/coinbasepro_websocket.go b/exchanges/coinbasepro/coinbasepro_websocket.go index 22a16cd6..e18f5255 100644 --- a/exchanges/coinbasepro/coinbasepro_websocket.go +++ b/exchanges/coinbasepro/coinbasepro_websocket.go @@ -426,7 +426,7 @@ func (c *CoinbasePro) Subscribe(subs subscription.List) error { } err := c.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, r) if err == nil { - err = c.Websocket.AddSuccessfulSubscriptions(subs...) + err = c.Websocket.AddSuccessfulSubscriptions(c.Websocket.Conn, subs...) } return err } @@ -462,7 +462,7 @@ func (c *CoinbasePro) Unsubscribe(subs subscription.List) error { } err := c.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, r) if err == nil { - err = c.Websocket.RemoveSubscriptions(subs...) + err = c.Websocket.RemoveSubscriptions(c.Websocket.Conn, subs...) } return err } diff --git a/exchanges/coinbasepro/coinbasepro_wrapper.go b/exchanges/coinbasepro/coinbasepro_wrapper.go index 81a64ccf..d5d9173d 100644 --- a/exchanges/coinbasepro/coinbasepro_wrapper.go +++ b/exchanges/coinbasepro/coinbasepro_wrapper.go @@ -174,7 +174,7 @@ func (c *CoinbasePro) Setup(exch *config.Exchange) error { return err } - return c.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return c.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, }) diff --git a/exchanges/coinut/coinut_websocket.go b/exchanges/coinut/coinut_websocket.go index e147cb9e..7e1dece6 100644 --- a/exchanges/coinut/coinut_websocket.go +++ b/exchanges/coinut/coinut_websocket.go @@ -621,7 +621,7 @@ func (c *COINUT) Subscribe(subs subscription.List) error { } err = c.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, subscribe) if err == nil { - err = c.Websocket.AddSuccessfulSubscriptions(s) + err = c.Websocket.AddSuccessfulSubscriptions(c.Websocket.Conn, s) } if err != nil { errs = common.AppendError(errs, err) @@ -664,7 +664,7 @@ func (c *COINUT) Unsubscribe(channelToUnsubscribe subscription.List) error { case len(val) == 0, val[0] != "OK": err = common.AppendError(errs, fmt.Errorf("%v unsubscribe failed for channel %v", c.Name, s.Channel)) default: - err = c.Websocket.RemoveSubscriptions(s) + err = c.Websocket.RemoveSubscriptions(c.Websocket.Conn, s) } } if err != nil { diff --git a/exchanges/coinut/coinut_wrapper.go b/exchanges/coinut/coinut_wrapper.go index 1dd0be29..20ae1cf7 100644 --- a/exchanges/coinut/coinut_wrapper.go +++ b/exchanges/coinut/coinut_wrapper.go @@ -148,7 +148,7 @@ func (c *COINUT) Setup(exch *config.Exchange) error { return err } - return c.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return c.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, RateLimit: request.NewWeightedRateLimitByDuration(33 * time.Millisecond), diff --git a/exchanges/deribit/deribit_wrapper.go b/exchanges/deribit/deribit_wrapper.go index abe7c67f..38023875 100644 --- a/exchanges/deribit/deribit_wrapper.go +++ b/exchanges/deribit/deribit_wrapper.go @@ -211,7 +211,7 @@ func (d *Deribit) Setup(exch *config.Exchange) error { // setup option decimal regex at startup to make constant checks more efficient optionRegex = regexp.MustCompile(optionDecimalRegex) - return d.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return d.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: d.Websocket.GetWebsocketURL(), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/exchange.go b/exchanges/exchange.go index 3599b5e9..9676364c 100644 --- a/exchanges/exchange.go +++ b/exchanges/exchange.go @@ -1123,7 +1123,7 @@ func (b *Base) SubscribeToWebsocketChannels(channels subscription.List) error { if b.Websocket == nil { return common.ErrFunctionNotSupported } - return b.Websocket.SubscribeToChannels(channels) + return b.Websocket.SubscribeToChannels(b.Websocket.Conn, channels) } // UnsubscribeToWebsocketChannels removes from ChannelsToSubscribe @@ -1132,7 +1132,7 @@ func (b *Base) UnsubscribeToWebsocketChannels(channels subscription.List) error if b.Websocket == nil { return common.ErrFunctionNotSupported } - return b.Websocket.UnsubscribeChannels(channels) + return b.Websocket.UnsubscribeChannels(b.Websocket.Conn, channels) } // GetSubscriptions returns a copied list of subscriptions diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index c179d97e..5e8ee348 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -22,7 +22,9 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/futures" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" + "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" + "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" @@ -2540,7 +2542,7 @@ const wsTickerPushDataJSON = `{"time": 1606291803, "channel": "spot.tickers", "e func TestWsTickerPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsTickerPushDataJSON)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsTickerPushDataJSON)); err != nil { t.Errorf("%s websocket ticker push data error: %v", g.Name, err) } } @@ -2549,7 +2551,7 @@ const wsTradePushDataJSON = `{ "time": 1606292218, "channel": "spot.trades", "ev func TestWsTradePushData(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsTradePushDataJSON)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsTradePushDataJSON)); err != nil { t.Errorf("%s websocket trade push data error: %v", g.Name, err) } } @@ -2558,7 +2560,7 @@ const wsCandlestickPushDataJSON = `{"time": 1606292600, "channel": "spot.candles func TestWsCandlestickPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsCandlestickPushDataJSON)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsCandlestickPushDataJSON)); err != nil { t.Errorf("%s websocket candlestick push data error: %v", g.Name, err) } } @@ -2567,7 +2569,7 @@ const wsOrderbookTickerJSON = `{"time": 1606293275, "channel": "spot.book_ticker func TestWsOrderbookTickerPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsOrderbookTickerJSON)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsOrderbookTickerJSON)); err != nil { t.Errorf("%s websocket orderbook push data error: %v", g.Name, err) } } @@ -2579,11 +2581,11 @@ const ( func TestWsOrderbookSnapshotPushData(t *testing.T) { t.Parallel() - err := g.wsHandleData([]byte(wsOrderbookSnapshotPushDataJSON)) + err := g.WsHandleSpotData(context.Background(), []byte(wsOrderbookSnapshotPushDataJSON)) if err != nil { t.Errorf("%s websocket orderbook snapshot push data error: %v", g.Name, err) } - if err = g.wsHandleData([]byte(wsOrderbookUpdatePushDataJSON)); err != nil { + if err = g.WsHandleSpotData(context.Background(), []byte(wsOrderbookUpdatePushDataJSON)); err != nil { t.Errorf("%s websocket orderbook update push data error: %v", g.Name, err) } } @@ -2592,7 +2594,7 @@ const wsSpotOrderPushDataJSON = `{"time": 1605175506, "channel": "spot.orders", func TestWsPushOrders(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsSpotOrderPushDataJSON)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsSpotOrderPushDataJSON)); err != nil { t.Errorf("%s websocket orders push data error: %v", g.Name, err) } } @@ -2601,7 +2603,7 @@ const wsUserTradePushDataJSON = `{"time": 1605176741, "channel": "spot.usertrade func TestWsUserTradesPushDataJSON(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsUserTradePushDataJSON)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsUserTradePushDataJSON)); err != nil { t.Errorf("%s websocket users trade push data error: %v", g.Name, err) } } @@ -2610,7 +2612,7 @@ const wsBalancesPushDataJSON = `{"time": 1605248616, "channel": "spot.balances", func TestBalancesPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsBalancesPushDataJSON)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsBalancesPushDataJSON)); err != nil { t.Errorf("%s websocket balances push data error: %v", g.Name, err) } } @@ -2619,7 +2621,7 @@ const wsMarginBalancePushDataJSON = `{"time": 1605248616, "channel": "spot.fundi func TestMarginBalancePushData(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsMarginBalancePushDataJSON)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsMarginBalancePushDataJSON)); err != nil { t.Errorf("%s websocket margin balance push data error: %v", g.Name, err) } } @@ -2628,7 +2630,7 @@ const wsCrossMarginBalancePushDataJSON = `{"time": 1605248616,"channel": "spot.c func TestCrossMarginBalancePushData(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsCrossMarginBalancePushDataJSON)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsCrossMarginBalancePushDataJSON)); err != nil { t.Errorf("%s websocket cross margin balance push data error: %v", g.Name, err) } } @@ -2637,7 +2639,7 @@ const wsCrossMarginBalanceLoan = `{ "time":1658289372, "channel":"spot.cross_loa func TestCrossMarginBalanceLoan(t *testing.T) { t.Parallel() - if err := g.wsHandleData([]byte(wsCrossMarginBalanceLoan)); err != nil { + if err := g.WsHandleSpotData(context.Background(), []byte(wsCrossMarginBalanceLoan)); err != nil { t.Errorf("%s websocket cross margin loan push data error: %v", g.Name, err) } } @@ -2646,7 +2648,7 @@ const wsFuturesTickerPushDataJSON = `{"time": 1541659086, "channel": "futures.ti func TestFuturesTicker(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesTickerPushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesTickerPushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket push data error: %v", g.Name, err) } } @@ -2655,7 +2657,7 @@ const wsFuturesTradesPushDataJSON = `{"channel": "futures.trades","event": "upda func TestFuturesTrades(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesTradesPushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesTradesPushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket push data error: %v", g.Name, err) } } @@ -2666,7 +2668,7 @@ const ( func TestOrderbookData(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesOrderbookTickerJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesOrderbookTickerJSON), asset.Futures); err != nil { t.Errorf("%s websocket orderbook ticker push data error: %v", g.Name, err) } } @@ -2675,7 +2677,7 @@ const wsFuturesOrderPushDataJSON = `{ "channel": "futures.orders", "event": "upd func TestFuturesOrderPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesOrderPushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesOrderPushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket futures order push data error: %v", g.Name, err) } } @@ -2684,7 +2686,7 @@ const wsFuturesUsertradesPushDataJSON = `{"time": 1543205083, "channel": "future func TestFuturesUserTrades(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesUsertradesPushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesUsertradesPushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket futures user trades push data error: %v", g.Name, err) } } @@ -2693,7 +2695,7 @@ const wsFuturesLiquidationPushDataJSON = `{"channel": "futures.liquidates", "eve func TestFuturesLiquidationPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesLiquidationPushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesLiquidationPushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket futures liquidation push data error: %v", g.Name, err) } } @@ -2702,7 +2704,7 @@ const wsFuturesAutoDelevergesNotification = `{"channel": "futures.auto_deleverag func TestFuturesAutoDeleverges(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesAutoDelevergesNotification), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesAutoDelevergesNotification), asset.Futures); err != nil { t.Errorf("%s websocket futures auto deleverge push data error: %v", g.Name, err) } } @@ -2711,7 +2713,7 @@ const wsFuturesPositionClosePushDataJSON = ` {"channel": "futures.position_close func TestPositionClosePushData(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesPositionClosePushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesPositionClosePushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket futures position close push data error: %v", g.Name, err) } } @@ -2720,7 +2722,7 @@ const wsFuturesBalanceNotificationPushDataJSON = `{"channel": "futures.balances" func TestFuturesBalanceNotification(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesBalanceNotificationPushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesBalanceNotificationPushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket futures balance notification push data error: %v", g.Name, err) } } @@ -2729,7 +2731,7 @@ const wsFuturesReduceRiskLimitNotificationPushDataJSON = `{"time": 1551858330, " func TestFuturesReduceRiskLimitPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesReduceRiskLimitNotificationPushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesReduceRiskLimitNotificationPushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket futures reduce risk limit notification push data error: %v", g.Name, err) } } @@ -2738,7 +2740,7 @@ const wsFuturesPositionsNotificationPushDataJSON = `{"time": 1588212926,"channel func TestFuturesPositionsNotification(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesPositionsNotificationPushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesPositionsNotificationPushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket futures positions change notification push data error: %v", g.Name, err) } } @@ -2747,7 +2749,7 @@ const wsFuturesAutoOrdersPushDataJSON = `{"time": 1596798126,"channel": "futures func TestFuturesAutoOrderPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleFuturesData([]byte(wsFuturesAutoOrdersPushDataJSON), asset.Futures); err != nil { + if err := g.WsHandleFuturesData(context.Background(), []byte(wsFuturesAutoOrdersPushDataJSON), asset.Futures); err != nil { t.Errorf("%s websocket futures auto orders push data error: %v", g.Name, err) } } @@ -2758,7 +2760,7 @@ const optionsContractTickerPushDataJSON = `{"time": 1630576352, "channel": "opti func TestOptionsContractTickerPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsContractTickerPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsContractTickerPushDataJSON)); err != nil { t.Errorf("%s websocket options contract ticker push data failed with error %v", g.Name, err) } } @@ -2767,7 +2769,7 @@ const optionsUnderlyingTickerPushDataJSON = `{"time": 1630576352, "channel": "op func TestOptionsUnderlyingTickerPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsUnderlyingTickerPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsUnderlyingTickerPushDataJSON)); err != nil { t.Errorf("%s websocket options underlying ticker push data error: %v", g.Name, err) } } @@ -2776,7 +2778,7 @@ const optionsContractTradesPushDataJSON = `{"time": 1630576356, "channel": "opti func TestOptionsContractTradesPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsContractTradesPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsContractTradesPushDataJSON)); err != nil { t.Errorf("%s websocket contract trades push data error: %v", g.Name, err) } } @@ -2785,7 +2787,7 @@ const optionsUnderlyingTradesPushDataJSON = `{"time": 1630576356, "channel": "op func TestOptionsUnderlyingTradesPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsUnderlyingTradesPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsUnderlyingTradesPushDataJSON)); err != nil { t.Errorf("%s websocket underlying trades push data error: %v", g.Name, err) } } @@ -2794,7 +2796,7 @@ const optionsUnderlyingPricePushDataJSON = `{ "time": 1630576356, "channel": "op func TestOptionsUnderlyingPricePushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsUnderlyingPricePushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsUnderlyingPricePushDataJSON)); err != nil { t.Errorf("%s websocket underlying price push data error: %v", g.Name, err) } } @@ -2803,7 +2805,7 @@ const optionsMarkPricePushDataJSON = `{ "time": 1630576356, "channel": "options. func TestOptionsMarkPricePushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsMarkPricePushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsMarkPricePushDataJSON)); err != nil { t.Errorf("%s websocket mark price push data error: %v", g.Name, err) } } @@ -2812,7 +2814,7 @@ const optionsSettlementsPushDataJSON = `{ "time": 1630576356, "channel": "option func TestSettlementsPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsSettlementsPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsSettlementsPushDataJSON)); err != nil { t.Errorf("%s websocket options settlements push data error: %v", g.Name, err) } } @@ -2821,7 +2823,7 @@ const optionsContractPushDataJSON = `{"time": 1630576356, "channel": "options.co func TestOptionsContractPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsContractPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsContractPushDataJSON)); err != nil { t.Errorf("%s websocket options contracts push data error: %v", g.Name, err) } } @@ -2833,10 +2835,10 @@ const ( func TestOptionsCandlesticksPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsContractCandlesticksPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsContractCandlesticksPushDataJSON)); err != nil { t.Errorf("%s websocket options contracts candlestick push data error: %v", g.Name, err) } - if err := g.wsHandleOptionsData([]byte(optionsUnderlyingCandlesticksPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsUnderlyingCandlesticksPushDataJSON)); err != nil { t.Errorf("%s websocket options underlying candlestick push data error: %v", g.Name, err) } } @@ -2850,17 +2852,17 @@ const ( func TestOptionsOrderbookPushData(t *testing.T) { t.Parallel() - err := g.wsHandleOptionsData([]byte(optionsOrderbookTickerPushDataJSON)) + err := g.WsHandleOptionsData(context.Background(), []byte(optionsOrderbookTickerPushDataJSON)) if err != nil { t.Errorf("%s websocket options orderbook ticker push data error: %v", g.Name, err) } - if err = g.wsHandleOptionsData([]byte(optionsOrderbookSnapshotPushDataJSON)); err != nil { + if err = g.WsHandleOptionsData(context.Background(), []byte(optionsOrderbookSnapshotPushDataJSON)); err != nil { t.Errorf("%s websocket options orderbook snapshot push data error: %v", g.Name, err) } - if err = g.wsHandleOptionsData([]byte(optionsOrderbookUpdatePushDataJSON)); err != nil { + if err = g.WsHandleOptionsData(context.Background(), []byte(optionsOrderbookUpdatePushDataJSON)); err != nil { t.Errorf("%s websocket options orderbook update push data error: %v", g.Name, err) } - if err = g.wsHandleOptionsData([]byte(optionsOrderbookSnapshotUpdateEventPushDataJSON)); err != nil { + if err = g.WsHandleOptionsData(context.Background(), []byte(optionsOrderbookSnapshotUpdateEventPushDataJSON)); err != nil { t.Errorf("%s websocket options orderbook snapshot update event push data error: %v", g.Name, err) } } @@ -2869,7 +2871,7 @@ const optionsOrderPushDataJSON = `{"time": 1630654851,"channel": "options.orders func TestOptionsOrderPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsOrderPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsOrderPushDataJSON)); err != nil { t.Errorf("%s websocket options orders push data error: %v", g.Name, err) } } @@ -2878,7 +2880,7 @@ const optionsUsersTradesPushDataJSON = `{ "time": 1639144214, "channel": "option func TestOptionUserTradesPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsUsersTradesPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsUsersTradesPushDataJSON)); err != nil { t.Errorf("%s websocket options orders push data error: %v", g.Name, err) } } @@ -2887,7 +2889,7 @@ const optionsLiquidatesPushDataJSON = `{ "channel": "options.liquidates", "event func TestOptionsLiquidatesPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsLiquidatesPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsLiquidatesPushDataJSON)); err != nil { t.Errorf("%s websocket options liquidates push data error: %v", g.Name, err) } } @@ -2896,7 +2898,7 @@ const optionsSettlementPushDataJSON = `{ "channel": "options.user_settlements", func TestOptionsSettlementPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsSettlementPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsSettlementPushDataJSON)); err != nil { t.Errorf("%s websocket options settlement push data error: %v", g.Name, err) } } @@ -2905,7 +2907,7 @@ const optionsPositionClosePushDataJSON = `{"channel": "options.position_closes", func TestOptionsPositionClosePushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsPositionClosePushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsPositionClosePushDataJSON)); err != nil { t.Errorf("%s websocket options position close push data error: %v", g.Name, err) } } @@ -2914,7 +2916,7 @@ const optionsBalancePushDataJSON = `{ "channel": "options.balances", "event": "u func TestOptionsBalancePushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsBalancePushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsBalancePushDataJSON)); err != nil { t.Errorf("%s websocket options balance push data error: %v", g.Name, err) } } @@ -2923,7 +2925,7 @@ const optionsPositionPushDataJSON = `{"time": 1630654851, "channel": "options.po func TestOptionsPositionPushData(t *testing.T) { t.Parallel() - if err := g.wsHandleOptionsData([]byte(optionsPositionPushDataJSON)); err != nil { + if err := g.WsHandleOptionsData(context.Background(), []byte(optionsPositionPushDataJSON)); err != nil { t.Errorf("%s websocket options position push data error: %v", g.Name, err) } } @@ -2935,11 +2937,11 @@ const ( func TestFuturesOrderbookPushData(t *testing.T) { t.Parallel() - err := g.wsHandleFuturesData([]byte(futuresOrderbookPushData), asset.Futures) + err := g.WsHandleFuturesData(context.Background(), []byte(futuresOrderbookPushData), asset.Futures) if err != nil { t.Error(err) } - err = g.wsHandleFuturesData([]byte(futuresOrderbookUpdatePushData), asset.Futures) + err = g.WsHandleFuturesData(context.Background(), []byte(futuresOrderbookUpdatePushData), asset.Futures) if err != nil { t.Error(err) } @@ -2949,14 +2951,13 @@ const futuresCandlesticksPushData = `{"time": 1678469467, "time_ms": 16784694679 func TestFuturesCandlestickPushData(t *testing.T) { t.Parallel() - err := g.wsHandleFuturesData([]byte(futuresCandlesticksPushData), asset.Futures) + err := g.WsHandleFuturesData(context.Background(), []byte(futuresCandlesticksPushData), asset.Futures) if err != nil { t.Error(err) } } -// TestGenerateSubscriptions exercises generateSubscriptions -func TestGenerateSubscriptions(t *testing.T) { +func TestGenerateSubscriptionsSpot(t *testing.T) { t.Parallel() g := new(Gateio) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes @@ -2966,7 +2967,7 @@ func TestGenerateSubscriptions(t *testing.T) { g.Features.Subscriptions = append(g.Features.Subscriptions, &subscription.Subscription{ Enabled: true, Channel: spotOrderbookChannel, Asset: asset.Spot, Interval: kline.ThousandMilliseconds, Levels: 5, }) - subs, err := g.generateSubscriptions() + subs, err := g.generateSubscriptionsSpot() require.NoError(t, err, "generateSubscriptions must not error") exp := subscription.List{} for _, s := range g.Features.Subscriptions { @@ -3005,13 +3006,10 @@ func TestGenerateSubscriptions(t *testing.T) { func TestSubscribe(t *testing.T) { t.Parallel() - g := new(Gateio) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes - require.NoError(t, testexch.Setup(g), "Test instance Setup must not error") subs, err := g.Features.Subscriptions.ExpandTemplates(g) require.NoError(t, err, "ExpandTemplates must not error") g.Features.Subscriptions = subscription.List{} - testexch.SetupWs(t, g) - err = g.Subscribe(subs) + err = g.Subscribe(context.Background(), &DummyConnection{}, subs) require.NoError(t, err, "Subscribe must not error") } @@ -3023,7 +3021,11 @@ func TestGenerateDeliveryFuturesDefaultSubscriptions(t *testing.T) { } func TestGenerateFuturesDefaultSubscriptions(t *testing.T) { t.Parallel() - if _, err := g.GenerateFuturesDefaultSubscriptions(); err != nil { + if _, err := g.GenerateFuturesDefaultSubscriptions(currency.USDT); err != nil { + t.Error(err) + } + + if _, err := g.GenerateFuturesDefaultSubscriptions(currency.BTC); err != nil { t.Error(err) } } @@ -3657,3 +3659,26 @@ func TestGenerateWebsocketMessageID(t *testing.T) { t.Parallel() require.NotEmpty(t, g.GenerateWebsocketMessageID(false)) } + +type DummyConnection struct{ stream.Connection } + +func (d *DummyConnection) GenerateMessageID(bool) int64 { return 1337 } +func (d *DummyConnection) SendMessageReturnResponse(context.Context, request.EndpointLimit, any, any) ([]byte, error) { + return []byte(`{"time":1726121320,"time_ms":1726121320745,"id":1,"conn_id":"f903779a148987ca","trace_id":"d8ee37cd14347e4ed298d44e69aedaa7","channel":"spot.tickers","event":"subscribe","payload":["BRETT_USDT"],"result":{"status":"success"},"requestId":"d8ee37cd14347e4ed298d44e69aedaa7"}`), nil +} + +func TestHandleSubscriptions(t *testing.T) { + t.Parallel() + + subs := subscription.List{{Channel: subscription.OrderbookChannel}} + + err := g.handleSubscription(context.Background(), &DummyConnection{}, subscribeEvent, subs, func(context.Context, stream.Connection, string, subscription.List) ([]WsInput, error) { + return []WsInput{{}}, nil + }) + require.NoError(t, err) + + err = g.handleSubscription(context.Background(), &DummyConnection{}, unsubscribeEvent, subs, func(context.Context, stream.Connection, string, subscription.List) ([]WsInput, error) { + return []WsInput{{}}, nil + }) + require.NoError(t, err) +} diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 4d38fd37..d00f03b4 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -50,6 +50,9 @@ const ( spotFundingBalanceChannel = "spot.funding_balances" crossMarginBalanceChannel = "spot.cross_balances" crossMarginLoanChannel = "spot.cross_loan" + + subscribeEvent = "subscribe" + unsubscribeEvent = "unsubscribe" ) var defaultSubscriptions = subscription.List{ @@ -71,16 +74,13 @@ var subscriptionNames = map[string]string{ subscription.AllTradesChannel: spotTradesChannel, } -// WsConnect initiates a websocket connection -func (g *Gateio) WsConnect() error { - if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return stream.ErrWebsocketNotEnabled - } +// WsConnectSpot initiates a websocket connection +func (g *Gateio) WsConnectSpot(ctx context.Context, conn stream.Connection) error { err := g.CurrencyPairs.IsAssetEnabled(asset.Spot) if err != nil { return err } - err = g.Websocket.Conn.Dial(&websocket.Dialer{}, http.Header{}) + err = conn.DialContext(ctx, &websocket.Dialer{}, http.Header{}) if err != nil { return err } @@ -88,14 +88,12 @@ func (g *Gateio) WsConnect() error { if err != nil { return err } - g.Websocket.Conn.SetupPingHandler(request.Unset, stream.PingHandler{ + conn.SetupPingHandler(request.Unset, stream.PingHandler{ Websocket: true, Delay: time.Second * 15, Message: pingMessage, MessageType: websocket.TextMessage, }) - g.Websocket.Wg.Add(1) - go g.wsReadConnData() return nil } @@ -108,29 +106,15 @@ func (g *Gateio) generateWsSignature(secret, event, channel string, t int64) (st return hex.EncodeToString(mac.Sum(nil)), nil } -// wsReadConnData receives and passes on websocket messages for processing -func (g *Gateio) wsReadConnData() { - defer g.Websocket.Wg.Done() - for { - resp := g.Websocket.Conn.ReadMessage() - if resp.Raw == nil { - return - } - err := g.wsHandleData(resp.Raw) - if err != nil { - g.Websocket.DataHandler <- err - } - } -} - -func (g *Gateio) wsHandleData(respRaw []byte) error { +// WsHandleSpotData handles spot data +func (g *Gateio) WsHandleSpotData(_ context.Context, respRaw []byte) error { var push WsResponse err := json.Unmarshal(respRaw, &push) if err != nil { return err } - if push.Event == "subscribe" || push.Event == "unsubscribe" { + if push.Event == subscribeEvent || push.Event == unsubscribeEvent { if !g.Websocket.Match.IncomingWithData(push.ID, respRaw) { return fmt.Errorf("couldn't match subscription message with ID: %d", push.ID) } @@ -641,22 +625,25 @@ func (g *Gateio) processCrossMarginLoans(data []byte) error { return nil } -// generateSubscriptions returns configured subscriptions -func (g *Gateio) generateSubscriptions() (subscription.List, error) { +// generateSubscriptionsSpot returns configured subscriptions +func (g *Gateio) generateSubscriptionsSpot() (subscription.List, error) { return g.Features.Subscriptions.ExpandTemplates(g) } // GetSubscriptionTemplate returns a subscription channel template func (g *Gateio) GetSubscriptionTemplate(_ *subscription.Subscription) (*template.Template, error) { - return template.New("master.tmpl").Funcs(sprig.FuncMap()).Funcs(template.FuncMap{ - "channelName": channelName, - "singleSymbolChannel": singleSymbolChannel, - "interval": g.GetIntervalString, - }).Parse(subTplText) + return template.New("master.tmpl"). + Funcs(sprig.FuncMap()). + Funcs(template.FuncMap{ + "channelName": channelName, + "singleSymbolChannel": singleSymbolChannel, + "interval": g.GetIntervalString, + }). + Parse(subTplText) } // manageSubs sends a websocket message to subscribe or unsubscribe from a list of channel -func (g *Gateio) manageSubs(event string, subs subscription.List) error { +func (g *Gateio) manageSubs(ctx context.Context, event string, conn stream.Connection, subs subscription.List) error { var errs error subs, errs = subs.ExpandTemplates(g) if errs != nil { @@ -665,11 +652,11 @@ func (g *Gateio) manageSubs(event string, subs subscription.List) error { for _, s := range subs { if err := func() error { - msg, err := g.manageSubReq(event, s) + msg, err := g.manageSubReq(ctx, event, conn, s) if err != nil { return err } - result, err := g.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.Unset, msg.ID, msg) + result, err := conn.SendMessageReturnResponse(ctx, request.Unset, msg.ID, msg) if err != nil { return err } @@ -681,9 +668,9 @@ func (g *Gateio) manageSubs(event string, subs subscription.List) error { return fmt.Errorf("(%d) %s", resp.Error.Code, resp.Error.Message) } if event == "unsubscribe" { - return g.Websocket.RemoveSubscriptions(s) + return g.Websocket.RemoveSubscriptions(conn, s) } - return g.Websocket.AddSuccessfulSubscriptions(s) + return g.Websocket.AddSuccessfulSubscriptions(conn, s) }(); err != nil { errs = common.AppendError(errs, fmt.Errorf("%s %s %s: %w", s.Channel, s.Asset, s.Pairs, err)) } @@ -692,16 +679,16 @@ func (g *Gateio) manageSubs(event string, subs subscription.List) error { } // manageSubReq constructs the subscription management message for a subscription -func (g *Gateio) manageSubReq(event string, s *subscription.Subscription) (*WsInput, error) { +func (g *Gateio) manageSubReq(ctx context.Context, event string, conn stream.Connection, s *subscription.Subscription) (*WsInput, error) { req := &WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), + ID: conn.GenerateMessageID(false), Event: event, Channel: channelName(s), Time: time.Now().Unix(), Payload: strings.Split(s.QualifiedChannel, ","), } if s.Authenticated { - creds, err := g.GetCredentials(context.TODO()) + creds, err := g.GetCredentials(ctx) if err != nil { return nil, err } @@ -719,13 +706,13 @@ func (g *Gateio) manageSubReq(event string, s *subscription.Subscription) (*WsIn } // Subscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) Subscribe(subs subscription.List) error { - return g.manageSubs("subscribe", subs) +func (g *Gateio) Subscribe(ctx context.Context, conn stream.Connection, subs subscription.List) error { + return g.manageSubs(ctx, subscribeEvent, conn, subs) } // Unsubscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) Unsubscribe(subs subscription.List) error { - return g.manageSubs("unsubscribe", subs) +func (g *Gateio) Unsubscribe(ctx context.Context, conn stream.Connection, subs subscription.List) error { + return g.manageSubs(ctx, unsubscribeEvent, conn, subs) } func (g *Gateio) listOfAssetsCurrencyPairEnabledFor(cp currency.Pair) map[asset.Item]bool { @@ -782,3 +769,37 @@ const subTplText = ` {{- end }} {{- end }} ` + +// GeneratePayload returns the payload for a websocket message +type GeneratePayload func(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) + +// handleSubscription sends a websocket message to receive data from the channel +func (g *Gateio) handleSubscription(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List, generatePayload GeneratePayload) error { + payloads, err := generatePayload(ctx, conn, event, channelsToSubscribe) + if err != nil { + return err + } + var errs error + for k := range payloads { + result, err := conn.SendMessageReturnResponse(ctx, request.Unset, payloads[k].ID, payloads[k]) + if err != nil { + errs = common.AppendError(errs, err) + continue + } + var resp WsEventResponse + if err = json.Unmarshal(result, &resp); err != nil { + errs = common.AppendError(errs, err) + } else { + if resp.Error != nil && resp.Error.Code != 0 { + errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s error code: %d message: %s", payloads[k].Event, payloads[k].Channel, resp.Error.Code, resp.Error.Message)) + continue + } + if event == subscribeEvent { + errs = common.AppendError(errs, g.Websocket.AddSuccessfulSubscriptions(conn, channelsToSubscribe[k])) + } else { + errs = common.AppendError(errs, g.Websocket.RemoveSubscriptions(conn, channelsToSubscribe[k])) + } + } + } + return errs +} diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index a1d4933e..6c337706 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -27,6 +27,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" "github.com/thrasher-corp/gocryptotrader/log" @@ -93,12 +94,12 @@ func (g *Gateio) SetDefaults() { OrderbookFetching: true, TradeFetching: true, KlineFetching: true, - FullPayloadSubscribe: true, AuthenticatedEndpoints: true, MessageCorrelation: true, GetOrder: true, AccountBalance: true, Subscribe: true, + Unsubscribe: true, }, WithdrawPermissions: exchange.AutoWithdrawCrypto | exchange.NoFiatWithdrawals, @@ -155,26 +156,16 @@ func (g *Gateio) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } + // TODO: Majority of margin REST endpoints are labelled as deprecated on the API docs. These will need to be removed. err = g.DisableAssetWebsocketSupport(asset.Margin) if err != nil { log.Errorln(log.ExchangeSys, err) } + // TODO: Add websocket cross margin support. err = g.DisableAssetWebsocketSupport(asset.CrossMargin) if err != nil { log.Errorln(log.ExchangeSys, err) } - err = g.DisableAssetWebsocketSupport(asset.Futures) - if err != nil { - log.Errorln(log.ExchangeSys, err) - } - err = g.DisableAssetWebsocketSupport(asset.DeliveryFutures) - if err != nil { - log.Errorln(log.ExchangeSys, err) - } - err = g.DisableAssetWebsocketSupport(asset.Options) - if err != nil { - log.Errorln(log.ExchangeSys, err) - } g.API.Endpoints = g.NewEndpoints() err = g.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: gateioTradeURL, @@ -206,31 +197,101 @@ func (g *Gateio) Setup(exch *config.Exchange) error { return err } - wsRunningURL, err := g.API.Endpoints.GetURL(exchange.WebsocketSpot) - if err != nil { - return err - } - err = g.Websocket.Setup(&stream.WebsocketSetup{ - ExchangeConfig: exch, - DefaultURL: gateioWebsocketEndpoint, - RunningURL: wsRunningURL, - Connector: g.WsConnect, - Subscriber: g.Subscribe, - Unsubscriber: g.Unsubscribe, - GenerateSubscriptions: g.generateSubscriptions, - Features: &g.Features.Supports.WebsocketCapabilities, - FillsFeed: g.Features.Enabled.FillsFeed, - TradeFeed: g.Features.Enabled.TradeFeed, + ExchangeConfig: exch, + Features: &g.Features.Supports.WebsocketCapabilities, + FillsFeed: g.Features.Enabled.FillsFeed, + TradeFeed: g.Features.Enabled.TradeFeed, + UseMultiConnectionManagement: true, }) if err != nil { return err } - return g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + // Spot connection + err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: gateioWebsocketEndpoint, RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: g.WsHandleSpotData, + Subscriber: g.Subscribe, + Unsubscriber: g.Unsubscribe, + GenerateSubscriptions: g.generateSubscriptionsSpot, + Connector: g.WsConnectSpot, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + }) + if err != nil { + return err + } + // Futures connection - USDT margined + err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + URL: futuresWebsocketUsdtURL, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: func(ctx context.Context, incoming []byte) error { + return g.WsHandleFuturesData(ctx, incoming, asset.Futures) + }, + Subscriber: g.FuturesSubscribe, + Unsubscriber: g.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) }, + Connector: g.WsFuturesConnect, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + }) + if err != nil { + return err + } + + // Futures connection - BTC margined + err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + URL: futuresWebsocketBtcURL, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: func(ctx context.Context, incoming []byte) error { + return g.WsHandleFuturesData(ctx, incoming, asset.Futures) + }, + Subscriber: g.FuturesSubscribe, + Unsubscriber: g.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) }, + Connector: g.WsFuturesConnect, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + }) + if err != nil { + return err + } + + // TODO: Add BTC margined delivery futures. + // Futures connection - Delivery - USDT margined + err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + URL: deliveryRealUSDTTradingURL, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: func(ctx context.Context, incoming []byte) error { + return g.WsHandleFuturesData(ctx, incoming, asset.DeliveryFutures) + }, + Subscriber: g.DeliveryFuturesSubscribe, + Unsubscriber: g.DeliveryFuturesUnsubscribe, + GenerateSubscriptions: g.GenerateDeliveryFuturesDefaultSubscriptions, + Connector: g.WsDeliveryFuturesConnect, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + }) + if err != nil { + return err + } + + // Futures connection - Options + return g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + URL: optionsWebsocketURL, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: g.WsHandleOptionsData, + Subscriber: g.OptionsSubscribe, + Unsubscriber: g.OptionsUnsubscribe, + GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions, + Connector: g.WsOptionsConnect, BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) } diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index 4972f6ef..085e0c2c 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -4,14 +4,11 @@ import ( "context" "encoding/json" "errors" - "fmt" "net/http" "strconv" - "strings" "time" "github.com/gorilla/websocket" - "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" @@ -19,7 +16,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" - "github.com/thrasher-corp/gocryptotrader/log" ) const ( @@ -39,60 +35,27 @@ var defaultDeliveryFuturesSubscriptions = []string{ futuresCandlesticksChannel, } -// responseDeliveryFuturesStream a channel thought which the data coming from the two websocket connection will go through. -var responseDeliveryFuturesStream = make(chan stream.Response) - var fetchedFuturesCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsDeliveryFuturesConnect initiates a websocket connection for delivery futures account -func (g *Gateio) WsDeliveryFuturesConnect() error { - if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return stream.ErrWebsocketNotEnabled - } +func (g *Gateio) WsDeliveryFuturesConnect(ctx context.Context, conn stream.Connection) error { err := g.CurrencyPairs.IsAssetEnabled(asset.DeliveryFutures) if err != nil { return err } - var dialer websocket.Dialer - err = g.Websocket.SetWebsocketURL(deliveryRealUSDTTradingURL, false, true) + err = conn.DialContext(ctx, &websocket.Dialer{}, http.Header{}) if err != nil { return err } - err = g.Websocket.Conn.Dial(&dialer, http.Header{}) - if err != nil { - return err - } - err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ - URL: deliveryRealBTCTradingURL, - RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), - ResponseCheckTimeout: g.Config.WebsocketResponseCheckTimeout, - ResponseMaxLimit: g.Config.WebsocketResponseMaxLimit, - Authenticated: true, - }) - if err != nil { - return err - } - err = g.Websocket.AuthConn.Dial(&dialer, http.Header{}) - if err != nil { - return err - } - g.Websocket.Wg.Add(3) - go g.wsReadDeliveryFuturesData() - go g.wsFunnelDeliveryFuturesConnectionData(g.Websocket.Conn) - go g.wsFunnelDeliveryFuturesConnectionData(g.Websocket.AuthConn) - if g.Verbose { - log.Debugf(log.ExchangeSys, "successful connection to %v\n", - g.Websocket.GetWebsocketURL()) - } pingMessage, err := json.Marshal(WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), - Time: time.Now().Unix(), + ID: conn.GenerateMessageID(false), + Time: time.Now().Unix(), // TODO: Func for dynamic time as this will be the same time for every ping message. Channel: futuresPingChannel, }) if err != nil { return err } - g.Websocket.Conn.SetupPingHandler(request.Unset, stream.PingHandler{ + conn.SetupPingHandler(request.Unset, stream.PingHandler{ Websocket: true, Delay: time.Second * 5, MessageType: websocket.PingMessage, @@ -101,47 +64,6 @@ func (g *Gateio) WsDeliveryFuturesConnect() error { return nil } -// wsReadDeliveryFuturesData read coming messages thought the websocket connection and pass the data to wsHandleFuturesData for further process. -func (g *Gateio) wsReadDeliveryFuturesData() { - defer g.Websocket.Wg.Done() - for { - select { - case <-g.Websocket.ShutdownC: - select { - case resp := <-responseDeliveryFuturesStream: - err := g.wsHandleFuturesData(resp.Raw, asset.DeliveryFutures) - if err != nil { - select { - case g.Websocket.DataHandler <- err: - default: - log.Errorf(log.WebsocketMgr, "%s websocket handle data error: %v", g.Name, err) - } - } - default: - } - return - case resp := <-responseDeliveryFuturesStream: - err := g.wsHandleFuturesData(resp.Raw, asset.DeliveryFutures) - if err != nil { - g.Websocket.DataHandler <- err - } - } - } -} - -// wsFunnelDeliveryFuturesConnectionData receives data from multiple connection and pass the data -// to wsRead through a channel responseStream -func (g *Gateio) wsFunnelDeliveryFuturesConnectionData(ws stream.Connection) { - defer g.Websocket.Wg.Done() - for { - resp := ws.ReadMessage() - if resp.Raw == nil { - return - } - responseDeliveryFuturesStream <- stream.Response{Raw: resp.Raw} - } -} - // GenerateDeliveryFuturesDefaultSubscriptions returns delivery futures default subscriptions params. func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() (subscription.List, error) { _, err := g.GetCredentials(context.Background()) @@ -150,21 +72,21 @@ func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() (subscription.Lis } channelsToSubscribe := defaultDeliveryFuturesSubscriptions if g.Websocket.CanUseAuthenticatedEndpoints() { - channelsToSubscribe = append( - channelsToSubscribe, - futuresOrdersChannel, - futuresUserTradesChannel, - futuresBalancesChannel, - ) + channelsToSubscribe = append(channelsToSubscribe, futuresOrdersChannel, futuresUserTradesChannel, futuresBalancesChannel) } - pairs, err := g.GetAvailablePairs(asset.DeliveryFutures) + + pairs, err := g.GetEnabledPairs(asset.DeliveryFutures) if err != nil { + if errors.Is(err, asset.ErrNotEnabled) { + return nil, nil // no enabled pairs, subscriptions require an associated pair. + } return nil, err } + var subscriptions subscription.List for i := range channelsToSubscribe { for j := range pairs { - params := make(map[string]interface{}) + params := make(map[string]any) switch channelsToSubscribe[i] { case futuresOrderbookChannel: params["limit"] = 20 @@ -172,13 +94,13 @@ func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() (subscription.Lis case futuresCandlesticksChannel: params["interval"] = kline.FiveMin } - fpair, err := g.FormatExchangeCurrency(pairs[j], asset.DeliveryFutures) + fPair, err := g.FormatExchangeCurrency(pairs[j], asset.DeliveryFutures) if err != nil { return nil, err } subscriptions = append(subscriptions, &subscription.Subscription{ Channel: channelsToSubscribe[i], - Pairs: currency.Pairs{fpair.Upper()}, + Pairs: currency.Pairs{fPair.Upper()}, Params: params, }) } @@ -187,68 +109,31 @@ func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() (subscription.Lis } // DeliveryFuturesSubscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) DeliveryFuturesSubscribe(channelsToUnsubscribe subscription.List) error { - return g.handleDeliveryFuturesSubscription("subscribe", channelsToUnsubscribe) +func (g *Gateio) DeliveryFuturesSubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { + return g.handleSubscription(ctx, conn, subscribeEvent, channelsToUnsubscribe, g.generateDeliveryFuturesPayload) } // DeliveryFuturesUnsubscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) DeliveryFuturesUnsubscribe(channelsToUnsubscribe subscription.List) error { - return g.handleDeliveryFuturesSubscription("unsubscribe", channelsToUnsubscribe) +func (g *Gateio) DeliveryFuturesUnsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { + return g.handleSubscription(ctx, conn, unsubscribeEvent, channelsToUnsubscribe, g.generateDeliveryFuturesPayload) } -// handleDeliveryFuturesSubscription sends a websocket message to receive data from the channel -func (g *Gateio) handleDeliveryFuturesSubscription(event string, channelsToSubscribe subscription.List) error { - payloads, err := g.generateDeliveryFuturesPayload(event, channelsToSubscribe) - if err != nil { - return err - } - var errs error - var respByte []byte - // con represents the websocket connection. 0 - for usdt settle and 1 - for btc settle connections. - for con, val := range payloads { - for k := range val { - if con == 0 { - respByte, err = g.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.Unset, val[k].ID, val[k]) - } else { - respByte, err = g.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), request.Unset, val[k].ID, val[k]) - } - if err != nil { - errs = common.AppendError(errs, err) - continue - } - var resp WsEventResponse - if err = json.Unmarshal(respByte, &resp); err != nil { - errs = common.AppendError(errs, err) - } else { - if resp.Error != nil && resp.Error.Code != 0 { - errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s error code: %d message: %s", val[k].Event, val[k].Channel, resp.Error.Code, resp.Error.Message)) - continue - } - if err = g.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[k]); err != nil { - errs = common.AppendError(errs, err) - } - } - } - } - return errs -} - -func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscribe subscription.List) ([2][]WsInput, error) { - payloads := [2][]WsInput{} +func (g *Gateio) generateDeliveryFuturesPayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { if len(channelsToSubscribe) == 0 { - return payloads, errors.New("cannot generate payload, no channels supplied") + return nil, errors.New("cannot generate payload, no channels supplied") } var creds *account.Credentials var err error if g.Websocket.CanUseAuthenticatedEndpoints() { - creds, err = g.GetCredentials(context.TODO()) + creds, err = g.GetCredentials(ctx) if err != nil { g.Websocket.SetCanUseAuthenticatedEndpoints(false) } } + outbound := make([]WsInput, 0, len(channelsToSubscribe)) for i := range channelsToSubscribe { if len(channelsToSubscribe[i].Pairs) != 1 { - return payloads, subscription.ErrNotSinglePair + return nil, subscription.ErrNotSinglePair } var auth *WsAuthInput timestamp := time.Now() @@ -268,7 +153,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscrib var sigTemp string sigTemp, err = g.generateWsSignature(creds.Secret, event, channelsToSubscribe[i].Channel, timestamp.Unix()) if err != nil { - return [2][]WsInput{}, err + return nil, err } auth = &WsAuthInput{ Method: "api_key", @@ -282,7 +167,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscrib var frequencyString string frequencyString, err = g.GetIntervalString(frequency) if err != nil { - return payloads, err + return nil, err } params = append(params, frequencyString) } @@ -305,7 +190,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscrib var intervalString string intervalString, err = g.GetIntervalString(interval) if err != nil { - return payloads, err + return nil, err } params = append([]string{intervalString}, params...) } @@ -315,25 +200,14 @@ func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscrib params = append(params, intervalString) } } - if strings.HasPrefix(channelsToSubscribe[i].Pairs[0].Quote.Upper().String(), "USDT") { - payloads[0] = append(payloads[0], WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), - Event: event, - Channel: channelsToSubscribe[i].Channel, - Payload: params, - Auth: auth, - Time: timestamp.Unix(), - }) - } else { - payloads[1] = append(payloads[1], WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), - Event: event, - Channel: channelsToSubscribe[i].Channel, - Payload: params, - Auth: auth, - Time: timestamp.Unix(), - }) - } + outbound = append(outbound, WsInput{ + ID: conn.GenerateMessageID(false), + Event: event, + Channel: channelsToSubscribe[i].Channel, + Payload: params, + Auth: auth, + Time: timestamp.Unix(), + }) } - return payloads, nil + return outbound, nil } diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index 61d299e8..00c37ba0 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -11,7 +11,6 @@ import ( "time" "github.com/gorilla/websocket" - "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" @@ -24,7 +23,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" - "github.com/thrasher-corp/gocryptotrader/log" ) const ( @@ -59,61 +57,25 @@ var defaultFuturesSubscriptions = []string{ futuresCandlesticksChannel, } -// responseFuturesStream a channel thought which the data coming from the two websocket connection will go through. -var responseFuturesStream = make(chan stream.Response) - // WsFuturesConnect initiates a websocket connection for futures account -func (g *Gateio) WsFuturesConnect() error { - if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return stream.ErrWebsocketNotEnabled - } +func (g *Gateio) WsFuturesConnect(ctx context.Context, conn stream.Connection) error { err := g.CurrencyPairs.IsAssetEnabled(asset.Futures) if err != nil { return err } - var dialer websocket.Dialer - err = g.Websocket.SetWebsocketURL(futuresWebsocketUsdtURL, false, true) + err = conn.DialContext(ctx, &websocket.Dialer{}, http.Header{}) if err != nil { return err } - err = g.Websocket.Conn.Dial(&dialer, http.Header{}) - if err != nil { - return err - } - - err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ - URL: futuresWebsocketBtcURL, - RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), - ResponseCheckTimeout: g.Config.WebsocketResponseCheckTimeout, - ResponseMaxLimit: g.Config.WebsocketResponseMaxLimit, - Authenticated: true, - }) - if err != nil { - return err - } - err = g.Websocket.AuthConn.Dial(&dialer, http.Header{}) - if err != nil { - return err - } - g.Websocket.Wg.Add(3) - go g.wsReadFuturesData() - go g.wsFunnelFuturesConnectionData(g.Websocket.Conn) - go g.wsFunnelFuturesConnectionData(g.Websocket.AuthConn) - if g.Verbose { - log.Debugf(log.ExchangeSys, "Successful connection to %v\n", - g.Websocket.GetWebsocketURL()) - } pingMessage, err := json.Marshal(WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), - Time: func() int64 { - return time.Now().Unix() - }(), + ID: conn.GenerateMessageID(false), + Time: time.Now().Unix(), // TODO: Func for dynamic time as this will be the same time for every ping message. Channel: futuresPingChannel, }) if err != nil { return err } - g.Websocket.Conn.SetupPingHandler(request.Unset, stream.PingHandler{ + conn.SetupPingHandler(request.Unset, stream.PingHandler{ Websocket: true, MessageType: websocket.PingMessage, Delay: time.Second * 15, @@ -123,7 +85,7 @@ func (g *Gateio) WsFuturesConnect() error { } // GenerateFuturesDefaultSubscriptions returns default subscriptions information. -func (g *Gateio) GenerateFuturesDefaultSubscriptions() (subscription.List, error) { +func (g *Gateio) GenerateFuturesDefaultSubscriptions(settlement currency.Code) (subscription.List, error) { channelsToSubscribe := defaultFuturesSubscriptions if g.Websocket.CanUseAuthenticatedEndpoints() { channelsToSubscribe = append(channelsToSubscribe, @@ -132,15 +94,39 @@ func (g *Gateio) GenerateFuturesDefaultSubscriptions() (subscription.List, error futuresBalancesChannel, ) } + pairs, err := g.GetEnabledPairs(asset.Futures) if err != nil { + if errors.Is(err, asset.ErrNotEnabled) { + return nil, nil // no enabled pairs, subscriptions require an associated pair. + } return nil, err } - subscriptions := make(subscription.List, len(channelsToSubscribe)*len(pairs)) - count := 0 + + var subscriptions subscription.List for i := range channelsToSubscribe { + switch { + case settlement.Equal(currency.USDT): + pairs, err = pairs.GetPairsByQuote(currency.USDT) + if err != nil { + return nil, err + } + case settlement.Equal(currency.BTC): + offset := 0 + for x := range pairs { + if pairs[x].Quote.Equal(currency.USDT) { + continue // skip USDT pairs + } + pairs[offset] = pairs[x] + offset++ + } + pairs = pairs[:offset] + default: + return nil, fmt.Errorf("settlement currency %s not supported", settlement) + } + for j := range pairs { - params := make(map[string]interface{}) + params := make(map[string]any) switch channelsToSubscribe[i] { case futuresOrderbookChannel: params["limit"] = 100 @@ -151,80 +137,39 @@ func (g *Gateio) GenerateFuturesDefaultSubscriptions() (subscription.List, error params["frequency"] = kline.ThousandMilliseconds params["level"] = "100" } - fpair, err := g.FormatExchangeCurrency(pairs[j], asset.Futures) + fPair, err := g.FormatExchangeCurrency(pairs[j], asset.Futures) if err != nil { return nil, err } - subscriptions[count] = &subscription.Subscription{ + subscriptions = append(subscriptions, &subscription.Subscription{ Channel: channelsToSubscribe[i], - Pairs: currency.Pairs{fpair.Upper()}, + Pairs: currency.Pairs{fPair.Upper()}, Params: params, - } - count++ + }) } } return subscriptions, nil } // FuturesSubscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) FuturesSubscribe(channelsToUnsubscribe subscription.List) error { - return g.handleFuturesSubscription("subscribe", channelsToUnsubscribe) +func (g *Gateio) FuturesSubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { + return g.handleSubscription(ctx, conn, subscribeEvent, channelsToUnsubscribe, g.generateFuturesPayload) } // FuturesUnsubscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) FuturesUnsubscribe(channelsToUnsubscribe subscription.List) error { - return g.handleFuturesSubscription("unsubscribe", channelsToUnsubscribe) +func (g *Gateio) FuturesUnsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { + return g.handleSubscription(ctx, conn, unsubscribeEvent, channelsToUnsubscribe, g.generateFuturesPayload) } -// wsReadFuturesData read coming messages thought the websocket connection and pass the data to wsHandleData for further process. -func (g *Gateio) wsReadFuturesData() { - defer g.Websocket.Wg.Done() - for { - select { - case <-g.Websocket.ShutdownC: - select { - case resp := <-responseFuturesStream: - err := g.wsHandleFuturesData(resp.Raw, asset.Futures) - if err != nil { - select { - case g.Websocket.DataHandler <- err: - default: - log.Errorf(log.WebsocketMgr, "%s websocket handle data error: %v", g.Name, err) - } - } - default: - } - return - case resp := <-responseFuturesStream: - err := g.wsHandleFuturesData(resp.Raw, asset.Futures) - if err != nil { - g.Websocket.DataHandler <- err - } - } - } -} - -// wsFunnelFuturesConnectionData receives data from multiple connection and pass the data -// to wsRead through a channel responseStream -func (g *Gateio) wsFunnelFuturesConnectionData(ws stream.Connection) { - defer g.Websocket.Wg.Done() - for { - resp := ws.ReadMessage() - if resp.Raw == nil { - return - } - responseFuturesStream <- stream.Response{Raw: resp.Raw} - } -} - -func (g *Gateio) wsHandleFuturesData(respRaw []byte, assetType asset.Item) error { +// WsHandleFuturesData handles futures websocket data +func (g *Gateio) WsHandleFuturesData(_ context.Context, respRaw []byte, a asset.Item) error { var push WsResponse err := json.Unmarshal(respRaw, &push) if err != nil { return err } - if push.Event == "subscribe" || push.Event == "unsubscribe" { + if push.Event == subscribeEvent || push.Event == unsubscribeEvent { if !g.Websocket.Match.IncomingWithData(push.ID, respRaw) { return fmt.Errorf("couldn't match subscription message with ID: %d", push.ID) } @@ -233,27 +178,27 @@ func (g *Gateio) wsHandleFuturesData(respRaw []byte, assetType asset.Item) error switch push.Channel { case futuresTickersChannel: - return g.processFuturesTickers(respRaw, assetType) + return g.processFuturesTickers(respRaw, a) case futuresTradesChannel: - return g.processFuturesTrades(respRaw, assetType) + return g.processFuturesTrades(respRaw, a) case futuresOrderbookChannel: - return g.processFuturesOrderbookSnapshot(push.Event, push.Result, assetType, push.Time.Time()) + return g.processFuturesOrderbookSnapshot(push.Event, push.Result, a, push.Time.Time()) case futuresOrderbookTickerChannel: return g.processFuturesOrderbookTicker(push.Result) case futuresOrderbookUpdateChannel: - return g.processFuturesAndOptionsOrderbookUpdate(push.Result, assetType) + return g.processFuturesAndOptionsOrderbookUpdate(push.Result, a) case futuresCandlesticksChannel: - return g.processFuturesCandlesticks(respRaw, assetType) + return g.processFuturesCandlesticks(respRaw, a) case futuresOrdersChannel: var processed []order.Detail - processed, err = g.processFuturesOrdersPushData(respRaw, assetType) + processed, err = g.processFuturesOrdersPushData(respRaw, a) if err != nil { return err } g.Websocket.DataHandler <- processed return nil case futuresUserTradesChannel: - return g.procesFuturesUserTrades(respRaw, assetType) + return g.procesFuturesUserTrades(respRaw, a) case futuresLiquidatesChannel: return g.processFuturesLiquidatesNotification(respRaw) case futuresAutoDeleveragesChannel: @@ -261,7 +206,7 @@ func (g *Gateio) wsHandleFuturesData(respRaw []byte, assetType asset.Item) error case futuresAutoPositionCloseChannel: return g.processPositionCloseData(respRaw) case futuresBalancesChannel: - return g.processBalancePushData(respRaw, assetType) + return g.processBalancePushData(respRaw, a) case futuresReduceRiskLimitsChannel: return g.processFuturesReduceRiskLimitNotification(respRaw) case futuresPositionsChannel: @@ -276,62 +221,23 @@ func (g *Gateio) wsHandleFuturesData(respRaw []byte, assetType asset.Item) error } } -// handleFuturesSubscription sends a websocket message to receive data from the channel -func (g *Gateio) handleFuturesSubscription(event string, channelsToSubscribe subscription.List) error { - payloads, err := g.generateFuturesPayload(event, channelsToSubscribe) - if err != nil { - return err - } - var errs error - var respByte []byte - // con represents the websocket connection. 0 - for usdt settle and 1 - for btc settle connections. - for con, val := range payloads { - for k := range val { - if con == 0 { - respByte, err = g.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.Unset, val[k].ID, val[k]) - } else { - respByte, err = g.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), request.Unset, val[k].ID, val[k]) - } - if err != nil { - errs = common.AppendError(errs, err) - continue - } - var resp WsEventResponse - if err = json.Unmarshal(respByte, &resp); err != nil { - errs = common.AppendError(errs, err) - } else { - if resp.Error != nil && resp.Error.Code != 0 { - errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s error code: %d message: %s", val[k].Event, val[k].Channel, resp.Error.Code, resp.Error.Message)) - continue - } - if err = g.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[k]); err != nil { - errs = common.AppendError(errs, err) - } - } - } - } - if errs != nil { - return errs - } - return nil -} - -func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe subscription.List) ([2][]WsInput, error) { - payloads := [2][]WsInput{} +func (g *Gateio) generateFuturesPayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { if len(channelsToSubscribe) == 0 { - return payloads, errors.New("cannot generate payload, no channels supplied") + return nil, errors.New("cannot generate payload, no channels supplied") } var creds *account.Credentials var err error if g.Websocket.CanUseAuthenticatedEndpoints() { - creds, err = g.GetCredentials(context.TODO()) + creds, err = g.GetCredentials(ctx) if err != nil { g.Websocket.SetCanUseAuthenticatedEndpoints(false) } } + + outbound := make([]WsInput, 0, len(channelsToSubscribe)) for i := range channelsToSubscribe { if len(channelsToSubscribe[i].Pairs) != 1 { - return payloads, subscription.ErrNotSinglePair + return nil, subscription.ErrNotSinglePair } var auth *WsAuthInput timestamp := time.Now() @@ -353,7 +259,7 @@ func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe subscr var sigTemp string sigTemp, err = g.generateWsSignature(creds.Secret, event, channelsToSubscribe[i].Channel, timestamp.Unix()) if err != nil { - return [2][]WsInput{}, err + return nil, err } auth = &WsAuthInput{ Method: "api_key", @@ -367,7 +273,7 @@ func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe subscr var frequencyString string frequencyString, err = g.GetIntervalString(frequency) if err != nil { - return payloads, err + return nil, err } params = append(params, frequencyString) } @@ -390,7 +296,7 @@ func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe subscr var intervalString string intervalString, err = g.GetIntervalString(interval) if err != nil { - return payloads, err + return nil, err } params = append([]string{intervalString}, params...) } @@ -400,27 +306,16 @@ func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe subscr params = append(params, intervalString) } } - if strings.HasPrefix(channelsToSubscribe[i].Pairs[0].Quote.Upper().String(), "USDT") { - payloads[0] = append(payloads[0], WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), - Event: event, - Channel: channelsToSubscribe[i].Channel, - Payload: params, - Auth: auth, - Time: timestamp.Unix(), - }) - } else { - payloads[1] = append(payloads[1], WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), - Event: event, - Channel: channelsToSubscribe[i].Channel, - Payload: params, - Auth: auth, - Time: timestamp.Unix(), - }) - } + outbound = append(outbound, WsInput{ + ID: conn.GenerateMessageID(false), + Event: event, + Channel: channelsToSubscribe[i].Channel, + Payload: params, + Auth: auth, + Time: timestamp.Unix(), + }) } - return payloads, nil + return outbound, nil } func (g *Gateio) processFuturesTickers(data []byte, assetType asset.Item) error { diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index b5d61c11..b91471d8 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -11,7 +11,6 @@ import ( "time" "github.com/gorilla/websocket" - "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" @@ -69,34 +68,24 @@ var defaultOptionsSubscriptions = []string{ var fetchedOptionsCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsOptionsConnect initiates a websocket connection to options websocket endpoints. -func (g *Gateio) WsOptionsConnect() error { - if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return stream.ErrWebsocketNotEnabled - } +func (g *Gateio) WsOptionsConnect(ctx context.Context, conn stream.Connection) error { err := g.CurrencyPairs.IsAssetEnabled(asset.Options) if err != nil { return err } - var dialer websocket.Dialer - err = g.Websocket.SetWebsocketURL(optionsWebsocketURL, false, true) - if err != nil { - return err - } - err = g.Websocket.Conn.Dial(&dialer, http.Header{}) + err = conn.DialContext(ctx, &websocket.Dialer{}, http.Header{}) if err != nil { return err } pingMessage, err := json.Marshal(WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), - Time: time.Now().Unix(), + ID: conn.GenerateMessageID(false), + Time: time.Now().Unix(), // TODO: Func for dynamic time as this will be the same time for every ping message. Channel: optionsPingChannel, }) if err != nil { return err } - g.Websocket.Wg.Add(1) - go g.wsReadOptionsConnData() - g.Websocket.Conn.SetupPingHandler(request.Unset, stream.PingHandler{ + conn.SetupPingHandler(request.Unset, stream.PingHandler{ Websocket: true, Delay: time.Second * 5, MessageType: websocket.PingMessage, @@ -130,15 +119,21 @@ func (g *Gateio) GenerateOptionsDefaultSubscriptions() (subscription.List, error log.Errorf(log.ExchangeSys, "no subaccount found for authenticated options channel subscriptions") } } + getEnabledPairs: - var subscriptions subscription.List + pairs, err := g.GetEnabledPairs(asset.Options) if err != nil { + if errors.Is(err, asset.ErrNotEnabled) { + return nil, nil // no enabled pairs, subscriptions require an associated pair. + } return nil, err } + + var subscriptions subscription.List for i := range channelsToSubscribe { for j := range pairs { - params := make(map[string]interface{}) + params := make(map[string]any) switch channelsToSubscribe[i] { case optionsOrderbookChannel: params["accuracy"] = "0" @@ -160,13 +155,13 @@ getEnabledPairs: } params["user_id"] = userID } - fpair, err := g.FormatExchangeCurrency(pairs[j], asset.Options) + fPair, err := g.FormatExchangeCurrency(pairs[j], asset.Options) if err != nil { return nil, err } subscriptions = append(subscriptions, &subscription.Subscription{ Channel: channelsToSubscribe[i], - Pairs: currency.Pairs{fpair.Upper()}, + Pairs: currency.Pairs{fPair.Upper()}, Params: params, }) } @@ -174,7 +169,7 @@ getEnabledPairs: return subscriptions, nil } -func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe subscription.List) ([]WsInput, error) { +func (g *Gateio) generateOptionsPayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { if len(channelsToSubscribe) == 0 { return nil, errors.New("cannot generate payload, no channels supplied") } @@ -233,7 +228,7 @@ func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe subscr } params = append([]string{strconv.FormatInt(userID, 10)}, params...) var creds *account.Credentials - creds, err = g.GetCredentials(context.Background()) + creds, err = g.GetCredentials(ctx) if err != nil { return nil, err } @@ -276,7 +271,7 @@ func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe subscr params...) } payloads[i] = WsInput{ - ID: g.Websocket.Conn.GenerateMessageID(false), + ID: conn.GenerateMessageID(false), Event: event, Channel: channelsToSubscribe[i].Channel, Payload: params, @@ -287,73 +282,25 @@ func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe subscr return payloads, nil } -// wsReadOptionsConnData receives and passes on websocket messages for processing -func (g *Gateio) wsReadOptionsConnData() { - defer g.Websocket.Wg.Done() - for { - resp := g.Websocket.Conn.ReadMessage() - if resp.Raw == nil { - return - } - err := g.wsHandleOptionsData(resp.Raw) - if err != nil { - g.Websocket.DataHandler <- err - } - } -} - // OptionsSubscribe sends a websocket message to stop receiving data for asset type options -func (g *Gateio) OptionsSubscribe(channelsToUnsubscribe subscription.List) error { - return g.handleOptionsSubscription("subscribe", channelsToUnsubscribe) +func (g *Gateio) OptionsSubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { + return g.handleSubscription(ctx, conn, subscribeEvent, channelsToUnsubscribe, g.generateOptionsPayload) } // OptionsUnsubscribe sends a websocket message to stop receiving data for asset type options -func (g *Gateio) OptionsUnsubscribe(channelsToUnsubscribe subscription.List) error { - return g.handleOptionsSubscription("unsubscribe", channelsToUnsubscribe) +func (g *Gateio) OptionsUnsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { + return g.handleSubscription(ctx, conn, unsubscribeEvent, channelsToUnsubscribe, g.generateOptionsPayload) } -// handleOptionsSubscription sends a websocket message to receive data from the channel -func (g *Gateio) handleOptionsSubscription(event string, channelsToSubscribe subscription.List) error { - payloads, err := g.generateOptionsPayload(event, channelsToSubscribe) - if err != nil { - return err - } - var errs error - for k := range payloads { - result, err := g.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.Unset, payloads[k].ID, payloads[k]) - if err != nil { - errs = common.AppendError(errs, err) - continue - } - var resp WsEventResponse - if err = json.Unmarshal(result, &resp); err != nil { - errs = common.AppendError(errs, err) - } else { - if resp.Error != nil && resp.Error.Code != 0 { - errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s asset type: options error code: %d message: %s", payloads[k].Event, payloads[k].Channel, resp.Error.Code, resp.Error.Message)) - continue - } - if payloads[k].Event == "subscribe" { - err = g.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[k]) - } else { - err = g.Websocket.RemoveSubscriptions(channelsToSubscribe[k]) - } - if err != nil { - errs = common.AppendError(errs, err) - } - } - } - return errs -} - -func (g *Gateio) wsHandleOptionsData(respRaw []byte) error { +// WsHandleOptionsData handles options websocket data +func (g *Gateio) WsHandleOptionsData(_ context.Context, respRaw []byte) error { var push WsResponse err := json.Unmarshal(respRaw, &push) if err != nil { return err } - if push.Event == "subscribe" || push.Event == "unsubscribe" { + if push.Event == subscribeEvent || push.Event == unsubscribeEvent { if !g.Websocket.Match.IncomingWithData(push.ID, respRaw) { return fmt.Errorf("couldn't match subscription message with ID: %d", push.ID) } diff --git a/exchanges/gemini/gemini_websocket.go b/exchanges/gemini/gemini_websocket.go index 27952409..ae8249a9 100644 --- a/exchanges/gemini/gemini_websocket.go +++ b/exchanges/gemini/gemini_websocket.go @@ -118,10 +118,10 @@ func (g *Gemini) manageSubs(subs subscription.List, op wsSubOp) error { } if op == wsUnsubscribeOp { - return g.Websocket.RemoveSubscriptions(subs...) + return g.Websocket.RemoveSubscriptions(g.Websocket.Conn, subs...) } - return g.Websocket.AddSuccessfulSubscriptions(subs...) + return g.Websocket.AddSuccessfulSubscriptions(g.Websocket.Conn, subs...) } // WsAuth will connect to Gemini's secure endpoint diff --git a/exchanges/gemini/gemini_wrapper.go b/exchanges/gemini/gemini_wrapper.go index e97c9d93..1159d376 100644 --- a/exchanges/gemini/gemini_wrapper.go +++ b/exchanges/gemini/gemini_wrapper.go @@ -152,7 +152,7 @@ func (g *Gemini) Setup(exch *config.Exchange) error { return err } - err = g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, URL: geminiWebsocketEndpoint + "/v2/" + geminiWsMarketData, @@ -161,7 +161,7 @@ func (g *Gemini) Setup(exch *config.Exchange) error { return err } - return g.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, URL: geminiWebsocketEndpoint + "/v1/" + geminiWsOrderEvents, diff --git a/exchanges/hitbtc/hitbtc_websocket.go b/exchanges/hitbtc/hitbtc_websocket.go index 51cc1806..2957b800 100644 --- a/exchanges/hitbtc/hitbtc_websocket.go +++ b/exchanges/hitbtc/hitbtc_websocket.go @@ -526,7 +526,7 @@ func (h *HitBTC) Subscribe(channelsToSubscribe subscription.List) error { err := h.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, r) if err == nil { - err = h.Websocket.AddSuccessfulSubscriptions(s) + err = h.Websocket.AddSuccessfulSubscriptions(h.Websocket.Conn, s) } if err != nil { errs = common.AppendError(errs, err) @@ -562,7 +562,7 @@ func (h *HitBTC) Unsubscribe(subs subscription.List) error { err := h.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, r) if err == nil { - err = h.Websocket.RemoveSubscriptions(s) + err = h.Websocket.RemoveSubscriptions(h.Websocket.Conn, s) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/hitbtc/hitbtc_wrapper.go b/exchanges/hitbtc/hitbtc_wrapper.go index 4d23ba01..8de08796 100644 --- a/exchanges/hitbtc/hitbtc_wrapper.go +++ b/exchanges/hitbtc/hitbtc_wrapper.go @@ -168,7 +168,7 @@ func (h *HitBTC) Setup(exch *config.Exchange) error { return err } - return h.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return h.Websocket.SetupNewConnection(&stream.ConnectionSetup{ RateLimit: request.NewWeightedRateLimitByDuration(20 * time.Millisecond), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/huobi/huobi_websocket.go b/exchanges/huobi/huobi_websocket.go index 83974372..999f0251 100644 --- a/exchanges/huobi/huobi_websocket.go +++ b/exchanges/huobi/huobi_websocket.go @@ -570,7 +570,7 @@ func (h *HUOBI) Subscribe(channelsToSubscribe subscription.List) error { }) } if err == nil { - err = h.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[i]) + err = h.Websocket.AddSuccessfulSubscriptions(h.Websocket.Conn, channelsToSubscribe[i]) } if err != nil { errs = common.AppendError(errs, err) @@ -604,7 +604,7 @@ func (h *HUOBI) Unsubscribe(channelsToUnsubscribe subscription.List) error { }) } if err == nil { - err = h.Websocket.RemoveSubscriptions(channelsToUnsubscribe[i]) + err = h.Websocket.RemoveSubscriptions(h.Websocket.Conn, channelsToUnsubscribe[i]) } if err != nil { errs = common.AppendError(errs, err) diff --git a/exchanges/huobi/huobi_wrapper.go b/exchanges/huobi/huobi_wrapper.go index 879b5285..cf9b0168 100644 --- a/exchanges/huobi/huobi_wrapper.go +++ b/exchanges/huobi/huobi_wrapper.go @@ -220,7 +220,7 @@ func (h *HUOBI) Setup(exch *config.Exchange) error { return err } - err = h.Websocket.SetupNewConnection(stream.ConnectionSetup{ + err = h.Websocket.SetupNewConnection(&stream.ConnectionSetup{ RateLimit: request.NewWeightedRateLimitByDuration(20 * time.Millisecond), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, @@ -229,7 +229,7 @@ func (h *HUOBI) Setup(exch *config.Exchange) error { return err } - return h.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return h.Websocket.SetupNewConnection(&stream.ConnectionSetup{ RateLimit: request.NewWeightedRateLimitByDuration(20 * time.Millisecond), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/kraken/kraken_test.go b/exchanges/kraken/kraken_test.go index b41af30e..7f8f8e05 100644 --- a/exchanges/kraken/kraken_test.go +++ b/exchanges/kraken/kraken_test.go @@ -29,6 +29,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" + mockws "github.com/thrasher-corp/gocryptotrader/internal/testing/websocket" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -1029,7 +1030,7 @@ func TestWsResubscribe(t *testing.T) { err = subs[0].SetState(subscription.UnsubscribingState) require.NoError(t, err) - err = k.Websocket.ResubscribeToChannel(subs[0]) + err = k.Websocket.ResubscribeToChannel(k.Websocket.Conn, subs[0]) require.NoError(t, err, "Resubscribe must not error") require.Equal(t, subscription.SubscribedState, subs[0].State(), "subscription must be subscribed again") } @@ -1209,7 +1210,7 @@ func TestWsHandleData(t *testing.T) { k := new(Kraken) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes require.NoError(t, testexch.Setup(k), "Setup Instance must not error") for _, l := range []int{10, 100} { - err := k.Websocket.AddSuccessfulSubscriptions(&subscription.Subscription{ + err := k.Websocket.AddSuccessfulSubscriptions(k.Websocket.Conn, &subscription.Subscription{ Channel: subscription.OrderbookChannel, Pairs: currency.Pairs{spotTestPair}, Asset: asset.Spot, @@ -1439,7 +1440,7 @@ func TestWsOrderbookMax10Depth(t *testing.T) { currency.NewPairWithDelimiter("GST", "EUR", "/"), } for _, p := range pairs { - err := k.Websocket.AddSuccessfulSubscriptions(&subscription.Subscription{ + err := k.Websocket.AddSuccessfulSubscriptions(k.Websocket.Conn, &subscription.Subscription{ Channel: subscription.OrderbookChannel, Pairs: currency.Pairs{p}, Asset: asset.Spot, @@ -1569,7 +1570,7 @@ func TestGetOpenInterest(t *testing.T) { } // curryWsMockUpgrader handles Kraken specific http auth token responses prior to handling off to standard Websocket upgrader -func curryWsMockUpgrader(tb testing.TB, h testexch.WsMockFunc) http.HandlerFunc { +func curryWsMockUpgrader(tb testing.TB, h mockws.WsMockFunc) http.HandlerFunc { tb.Helper() return func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "GetWebSocketsToken") { @@ -1577,7 +1578,7 @@ func curryWsMockUpgrader(tb testing.TB, h testexch.WsMockFunc) http.HandlerFunc assert.NoError(tb, err, "Write should not error") return } - testexch.WsMockUpgrader(tb, w, r, h) + mockws.WsMockUpgrader(tb, w, r, h) } } diff --git a/exchanges/kraken/kraken_websocket.go b/exchanges/kraken/kraken_websocket.go index a964bd28..b572af6a 100644 --- a/exchanges/kraken/kraken_websocket.go +++ b/exchanges/kraken/kraken_websocket.go @@ -628,7 +628,7 @@ func (k *Kraken) wsProcessOrderBook(c string, response []any, pair currency.Pair if errors.Is(err, errInvalidChecksum) { log.Debugf(log.Global, "%s Resubscribing to invalid %s orderbook", k.Name, pair) go func() { - if e2 := k.Websocket.ResubscribeToChannel(s); e2 != nil && !errors.Is(e2, subscription.ErrInStateAlready) { + if e2 := k.Websocket.ResubscribeToChannel(k.Websocket.Conn, s); e2 != nil && !errors.Is(e2, subscription.ErrInStateAlready) { log.Errorf(log.ExchangeSys, "%s resubscription failure for %v: %v", k.Name, pair, e2) } }() @@ -981,7 +981,7 @@ func (k *Kraken) Subscribe(in subscription.List) error { subs := subscription.List{} for _, s := range in { if s.State() != subscription.ResubscribingState { - if err := k.Websocket.AddSubscriptions(s); err != nil { + if err := k.Websocket.AddSubscriptions(k.Websocket.Conn, s); err != nil { errs = common.AppendError(errs, fmt.Errorf("%w; Channel: %s Pairs: %s", err, s.Channel, s.Pairs.Join())) continue } @@ -999,7 +999,7 @@ func (k *Kraken) Subscribe(in subscription.List) error { for _, s := range subs { if s.State() != subscription.SubscribedState { _ = s.SetState(subscription.InactiveState) - if err := k.Websocket.RemoveSubscriptions(s); err != nil { + if err := k.Websocket.RemoveSubscriptions(k.Websocket.Conn, s); err != nil { errs = common.AppendError(errs, fmt.Errorf("error removing failed subscription: %w; Channel: %s Pairs: %s", err, s.Channel, s.Pairs.Join())) } } @@ -1215,7 +1215,7 @@ func (k *Kraken) wsProcessSubStatus(resp []byte) { if status == krakenWsSubscribed { err = s.SetState(subscription.SubscribedState) } else if s.State() != subscription.ResubscribingState { // Do not remove a resubscribing sub which just unsubbed - err = k.Websocket.RemoveSubscriptions(s) + err = k.Websocket.RemoveSubscriptions(k.Websocket.Conn, s) if e2 := s.SetState(subscription.UnsubscribedState); e2 != nil { err = common.AppendError(err, e2) } diff --git a/exchanges/kraken/kraken_wrapper.go b/exchanges/kraken/kraken_wrapper.go index 82447a60..610350c2 100644 --- a/exchanges/kraken/kraken_wrapper.go +++ b/exchanges/kraken/kraken_wrapper.go @@ -228,7 +228,7 @@ func (k *Kraken) Setup(exch *config.Exchange) error { return err } - err = k.Websocket.SetupNewConnection(stream.ConnectionSetup{ + err = k.Websocket.SetupNewConnection(&stream.ConnectionSetup{ RateLimit: request.NewWeightedRateLimitByDuration(50 * time.Millisecond), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, @@ -241,7 +241,7 @@ func (k *Kraken) Setup(exch *config.Exchange) error { if err != nil { return err } - return k.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return k.Websocket.SetupNewConnection(&stream.ConnectionSetup{ RateLimit: request.NewWeightedRateLimitByDuration(50 * time.Millisecond), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/kucoin/kucoin_websocket.go b/exchanges/kucoin/kucoin_websocket.go index 27c71925..ed97c1ab 100644 --- a/exchanges/kucoin/kucoin_websocket.go +++ b/exchanges/kucoin/kucoin_websocket.go @@ -1036,9 +1036,9 @@ func (ku *Kucoin) manageSubscriptions(subs subscription.List, operation string) errs = common.AppendError(errs, fmt.Errorf("%w: %s from %s", errInvalidMsgType, rType, respRaw)) default: if operation == "unsubscribe" { - err = ku.Websocket.RemoveSubscriptions(s) + err = ku.Websocket.RemoveSubscriptions(ku.Websocket.Conn, s) } else { - err = ku.Websocket.AddSuccessfulSubscriptions(s) + err = ku.Websocket.AddSuccessfulSubscriptions(ku.Websocket.Conn, s) if ku.Verbose { log.Debugf(log.ExchangeSys, "%s Subscribed to Channel: %s", ku.Name, s.Channel) } diff --git a/exchanges/kucoin/kucoin_wrapper.go b/exchanges/kucoin/kucoin_wrapper.go index 024afc82..12a2bc42 100644 --- a/exchanges/kucoin/kucoin_wrapper.go +++ b/exchanges/kucoin/kucoin_wrapper.go @@ -214,7 +214,7 @@ func (ku *Kucoin) Setup(exch *config.Exchange) error { if err != nil { return err } - return ku.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return ku.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, RateLimit: request.NewRateLimitWithWeight(time.Second, 2, 1), diff --git a/exchanges/okx/okx_websocket.go b/exchanges/okx/okx_websocket.go index 5a4d1b87..ab53be1f 100644 --- a/exchanges/okx/okx_websocket.go +++ b/exchanges/okx/okx_websocket.go @@ -487,9 +487,9 @@ func (ok *Okx) handleSubscription(operation string, subscriptions subscription.L return err } if operation == operationUnsubscribe { - err = ok.Websocket.RemoveSubscriptions(channels...) + err = ok.Websocket.RemoveSubscriptions(ok.Websocket.AuthConn, channels...) } else { - err = ok.Websocket.AddSuccessfulSubscriptions(channels...) + err = ok.Websocket.AddSuccessfulSubscriptions(ok.Websocket.AuthConn, channels...) } if err != nil { return err @@ -511,9 +511,9 @@ func (ok *Okx) handleSubscription(operation string, subscriptions subscription.L return err } if operation == operationUnsubscribe { - err = ok.Websocket.RemoveSubscriptions(channels...) + err = ok.Websocket.RemoveSubscriptions(ok.Websocket.Conn, channels...) } else { - err = ok.Websocket.AddSuccessfulSubscriptions(channels...) + err = ok.Websocket.AddSuccessfulSubscriptions(ok.Websocket.Conn, channels...) } if err != nil { return err @@ -539,10 +539,10 @@ func (ok *Okx) handleSubscription(operation string, subscriptions subscription.L channels = append(channels, authChannels...) if operation == operationUnsubscribe { - return ok.Websocket.RemoveSubscriptions(channels...) + return ok.Websocket.RemoveSubscriptions(ok.Websocket.Conn, channels...) } - return ok.Websocket.AddSuccessfulSubscriptions(channels...) + return ok.Websocket.AddSuccessfulSubscriptions(ok.Websocket.Conn, channels...) } // WsHandleData will read websocket raw data and pass to appropriate handler diff --git a/exchanges/okx/okx_wrapper.go b/exchanges/okx/okx_wrapper.go index 971e19d5..debfb054 100644 --- a/exchanges/okx/okx_wrapper.go +++ b/exchanges/okx/okx_wrapper.go @@ -224,7 +224,7 @@ func (ok *Okx) Setup(exch *config.Exchange) error { go ok.WsResponseMultiplexer.Run() - if err := ok.Websocket.SetupNewConnection(stream.ConnectionSetup{ + if err := ok.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: okxAPIWebsocketPublicURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: okxWebsocketResponseMaxLimit, @@ -233,7 +233,7 @@ func (ok *Okx) Setup(exch *config.Exchange) error { return err } - return ok.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return ok.Websocket.SetupNewConnection(&stream.ConnectionSetup{ URL: okxAPIWebsocketPrivateURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: okxWebsocketResponseMaxLimit, diff --git a/exchanges/poloniex/poloniex_websocket.go b/exchanges/poloniex/poloniex_websocket.go index 04eb08d3..2d2ce08c 100644 --- a/exchanges/poloniex/poloniex_websocket.go +++ b/exchanges/poloniex/poloniex_websocket.go @@ -609,9 +609,9 @@ func (p *Poloniex) manageSubs(subs subscription.List, op wsOp) error { } if err == nil { if op == wsSubscribeOp { - err = p.Websocket.AddSuccessfulSubscriptions(s) + err = p.Websocket.AddSuccessfulSubscriptions(p.Websocket.Conn, s) } else { - err = p.Websocket.RemoveSubscriptions(s) + err = p.Websocket.RemoveSubscriptions(p.Websocket.Conn, s) } } if err != nil { diff --git a/exchanges/poloniex/poloniex_wrapper.go b/exchanges/poloniex/poloniex_wrapper.go index db35face..183e96dd 100644 --- a/exchanges/poloniex/poloniex_wrapper.go +++ b/exchanges/poloniex/poloniex_wrapper.go @@ -181,7 +181,7 @@ func (p *Poloniex) Setup(exch *config.Exchange) error { return err } - return p.Websocket.SetupNewConnection(stream.ConnectionSetup{ + return p.Websocket.SetupNewConnection(&stream.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, }) diff --git a/exchanges/protocol/features.go b/exchanges/protocol/features.go index b3c0c666..d7a0a43d 100644 --- a/exchanges/protocol/features.go +++ b/exchanges/protocol/features.go @@ -3,40 +3,37 @@ package protocol // Features holds all variables for the exchanges supported features // for a protocol (e.g REST or Websocket) type Features struct { - TickerBatching bool `json:"tickerBatching,omitempty"` - AutoPairUpdates bool `json:"autoPairUpdates,omitempty"` - AccountBalance bool `json:"accountBalance,omitempty"` - CryptoDeposit bool `json:"cryptoDeposit,omitempty"` - CryptoWithdrawal bool `json:"cryptoWithdrawal,omitempty"` - FiatWithdraw bool `json:"fiatWithdraw,omitempty"` - GetOrder bool `json:"getOrder,omitempty"` - GetOrders bool `json:"getOrders,omitempty"` - CancelOrders bool `json:"cancelOrders,omitempty"` - CancelOrder bool `json:"cancelOrder,omitempty"` - SubmitOrder bool `json:"submitOrder,omitempty"` - SubmitOrders bool `json:"submitOrders,omitempty"` - ModifyOrder bool `json:"modifyOrder,omitempty"` - DepositHistory bool `json:"depositHistory,omitempty"` - WithdrawalHistory bool `json:"withdrawalHistory,omitempty"` - TradeHistory bool `json:"tradeHistory,omitempty"` - UserTradeHistory bool `json:"userTradeHistory,omitempty"` - TradeFee bool `json:"tradeFee,omitempty"` - FiatDepositFee bool `json:"fiatDepositFee,omitempty"` - FiatWithdrawalFee bool `json:"fiatWithdrawalFee,omitempty"` - CryptoDepositFee bool `json:"cryptoDepositFee,omitempty"` - CryptoWithdrawalFee bool `json:"cryptoWithdrawalFee,omitempty"` - TickerFetching bool `json:"tickerFetching,omitempty"` - KlineFetching bool `json:"klineFetching,omitempty"` - TradeFetching bool `json:"tradeFetching,omitempty"` - OrderbookFetching bool `json:"orderbookFetching,omitempty"` - AccountInfo bool `json:"accountInfo,omitempty"` - FiatDeposit bool `json:"fiatDeposit,omitempty"` - DeadMansSwitch bool `json:"deadMansSwitch,omitempty"` - FundingRateFetching bool `json:"fundingRateFetching"` - PredictedFundingRate bool `json:"predictedFundingRate,omitempty"` - // FullPayloadSubscribe flushes and changes full subscription on websocket - // connection by subscribing with full default stream channel list - FullPayloadSubscribe bool `json:"fullPayloadSubscribe,omitempty"` + TickerBatching bool `json:"tickerBatching,omitempty"` + AutoPairUpdates bool `json:"autoPairUpdates,omitempty"` + AccountBalance bool `json:"accountBalance,omitempty"` + CryptoDeposit bool `json:"cryptoDeposit,omitempty"` + CryptoWithdrawal bool `json:"cryptoWithdrawal,omitempty"` + FiatWithdraw bool `json:"fiatWithdraw,omitempty"` + GetOrder bool `json:"getOrder,omitempty"` + GetOrders bool `json:"getOrders,omitempty"` + CancelOrders bool `json:"cancelOrders,omitempty"` + CancelOrder bool `json:"cancelOrder,omitempty"` + SubmitOrder bool `json:"submitOrder,omitempty"` + SubmitOrders bool `json:"submitOrders,omitempty"` + ModifyOrder bool `json:"modifyOrder,omitempty"` + DepositHistory bool `json:"depositHistory,omitempty"` + WithdrawalHistory bool `json:"withdrawalHistory,omitempty"` + TradeHistory bool `json:"tradeHistory,omitempty"` + UserTradeHistory bool `json:"userTradeHistory,omitempty"` + TradeFee bool `json:"tradeFee,omitempty"` + FiatDepositFee bool `json:"fiatDepositFee,omitempty"` + FiatWithdrawalFee bool `json:"fiatWithdrawalFee,omitempty"` + CryptoDepositFee bool `json:"cryptoDepositFee,omitempty"` + CryptoWithdrawalFee bool `json:"cryptoWithdrawalFee,omitempty"` + TickerFetching bool `json:"tickerFetching,omitempty"` + KlineFetching bool `json:"klineFetching,omitempty"` + TradeFetching bool `json:"tradeFetching,omitempty"` + OrderbookFetching bool `json:"orderbookFetching,omitempty"` + AccountInfo bool `json:"accountInfo,omitempty"` + FiatDeposit bool `json:"fiatDeposit,omitempty"` + DeadMansSwitch bool `json:"deadMansSwitch,omitempty"` + FundingRateFetching bool `json:"fundingRateFetching"` + PredictedFundingRate bool `json:"predictedFundingRate,omitempty"` Subscribe bool `json:"subscribe,omitempty"` Unsubscribe bool `json:"unsubscribe,omitempty"` AuthenticatedEndpoints bool `json:"authenticatedEndpoints,omitempty"` diff --git a/exchanges/stream/README.md b/exchanges/stream/README.md new file mode 100644 index 00000000..3a02c2ef --- /dev/null +++ b/exchanges/stream/README.md @@ -0,0 +1,137 @@ +# GoCryptoTrader Exchange Stream Package + +This package is part of the GoCryptoTrader project and is responsible for handling exchange streaming data. + +## Overview + +The `stream` package uses Gorilla Websocket and provides functionalities to connect to various cryptocurrency exchanges and handle real-time data streams. + +## Features + +- Handle real-time market data streams +- Unified interface for managing data streams +- Multi-connection management - a system that can be used to manage multiple connections to the same exchange +- Connection monitoring - a system that can be used to monitor the health of the websocket connections. This can be used to check if the connection is still alive and if it is not, it will attempt to reconnect +- Traffic monitoring - will reconnect if no message is sent for a period of time defined in your config +- Subscription management - a system that can be used to manage subscriptions to various data streams +- Rate limiting - a system that can be used to rate limit the number of requests sent to the exchange +- Message ID generation - a system that can be used to generate message IDs for websocket requests +- Websocket message response matching - can be used to match websocket responses to the requests that were sent + +## Usage + +### Default single websocket connection +Here is a basic example of how to setup the `stream` package for websocket: + +```go +package main + +import ( + "github.com/thrasher-corp/gocryptotrader/exchanges/stream" + exchange "github.com/thrasher-corp/gocryptotrader/exchanges" + "github.com/thrasher-corp/gocryptotrader/exchanges/request" +) + +type Exchange struct { + exchange.Base +} + +// In the exchange wrapper this will set up the initial pointer field provided by exchange.Base +func (e *Exchange) SetDefault() { + e.Websocket = stream.NewWebsocket() + e.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit + e.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout + e.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit +} + +// In the exchange wrapper this is the original setup pattern for the websocket services +func (e *Exchange) Setup(exch *config.Exchange) error { + // This sets up global connection, sub, unsub and generate subscriptions for each connection defined below. + if err := e.Websocket.Setup(&stream.WebsocketSetup{ + ExchangeConfig: exch, + DefaultURL: connectionURLString, + RunningURL: connectionURLString, + Connector: e.WsConnect, + Subscriber: e.Subscribe, + Unsubscriber: e.Unsubscribe, + GenerateSubscriptions: e.GenerateDefaultSubscriptions, + Features: &e.Features.Supports.WebsocketCapabilities, + MaxWebsocketSubscriptionsPerConnection: 240, + OrderbookBufferConfig: buffer.Config{ Checksum: e.CalculateUpdateOrderbookChecksum }, + }); err != nil { + return err + } + + // This is a public websocket connection + if err := ok.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + URL: connectionURLString, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exchangeWebsocketResponseMaxLimit, + RateLimit: request.NewRateLimitWithWeight(time.Second, 2, 1), + }); err != nil { + return err + } + + // This is a private websocket connection + return ok.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + URL: privateConnectionURLString, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exchangeWebsocketResponseMaxLimit, + Authenticated: true, + RateLimit: request.NewRateLimitWithWeight(time.Second, 2, 1), + }) +} +``` + +### Multiple websocket connections + The example below provides the now optional multi connection management system which allows for more connections + to be maintained and established based off URL, connections types, asset types etc. +```go +func (e *Exchange) Setup(exch *config.Exchange) error { + // This sets up global connection, sub, unsub and generate subscriptions for each connection defined below. + if err := e.Websocket.Setup(&stream.WebsocketSetup{ + ExchangeConfig: exch, + Features: &e.Features.Supports.WebsocketCapabilities, + FillsFeed: e.Features.Enabled.FillsFeed, + TradeFeed: e.Features.Enabled.TradeFeed, + UseMultiConnectionManagement: true, + }) + if err != nil { + return err + } + // Spot connection + err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + URL: connectionURLStringForSpot, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + // Custom handlers for the specific connection: + Handler: e.WsHandleSpotData, + Subscriber: e.SpotSubscribe, + Unsubscriber: e.SpotUnsubscribe, + GenerateSubscriptions: e.GenerateDefaultSubscriptionsSpot, + Connector: e.WsConnectSpot, + BespokeGenerateMessageID: e.GenerateWebsocketMessageID, + }) + if err != nil { + return err + } + // Futures connection - USDT margined + err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + URL: connectionURLStringForSpotForFutures, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + // Custom handlers for the specific connection: + Handler: func(ctx context.Context, incoming []byte) error { return e.WsHandleFuturesData(ctx, incoming, asset.Futures) }, + Subscriber: e.FuturesSubscribe, + Unsubscriber: e.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return e.GenerateFuturesDefaultSubscriptions(currency.USDT) }, + Connector: e.WsFuturesConnect, + BespokeGenerateMessageID: e.GenerateWebsocketMessageID, + }) + if err != nil { + return err + } +} +``` \ No newline at end of file diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index 22963e25..2cbf0a2f 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -10,11 +10,13 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/request" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" ) // Connection defines a streaming services connection type Connection interface { Dial(*websocket.Dialer, http.Header) error + DialContext(context.Context, *websocket.Dialer, http.Header) error ReadMessage() Response SetupPingHandler(request.EndpointLimit, PingHandler) // GenerateMessageID generates a message ID for the individual connection. If a bespoke function is set @@ -46,15 +48,50 @@ type ConnectionSetup struct { ResponseCheckTimeout time.Duration ResponseMaxLimit time.Duration RateLimit *request.RateLimiterWithWeight - URL string Authenticated bool ConnectionLevelReporter Reporter + + // URL defines the websocket server URL to connect to + URL string + // Connector is the function that will be called to connect to the + // exchange's websocket server. This will be called once when the stream + // service is started. Any bespoke connection logic should be handled here. + Connector func(ctx context.Context, conn Connection) error + // GenerateSubscriptions is a function that will be called to generate a + // list of subscriptions to be made to the exchange's websocket server. + GenerateSubscriptions func() (subscription.List, error) + // Subscriber is a function that will be called to send subscription + // messages based on the exchange's websocket server requirements to + // subscribe to specific channels. + Subscriber func(ctx context.Context, conn Connection, sub subscription.List) error + // Unsubscriber is a function that will be called to send unsubscription + // messages based on the exchange's websocket server requirements to + // unsubscribe from specific channels. NOTE: IF THE FEATURE IS ENABLED. + Unsubscriber func(ctx context.Context, conn Connection, unsub subscription.List) error + // Handler defines the function that will be called when a message is + // received from the exchange's websocket server. This function should + // handle the incoming message and pass it to the appropriate data handler. + Handler func(ctx context.Context, incoming []byte) error // BespokeGenerateMessageID is a function that returns a unique message ID. // This is useful for when an exchange connection requires a unique or // structured message ID for each message sent. BespokeGenerateMessageID func(highPrecision bool) int64 } +// ConnectionWrapper contains the connection setup details to be used when +// attempting a new connection. It also contains the subscriptions that are +// associated with the specific connection. +type ConnectionWrapper struct { + // Setup contains the connection setup details + Setup *ConnectionSetup + // Subscriptions contains the subscriptions that are associated with the + // specific connection(s) + Subscriptions *subscription.Store + // Connection contains the active connection based off the connection + // details above. + Connection Connection // TODO: Upgrade to slice of connections. +} + // PingHandler container for ping handler settings type PingHandler struct { Websocket bool diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 69693db7..309db9a7 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -1,14 +1,14 @@ package stream import ( + "context" "errors" "fmt" - "net" "net/url" "slices" + "sync" "time" - "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" @@ -17,9 +17,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/log" ) -const ( - jobBuffer = 5000 -) +const jobBuffer = 5000 // Public websocket errors var ( @@ -35,9 +33,7 @@ var ( // Private websocket errors var ( - errAlreadyRunning = errors.New("connection monitor is already running") errExchangeConfigIsNil = errors.New("exchange config is nil") - errExchangeConfigEmpty = errors.New("exchange config is empty") errWebsocketIsNil = errors.New("websocket is nil") errWebsocketSetupIsNil = errors.New("websocket setup is nil") errWebsocketAlreadyInitialised = errors.New("websocket already initialised") @@ -53,9 +49,9 @@ var ( errWebsocketSubscriberUnset = errors.New("websocket subscriber function needs to be set") errWebsocketUnsubscriberUnset = errors.New("websocket unsubscriber functionality allowed but unsubscriber function not set") errWebsocketConnectorUnset = errors.New("websocket connector function not set") + errWebsocketDataHandlerUnset = errors.New("websocket data handler not set") errReadMessageErrorsNil = errors.New("read message errors is nil") errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set") - errClosedConnection = errors.New("use of closed network connection") errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit") errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0") errSameProxyAddress = errors.New("cannot set proxy address to the same address") @@ -64,12 +60,13 @@ var ( errCannotShutdown = errors.New("websocket cannot shutdown") errAlreadyReconnecting = errors.New("websocket in the process of reconnection") errConnSetup = errors.New("error in connection setup") + errNoPendingConnections = errors.New("no pending connections, call SetupNewConnection first") + errConnectionWrapperDuplication = errors.New("connection wrapper duplication") + errCannotChangeConnectionURL = errors.New("cannot change connection URL when using multi connection management") + errExchangeConfigEmpty = errors.New("exchange config is empty") ) -var ( - globalReporter Reporter - trafficCheckInterval = 100 * time.Millisecond -) +var globalReporter Reporter // SetupGlobalReporter sets a reporter interface to be used // for all exchange requests @@ -80,15 +77,20 @@ func SetupGlobalReporter(r Reporter) { // NewWebsocket initialises the websocket struct func NewWebsocket() *Websocket { return &Websocket{ - DataHandler: make(chan interface{}, jobBuffer), - ToRoutine: make(chan interface{}, jobBuffer), - ShutdownC: make(chan struct{}), - TrafficAlert: make(chan struct{}, 1), - ReadMessageErrors: make(chan error), + DataHandler: make(chan interface{}, jobBuffer), + ToRoutine: make(chan interface{}, jobBuffer), + ShutdownC: make(chan struct{}), + TrafficAlert: make(chan struct{}, 1), + // ReadMessageErrors is buffered for an edge case when `Connect` fails + // after subscriptions are made but before the connectionMonitor has + // started. This allows the error to be read and handled in the + // connectionMonitor and start a connection cycle again. + ReadMessageErrors: make(chan error, 1), Match: NewMatch(), subscriptions: subscription.NewStore(), features: &protocol.Features{}, Orderbook: buffer.Orderbook{}, + connections: make(map[Connection]*ConnectionWrapper), } } @@ -129,47 +131,52 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { } w.setEnabled(s.ExchangeConfig.Features.Enabled.Websocket) - if s.Connector == nil { - return fmt.Errorf("%s %w", w.exchangeName, errWebsocketConnectorUnset) - } - w.connector = s.Connector + w.useMultiConnectionManagement = s.UseMultiConnectionManagement - if s.Subscriber == nil { - return fmt.Errorf("%s %w", w.exchangeName, errWebsocketSubscriberUnset) - } - w.Subscriber = s.Subscriber + if !w.useMultiConnectionManagement { + // TODO: Remove this block when all exchanges are updated and backwards + // compatibility is no longer required. + if s.Connector == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketConnectorUnset) + } + if s.Subscriber == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriberUnset) + } + if s.Unsubscriber == nil && w.features.Unsubscribe { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketUnsubscriberUnset) + } + if s.GenerateSubscriptions == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriptionsGeneratorUnset) + } + if s.DefaultURL == "" { + return fmt.Errorf("%s websocket %w", w.exchangeName, errDefaultURLIsEmpty) + } + w.defaultURL = s.DefaultURL + if s.RunningURL == "" { + return fmt.Errorf("%s websocket %w", w.exchangeName, errRunningURLIsEmpty) + } - if w.features.Unsubscribe && s.Unsubscriber == nil { - return fmt.Errorf("%s %w", w.exchangeName, errWebsocketUnsubscriberUnset) - } - w.connectionMonitorDelay = s.ExchangeConfig.ConnectionMonitorDelay - if w.connectionMonitorDelay <= 0 { - w.connectionMonitorDelay = config.DefaultConnectionMonitorDelay - } - w.Unsubscriber = s.Unsubscriber + w.connector = s.Connector + w.Subscriber = s.Subscriber + w.Unsubscriber = s.Unsubscriber + w.GenerateSubs = s.GenerateSubscriptions - if s.GenerateSubscriptions == nil { - return fmt.Errorf("%s %w", w.exchangeName, errWebsocketSubscriptionsGeneratorUnset) - } - w.GenerateSubs = s.GenerateSubscriptions - - if s.DefaultURL == "" { - return fmt.Errorf("%s websocket %w", w.exchangeName, errDefaultURLIsEmpty) - } - w.defaultURL = s.DefaultURL - if s.RunningURL == "" { - return fmt.Errorf("%s websocket %w", w.exchangeName, errRunningURLIsEmpty) - } - err := w.SetWebsocketURL(s.RunningURL, false, false) - if err != nil { - return fmt.Errorf("%s %w", w.exchangeName, err) - } - - if s.RunningURLAuth != "" { - err = w.SetWebsocketURL(s.RunningURLAuth, true, false) + err := w.SetWebsocketURL(s.RunningURL, false, false) if err != nil { return fmt.Errorf("%s %w", w.exchangeName, err) } + + if s.RunningURLAuth != "" { + err = w.SetWebsocketURL(s.RunningURLAuth, true, false) + if err != nil { + return fmt.Errorf("%s %w", w.exchangeName, err) + } + } + } + + w.connectionMonitorDelay = s.ExchangeConfig.ConnectionMonitorDelay + if w.connectionMonitorDelay <= 0 { + w.connectionMonitorDelay = config.DefaultConnectionMonitorDelay } if s.ExchangeConfig.WebsocketTrafficTimeout < time.Second { @@ -180,7 +187,6 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { } w.trafficTimeout = s.ExchangeConfig.WebsocketTrafficTimeout - w.ShutdownC = make(chan struct{}) w.SetCanUseAuthenticatedEndpoints(s.ExchangeConfig.API.AuthenticatedWebsocketSupport) if err := w.Orderbook.Setup(s.ExchangeConfig, &s.OrderbookBufferConfig, w.DataHandler); err != nil { @@ -201,12 +207,12 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { } // SetupNewConnection sets up an auth or unauth streaming connection -func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { +func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { if w == nil { return fmt.Errorf("%w: %w", errConnSetup, errWebsocketIsNil) } - if c.ResponseCheckTimeout == 0 && + if c == nil || c.ResponseCheckTimeout == 0 && c.ResponseMaxLimit == 0 && c.RateLimit == nil && c.URL == "" && @@ -218,29 +224,71 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { if w.exchangeName == "" { return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigNameEmpty) } - if w.TrafficAlert == nil { return fmt.Errorf("%w: %w", errConnSetup, errTrafficAlertNil) } - if w.ReadMessageErrors == nil { return fmt.Errorf("%w: %w", errConnSetup, errReadMessageErrorsNil) } - - connectionURL := w.GetWebsocketURL() - if c.URL != "" { - connectionURL = c.URL - } - if c.ConnectionLevelReporter == nil { c.ConnectionLevelReporter = w.ExchangeLevelReporter } - if c.ConnectionLevelReporter == nil { c.ConnectionLevelReporter = globalReporter } - newConn := &WebsocketConnection{ + if w.useMultiConnectionManagement { + // The connection and supporting functions are defined per connection + // and the connection wrapper is stored in the connection manager. + if c.URL == "" { + return fmt.Errorf("%w: %w", errConnSetup, errDefaultURLIsEmpty) + } + if c.Connector == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketConnectorUnset) + } + if c.GenerateSubscriptions == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriptionsGeneratorUnset) + } + if c.Subscriber == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriberUnset) + } + if c.Unsubscriber == nil && w.features.Unsubscribe { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketUnsubscriberUnset) + } + if c.Handler == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketDataHandlerUnset) + } + + for x := range w.connectionManager { + if w.connectionManager[x].Setup.URL == c.URL { + return fmt.Errorf("%w: %w", errConnSetup, errConnectionWrapperDuplication) + } + } + + w.connectionManager = append(w.connectionManager, ConnectionWrapper{ + Setup: c, + Subscriptions: subscription.NewStore(), + }) + return nil + } + + if c.Authenticated { + w.AuthConn = w.getConnectionFromSetup(c) + } else { + w.Conn = w.getConnectionFromSetup(c) + } + + return nil +} + +// getConnectionFromSetup returns a websocket connection from a setup +// configuration. This is used for setting up new connections on the fly. +func (w *Websocket) getConnectionFromSetup(c *ConnectionSetup) *WebsocketConnection { + connectionURL := w.GetWebsocketURL() + if c.URL != "" { + connectionURL = c.URL + } + return &WebsocketConnection{ ExchangeName: w.exchangeName, URL: connectionURL, ProxyURL: w.GetProxyAddress(), @@ -248,7 +296,7 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { ResponseMaxLimit: c.ResponseMaxLimit, Traffic: w.TrafficAlert, readMessageErrors: w.ReadMessageErrors, - ShutdownC: w.ShutdownC, + shutdown: w.ShutdownC, Wg: &w.Wg, Match: w.Match, RateLimit: c.RateLimit, @@ -256,25 +304,17 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { bespokeGenerateMessageID: c.BespokeGenerateMessageID, RateLimitDefinitions: w.rateLimitDefinitions, } - - if c.Authenticated { - w.AuthConn = newConn - } else { - w.Conn = newConn - } - - return nil } // Connect initiates a websocket connection by using a package defined connection // function func (w *Websocket) Connect() error { - if w.connector == nil { - return errNoConnectFunc - } w.m.Lock() defer w.m.Unlock() + return w.connect() +} +func (w *Websocket) connect() error { if !w.IsEnabled() { return ErrWebsocketNotEnabled } @@ -290,32 +330,149 @@ func (w *Websocket) Connect() error { } w.subscriptions.Clear() - w.dataMonitor() - w.trafficMonitor() w.setState(connectingState) - err := w.connector() - if err != nil { - w.setState(disconnectedState) - return fmt.Errorf("%v Error connecting %w", w.exchangeName, err) + w.Wg.Add(2) + go w.monitorFrame(&w.Wg, w.monitorData) + go w.monitorFrame(&w.Wg, w.monitorTraffic) + + if !w.useMultiConnectionManagement { + if w.connector == nil { + return fmt.Errorf("%v %w", w.exchangeName, errNoConnectFunc) + } + err := w.connector() + if err != nil { + w.setState(disconnectedState) + return fmt.Errorf("%v Error connecting %w", w.exchangeName, err) + } + w.setState(connectedState) + + if w.connectionMonitorRunning.CompareAndSwap(false, true) { + // This oversees all connections and does not need to be part of wait group management. + go w.monitorFrame(nil, w.monitorConnection) + } + + subs, err := w.GenerateSubs() // regenerate state on new connection + if err != nil { + return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) + } + if len(subs) != 0 { + if err := w.SubscribeToChannels(nil, subs); err != nil { + return err + } + } + return nil } + + if len(w.connectionManager) == 0 { + w.setState(disconnectedState) + return fmt.Errorf("cannot connect: %w", errNoPendingConnections) + } + + // multiConnectFatalError is a fatal error that will cause all connections to + // be shutdown and the websocket to be disconnected. + var multiConnectFatalError error + + // TODO: Implement concurrency below. + for i := range w.connectionManager { + if w.connectionManager[i].Setup.GenerateSubscriptions == nil { + multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errWebsocketSubscriptionsGeneratorUnset) + break + } + + subs, err := w.connectionManager[i].Setup.GenerateSubscriptions() // regenerate state on new connection + if err != nil { + multiConnectFatalError = fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) + break + } + + if len(subs) == 0 { + // If no subscriptions are generated, we skip the connection + if w.verbose { + log.Warnf(log.WebsocketMgr, "%s websocket: no subscriptions generated", w.exchangeName) + } + continue + } + + if w.connectionManager[i].Setup.Connector == nil { + multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errNoConnectFunc) + break + } + if w.connectionManager[i].Setup.Handler == nil { + multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errWebsocketDataHandlerUnset) + break + } + if w.connectionManager[i].Setup.Subscriber == nil { + multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errWebsocketSubscriberUnset) + break + } + + // TODO: Add window for max subscriptions per connection, to spawn new connections if needed. + + conn := w.getConnectionFromSetup(w.connectionManager[i].Setup) + + err = w.connectionManager[i].Setup.Connector(context.TODO(), conn) + if err != nil { + multiConnectFatalError = fmt.Errorf("%v Error connecting %w", w.exchangeName, err) + break + } + + if !conn.IsConnected() { + multiConnectFatalError = fmt.Errorf("%s websocket: [conn:%d] [URL:%s] failed to connect", w.exchangeName, i+1, conn.URL) + break + } + + w.connections[conn] = &w.connectionManager[i] + w.connectionManager[i].Connection = conn + + w.Wg.Add(1) + go w.Reader(context.TODO(), conn, w.connectionManager[i].Setup.Handler) + + err = w.connectionManager[i].Setup.Subscriber(context.TODO(), conn, subs) + if err != nil { + multiConnectFatalError = fmt.Errorf("%v Error subscribing %w", w.exchangeName, err) + break + } + + if w.verbose { + log.Debugf(log.WebsocketMgr, "%s websocket: [conn:%d] [URL:%s] connected. [Subscribed: %d]", + w.exchangeName, + i+1, + conn.URL, + len(subs)) + } + } + + if multiConnectFatalError != nil { + // Roll back any successful connections and flush subscriptions + for x := range w.connectionManager { + if w.connectionManager[x].Connection != nil { + if err := w.connectionManager[x].Connection.Shutdown(); err != nil { + log.Errorln(log.WebsocketMgr, err) + } + w.connectionManager[x].Connection = nil + } + w.connectionManager[x].Subscriptions.Clear() + } + clear(w.connections) + w.setState(disconnectedState) // Flip from connecting to disconnected. + + // Drain residual error in the single buffered channel, this mitigates + // the cycle when `Connect` is called again and the connectionMonitor + // starts but there is an old error in the channel. + drain(w.ReadMessageErrors) + + return multiConnectFatalError + } + + // Assume connected state here. All connections have been established. + // All subscriptions have been sent and stored. All data received is being + // handled by the appropriate data handler. w.setState(connectedState) - if !w.IsConnectionMonitorRunning() { - err = w.connectionMonitor() - if err != nil { - log.Errorf(log.WebsocketMgr, "%s cannot start websocket connection monitor %v", w.GetName(), err) - } - } - - subs, err := w.GenerateSubs() // regenerate state on new connection - if err != nil { - return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) - } - if len(subs) != 0 { - if err := w.SubscribeToChannels(subs); err != nil { - return err - } + if w.connectionMonitorRunning.CompareAndSwap(false, true) { + // This oversees all connections and does not need to be part of wait group management. + go w.monitorFrame(nil, w.monitorConnection) } return nil @@ -342,109 +499,15 @@ func (w *Websocket) Enable() error { return w.Connect() } -// dataMonitor monitors job throughput and logs if there is a back log of data -func (w *Websocket) dataMonitor() { - if w.IsDataMonitorRunning() { - return - } - w.setDataMonitorRunning(true) - w.Wg.Add(1) - - go func() { - defer func() { - w.setDataMonitorRunning(false) - w.Wg.Done() - }() - dropped := 0 - for { - select { - case <-w.ShutdownC: - return - case d := <-w.DataHandler: - select { - case w.ToRoutine <- d: - if dropped != 0 { - log.Infof(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer recovered; %d messages were dropped", w.exchangeName, dropped) - dropped = 0 - } - default: - if dropped == 0 { - // If this becomes prone to flapping we could drain the buffer, but that's extreme and we'd like to avoid it if possible - log.Warnf(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer full; dropping messages", w.exchangeName) - } - dropped++ - } - } - } - }() -} - -// connectionMonitor ensures that the WS keeps connecting -func (w *Websocket) connectionMonitor() error { - if w.checkAndSetMonitorRunning() { - return errAlreadyRunning - } - delay := w.connectionMonitorDelay - - go func() { - timer := time.NewTimer(delay) - for { - if w.verbose { - log.Debugf(log.WebsocketMgr, "%v websocket: running connection monitor cycle", w.exchangeName) - } - if !w.IsEnabled() { - if w.verbose { - log.Debugf(log.WebsocketMgr, "%v websocket: connectionMonitor - websocket disabled, shutting down", w.exchangeName) - } - if w.IsConnected() { - if err := w.Shutdown(); err != nil { - log.Errorln(log.WebsocketMgr, err) - } - } - if w.verbose { - log.Debugf(log.WebsocketMgr, "%v websocket: connection monitor exiting", w.exchangeName) - } - timer.Stop() - w.setConnectionMonitorRunning(false) - return - } - select { - case err := <-w.ReadMessageErrors: - w.DataHandler <- err - if IsDisconnectionError(err) { - log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err) - if w.IsConnected() { - if shutdownErr := w.Shutdown(); shutdownErr != nil { - log.Errorf(log.WebsocketMgr, "%v websocket: connectionMonitor shutdown err: %s", w.exchangeName, shutdownErr) - } - } - } - case <-timer.C: - if !w.IsConnecting() && !w.IsConnected() { - err := w.Connect() - if err != nil { - log.Errorln(log.WebsocketMgr, err) - } - } - if !timer.Stop() { - select { - case <-timer.C: - default: - } - } - timer.Reset(delay) - } - } - }() - return nil -} - // Shutdown attempts to shut down a websocket connection and associated routines // by using a package defined shutdown function func (w *Websocket) Shutdown() error { w.m.Lock() defer w.m.Unlock() + return w.shutdown() +} +func (w *Websocket) shutdown() error { if !w.IsConnected() { return fmt.Errorf("%v %w: %w", w.exchangeName, errCannotShutdown, ErrNotConnected) } @@ -460,18 +523,37 @@ func (w *Websocket) Shutdown() error { defer w.Orderbook.FlushBuffer() + // During the shutdown process, all errors are treated as non-fatal to avoid issues when the connection has already + // been closed. In such cases, attempting to close the connection may result in a + // "failed to send closeNotify alert (but connection was closed anyway)" error. Treating these errors as non-fatal + // prevents the shutdown process from being interrupted, which could otherwise trigger a continuous traffic monitor + // cycle and potentially block the initiation of a new connection. + var nonFatalCloseConnectionErrors error + + // Shutdown managed connections + for x := range w.connectionManager { + if w.connectionManager[x].Connection != nil { + if err := w.connectionManager[x].Connection.Shutdown(); err != nil { + nonFatalCloseConnectionErrors = common.AppendError(nonFatalCloseConnectionErrors, err) + } + w.connectionManager[x].Connection = nil + // Flush any subscriptions from last connection across any managed connections + w.connectionManager[x].Subscriptions.Clear() + } + } + // Clean map of old connections + clear(w.connections) + if w.Conn != nil { if err := w.Conn.Shutdown(); err != nil { - return err + nonFatalCloseConnectionErrors = common.AppendError(nonFatalCloseConnectionErrors, err) } } - if w.AuthConn != nil { if err := w.AuthConn.Shutdown(); err != nil { - return err + nonFatalCloseConnectionErrors = common.AppendError(nonFatalCloseConnectionErrors, err) } } - // flush any subscriptions from last connection if needed w.subscriptions.Clear() @@ -483,6 +565,16 @@ func (w *Websocket) Shutdown() error { if w.verbose { log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", w.exchangeName) } + + // Drain residual error in the single buffered channel, this mitigates + // the cycle when `Connect` is called again and the connectionMonitor + // starts but there is an old error in the channel. + drain(w.ReadMessageErrors) + + if nonFatalCloseConnectionErrors != nil { + log.Warnf(log.WebsocketMgr, "%v websocket: shutdown error: %v", w.exchangeName, nonFatalCloseConnectionErrors) + } + return nil } @@ -496,108 +588,78 @@ func (w *Websocket) FlushChannels() error { return fmt.Errorf("%s %w", w.exchangeName, ErrNotConnected) } - if w.features.Subscribe { - newsubs, err := w.GenerateSubs() + // If the exchange does not support subscribing and or unsubscribing the full connection needs to be flushed to + // maintain consistency. + if !w.features.Subscribe || !w.features.Unsubscribe { + w.m.Lock() + defer w.m.Unlock() + if err := w.shutdown(); err != nil { + return err + } + return w.connect() + } + + if !w.useMultiConnectionManagement { + newSubs, err := w.GenerateSubs() if err != nil { return err } - - subs, unsubs := w.GetChannelDifference(newsubs) - if w.features.Unsubscribe { - if len(unsubs) != 0 { - err := w.UnsubscribeChannels(unsubs) - if err != nil { - return err - } - } + subs, unsubs := w.GetChannelDifference(nil, newSubs) + if err := w.UnsubscribeChannels(nil, unsubs); err != nil { + return err } - - if len(subs) < 1 { + if len(subs) == 0 { return nil } - return w.SubscribeToChannels(subs) - } else if w.features.FullPayloadSubscribe { - // FullPayloadSubscribe means that the endpoint requires all - // subscriptions to be sent via the websocket connection e.g. if you are - // subscribed to ticker and orderbook but require trades as well, you - // would need to send ticker, orderbook and trades channel subscription - // messages. - newsubs, err := w.GenerateSubs() + return w.SubscribeToChannels(nil, subs) + } + + for x := range w.connectionManager { + newSubs, err := w.connectionManager[x].Setup.GenerateSubscriptions() if err != nil { return err } - if len(newsubs) != 0 { - // Purge subscription list as there will be conflicts - w.subscriptions.Clear() - return w.SubscribeToChannels(newsubs) + // Case if there is nothing to unsubscribe from and the connection is nil + if len(newSubs) == 0 && w.connectionManager[x].Connection == nil { + continue } - return nil - } - if err := w.Shutdown(); err != nil { - return err - } - return w.Connect() -} + // If there are subscriptions to subscribe to but no connection to subscribe to, establish a new connection. + if w.connectionManager[x].Connection == nil { + conn := w.getConnectionFromSetup(w.connectionManager[x].Setup) + if err := w.connectionManager[x].Setup.Connector(context.TODO(), conn); err != nil { + return err + } + w.Wg.Add(1) + go w.Reader(context.TODO(), conn, w.connectionManager[x].Setup.Handler) + w.connections[conn] = &w.connectionManager[x] + w.connectionManager[x].Connection = conn + } -// trafficMonitor waits trafficCheckInterval before checking for a trafficAlert -// 1 slot buffer means that connection will only write to trafficAlert once per trafficCheckInterval to avoid read/write flood in high traffic -// Otherwise we Shutdown the connection after trafficTimeout, unless it's connecting. connectionMonitor is responsible for Connecting again -func (w *Websocket) trafficMonitor() { - if w.IsTrafficMonitorRunning() { - return - } - w.setTrafficMonitorRunning(true) - w.Wg.Add(1) + subs, unsubs := w.GetChannelDifference(w.connectionManager[x].Connection, newSubs) - go func() { - t := time.NewTimer(w.trafficTimeout) - for { - select { - case <-w.ShutdownC: - if w.verbose { - log.Debugf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown message received", w.exchangeName) - } - t.Stop() - w.setTrafficMonitorRunning(false) - w.Wg.Done() - return - case <-time.After(trafficCheckInterval): - select { - case <-w.TrafficAlert: - if !t.Stop() { - <-t.C - } - t.Reset(w.trafficTimeout) - default: - } - case <-t.C: - checkAgain := w.IsConnecting() - select { - case <-w.TrafficAlert: - checkAgain = true - default: - } - if checkAgain { - t.Reset(w.trafficTimeout) - break - } - if w.verbose { - log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", w.exchangeName, w.trafficTimeout) - } - w.setTrafficMonitorRunning(false) // Cannot defer lest Connect is called after Shutdown but before deferred call - w.Wg.Done() // Without this the w.Shutdown() call below will deadlock - if w.IsConnected() { - err := w.Shutdown() - if err != nil { - log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", w.exchangeName, err) - } - } - return + if len(unsubs) != 0 { + if err := w.UnsubscribeChannels(w.connectionManager[x].Connection, unsubs); err != nil { + return err } } - }() + if len(subs) != 0 { + if err := w.SubscribeToChannels(w.connectionManager[x].Connection, subs); err != nil { + return err + } + } + + // If there are no subscriptions to subscribe to, close the connection as it is no longer needed. + if w.connectionManager[x].Subscriptions.Len() == 0 { + delete(w.connections, w.connectionManager[x].Connection) // Remove from lookup map + if err := w.connectionManager[x].Connection.Shutdown(); err != nil { + log.Warnf(log.WebsocketMgr, "%v websocket: failed to shutdown connection: %v", w.exchangeName, err) + } + w.connectionManager[x].Connection = nil + } + } + return nil } func (w *Websocket) setState(s uint32) { @@ -628,37 +690,6 @@ func (w *Websocket) IsEnabled() bool { return w.enabled.Load() } -func (w *Websocket) setTrafficMonitorRunning(b bool) { - w.trafficMonitorRunning.Store(b) -} - -// IsTrafficMonitorRunning returns status of the traffic monitor -func (w *Websocket) IsTrafficMonitorRunning() bool { - return w.trafficMonitorRunning.Load() -} - -func (w *Websocket) checkAndSetMonitorRunning() (alreadyRunning bool) { - return !w.connectionMonitorRunning.CompareAndSwap(false, true) -} - -func (w *Websocket) setConnectionMonitorRunning(b bool) { - w.connectionMonitorRunning.Store(b) -} - -// IsConnectionMonitorRunning returns status of connection monitor -func (w *Websocket) IsConnectionMonitorRunning() bool { - return w.connectionMonitorRunning.Load() -} - -func (w *Websocket) setDataMonitorRunning(b bool) { - w.dataMonitorRunning.Store(b) -} - -// IsDataMonitorRunning returns status of data monitor -func (w *Websocket) IsDataMonitorRunning() bool { - return w.dataMonitorRunning.Load() -} - // CanUseAuthenticatedWebsocketForWrapper Handles a common check to // verify whether a wrapper can use an authenticated websocket endpoint func (w *Websocket) CanUseAuthenticatedWebsocketForWrapper() bool { @@ -673,6 +704,10 @@ func (w *Websocket) CanUseAuthenticatedWebsocketForWrapper() bool { // SetWebsocketURL sets websocket URL and can refresh underlying connections func (w *Websocket) SetWebsocketURL(url string, auth, reconnect bool) error { + if w.useMultiConnectionManagement { + // TODO: Add functionality for multi-connection management to change URL + return fmt.Errorf("%s: %w", w.exchangeName, errCannotChangeConnectionURL) + } defaultVals := url == "" || url == config.WebsocketURLNonDefaultMessage if auth { if defaultVals { @@ -686,10 +721,7 @@ func (w *Websocket) SetWebsocketURL(url string, auth, reconnect bool) error { w.runningURLAuth = url if w.verbose { - log.Debugf(log.WebsocketMgr, - "%s websocket: setting authenticated websocket URL: %s\n", - w.exchangeName, - url) + log.Debugf(log.WebsocketMgr, "%s websocket: setting authenticated websocket URL: %s\n", w.exchangeName, url) } if w.AuthConn != nil { @@ -706,10 +738,7 @@ func (w *Websocket) SetWebsocketURL(url string, auth, reconnect bool) error { w.runningURL = url if w.verbose { - log.Debugf(log.WebsocketMgr, - "%s websocket: setting unauthenticated websocket URL: %s\n", - w.exchangeName, - url) + log.Debugf(log.WebsocketMgr, "%s websocket: setting unauthenticated websocket URL: %s\n", w.exchangeName, url) } if w.Conn != nil { @@ -718,10 +747,7 @@ func (w *Websocket) SetWebsocketURL(url string, auth, reconnect bool) error { } if w.IsConnected() && reconnect { - log.Debugf(log.WebsocketMgr, - "%s websocket: flushing websocket connection to %s\n", - w.exchangeName, - url) + log.Debugf(log.WebsocketMgr, "%s websocket: flushing websocket connection to %s\n", w.exchangeName, url) return w.Shutdown() } return nil @@ -735,15 +761,13 @@ func (w *Websocket) GetWebsocketURL() string { // SetProxyAddress sets websocket proxy address func (w *Websocket) SetProxyAddress(proxyAddr string) error { w.m.Lock() - + defer w.m.Unlock() if proxyAddr != "" { if _, err := url.ParseRequestURI(proxyAddr); err != nil { - w.m.Unlock() return fmt.Errorf("%v websocket: cannot set proxy address: %w", w.exchangeName, err) } if w.proxyAddr == proxyAddr { - w.m.Unlock() return fmt.Errorf("%v websocket: %w '%v'", w.exchangeName, errSameProxyAddress, w.proxyAddr) } @@ -752,6 +776,11 @@ func (w *Websocket) SetProxyAddress(proxyAddr string) error { log.Debugf(log.ExchangeSys, "%s websocket: removing websocket proxy", w.exchangeName) } + for _, wrapper := range w.connectionManager { + if wrapper.Connection != nil { + wrapper.Connection.SetProxy(proxyAddr) + } + } if w.Conn != nil { w.Conn.SetProxy(proxyAddr) } @@ -761,17 +790,13 @@ func (w *Websocket) SetProxyAddress(proxyAddr string) error { w.proxyAddr = proxyAddr - if w.IsConnected() { - w.m.Unlock() - if err := w.Shutdown(); err != nil { - return err - } - return w.Connect() + if !w.IsConnected() { + return nil } - - w.m.Unlock() - - return nil + if err := w.shutdown(); err != nil { + return err + } + return w.connect() } // GetProxyAddress returns the current websocket proxy @@ -786,49 +811,78 @@ func (w *Websocket) GetName() string { // GetChannelDifference finds the difference between the subscribed channels // and the new subscription list when pairs are disabled or enabled. -func (w *Websocket) GetChannelDifference(newSubs subscription.List) (sub, unsub subscription.List) { - if w.subscriptions == nil { - w.subscriptions = subscription.NewStore() +func (w *Websocket) GetChannelDifference(conn Connection, newSubs subscription.List) (sub, unsub subscription.List) { + var subscriptionStore **subscription.Store + if wrapper, ok := w.connections[conn]; ok && conn != nil { + subscriptionStore = &wrapper.Subscriptions + } else { + subscriptionStore = &w.subscriptions } - return w.subscriptions.Diff(newSubs) + if *subscriptionStore == nil { + *subscriptionStore = subscription.NewStore() + } + return (*subscriptionStore).Diff(newSubs) } // UnsubscribeChannels unsubscribes from a list of websocket channel -func (w *Websocket) UnsubscribeChannels(channels subscription.List) error { - if w.subscriptions == nil || len(channels) == 0 { +func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.List) error { + if len(channels) == 0 { + return nil // No channels to unsubscribe from is not an error + } + if wrapper, ok := w.connections[conn]; ok && conn != nil { + return w.unsubscribe(wrapper.Subscriptions, channels, func(channels subscription.List) error { + return wrapper.Setup.Unsubscriber(context.TODO(), conn, channels) + }) + } + return w.unsubscribe(w.subscriptions, channels, func(channels subscription.List) error { + return w.Unsubscriber(channels) + }) +} + +func (w *Websocket) unsubscribe(store *subscription.Store, channels subscription.List, unsub func(channels subscription.List) error) error { + if store == nil { return nil // No channels to unsubscribe from is not an error } for _, s := range channels { - if w.subscriptions.Get(s) == nil { + if store.Get(s) == nil { return fmt.Errorf("%w: %s", subscription.ErrNotFound, s) } } - return w.Unsubscriber(channels) + return unsub(channels) } // ResubscribeToChannel resubscribes to channel // Sets state to Resubscribing, and exchanges which want to maintain a lock on it can respect this state and not RemoveSubscription // Errors if subscription is already subscribing -func (w *Websocket) ResubscribeToChannel(s *subscription.Subscription) error { +func (w *Websocket) ResubscribeToChannel(conn Connection, s *subscription.Subscription) error { l := subscription.List{s} if err := s.SetState(subscription.ResubscribingState); err != nil { return fmt.Errorf("%w: %s", err, s) } - if err := w.UnsubscribeChannels(l); err != nil { + if err := w.UnsubscribeChannels(conn, l); err != nil { return err } - return w.SubscribeToChannels(l) + return w.SubscribeToChannels(conn, l) } // SubscribeToChannels subscribes to websocket channels using the exchange specific Subscriber method // Errors are returned for duplicates or exceeding max Subscriptions -func (w *Websocket) SubscribeToChannels(subs subscription.List) error { +func (w *Websocket) SubscribeToChannels(conn Connection, subs subscription.List) error { if slices.Contains(subs, nil) { return fmt.Errorf("%w: List parameter contains an nil element", common.ErrNilPointer) } - if err := w.checkSubscriptions(subs); err != nil { + if err := w.checkSubscriptions(conn, subs); err != nil { return err } + + if wrapper, ok := w.connections[conn]; ok && conn != nil { + return wrapper.Setup.Subscriber(context.TODO(), conn, subs) + } + + if w.Subscriber == nil { + return fmt.Errorf("%w: Global Subscriber not set", common.ErrNilPointer) + } + if err := w.Subscriber(subs); err != nil { return fmt.Errorf("%w: %w", ErrSubscriptionFailure, err) } @@ -837,12 +891,19 @@ func (w *Websocket) SubscribeToChannels(subs subscription.List) error { // AddSubscriptions adds subscriptions to the subscription store // Sets state to Subscribing unless the state is already set -func (w *Websocket) AddSubscriptions(subs ...*subscription.Subscription) error { +func (w *Websocket) AddSubscriptions(conn Connection, subs ...*subscription.Subscription) error { if w == nil { return fmt.Errorf("%w: AddSubscriptions called on nil Websocket", common.ErrNilPointer) } - if w.subscriptions == nil { - w.subscriptions = subscription.NewStore() + var subscriptionStore **subscription.Store + if wrapper, ok := w.connections[conn]; ok && conn != nil { + subscriptionStore = &wrapper.Subscriptions + } else { + subscriptionStore = &w.subscriptions + } + + if *subscriptionStore == nil { + *subscriptionStore = subscription.NewStore() } var errs error for _, s := range subs { @@ -851,7 +912,7 @@ func (w *Websocket) AddSubscriptions(subs ...*subscription.Subscription) error { errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s)) } } - if err := w.subscriptions.Add(s); err != nil { + if err := (*subscriptionStore).Add(s); err != nil { errs = common.AppendError(errs, err) } } @@ -859,19 +920,28 @@ func (w *Websocket) AddSubscriptions(subs ...*subscription.Subscription) error { } // AddSuccessfulSubscriptions marks subscriptions as subscribed and adds them to the subscription store -func (w *Websocket) AddSuccessfulSubscriptions(subs ...*subscription.Subscription) error { +func (w *Websocket) AddSuccessfulSubscriptions(conn Connection, subs ...*subscription.Subscription) error { if w == nil { return fmt.Errorf("%w: AddSuccessfulSubscriptions called on nil Websocket", common.ErrNilPointer) } - if w.subscriptions == nil { - w.subscriptions = subscription.NewStore() + + var subscriptionStore **subscription.Store + if wrapper, ok := w.connections[conn]; ok && conn != nil { + subscriptionStore = &wrapper.Subscriptions + } else { + subscriptionStore = &w.subscriptions } + + if *subscriptionStore == nil { + *subscriptionStore = subscription.NewStore() + } + var errs error for _, s := range subs { if err := s.SetState(subscription.SubscribedState); err != nil { errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s)) } - if err := w.subscriptions.Add(s); err != nil { + if err := (*subscriptionStore).Add(s); err != nil { errs = common.AppendError(errs, err) } } @@ -879,19 +949,28 @@ func (w *Websocket) AddSuccessfulSubscriptions(subs ...*subscription.Subscriptio } // RemoveSubscriptions removes subscriptions from the subscription list and sets the status to Unsubscribed -func (w *Websocket) RemoveSubscriptions(subs ...*subscription.Subscription) error { +func (w *Websocket) RemoveSubscriptions(conn Connection, subs ...*subscription.Subscription) error { if w == nil { return fmt.Errorf("%w: RemoveSubscriptions called on nil Websocket", common.ErrNilPointer) } - if w.subscriptions == nil { + + var subscriptionStore *subscription.Store + if wrapper, ok := w.connections[conn]; ok && conn != nil { + subscriptionStore = wrapper.Subscriptions + } else { + subscriptionStore = w.subscriptions + } + + if subscriptionStore == nil { return fmt.Errorf("%w: RemoveSubscriptions called on uninitialised Websocket", common.ErrNilPointer) } + var errs error for _, s := range subs { if err := s.SetState(subscription.UnsubscribedState); err != nil { errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s)) } - if err := w.subscriptions.Remove(s); err != nil { + if err := subscriptionStore.Remove(s); err != nil { errs = common.AppendError(errs, err) } } @@ -902,7 +981,19 @@ func (w *Websocket) RemoveSubscriptions(subs ...*subscription.Subscription) erro // returns nil if no subscription is at that key or the key is nil // Keys can implement subscription.MatchableKey in order to provide custom matching logic func (w *Websocket) GetSubscription(key any) *subscription.Subscription { - if w == nil || w.subscriptions == nil || key == nil { + if w == nil || key == nil { + return nil + } + for _, c := range w.connectionManager { + if c.Subscriptions == nil { + continue + } + sub := c.Subscriptions.Get(key) + if sub != nil { + return sub + } + } + if w.subscriptions == nil { return nil } return w.subscriptions.Get(key) @@ -910,10 +1001,19 @@ func (w *Websocket) GetSubscription(key any) *subscription.Subscription { // GetSubscriptions returns a new slice of the subscriptions func (w *Websocket) GetSubscriptions() subscription.List { - if w == nil || w.subscriptions == nil { + if w == nil { return nil } - return w.subscriptions.List() + var subs subscription.List + for _, c := range w.connectionManager { + if c.Subscriptions != nil { + subs = append(subs, c.Subscriptions.List()...) + } + } + if w.subscriptions != nil { + subs = append(subs, w.subscriptions.List()...) + } + return subs } // SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in a thread safe manner @@ -926,17 +1026,6 @@ func (w *Websocket) CanUseAuthenticatedEndpoints() bool { return w.canUseAuthenticatedEndpoints.Load() } -// IsDisconnectionError Determines if the error sent over chan ReadMessageErrors is a disconnection error -func IsDisconnectionError(err error) bool { - if websocket.IsUnexpectedCloseError(err) { - return true - } - if _, ok := err.(*net.OpError); ok { - return !errors.Is(err, errClosedConnection) - } - return false -} - // checkWebsocketURL checks for a valid websocket url func checkWebsocketURL(s string) error { u, err := url.Parse(s) @@ -951,12 +1040,18 @@ func checkWebsocketURL(s string) error { // checkSubscriptions checks subscriptions against the max subscription limit and if the subscription already exists // The subscription state is not considered when counting existing subscriptions -func (w *Websocket) checkSubscriptions(subs subscription.List) error { - if w.subscriptions == nil { +func (w *Websocket) checkSubscriptions(conn Connection, subs subscription.List) error { + var subscriptionStore *subscription.Store + if wrapper, ok := w.connections[conn]; ok && conn != nil { + subscriptionStore = wrapper.Subscriptions + } else { + subscriptionStore = w.subscriptions + } + if subscriptionStore == nil { return fmt.Errorf("%w: Websocket.subscriptions", common.ErrNilPointer) } - existing := w.subscriptions.Len() + existing := subscriptionStore.Len() if w.MaxSubscriptionsPerConnection > 0 && existing+len(subs) > w.MaxSubscriptionsPerConnection { return fmt.Errorf("%w: current subscriptions: %v, incoming subscriptions: %v, max subscriptions per connection: %v - please reduce enabled pairs", errSubscriptionsExceedsLimit, @@ -976,3 +1071,173 @@ func (w *Websocket) checkSubscriptions(subs subscription.List) error { return nil } + +// Reader reads and handles data from a specific connection +func (w *Websocket) Reader(ctx context.Context, conn Connection, handler func(ctx context.Context, message []byte) error) { + defer w.Wg.Done() + for { + resp := conn.ReadMessage() + if resp.Raw == nil { + return // Connection has been closed + } + if err := handler(ctx, resp.Raw); err != nil { + w.DataHandler <- fmt.Errorf("connection URL:[%v] error: %w", conn.GetURL(), err) + } + } +} + +func drain(ch <-chan error) { + for { + select { + case <-ch: + default: + return + } + } +} + +// ClosureFrame is a closure function that wraps monitoring variables with observer, if the return is true the frame will exit +type ClosureFrame func() func() bool + +// monitorFrame monitors a specific websocket component or critical system. It will exit if the observer returns true +// This is used for monitoring data throughput, connection status and other critical websocket components. The waitgroup +// is optional and is used to signal when the monitor has finished. +func (w *Websocket) monitorFrame(wg *sync.WaitGroup, fn ClosureFrame) { + if wg != nil { + defer w.Wg.Done() + } + observe := fn() + for { + if observe() { + return + } + } +} + +// monitorData monitors data throughput and logs if there is a back log of data +func (w *Websocket) monitorData() func() bool { + dropped := 0 + return func() bool { return w.observeData(&dropped) } +} + +// observeData observes data throughput and logs if there is a back log of data +func (w *Websocket) observeData(dropped *int) (exit bool) { + select { + case <-w.ShutdownC: + return true + case d := <-w.DataHandler: + select { + case w.ToRoutine <- d: + if *dropped != 0 { + log.Infof(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer recovered; %d messages were dropped", w.exchangeName, dropped) + *dropped = 0 + } + default: + if *dropped == 0 { + // If this becomes prone to flapping we could drain the buffer, but that's extreme and we'd like to avoid it if possible + log.Warnf(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer full; dropping messages", w.exchangeName) + } + *dropped++ + } + return false + } +} + +// monitorConnection monitors the connection and attempts to reconnect if the connection is lost +func (w *Websocket) monitorConnection() func() bool { + timer := time.NewTimer(w.connectionMonitorDelay) + return func() bool { return w.observeConnection(timer) } +} + +// observeConnection observes the connection and attempts to reconnect if the connection is lost +func (w *Websocket) observeConnection(t *time.Timer) (exit bool) { + select { + case err := <-w.ReadMessageErrors: + if errors.Is(err, errConnectionFault) { + log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err) + if w.IsConnected() { + if shutdownErr := w.Shutdown(); shutdownErr != nil { + log.Errorf(log.WebsocketMgr, "%v websocket: connectionMonitor shutdown err: %s", w.exchangeName, shutdownErr) + } + } + } + // Speedier reconnection, instead of waiting for the next cycle. + if w.IsEnabled() && (!w.IsConnected() && !w.IsConnecting()) { + if connectErr := w.Connect(); connectErr != nil { + log.Errorln(log.WebsocketMgr, connectErr) + } + } + w.DataHandler <- err // hand over the error to the data handler (shutdown and reconnection is priority) + case <-t.C: + if w.verbose { + log.Debugf(log.WebsocketMgr, "%v websocket: running connection monitor cycle", w.exchangeName) + } + if !w.IsEnabled() { + if w.verbose { + log.Debugf(log.WebsocketMgr, "%v websocket: connectionMonitor - websocket disabled, shutting down", w.exchangeName) + } + if w.IsConnected() { + if err := w.Shutdown(); err != nil { + log.Errorln(log.WebsocketMgr, err) + } + } + if w.verbose { + log.Debugf(log.WebsocketMgr, "%v websocket: connection monitor exiting", w.exchangeName) + } + t.Stop() + w.connectionMonitorRunning.Store(false) + return true + } + if !w.IsConnecting() && !w.IsConnected() { + err := w.Connect() + if err != nil { + log.Errorln(log.WebsocketMgr, err) + } + } + t.Reset(w.connectionMonitorDelay) + } + return false +} + +// monitorTraffic monitors to see if there has been traffic within the trafficTimeout time window. If there is no traffic +// the connection is shutdown and will be reconnected by the connectionMonitor routine. +func (w *Websocket) monitorTraffic() func() bool { + timer := time.NewTimer(w.trafficTimeout) + return func() bool { return w.observeTraffic(timer) } +} + +func (w *Websocket) observeTraffic(t *time.Timer) bool { + select { + case <-w.ShutdownC: + if w.verbose { + log.Debugf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown message received", w.exchangeName) + } + case <-t.C: + if w.IsConnecting() || signalReceived(w.TrafficAlert) { + t.Reset(w.trafficTimeout) + return false + } + if w.verbose { + log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", w.exchangeName, w.trafficTimeout) + } + if w.IsConnected() { + go func() { // Without this the w.Shutdown() call below will deadlock + if err := w.Shutdown(); err != nil { + log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", w.exchangeName, err) + } + }() + } + } + t.Stop() + return true +} + +// signalReceived checks if a signal has been received, this also clears the signal. +func signalReceived(ch chan struct{}) bool { + select { + case <-ch: + return true + default: + return false + } +} diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index a98ccee0..1f6f5e60 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -24,12 +24,19 @@ import ( ) var ( + // errConnectionFault is a connection fault error which alerts the system that a connection cycle needs to take place. + errConnectionFault = errors.New("connection fault") errWebsocketIsDisconnected = errors.New("websocket connection is disconnected") errRateLimitNotFound = errors.New("rate limit definition not found") ) // Dial sets proxy urls and then connects to the websocket func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header) error { + return w.DialContext(context.Background(), dialer, headers) +} + +// DialContext sets proxy urls and then connects to the websocket +func (w *WebsocketConnection) DialContext(ctx context.Context, dialer *websocket.Dialer, headers http.Header) error { if w.ProxyURL != "" { proxy, err := url.Parse(w.ProxyURL) if err != nil { @@ -40,15 +47,15 @@ func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header var err error var conStatus *http.Response - - w.Connection, conStatus, err = dialer.Dial(w.URL, headers) + w.Connection, conStatus, err = dialer.DialContext(ctx, w.URL, headers) if err != nil { if conStatus != nil { + _ = conStatus.Body.Close() return fmt.Errorf("%s websocket connection: %v %v %v Error: %w", w.ExchangeName, w.URL, conStatus, conStatus.StatusCode, err) } return fmt.Errorf("%s websocket connection: %v Error: %w", w.ExchangeName, w.URL, err) } - defer conStatus.Body.Close() + _ = conStatus.Body.Close() if w.Verbose { log.Infof(log.WebsocketMgr, "%v Websocket connected to %s\n", w.ExchangeName, w.URL) @@ -131,18 +138,18 @@ func (w *WebsocketConnection) SetupPingHandler(epl request.EndpointLimit, handle return } w.Wg.Add(1) - defer w.Wg.Done() go func() { + defer w.Wg.Done() ticker := time.NewTicker(handler.Delay) for { select { - case <-w.ShutdownC: + case <-w.shutdown: ticker.Stop() return case <-ticker.C: err := w.SendRawMessage(context.TODO(), epl, handler.MessageType, handler.Message) if err != nil { - log.Errorf(log.WebsocketMgr, "%v websocket connection: ping handler failed to send message [%s]", w.ExchangeName, handler.Message) + log.Errorf(log.WebsocketMgr, "%v websocket connection: ping handler failed to send message [%s]: %v", w.ExchangeName, handler.Message, err) return } } @@ -168,23 +175,24 @@ func (w *WebsocketConnection) IsConnected() bool { func (w *WebsocketConnection) ReadMessage() Response { mType, resp, err := w.Connection.ReadMessage() if err != nil { - if IsDisconnectionError(err) { - if w.setConnectedStatus(false) { - // NOTE: When w.setConnectedStatus() returns true the underlying - // state was changed and this infers that the connection was - // externally closed and an error is reported else Shutdown() - // method on WebsocketConnection type has been called and can - // be skipped. - select { - case w.readMessageErrors <- err: - default: - // bypass if there is no receiver, as this stops it returning - // when shutdown is called. - log.Warnf(log.WebsocketMgr, - "%s failed to relay error: %v", - w.ExchangeName, - err) - } + // If any error occurs, a Response{Raw: nil, Type: 0} is returned, causing the + // reader routine to exit. This leaves the connection without an active reader, + // leading to potential buffer issue from the ongoing websocket writes. + // Such errors are passed to `w.readMessageErrors` when the connection is active. + // The `connectionMonitor` handles these errors by flushing the buffer, reconnecting, + // and resubscribing to the websocket to restore the connection. + if w.setConnectedStatus(false) { + // NOTE: When w.setConnectedStatus() returns true the underlying + // state was changed and this infers that the connection was + // externally closed and an error is reported else Shutdown() + // method on WebsocketConnection type has been called and can + // be skipped. + select { + case w.readMessageErrors <- fmt.Errorf("%w: %w", err, errConnectionFault): + default: + // bypass if there is no receiver, as this stops it returning + // when shutdown is called. + log.Warnf(log.WebsocketMgr, "%s failed to relay error: %v", w.ExchangeName, err) } } return Response{} @@ -203,7 +211,7 @@ func (w *WebsocketConnection) ReadMessage() Response { standardMessage, err = w.parseBinaryResponse(resp) if err != nil { log.Errorf(log.WebsocketMgr, "%v %v: Parse binary response error: %v", w.ExchangeName, removeURLQueryString(w.URL), err) - return Response{} + return Response{Raw: []byte(``)} // Non-nil response to avoid the reader returning on this case. } } if w.Verbose { @@ -264,7 +272,9 @@ func (w *WebsocketConnection) Shutdown() error { return nil } w.setConnectedStatus(false) - return w.Connection.UnderlyingConn().Close() + w.writeControl.Lock() + defer w.writeControl.Unlock() + return w.Connection.NetConn().Close() } // SetURL sets connection URL diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index ecd32646..2904bcca 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -7,10 +7,8 @@ import ( "context" "encoding/json" "errors" - "fmt" - "net" "net/http" - "os" + "net/http/httptest" "strconv" "strings" "sync" @@ -26,20 +24,18 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" + mockws "github.com/thrasher-corp/gocryptotrader/internal/testing/websocket" ) const ( - websocketTestURL = "wss://www.bitmex.com/realtime" - useProxyTests = false // Disabled by default. Freely available proxy servers that work all the time are difficult to find - proxyURL = "http://212.186.171.4:80" // Replace with a usable proxy server + useProxyTests = false // Disabled by default. Freely available proxy servers that work all the time are difficult to find + proxyURL = "http://212.186.171.4:80" // Replace with a usable proxy server ) var ( errDastardlyReason = errors.New("some dastardly reason") ) -var dialer websocket.Dialer - type testStruct struct { Error error WC WebsocketConnection @@ -94,215 +90,83 @@ var defaultSetup = &WebsocketSetup{ Features: &protocol.Features{Subscribe: true, Unsubscribe: true}, } -type dodgyConnection struct { - WebsocketConnection -} - -// override websocket connection method to produce a wicked terrible error -func (d *dodgyConnection) Shutdown() error { - return fmt.Errorf("%w: %w", errCannotShutdown, errDastardlyReason) -} - -// override websocket connection method to produce a wicked terrible error -func (d *dodgyConnection) Connect() error { - return fmt.Errorf("cannot connect: %w", errDastardlyReason) -} - -func TestMain(m *testing.M) { - // Change trafficCheckInterval for TestTrafficMonitorTimeout before parallel tests to avoid racing - trafficCheckInterval = 50 * time.Millisecond - os.Exit(m.Run()) -} - func TestSetup(t *testing.T) { t.Parallel() var w *Websocket err := w.Setup(nil) - assert.ErrorIs(t, err, errWebsocketIsNil, "Setup should error correctly") + assert.ErrorIs(t, err, errWebsocketIsNil) w = &Websocket{DataHandler: make(chan interface{})} err = w.Setup(nil) - assert.ErrorIs(t, err, errWebsocketSetupIsNil, "Setup should error correctly") + assert.ErrorIs(t, err, errWebsocketSetupIsNil) websocketSetup := &WebsocketSetup{} err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errExchangeConfigIsNil, "Setup should error correctly") + assert.ErrorIs(t, err, errExchangeConfigIsNil) websocketSetup.ExchangeConfig = &config.Exchange{} err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errExchangeConfigNameEmpty, "Setup should error correctly") + assert.ErrorIs(t, err, errExchangeConfigNameEmpty) websocketSetup.ExchangeConfig.Name = "testname" err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errWebsocketFeaturesIsUnset, "Setup should error correctly") + assert.ErrorIs(t, err, errWebsocketFeaturesIsUnset) websocketSetup.Features = &protocol.Features{} err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errConfigFeaturesIsNil, "Setup should error correctly") + assert.ErrorIs(t, err, errConfigFeaturesIsNil) websocketSetup.ExchangeConfig.Features = &config.FeaturesConfig{} + websocketSetup.Subscriber = func(subscription.List) error { return nil } // kicks off the setup err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errWebsocketConnectorUnset, "Setup should error correctly") + assert.ErrorIs(t, err, errWebsocketConnectorUnset) + websocketSetup.Subscriber = nil websocketSetup.Connector = func() error { return nil } err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errWebsocketSubscriberUnset, "Setup should error correctly") + assert.ErrorIs(t, err, errWebsocketSubscriberUnset) websocketSetup.Subscriber = func(subscription.List) error { return nil } - websocketSetup.Features.Unsubscribe = true + w.features.Unsubscribe = true err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errWebsocketUnsubscriberUnset, "Setup should error correctly") + assert.ErrorIs(t, err, errWebsocketUnsubscriberUnset) websocketSetup.Unsubscriber = func(subscription.List) error { return nil } err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset, "Setup should error correctly") + assert.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset) websocketSetup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil } err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errDefaultURLIsEmpty, "Setup should error correctly") + assert.ErrorIs(t, err, errDefaultURLIsEmpty) websocketSetup.DefaultURL = "test" err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errRunningURLIsEmpty, "Setup should error correctly") + assert.ErrorIs(t, err, errRunningURLIsEmpty) websocketSetup.RunningURL = "http://www.google.com" err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errInvalidWebsocketURL, "Setup should error correctly") + assert.ErrorIs(t, err, errInvalidWebsocketURL) websocketSetup.RunningURL = "wss://www.google.com" websocketSetup.RunningURLAuth = "http://www.google.com" err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errInvalidWebsocketURL, "Setup should error correctly") + assert.ErrorIs(t, err, errInvalidWebsocketURL) websocketSetup.RunningURLAuth = "wss://www.google.com" err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errInvalidTrafficTimeout, "Setup should error correctly") + assert.ErrorIs(t, err, errInvalidTrafficTimeout) websocketSetup.ExchangeConfig.WebsocketTrafficTimeout = time.Minute err = w.Setup(websocketSetup) assert.NoError(t, err, "Setup should not error") } -// TestTrafficMonitorTrafficAlerts ensures multiple traffic alerts work and only process one trafficAlert per interval -// ensures shutdown works after traffic alerts -func TestTrafficMonitorTrafficAlerts(t *testing.T) { - t.Parallel() - ws := NewWebsocket() - err := ws.Setup(defaultSetup) - require.NoError(t, err, "Setup must not error") - - signal := struct{}{} - patience := 10 * time.Millisecond - ws.trafficTimeout = 200 * time.Millisecond - ws.state.Store(connectedState) - - thenish := time.Now() - ws.trafficMonitor() - - assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") - require.Equal(t, connectedState, ws.state.Load(), "websocket must be connected") - - for i := range 6 { // Timeout will happen at 200ms so we want 6 * 50ms checks to pass - select { - case ws.TrafficAlert <- signal: - if i == 0 { - require.WithinDurationf(t, time.Now(), thenish, trafficCheckInterval, "First Non-blocking test must happen before the traffic is checked") - } - default: - require.Failf(t, "", "TrafficAlert should not block; Check #%d", i) - } - - select { - case ws.TrafficAlert <- signal: - require.Failf(t, "", "TrafficAlert should block after first slot used; Check #%d", i) - default: - if i == 0 { - require.WithinDuration(t, time.Now(), thenish, trafficCheckInterval, "First Blocking test must happen before the traffic is checked") - } - } - - require.Eventuallyf(t, func() bool { return len(ws.TrafficAlert) == 0 }, 5*time.Second, patience, "trafficAlert should be drained; Check #%d", i) - assert.Truef(t, ws.IsConnected(), "state should still be connected; Check #%d", i) - } - - require.EventuallyWithT(t, func(c *assert.CollectT) { - assert.Equal(c, disconnectedState, ws.state.Load(), "websocket must be disconnected") - assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") - }, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown once we stop feeding trafficAlerts") -} - -// TestTrafficMonitorConnecting ensures connecting status doesn't trigger shutdown -func TestTrafficMonitorConnecting(t *testing.T) { - t.Parallel() - ws := NewWebsocket() - err := ws.Setup(defaultSetup) - require.NoError(t, err, "Setup must not error") - - ws.state.Store(connectingState) - ws.trafficTimeout = 50 * time.Millisecond - ws.trafficMonitor() - require.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") - require.Equal(t, connectingState, ws.state.Load(), "websocket must be connecting") - <-time.After(4 * ws.trafficTimeout) - require.Equal(t, connectingState, ws.state.Load(), "websocket must still be connecting after several checks") - ws.state.Store(connectedState) - require.EventuallyWithT(t, func(c *assert.CollectT) { - assert.Equal(c, disconnectedState, ws.state.Load(), "websocket must be disconnected") - assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") - }, 4*ws.trafficTimeout, 10*time.Millisecond, "trafficTimeout should trigger a shutdown after connecting status changes") -} - -// TestTrafficMonitorShutdown ensures shutdown is processed and waitgroup is cleared -func TestTrafficMonitorShutdown(t *testing.T) { - t.Parallel() - ws := NewWebsocket() - err := ws.Setup(defaultSetup) - require.NoError(t, err, "Setup must not error") - - ws.state.Store(connectedState) - ws.trafficTimeout = time.Minute - ws.trafficMonitor() - assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") - - wgReady := make(chan bool) - go func() { - ws.Wg.Wait() - close(wgReady) - }() - select { - case <-wgReady: - require.Failf(t, "", "WaitGroup should be blocking still") - case <-time.After(trafficCheckInterval): - } - - close(ws.ShutdownC) - - <-time.After(2 * trafficCheckInterval) - assert.False(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be shutdown") - select { - case <-wgReady: - default: - require.Failf(t, "", "WaitGroup should be freed now") - } -} - -func TestIsDisconnectionError(t *testing.T) { - t.Parallel() - assert.False(t, IsDisconnectionError(errors.New("errorText")), "IsDisconnectionError should return false") - assert.True(t, IsDisconnectionError(&websocket.CloseError{Code: 1006, Text: "errorText"}), "IsDisconnectionError should return true") - assert.False(t, IsDisconnectionError(&net.OpError{Err: errClosedConnection}), "IsDisconnectionError should return false") - assert.True(t, IsDisconnectionError(&net.OpError{Err: errors.New("errText")}), "IsDisconnectionError should return true") -} - func TestConnectionMessageErrors(t *testing.T) { t.Parallel() var wsWrong = &Websocket{} - err := wsWrong.Connect() - assert.ErrorIs(t, err, errNoConnectFunc, "Connect should error correctly") - wsWrong.connector = func() error { return nil } - err = wsWrong.Connect() + err := wsWrong.Connect() assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "Connect should error correctly") wsWrong.setEnabled(true) @@ -330,29 +194,95 @@ func TestConnectionMessageErrors(t *testing.T) { err = ws.Connect() require.NoError(t, err, "Connect must not error") - c := func(tb *assert.CollectT) { - select { - case v, ok := <-ws.ToRoutine: - require.True(tb, ok, "ToRoutine should not be closed on us") - switch err := v.(type) { - case *websocket.CloseError: - assert.Equal(tb, "SpecialText", err.Text, "Should get correct Close Error") - case error: - assert.ErrorIs(tb, err, errDastardlyReason, "Should get the correct error") - default: - assert.Failf(tb, "Wrong data type sent to ToRoutine", "Got type: %T", err) - } + checkToRoutineResult := func(t *testing.T) { + t.Helper() + v, ok := <-ws.ToRoutine + require.True(t, ok, "ToRoutine should not be closed on us") + switch err := v.(type) { + case *websocket.CloseError: + assert.Equal(t, "SpecialText", err.Text, "Should get correct Close Error") + case error: + assert.ErrorIs(t, err, errDastardlyReason, "Should get the correct error") default: - assert.Fail(tb, "Nothing available on ToRoutine") + assert.Failf(t, "Wrong data type sent to ToRoutine", "Got type: %T", err) } } ws.TrafficAlert <- struct{}{} ws.ReadMessageErrors <- errDastardlyReason - assert.EventuallyWithT(t, c, 2*time.Second, 10*time.Millisecond, "Should get an error down the routine") + checkToRoutineResult(t) ws.ReadMessageErrors <- &websocket.CloseError{Code: 1006, Text: "SpecialText"} - assert.EventuallyWithT(t, c, 2*time.Second, 10*time.Millisecond, "Should get an error down the routine") + checkToRoutineResult(t) + + // Test individual connection defined functions + require.NoError(t, ws.Shutdown()) + ws.useMultiConnectionManagement = true + + err = ws.Connect() + assert.ErrorIs(t, err, errNoPendingConnections, "Connect should error correctly") + + ws.useMultiConnectionManagement = true + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) + defer mock.Close() + ws.connectionManager = []ConnectionWrapper{{Setup: &ConnectionSetup{URL: "ws" + mock.URL[len("http"):] + "/ws"}}} + err = ws.Connect() + require.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset) + + ws.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) { + return nil, errDastardlyReason + } + err = ws.Connect() + require.ErrorIs(t, err, errDastardlyReason) + + ws.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) { + return subscription.List{{}}, nil + } + err = ws.Connect() + require.ErrorIs(t, err, errNoConnectFunc) + + ws.connectionManager[0].Setup.Connector = func(context.Context, Connection) error { + return errDastardlyReason + } + err = ws.Connect() + require.ErrorIs(t, err, errWebsocketDataHandlerUnset) + + ws.connectionManager[0].Setup.Handler = func(context.Context, []byte) error { + return errDastardlyReason + } + err = ws.Connect() + require.ErrorIs(t, err, errWebsocketSubscriberUnset) + + ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error { + return errDastardlyReason + } + err = ws.Connect() + require.ErrorIs(t, err, errDastardlyReason) + + ws.connectionManager[0].Setup.Connector = func(ctx context.Context, conn Connection) error { + return conn.DialContext(ctx, websocket.DefaultDialer, nil) + } + err = ws.Connect() + require.ErrorIs(t, err, errDastardlyReason) + + ws.connectionManager[0].Setup.Handler = func(context.Context, []byte) error { + return errDastardlyReason + } + err = ws.Connect() + require.ErrorIs(t, err, errDastardlyReason) + + ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error { + return nil + } + err = ws.Connect() + require.NoError(t, err) + + err = ws.connectionManager[0].Connection.SendRawMessage(context.Background(), request.Unset, websocket.TextMessage, []byte("test")) + require.NoError(t, err) + + require.NoError(t, err) + require.NoError(t, ws.Shutdown()) } func TestWebsocket(t *testing.T) { @@ -363,10 +293,7 @@ func TestWebsocket(t *testing.T) { err := ws.SetProxyAddress("garbagio") assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly") - ws.Conn = &dodgyConnection{} - ws.AuthConn = &WebsocketConnection{} ws.setEnabled(true) - err = ws.Setup(defaultSetup) // Sets to enabled again require.NoError(t, err, "Setup may not error") @@ -387,6 +314,7 @@ func TestWebsocket(t *testing.T) { ws.setState(connectedState) + ws.connector = func() error { return errDastardlyReason } err = ws.SetProxyAddress("https://192.168.0.1:1336") assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Connect and error from there") @@ -394,13 +322,9 @@ func TestWebsocket(t *testing.T) { assert.ErrorIs(t, err, errSameProxyAddress, "SetProxyAddress should error correctly") // removing proxy - err = ws.SetProxyAddress("") - assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Shutdown and error from there") - assert.ErrorIs(t, err, errCannotShutdown, "SetProxyAddress should call Shutdown and error from there") + assert.NoError(t, ws.SetProxyAddress("")) - ws.Conn = &WebsocketConnection{} ws.setEnabled(true) - // reinstate proxy err = ws.SetProxyAddress("http://localhost:1337") assert.NoError(t, err, "SetProxyAddress should not error") @@ -408,15 +332,11 @@ func TestWebsocket(t *testing.T) { assert.Equal(t, "wss://testRunningURL", ws.GetWebsocketURL(), "GetWebsocketURL should return correctly") assert.Equal(t, time.Second*5, ws.trafficTimeout, "trafficTimeout should default correctly") + assert.ErrorIs(t, ws.Shutdown(), ErrNotConnected) ws.setState(connectedState) - ws.AuthConn = &dodgyConnection{} - err = ws.Shutdown() - assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy authConn") - assert.ErrorIs(t, err, errCannotShutdown, "Shutdown should error correctly with a dodgy authConn") - - ws.AuthConn = &WebsocketConnection{} - ws.setState(disconnectedState) + assert.NoError(t, ws.Shutdown()) + ws.connector = func() error { return nil } err = ws.Connect() assert.NoError(t, err, "Connect should not error") @@ -448,17 +368,35 @@ func TestWebsocket(t *testing.T) { err = ws.Shutdown() assert.NoError(t, err, "Shutdown should not error") ws.Wg.Wait() + + ws.useMultiConnectionManagement = true + + ws.connectionManager = []ConnectionWrapper{{Setup: &ConnectionSetup{URL: "ws://demos.kaazing.com/echo"}, Connection: &WebsocketConnection{}}} + err = ws.SetProxyAddress("https://192.168.0.1:1337") + require.NoError(t, err) } func currySimpleSub(w *Websocket) func(subscription.List) error { return func(subs subscription.List) error { - return w.AddSuccessfulSubscriptions(subs...) + return w.AddSuccessfulSubscriptions(nil, subs...) + } +} + +func currySimpleSubConn(w *Websocket) func(context.Context, Connection, subscription.List) error { + return func(_ context.Context, conn Connection, subs subscription.List) error { + return w.AddSuccessfulSubscriptions(conn, subs...) } } func currySimpleUnsub(w *Websocket) func(subscription.List) error { return func(unsubs subscription.List) error { - return w.RemoveSubscriptions(unsubs...) + return w.RemoveSubscriptions(nil, unsubs...) + } +} + +func currySimpleUnsubConn(w *Websocket) func(context.Context, Connection, subscription.List) error { + return func(_ context.Context, conn Connection, unsubs subscription.List) error { + return w.RemoveSubscriptions(conn, unsubs...) } } @@ -473,11 +411,11 @@ func TestSubscribeUnsubscribe(t *testing.T) { subs, err := ws.GenerateSubs() require.NoError(t, err, "Generating test subscriptions should not error") - assert.NoError(t, new(Websocket).UnsubscribeChannels(subs), "Should not error when w.subscriptions is nil") - assert.NoError(t, ws.UnsubscribeChannels(nil), "Unsubscribing from nil should not error") - assert.ErrorIs(t, ws.UnsubscribeChannels(subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed") + assert.NoError(t, new(Websocket).UnsubscribeChannels(nil, subs), "Should not error when w.subscriptions is nil") + assert.NoError(t, ws.UnsubscribeChannels(nil, nil), "Unsubscribing from nil should not error") + assert.ErrorIs(t, ws.UnsubscribeChannels(nil, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed") assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return") - assert.NoError(t, ws.SubscribeToChannels(subs), "Basic Subscribing should not error") + assert.NoError(t, ws.SubscribeToChannels(nil, subs), "Basic Subscribing should not error") assert.Len(t, ws.GetSubscriptions(), 4, "Should have 4 subscriptions") bySub := ws.GetSubscription(subscription.Subscription{Channel: "TestSub"}) if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") { @@ -495,14 +433,76 @@ func TestSubscribeUnsubscribe(t *testing.T) { } assert.Nil(t, ws.GetSubscription(nil), "GetSubscription by nil should return nil") assert.Nil(t, ws.GetSubscription(45), "GetSubscription by invalid key should return nil") - assert.ErrorIs(t, ws.SubscribeToChannels(subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed") - assert.NoError(t, ws.SubscribeToChannels(nil), "Subscribe to an nil List should not error") - assert.NoError(t, ws.UnsubscribeChannels(subs), "Unsubscribing should not error") + assert.ErrorIs(t, ws.SubscribeToChannels(nil, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed") + assert.NoError(t, ws.SubscribeToChannels(nil, nil), "Subscribe to an nil List should not error") + assert.NoError(t, ws.UnsubscribeChannels(nil, subs), "Unsubscribing should not error") ws.Subscriber = func(subscription.List) error { return errDastardlyReason } - assert.ErrorIs(t, ws.SubscribeToChannels(subs), errDastardlyReason, "Should error correctly when error returned from Subscriber") + assert.ErrorIs(t, ws.SubscribeToChannels(nil, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber") - err = ws.SubscribeToChannels(subscription.List{nil}) + err = ws.SubscribeToChannels(nil, subscription.List{nil}) + assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription") + + multi := NewWebsocket() + set := *defaultSetup + set.UseMultiConnectionManagement = true + assert.NoError(t, multi.Setup(&set)) + + amazingCandidate := &ConnectionSetup{ + URL: "AMAZING", + Connector: func(context.Context, Connection) error { return nil }, + GenerateSubscriptions: ws.GenerateSubs, + Subscriber: func(ctx context.Context, c Connection, s subscription.List) error { + return currySimpleSubConn(multi)(ctx, c, s) + }, + Unsubscriber: func(ctx context.Context, c Connection, s subscription.List) error { + return currySimpleUnsubConn(multi)(ctx, c, s) + }, + Handler: func(context.Context, []byte) error { return nil }, + } + require.NoError(t, multi.SetupNewConnection(amazingCandidate)) + + amazingConn := multi.getConnectionFromSetup(amazingCandidate) + multi.connections = map[Connection]*ConnectionWrapper{ + amazingConn: &multi.connectionManager[0], + } + + subs, err = amazingCandidate.GenerateSubscriptions() + require.NoError(t, err, "Generating test subscriptions should not error") + assert.NoError(t, new(Websocket).UnsubscribeChannels(nil, subs), "Should not error when w.subscriptions is nil") + assert.NoError(t, new(Websocket).UnsubscribeChannels(amazingConn, subs), "Should not error when w.subscriptions is nil") + assert.NoError(t, multi.UnsubscribeChannels(amazingConn, nil), "Unsubscribing from nil should not error") + assert.ErrorIs(t, multi.UnsubscribeChannels(amazingConn, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed") + assert.Nil(t, multi.GetSubscription(42), "GetSubscription on empty internal map should return") + + assert.ErrorIs(t, multi.SubscribeToChannels(nil, subs), common.ErrNilPointer, "If no connection is set, Subscribe should error") + + assert.NoError(t, multi.SubscribeToChannels(amazingConn, subs), "Basic Subscribing should not error") + assert.Len(t, multi.GetSubscriptions(), 4, "Should have 4 subscriptions") + bySub = multi.GetSubscription(subscription.Subscription{Channel: "TestSub"}) + if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") { + assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel") + assert.Same(t, bySub, subs[0], "GetSubscription returns the same pointer") + } + if assert.NotNil(t, multi.GetSubscription("purple"), "GetSubscription by string key should find a channel") { + assert.Equal(t, "TestSub2", multi.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel") + } + if assert.NotNil(t, multi.GetSubscription(testSubKey{"mauve"}), "GetSubscription by type key should find a channel") { + assert.Equal(t, "TestSub3", multi.GetSubscription(testSubKey{"mauve"}).Channel, "GetSubscription by type key should return a pointer a copy of the right channel") + } + if assert.NotNil(t, multi.GetSubscription(42), "GetSubscription by int key should find a channel") { + assert.Equal(t, "TestSub4", multi.GetSubscription(42).Channel, "GetSubscription by int key should return a pointer a copy of the right channel") + } + assert.Nil(t, multi.GetSubscription(nil), "GetSubscription by nil should return nil") + assert.Nil(t, multi.GetSubscription(45), "GetSubscription by invalid key should return nil") + assert.ErrorIs(t, multi.SubscribeToChannels(amazingConn, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed") + assert.NoError(t, multi.SubscribeToChannels(amazingConn, nil), "Subscribe to an nil List should not error") + assert.NoError(t, multi.UnsubscribeChannels(amazingConn, subs), "Unsubscribing should not error") + + amazingCandidate.Subscriber = func(context.Context, Connection, subscription.List) error { return errDastardlyReason } + assert.ErrorIs(t, multi.SubscribeToChannels(amazingConn, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber") + + err = multi.SubscribeToChannels(amazingConn, subscription.List{nil}) assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription") } @@ -524,9 +524,9 @@ func TestResubscribe(t *testing.T) { channel := subscription.List{{Channel: "resubTest"}} - assert.ErrorIs(t, ws.ResubscribeToChannel(channel[0]), subscription.ErrNotFound, "Resubscribe should error when channel isn't subscribed yet") - assert.NoError(t, ws.SubscribeToChannels(channel), "Subscribe should not error") - assert.NoError(t, ws.ResubscribeToChannel(channel[0]), "Resubscribe should not error now the channel is subscribed") + assert.ErrorIs(t, ws.ResubscribeToChannel(nil, channel[0]), subscription.ErrNotFound, "Resubscribe should error when channel isn't subscribed yet") + assert.NoError(t, ws.SubscribeToChannels(nil, channel), "Subscribe should not error") + assert.NoError(t, ws.ResubscribeToChannel(nil, channel[0]), "Resubscribe should not error now the channel is subscribed") } // TestSubscriptions tests adding, getting and removing subscriptions @@ -535,18 +535,18 @@ func TestSubscriptions(t *testing.T) { w := new(Websocket) // Do not use NewWebsocket; We want to exercise w.subs == nil assert.ErrorIs(t, (*Websocket)(nil).AddSubscriptions(nil), common.ErrNilPointer, "Should error correctly when nil websocket") s := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel} - require.NoError(t, w.AddSubscriptions(s), "Adding first subscription should not error") + require.NoError(t, w.AddSubscriptions(nil, s), "Adding first subscription should not error") assert.Same(t, s, w.GetSubscription(42), "Get Subscription should retrieve the same subscription") - assert.ErrorIs(t, w.AddSubscriptions(s), subscription.ErrDuplicate, "Adding same subscription should return error") + assert.ErrorIs(t, w.AddSubscriptions(nil, s), subscription.ErrDuplicate, "Adding same subscription should return error") assert.Equal(t, subscription.SubscribingState, s.State(), "Should set state to Subscribing") - err := w.RemoveSubscriptions(s) + err := w.RemoveSubscriptions(nil, s) require.NoError(t, err, "RemoveSubscriptions must not error") assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub") assert.Equal(t, subscription.UnsubscribedState, s.State(), "Should set state to Unsubscribed") require.NoError(t, s.SetState(subscription.ResubscribingState), "SetState must not error") - require.NoError(t, w.AddSubscriptions(s), "Adding first subscription should not error") + require.NoError(t, w.AddSubscriptions(nil, s), "Adding first subscription should not error") assert.Equal(t, subscription.ResubscribingState, s.State(), "Should not change resubscribing state") } @@ -554,35 +554,21 @@ func TestSubscriptions(t *testing.T) { func TestSuccessfulSubscriptions(t *testing.T) { t.Parallel() w := new(Websocket) // Do not use NewWebsocket; We want to exercise w.subs == nil - assert.ErrorIs(t, (*Websocket)(nil).AddSuccessfulSubscriptions(nil), common.ErrNilPointer, "Should error correctly when nil websocket") + assert.ErrorIs(t, (*Websocket)(nil).AddSuccessfulSubscriptions(nil, nil), common.ErrNilPointer, "Should error correctly when nil websocket") c := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel} - require.NoError(t, w.AddSuccessfulSubscriptions(c), "Adding first subscription should not error") + require.NoError(t, w.AddSuccessfulSubscriptions(nil, c), "Adding first subscription should not error") assert.Same(t, c, w.GetSubscription(42), "Get Subscription should retrieve the same subscription") - assert.ErrorIs(t, w.AddSuccessfulSubscriptions(c), subscription.ErrInStateAlready, "Adding subscription in same state should return error") + assert.ErrorIs(t, w.AddSuccessfulSubscriptions(nil, c), subscription.ErrInStateAlready, "Adding subscription in same state should return error") require.NoError(t, c.SetState(subscription.SubscribingState), "SetState must not error") - assert.ErrorIs(t, w.AddSuccessfulSubscriptions(c), subscription.ErrDuplicate, "Adding same subscription should return error") + assert.ErrorIs(t, w.AddSuccessfulSubscriptions(nil, c), subscription.ErrDuplicate, "Adding same subscription should return error") - err := w.RemoveSubscriptions(c) + err := w.RemoveSubscriptions(nil, c) require.NoError(t, err, "RemoveSubscriptions must not error") assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub") - assert.ErrorIs(t, w.RemoveSubscriptions(c), subscription.ErrNotFound, "Should error correctly when not found") - assert.ErrorIs(t, (*Websocket)(nil).RemoveSubscriptions(nil), common.ErrNilPointer, "Should error correctly when nil websocket") + assert.ErrorIs(t, w.RemoveSubscriptions(nil, c), subscription.ErrNotFound, "Should error correctly when not found") + assert.ErrorIs(t, (*Websocket)(nil).RemoveSubscriptions(nil, nil), common.ErrNilPointer, "Should error correctly when nil websocket") w.subscriptions = nil - assert.ErrorIs(t, w.RemoveSubscriptions(c), common.ErrNilPointer, "Should error correctly when nil websocket") -} - -// TestConnectionMonitorNoConnection logic test -func TestConnectionMonitorNoConnection(t *testing.T) { - t.Parallel() - ws := NewWebsocket() - ws.connectionMonitorDelay = 500 - ws.exchangeName = "hello" - ws.setEnabled(true) - err := ws.connectionMonitor() - require.NoError(t, err, "connectionMonitor must not error") - assert.True(t, ws.IsConnectionMonitorRunning(), "IsConnectionMonitorRunning should return true") - err = ws.connectionMonitor() - assert.ErrorIs(t, err, errAlreadyRunning, "connectionMonitor should error correctly") + assert.ErrorIs(t, w.RemoveSubscriptions(nil, c), common.ErrNilPointer, "Should error correctly when nil websocket") } // TestGetSubscription logic test @@ -593,7 +579,7 @@ func TestGetSubscription(t *testing.T) { w := NewWebsocket() assert.Nil(t, w.GetSubscription(nil), "GetSubscription with a nil key should return nil") s := &subscription.Subscription{Key: 42, Channel: "hello3"} - require.NoError(t, w.AddSubscriptions(s), "AddSubscriptions must not error") + require.NoError(t, w.AddSubscriptions(nil, s), "AddSubscriptions must not error") assert.Same(t, s, w.GetSubscription(42), "GetSubscription should delegate to the store") } @@ -607,7 +593,7 @@ func TestGetSubscriptions(t *testing.T) { {Key: 42, Channel: "hello3"}, {Key: 45, Channel: "hello4"}, } - err := w.AddSubscriptions(s...) + err := w.AddSubscriptions(nil, s...) require.NoError(t, err, "AddSubscriptions must not error") assert.ElementsMatch(t, s, w.GetSubscriptions(), "GetSubscriptions should return the correct channel details") } @@ -624,17 +610,22 @@ func TestSetCanUseAuthenticatedEndpoints(t *testing.T) { // TestDial logic test func TestDial(t *testing.T) { t.Parallel() + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) + defer mock.Close() + var testCases = []testStruct{ - {Error: nil, + { WC: WebsocketConnection{ ExchangeName: "test1", Verbose: true, - URL: websocketTestURL, + URL: "ws" + mock.URL[len("http"):] + "/ws", RateLimit: request.NewWeightedRateLimitByDuration(10 * time.Millisecond), ResponseMaxLimit: 7000000000, }, }, - {Error: errors.New(" Error: malformed ws or wss URL"), + { + Error: errors.New(" Error: malformed ws or wss URL"), WC: WebsocketConnection{ ExchangeName: "test2", Verbose: true, @@ -642,47 +633,51 @@ func TestDial(t *testing.T) { ResponseMaxLimit: 7000000000, }, }, - {Error: nil, + { WC: WebsocketConnection{ ExchangeName: "test3", Verbose: true, - URL: websocketTestURL, + URL: "ws" + mock.URL[len("http"):] + "/ws", ProxyURL: proxyURL, ResponseMaxLimit: 7000000000, }, }, } + // Mock server rejects parallel connections for i := range testCases { - testData := &testCases[i] - t.Run(testData.WC.ExchangeName, func(t *testing.T) { - t.Parallel() - if testData.WC.ProxyURL != "" && !useProxyTests { - t.Skip("Proxy testing not enabled, skipping") + if testCases[i].WC.ProxyURL != "" && !useProxyTests { + t.Log("Proxy testing not enabled, skipping") + continue + } + err := testCases[i].WC.Dial(&websocket.Dialer{}, http.Header{}) + if err != nil { + if testCases[i].Error != nil && strings.Contains(err.Error(), testCases[i].Error.Error()) { + return } - err := testData.WC.Dial(&dialer, http.Header{}) - if err != nil { - if testData.Error != nil && strings.Contains(err.Error(), testData.Error.Error()) { - return - } - t.Fatal(err) - } - }) + t.Fatal(err) + } } } // TestSendMessage logic test func TestSendMessage(t *testing.T) { t.Parallel() + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) + defer mock.Close() + var testCases = []testStruct{ - {Error: nil, WC: WebsocketConnection{ - ExchangeName: "test1", - Verbose: true, - URL: websocketTestURL, - RateLimit: request.NewWeightedRateLimitByDuration(10 * time.Millisecond), - ResponseMaxLimit: 7000000000, + { + WC: WebsocketConnection{ + ExchangeName: "test1", + Verbose: true, + URL: "ws" + mock.URL[len("http"):] + "/ws", + RateLimit: request.NewWeightedRateLimitByDuration(10 * time.Millisecond), + ResponseMaxLimit: 7000000000, + }, }, - }, - {Error: errors.New(" Error: malformed ws or wss URL"), + { + Error: errors.New(" Error: malformed ws or wss URL"), WC: WebsocketConnection{ ExchangeName: "test2", Verbose: true, @@ -690,47 +685,45 @@ func TestSendMessage(t *testing.T) { ResponseMaxLimit: 7000000000, }, }, - {Error: nil, + { WC: WebsocketConnection{ ExchangeName: "test3", Verbose: true, - URL: websocketTestURL, + URL: "ws" + mock.URL[len("http"):] + "/ws", ProxyURL: proxyURL, ResponseMaxLimit: 7000000000, }, }, } - for i := range testCases { - testData := &testCases[i] - t.Run(testData.WC.ExchangeName, func(t *testing.T) { - t.Parallel() - if testData.WC.ProxyURL != "" && !useProxyTests { - t.Skip("Proxy testing not enabled, skipping") + // Mock server rejects parallel connections + for x := range testCases { + if testCases[x].WC.ProxyURL != "" && !useProxyTests { + t.Log("Proxy testing not enabled, skipping") + continue + } + err := testCases[x].WC.Dial(&websocket.Dialer{}, http.Header{}) + if err != nil { + if testCases[x].Error != nil && strings.Contains(err.Error(), testCases[x].Error.Error()) { + return } - err := testData.WC.Dial(&dialer, http.Header{}) - if err != nil { - if testData.Error != nil && strings.Contains(err.Error(), testData.Error.Error()) { - return - } - t.Fatal(err) - } - err = testData.WC.SendJSONMessage(context.Background(), request.Unset, Ping) - if err != nil { - t.Error(err) - } - err = testData.WC.SendRawMessage(context.Background(), request.Unset, websocket.TextMessage, []byte(Ping)) - if err != nil { - t.Error(err) - } - }) + t.Fatal(err) + } + err = testCases[x].WC.SendJSONMessage(context.Background(), request.Unset, Ping) + require.NoError(t, err) + err = testCases[x].WC.SendRawMessage(context.Background(), request.Unset, websocket.TextMessage, []byte(Ping)) + require.NoError(t, err) } } func TestSendMessageReturnResponse(t *testing.T) { t.Parallel() + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) + defer mock.Close() + wc := &WebsocketConnection{ Verbose: true, - URL: "wss://ws.kraken.com", + URL: "ws" + mock.URL[len("http"):] + "/ws", ResponseMaxLimit: time.Second * 5, Match: NewMatch(), } @@ -738,7 +731,7 @@ func TestSendMessageReturnResponse(t *testing.T) { t.Skip("Proxy testing not enabled, skipping") } - err := wc.Dial(&dialer, http.Header{}) + err := wc.Dial(&websocket.Dialer{}, http.Header{}) if err != nil { t.Fatal(err) } @@ -813,8 +806,12 @@ func readMessages(t *testing.T, wc *WebsocketConnection) { // TestSetupPingHandler logic test func TestSetupPingHandler(t *testing.T) { t.Parallel() + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) + defer mock.Close() + wc := &WebsocketConnection{ - URL: websocketTestURL, + URL: "ws" + mock.URL[len("http"):] + "/ws", ResponseMaxLimit: time.Second * 5, Match: NewMatch(), Wg: &sync.WaitGroup{}, @@ -823,8 +820,8 @@ func TestSetupPingHandler(t *testing.T) { if wc.ProxyURL != "" && !useProxyTests { t.Skip("Proxy testing not enabled, skipping") } - wc.ShutdownC = make(chan struct{}) - err := wc.Dial(&dialer, http.Header{}) + wc.shutdown = make(chan struct{}) + err := wc.Dial(&websocket.Dialer{}, http.Header{}) if err != nil { t.Fatal(err) } @@ -840,7 +837,7 @@ func TestSetupPingHandler(t *testing.T) { t.Error(err) } - err = wc.Dial(&dialer, http.Header{}) + err = wc.Dial(&websocket.Dialer{}, http.Header{}) if err != nil { t.Fatal(err) } @@ -850,15 +847,19 @@ func TestSetupPingHandler(t *testing.T) { Delay: 200, }) time.Sleep(time.Millisecond * 201) - close(wc.ShutdownC) + close(wc.shutdown) wc.Wg.Wait() } // TestParseBinaryResponse logic test func TestParseBinaryResponse(t *testing.T) { t.Parallel() + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) + defer mock.Close() + wc := &WebsocketConnection{ - URL: websocketTestURL, + URL: "ws" + mock.URL[len("http"):] + "/ws", ResponseMaxLimit: time.Second * 5, Match: NewMatch(), } @@ -917,7 +918,7 @@ func TestGenerateMessageID(t *testing.T) { assert.EqualValues(t, 42, wc.GenerateMessageID(true), "GenerateMessageID must use bespokeGenerateMessageID") } -// BenchmarkGenerateMessageID-8 2850018 408 ns/op 56 B/op 4 allocs/op +// 7002502 166.7 ns/op 48 B/op 3 allocs/op func BenchmarkGenerateMessageID_High(b *testing.B) { wc := WebsocketConnection{} for i := 0; i < b.N; i++ { @@ -925,7 +926,7 @@ func BenchmarkGenerateMessageID_High(b *testing.B) { } } -// BenchmarkGenerateMessageID_Low-8 2591596 447 ns/op 56 B/op 4 allocs/op +// 6536250 186.1 ns/op 48 B/op 3 allocs/op func BenchmarkGenerateMessageID_Low(b *testing.B) { wc := WebsocketConnection{} for i := 0; i < b.N; i++ { @@ -959,14 +960,44 @@ func TestGetChannelDifference(t *testing.T) { t.Parallel() w := &Websocket{} - assert.NotPanics(t, func() { w.GetChannelDifference(subscription.List{}) }, "Should not panic when called without a store") - subs, unsubs := w.GetChannelDifference(subscription.List{{Channel: subscription.CandlesChannel}}) + assert.NotPanics(t, func() { w.GetChannelDifference(nil, subscription.List{}) }, "Should not panic when called without a store") + subs, unsubs := w.GetChannelDifference(nil, subscription.List{{Channel: subscription.CandlesChannel}}) require.Equal(t, 1, len(subs), "Should get the correct number of subs") require.Empty(t, unsubs, "Should get no unsubs") - require.NoError(t, w.AddSubscriptions(subs...), "AddSubscriptions must not error") - subs, unsubs = w.GetChannelDifference(subscription.List{{Channel: subscription.TickerChannel}}) + require.NoError(t, w.AddSubscriptions(nil, subs...), "AddSubscriptions must not error") + subs, unsubs = w.GetChannelDifference(nil, subscription.List{{Channel: subscription.TickerChannel}}) require.Equal(t, 1, len(subs), "Should get the correct number of subs") assert.Equal(t, 1, len(unsubs), "Should get the correct number of unsubs") + + w = &Websocket{} + sweetConn := &WebsocketConnection{} + subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}}) + require.Equal(t, 1, len(subs)) + require.Empty(t, unsubs, "Should get no unsubs") + + w.connections = map[Connection]*ConnectionWrapper{ + sweetConn: {Setup: &ConnectionSetup{URL: "ws://localhost:8080/ws"}}, + } + + naughtyConn := &WebsocketConnection{} + subs, unsubs = w.GetChannelDifference(naughtyConn, subscription.List{{Channel: subscription.CandlesChannel}}) + require.Equal(t, 1, len(subs)) + require.Empty(t, unsubs, "Should get no unsubs") + + subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}}) + require.Equal(t, 1, len(subs)) + require.Empty(t, unsubs, "Should get no unsubs") + + err := w.connections[sweetConn].Subscriptions.Add(&subscription.Subscription{Channel: subscription.CandlesChannel}) + require.NoError(t, err) + + subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}}) + require.Empty(t, subs, "Should get no subs") + require.Empty(t, unsubs, "Should get no unsubs") + + subs, unsubs = w.GetChannelDifference(sweetConn, nil) + require.Empty(t, subs, "Should get no subs") + require.Equal(t, 1, len(unsubs)) } // GenSubs defines a theoretical exchange with pair management @@ -1010,10 +1041,6 @@ func connect() error { return nil } func TestFlushChannels(t *testing.T) { t.Parallel() // Enabled pairs/setup system - newgen := GenSubs{EnabledPairs: []currency.Pair{ - currency.NewPair(currency.BTC, currency.AUD), - currency.NewPair(currency.BTC, currency.USDT), - }} dodgyWs := Websocket{} err := dodgyWs.FlushChannels() @@ -1023,7 +1050,13 @@ func TestFlushChannels(t *testing.T) { err = dodgyWs.FlushChannels() assert.ErrorIs(t, err, ErrNotConnected, "FlushChannels should error correctly") + newgen := GenSubs{EnabledPairs: []currency.Pair{ + currency.NewPair(currency.BTC, currency.AUD), + currency.NewPair(currency.BTC, currency.USDT), + }} + w := NewWebsocket() + w.exchangeName = "test" w.connector = connect w.Subscriber = newgen.SUBME w.Unsubscriber = newgen.UNSUBME @@ -1033,40 +1066,30 @@ func TestFlushChannels(t *testing.T) { w.setEnabled(true) w.setState(connectedState) - problemFunc := func() (subscription.List, error) { - return nil, errDastardlyReason - } - - noSub := func() (subscription.List, error) { - return nil, nil - } + // Allow subscribe and unsubscribe feature set, without these the tests will call shutdown and connect. + w.features.Subscribe = true + w.features.Unsubscribe = true // Disable pair and flush system - newgen.EnabledPairs = []currency.Pair{ - currency.NewPair(currency.BTC, currency.AUD)} - w.GenerateSubs = func() (subscription.List, error) { - return subscription.List{{Channel: "test"}}, nil - } + newgen.EnabledPairs = []currency.Pair{currency.NewPair(currency.BTC, currency.AUD)} + w.GenerateSubs = func() (subscription.List, error) { return subscription.List{{Channel: "test"}}, nil } err = w.FlushChannels() require.NoError(t, err, "Flush Channels must not error") - w.features.FullPayloadSubscribe = true - w.GenerateSubs = problemFunc - err = w.FlushChannels() // error on full subscribeToChannels + w.GenerateSubs = func() (subscription.List, error) { return nil, errDastardlyReason } // error on generateSubs + err = w.FlushChannels() // error on full subscribeToChannels assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly on GenerateSubs") - w.GenerateSubs = noSub - err = w.FlushChannels() // No subs to sub + w.GenerateSubs = func() (subscription.List, error) { return nil, nil } // No subs to sub + err = w.FlushChannels() // No subs to sub assert.NoError(t, err, "Flush Channels should not error") w.GenerateSubs = newgen.generateSubs subs, err := w.GenerateSubs() require.NoError(t, err, "GenerateSubs must not error") - require.NoError(t, w.AddSubscriptions(subs...), "AddSubscriptions must not error") + require.NoError(t, w.AddSubscriptions(nil, subs...), "AddSubscriptions must not error") err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - w.features.FullPayloadSubscribe = false - w.features.Subscribe = true w.GenerateSubs = newgen.generateSubs w.subscriptions = subscription.NewStore() @@ -1087,9 +1110,40 @@ func TestFlushChannels(t *testing.T) { assert.NoError(t, err, "FlushChannels should not error") w.setState(connectedState) - w.features.Unsubscribe = true err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") + + // Multi connection management + w.useMultiConnectionManagement = true + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) + defer mock.Close() + + amazingCandidate := &ConnectionSetup{ + URL: "ws" + mock.URL[len("http"):] + "/ws", + Connector: func(ctx context.Context, conn Connection) error { + return conn.DialContext(ctx, websocket.DefaultDialer, nil) + }, + GenerateSubscriptions: newgen.generateSubs, + Subscriber: func(ctx context.Context, c Connection, s subscription.List) error { + return currySimpleSubConn(w)(ctx, c, s) + }, + Unsubscriber: func(ctx context.Context, c Connection, s subscription.List) error { + return currySimpleUnsubConn(w)(ctx, c, s) + }, + Handler: func(context.Context, []byte) error { return nil }, + } + require.NoError(t, w.SetupNewConnection(amazingCandidate)) + require.NoError(t, w.FlushChannels(), "FlushChannels must not error") + + // Forces full connection cycle (shutdown, connect, subscribe). This will also start monitoring routines. + w.features.Subscribe = false + require.NoError(t, w.FlushChannels(), "FlushChannels must not error") + + // Unsubscribe what's already subscribed. No subscriptions left over, which then forces the shutdown and removal + // of the connection from management. + w.features.Subscribe = true + w.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil } + require.NoError(t, w.FlushChannels(), "FlushChannels must not error") } func TestDisable(t *testing.T) { @@ -1115,19 +1169,19 @@ func TestEnable(t *testing.T) { func TestSetupNewConnection(t *testing.T) { t.Parallel() var nonsenseWebsock *Websocket - err := nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) + err := nonsenseWebsock.SetupNewConnection(&ConnectionSetup{URL: "urlstring"}) assert.ErrorIs(t, err, errWebsocketIsNil, "SetupNewConnection should error correctly") nonsenseWebsock = &Websocket{} - err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) + err = nonsenseWebsock.SetupNewConnection(&ConnectionSetup{URL: "urlstring"}) assert.ErrorIs(t, err, errExchangeConfigNameEmpty, "SetupNewConnection should error correctly") nonsenseWebsock = &Websocket{exchangeName: "test"} - err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) + err = nonsenseWebsock.SetupNewConnection(&ConnectionSetup{URL: "urlstring"}) assert.ErrorIs(t, err, errTrafficAlertNil, "SetupNewConnection should error correctly") nonsenseWebsock.TrafficAlert = make(chan struct{}, 1) - err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) + err = nonsenseWebsock.SetupNewConnection(&ConnectionSetup{URL: "urlstring"}) assert.ErrorIs(t, err, errReadMessageErrorsNil, "SetupNewConnection should error correctly") web := NewWebsocket() @@ -1135,26 +1189,71 @@ func TestSetupNewConnection(t *testing.T) { err = web.Setup(defaultSetup) assert.NoError(t, err, "Setup should not error") - err = web.SetupNewConnection(ConnectionSetup{}) - assert.ErrorIs(t, err, errExchangeConfigEmpty, "SetupNewConnection should error correctly") - - err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) + err = web.SetupNewConnection(&ConnectionSetup{URL: "urlstring"}) assert.NoError(t, err, "SetupNewConnection should not error") - err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring", Authenticated: true}) + err = web.SetupNewConnection(&ConnectionSetup{URL: "urlstring", Authenticated: true}) assert.NoError(t, err, "SetupNewConnection should not error") + + // Test connection candidates for multi connection tracking. + multi := NewWebsocket() + set := *defaultSetup + set.UseMultiConnectionManagement = true + require.NoError(t, multi.Setup(&set)) + + err = multi.SetupNewConnection(nil) + require.ErrorIs(t, err, errExchangeConfigEmpty) + + connSetup := &ConnectionSetup{ResponseCheckTimeout: time.Millisecond} + err = multi.SetupNewConnection(connSetup) + require.ErrorIs(t, err, errDefaultURLIsEmpty) + + connSetup.URL = "urlstring" + err = multi.SetupNewConnection(connSetup) + require.ErrorIs(t, err, errWebsocketConnectorUnset) + + connSetup.Connector = func(context.Context, Connection) error { return nil } + err = multi.SetupNewConnection(connSetup) + require.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset) + + connSetup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil } + err = multi.SetupNewConnection(connSetup) + require.ErrorIs(t, err, errWebsocketSubscriberUnset) + + connSetup.Subscriber = func(context.Context, Connection, subscription.List) error { return nil } + err = multi.SetupNewConnection(connSetup) + require.ErrorIs(t, err, errWebsocketUnsubscriberUnset) + + connSetup.Unsubscriber = func(context.Context, Connection, subscription.List) error { return nil } + err = multi.SetupNewConnection(connSetup) + require.ErrorIs(t, err, errWebsocketDataHandlerUnset) + + connSetup.Handler = func(context.Context, []byte) error { return nil } + err = multi.SetupNewConnection(connSetup) + require.NoError(t, err) + + require.Len(t, multi.connectionManager, 1) + + require.Nil(t, multi.AuthConn) + require.Nil(t, multi.Conn) + + err = multi.SetupNewConnection(connSetup) + require.ErrorIs(t, err, errConnectionWrapperDuplication) } func TestWebsocketConnectionShutdown(t *testing.T) { t.Parallel() - wc := WebsocketConnection{} + wc := WebsocketConnection{shutdown: make(chan struct{})} err := wc.Shutdown() assert.NoError(t, err, "Shutdown should not error") err = wc.Dial(&websocket.Dialer{}, nil) assert.ErrorContains(t, err, "malformed ws or wss URL", "Dial must error correctly") - wc.URL = websocketTestURL + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) + defer mock.Close() + + wc.URL = "ws" + mock.URL[len("http"):] + "/ws" err = wc.Dial(&websocket.Dialer{}, nil) require.NoError(t, err, "Dial must not error") @@ -1166,13 +1265,17 @@ func TestWebsocketConnectionShutdown(t *testing.T) { // TestLatency logic test func TestLatency(t *testing.T) { t.Parallel() + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) + defer mock.Close() + r := &reporter{} exch := "Kraken" wc := &WebsocketConnection{ ExchangeName: exch, Verbose: true, - URL: "wss://ws.kraken.com", - ResponseMaxLimit: time.Second * 5, + URL: "ws" + mock.URL[len("http"):] + "/ws", + ResponseMaxLimit: time.Second * 1, Match: NewMatch(), Reporter: r, } @@ -1180,54 +1283,42 @@ func TestLatency(t *testing.T) { t.Skip("Proxy testing not enabled, skipping") } - err := wc.Dial(&dialer, http.Header{}) - if err != nil { - t.Fatal(err) - } + err := wc.Dial(&websocket.Dialer{}, http.Header{}) + require.NoError(t, err) go readMessages(t, wc) req := testRequest{ - Event: "subscribe", - Pairs: []string{currency.NewPairWithDelimiter("XBT", "USD", "/").String()}, - Subscription: testRequestData{ - Name: "ticker", - }, - RequestID: wc.GenerateMessageID(false), + Event: "subscribe", + Pairs: []string{currency.NewPairWithDelimiter("XBT", "USD", "/").String()}, + Subscription: testRequestData{Name: "ticker"}, + RequestID: wc.GenerateMessageID(false), } _, err = wc.SendMessageReturnResponse(context.Background(), request.Unset, req.RequestID, req) - if err != nil { - t.Error(err) - } - - if r.t == 0 { - t.Error("expected a nonzero duration, got zero") - } - - if r.name != exch { - t.Errorf("expected %v, got %v", exch, r.name) - } + require.NoError(t, err) + require.NotEmpty(t, r.t, "Latency should have a duration") + require.Equal(t, exch, r.name, "Latency should have the correct exchange name") } func TestCheckSubscriptions(t *testing.T) { t.Parallel() ws := Websocket{} - err := ws.checkSubscriptions(nil) + err := ws.checkSubscriptions(nil, nil) assert.ErrorIs(t, err, common.ErrNilPointer, "checkSubscriptions should error correctly on nil w.subscriptions") assert.ErrorContains(t, err, "Websocket.subscriptions", "checkSubscriptions should error giving context correctly on nil w.subscriptions") ws.subscriptions = subscription.NewStore() - err = ws.checkSubscriptions(nil) + err = ws.checkSubscriptions(nil, nil) assert.NoError(t, err, "checkSubscriptions should not error on a nil list") ws.MaxSubscriptionsPerConnection = 1 - err = ws.checkSubscriptions(subscription.List{{}}) + err = ws.checkSubscriptions(nil, subscription.List{{}}) assert.NoError(t, err, "checkSubscriptions should not error when subscriptions is empty") ws.subscriptions = subscription.NewStore() - err = ws.checkSubscriptions(subscription.List{{}, {}}) + err = ws.checkSubscriptions(nil, subscription.List{{}, {}}) assert.ErrorIs(t, err, errSubscriptionsExceedsLimit, "checkSubscriptions should error correctly") ws.MaxSubscriptionsPerConnection = 2 @@ -1235,10 +1326,10 @@ func TestCheckSubscriptions(t *testing.T) { ws.subscriptions = subscription.NewStore() err = ws.subscriptions.Add(&subscription.Subscription{Key: 42, Channel: "test"}) require.NoError(t, err, "Add subscription must not error") - err = ws.checkSubscriptions(subscription.List{{Key: 42, Channel: "test"}}) + err = ws.checkSubscriptions(nil, subscription.List{{Key: 42, Channel: "test"}}) assert.ErrorIs(t, err, subscription.ErrDuplicate, "checkSubscriptions should error correctly") - err = ws.checkSubscriptions(subscription.List{{}}) + err = ws.checkSubscriptions(nil, subscription.List{{}}) assert.NoError(t, err, "checkSubscriptions should not error") } @@ -1259,10 +1350,9 @@ func TestWriteToConn(t *testing.T) { // connection rate limit set wc.RateLimit = request.NewWeightedRateLimitByDuration(time.Millisecond) require.NoError(t, wc.writeToConn(context.Background(), request.Unset, func() error { return nil })) - // context cancelled - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), 0) // deadline exceeded cancel() - require.ErrorIs(t, wc.writeToConn(ctx, request.Unset, func() error { return nil }), context.Canceled) + require.ErrorIs(t, wc.writeToConn(ctx, request.Unset, func() error { return nil }), context.DeadlineExceeded) // definitions set but with fallover wc.RateLimitDefinitions = request.RateLimitDefinitions{ request.Auth: request.NewWeightedRateLimitByDuration(time.Millisecond), @@ -1274,3 +1364,123 @@ func TestWriteToConn(t *testing.T) { wc.RateLimit = nil require.ErrorIs(t, wc.writeToConn(ctx, request.Unset, func() error { return nil }), errRateLimitNotFound) } + +func TestDrain(t *testing.T) { + t.Parallel() + drain(nil) + ch := make(chan error) + drain(ch) + require.Empty(t, ch, "Drain should empty the channel") + ch = make(chan error, 10) + for range 10 { + ch <- errors.New("test") + } + drain(ch) + require.Empty(t, ch, "Drain should empty the channel") +} + +func TestMonitorFrame(t *testing.T) { + t.Parallel() + ws := Websocket{} + require.Panics(t, func() { ws.monitorFrame(nil, nil) }, "monitorFrame must panic on nil frame") + require.Panics(t, func() { ws.monitorFrame(nil, func() func() bool { return nil }) }, "monitorFrame must panic on nil function") + ws.Wg.Add(1) + ws.monitorFrame(&ws.Wg, func() func() bool { return func() bool { return true } }) + ws.Wg.Wait() +} + +func TestMonitorData(t *testing.T) { + t.Parallel() + ws := Websocket{ShutdownC: make(chan struct{}), DataHandler: make(chan interface{}, 10)} + // Handle shutdown signal + close(ws.ShutdownC) + require.True(t, ws.observeData(nil)) + ws.ShutdownC = make(chan struct{}) + // Handle blockage of ToRoutine + go func() { ws.DataHandler <- nil }() + var dropped int + require.False(t, ws.observeData(&dropped)) + require.Equal(t, 1, dropped) + // Handle reinstate of ToRoutine functionality which will reset dropped counter + ws.ToRoutine = make(chan interface{}, 10) + go func() { ws.DataHandler <- nil }() + require.False(t, ws.observeData(&dropped)) + require.Empty(t, dropped) + // Handle outer closure shell + innerShell := ws.monitorData() + go func() { ws.DataHandler <- nil }() + require.False(t, innerShell()) + // Handle shutdown signal + close(ws.ShutdownC) + require.True(t, innerShell()) +} + +func TestMonitorConnection(t *testing.T) { + t.Parallel() + ws := Websocket{verbose: true, ReadMessageErrors: make(chan error, 1), ShutdownC: make(chan struct{})} + // Handle timer expired and websocket disabled, shutdown everything. + timer := time.NewTimer(0) + ws.setState(connectedState) + ws.connectionMonitorRunning.Store(true) + require.True(t, ws.observeConnection(timer)) + require.False(t, ws.connectionMonitorRunning.Load()) + require.Equal(t, disconnectedState, ws.state.Load()) + // Handle timer expired and everything is great, reset the timer. + ws.setEnabled(true) + ws.setState(connectedState) + ws.connectionMonitorRunning.Store(true) + timer = time.NewTimer(0) + require.False(t, ws.observeConnection(timer)) // Not shutting down + // Handle timer expired and for reason its not connected, so lets happily connect again. + ws.setState(disconnectedState) + require.False(t, ws.observeConnection(timer)) // Connect is intentionally erroring + // Handle error from a connection which will then trigger a reconnect + ws.setState(connectedState) + ws.DataHandler = make(chan interface{}, 1) + ws.ReadMessageErrors <- errConnectionFault + timer = time.NewTimer(time.Second) + require.False(t, ws.observeConnection(timer)) + payload := <-ws.DataHandler + err, ok := payload.(error) + require.True(t, ok) + require.ErrorIs(t, err, errConnectionFault) + // Handle outta closure shell + innerShell := ws.monitorConnection() + ws.setState(connectedState) + ws.ReadMessageErrors <- errConnectionFault + require.False(t, innerShell()) +} + +func TestMonitorTraffic(t *testing.T) { + t.Parallel() + ws := Websocket{verbose: true, ShutdownC: make(chan struct{}), TrafficAlert: make(chan struct{}, 1)} + ws.Wg.Add(1) + // Handle external shutdown signal + timer := time.NewTimer(time.Second) + close(ws.ShutdownC) + require.True(t, ws.observeTraffic(timer)) + // Handle timer expired but system is connecting, so reset the timer + ws.ShutdownC = make(chan struct{}) + ws.setState(connectingState) + timer = time.NewTimer(0) + require.False(t, ws.observeTraffic(timer)) + // Handle timer expired and system is connected and has traffic within time window + ws.setState(connectedState) + timer = time.NewTimer(0) + ws.TrafficAlert <- struct{}{} + require.False(t, ws.observeTraffic(timer)) + // Handle timer expired and system is connected but no traffic within time window, causes shutdown to occur. + timer = time.NewTimer(0) + require.True(t, ws.observeTraffic(timer)) + ws.Wg.Done() + // Shutdown is done in a routine, so we need to wait for it to finish + require.Eventually(t, func() bool { return disconnectedState == ws.state.Load() }, time.Second, time.Millisecond) + // Handle outer closure shell + innerShell := ws.monitorTraffic() + ws.m.Lock() + ws.ShutdownC = make(chan struct{}) + ws.m.Unlock() + ws.setState(connectedState) + ws.TrafficAlert <- struct{}{} + require.False(t, innerShell()) +} diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 54d73fcb..27a5c819 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -38,8 +38,6 @@ type Websocket struct { state atomic.Uint32 verbose bool connectionMonitorRunning atomic.Bool - trafficMonitorRunning atomic.Bool - dataMonitorRunning atomic.Bool trafficTimeout time.Duration connectionMonitorDelay time.Duration proxyAddr string @@ -51,6 +49,15 @@ type Websocket struct { m sync.Mutex connector func() error + // connectionManager stores all *potential* connections for the exchange, organised within ConnectionWrapper structs. + // Each ConnectionWrapper one connection (will be expanded soon) tailored for specific exchange functionalities or asset types. // TODO: Expand this to support multiple connections per ConnectionWrapper + // For example, separate connections can be used for Spot, Margin, and Futures trading. This structure is especially useful + // for exchanges that differentiate between trading pairs by using different connection endpoints or protocols for various asset classes. + // If an exchange does not require such differentiation, all connections may be managed under a single ConnectionWrapper. + connectionManager []ConnectionWrapper + // connections holds a look up table for all connections to their corresponding ConnectionWrapper and subscription holder + connections map[Connection]*ConnectionWrapper + subscriptions *subscription.Store // Subscriber function for exchange specific subscribe implementation @@ -60,6 +67,8 @@ type Websocket struct { // GenerateSubs function for exchange specific generating subscriptions from Features.Subscriptions, Pairs and Assets GenerateSubs func() (subscription.List, error) + useMultiConnectionManagement bool + DataHandler chan interface{} ToRoutine chan interface{} @@ -117,6 +126,11 @@ type WebsocketSetup struct { // Local orderbook buffer config values OrderbookBufferConfig buffer.Config + // UseMultiConnectionManagement allows the connections to be managed by the + // connection manager. If false, this will default to the global fields + // provided in this struct. + UseMultiConnectionManagement bool + TradeFeed bool // Fill data config values @@ -155,7 +169,9 @@ type WebsocketConnection struct { ProxyURL string Wg *sync.WaitGroup Connection *websocket.Conn - ShutdownC chan struct{} + + // shutdown synchronises shutdown event across routines associated with this connection only e.g. ping handler + shutdown chan struct{} Match *Match ResponseMaxLimit time.Duration diff --git a/internal/testing/exchange/exchange.go b/internal/testing/exchange/exchange.go index 048585d2..fe82b4f5 100644 --- a/internal/testing/exchange/exchange.go +++ b/internal/testing/exchange/exchange.go @@ -14,7 +14,6 @@ import ( "sync" "testing" - "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/config" @@ -86,11 +85,6 @@ func MockHTTPInstance(e exchange.IBotExchange) error { return nil } -var upgrader = websocket.Upgrader{} - -// WsMockFunc is a websocket handler to be called with each websocket message -type WsMockFunc func(testing.TB, []byte, *websocket.Conn) error - // MockWsInstance creates a new Exchange instance with a mock websocket instance and HTTP server // It accepts an exchange package type argument and a http.HandlerFunc // See CurryWsMockUpgrader for a convenient way to curry t and a ws mock function @@ -128,33 +122,6 @@ func MockWsInstance[T any, PT interface { return e } -// CurryWsMockUpgrader curries a WsMockUpgrader with a testing.TB and a mock func -// bridging the gap between information known before the Server is created and during a request -func CurryWsMockUpgrader(tb testing.TB, wsHandler WsMockFunc) http.HandlerFunc { - tb.Helper() - return func(w http.ResponseWriter, r *http.Request) { - WsMockUpgrader(tb, w, r, wsHandler) - } -} - -// WsMockUpgrader handles upgrading an initial HTTP request to WS, and then runs a for loop calling the mock func on each input -func WsMockUpgrader(tb testing.TB, w http.ResponseWriter, r *http.Request, wsHandler WsMockFunc) { - tb.Helper() - c, err := upgrader.Upgrade(w, r, nil) - require.NoError(tb, err, "Upgrade connection should not error") - defer c.Close() - for { - _, p, err := c.ReadMessage() - if websocket.IsUnexpectedCloseError(err) { - return - } - require.NoError(tb, err, "ReadMessage should not error") - - err = wsHandler(tb, p, c) - assert.NoError(tb, err, "WS Mock Function should not error") - } -} - // FixtureToDataHandler squirts the contents of a file to a reader function (probably e.wsHandleData) func FixtureToDataHandler(tb testing.TB, fixturePath string, reader func([]byte) error) { tb.Helper() diff --git a/internal/testing/exchange/exchange_test.go b/internal/testing/exchange/exchange_test.go index 0976d47b..c77848a7 100644 --- a/internal/testing/exchange/exchange_test.go +++ b/internal/testing/exchange/exchange_test.go @@ -9,6 +9,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/exchanges/binance" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" + mockws "github.com/thrasher-corp/gocryptotrader/internal/testing/websocket" ) // TestSetup exercises Setup @@ -30,6 +31,6 @@ func TestMockHTTPInstance(t *testing.T) { // TestMockWsInstance exercises MockWsInstance func TestMockWsInstance(t *testing.T) { - b := MockWsInstance[binance.Binance](t, CurryWsMockUpgrader(t, func(_ testing.TB, _ []byte, _ *websocket.Conn) error { return nil })) + b := MockWsInstance[binance.Binance](t, mockws.CurryWsMockUpgrader(t, func(_ testing.TB, _ []byte, _ *websocket.Conn) error { return nil })) require.NotNil(t, b, "MockWsInstance must not be nil") } diff --git a/internal/testing/websocket/mock.go b/internal/testing/websocket/mock.go new file mode 100644 index 00000000..4bf28f03 --- /dev/null +++ b/internal/testing/websocket/mock.go @@ -0,0 +1,49 @@ +package websocket + +import ( + "net/http" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var upgrader = websocket.Upgrader{CheckOrigin: func(_ *http.Request) bool { return true }} + +// WsMockFunc is a websocket handler to be called with each websocket message +type WsMockFunc func(testing.TB, []byte, *websocket.Conn) error + +// CurryWsMockUpgrader curries a WsMockUpgrader with a testing.TB and a mock func +// bridging the gap between information known before the Server is created and during a request +func CurryWsMockUpgrader(tb testing.TB, wsHandler WsMockFunc) http.HandlerFunc { + tb.Helper() + return func(w http.ResponseWriter, r *http.Request) { + WsMockUpgrader(tb, w, r, wsHandler) + } +} + +// WsMockUpgrader handles upgrading an initial HTTP request to WS, and then runs a for loop calling the mock func on each input +func WsMockUpgrader(tb testing.TB, w http.ResponseWriter, r *http.Request, wsHandler WsMockFunc) { + tb.Helper() + c, err := upgrader.Upgrade(w, r, nil) + require.NoError(tb, err, "Upgrade connection should not error") + defer c.Close() + for { + _, p, err := c.ReadMessage() + if err != nil { + // Any error here is likely due to the connection closing + return + } + err = wsHandler(tb, p, c) + assert.NoError(tb, err, "WS Mock Function should not error") + } +} + +// EchoHandler is a simple echo function after a read, this doesn't need to worry if writing to the connection fails +func EchoHandler(_ testing.TB, p []byte, c *websocket.Conn) error { + time.Sleep(time.Nanosecond) // Shift clock to simulate time passing + _ = c.WriteMessage(websocket.TextMessage, p) + return nil +}