From 52c6b3bf0be42185bcb33d4fcc5d46d9b75e47eb Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Fri, 23 Feb 2024 08:39:25 +0100 Subject: [PATCH] Websocket: Various refactors and test improvements (#1466) * Websocket: Remove IsInit and simplify SetProxyAddress IsInit was basically the same as IsConnected. Any time Connect was called both would be set to true. Any time we had a disconnect they'd both be set to false Shutdown() incorrectly didn't setInit(false) SetProxyAddress simplified to only reconnect a connected Websocket. Any other state means it hasn't been Connected, or it's about to reconnect anyway. There's no handling for IsConnecting previously, either, so I've wrapped that behind the main mutex. * Websocket: Expand and Assertify tests * Websocket: Simplify state transistions * Websocket: Simplify Connecting/Connected state * Websocket: Tests and errors for websocket * Websocket: Make WebsocketNotEnabled a real error This allows for testing and avoids the repetition. If each returned error is a error.New() you can never use errors.Is() * Websocket: Add more testable errors * Websocket: Improve GenerateMessageID test Testing just the last id doesn't feel very robust * Websocket: Protect Setup() from races * Websocket: Use atomics instead of mutex This was spurred by looking at the setState call in trafficMonitor and the effect on blocking and efficiency. With the new atomic types in Go 1.19, and the small types in use here, atomics should be safe for our usage. bools should be truly atomic, and uint32 is atomic when the accepted value range is less than one byte/uint8 since that can be written atomicly by concurrent processors. Maybe that's not even a factor any more, however we don't even have to worry enough to check. * Websocket: Fix and simplify traffic monitor trafficMonitor had a check throttle at the end of the for loop to stop it just gobbling the (blocking) trafficAlert channel non-stop. That makes sense, except that nothing is sent to the trafficAlert channel if there's no listener. So that means that it's out by one second on the trafficAlert, because any traffic received during the pause is doesn't try to send a traffic alert. The unstopped timer is deliberately leaked for later GC when shutdown. It won't delay/block anything, and it's a trivial memory leak during an infrequent event. Deliberately Choosing to recreate the timer each time instead of using Stop, drain and reset * Websocket: Split traficMonitor test on behaviours * Websocket: Remove trafficMonitor connected status trafficMonitor does not need to set the connection to be connected. Connect() does that. Anything after that should result in a full shutdown and restart. It can't and shouldn't become connected unexpectedly, and this is most likely a race anyway. Also dropped trafficCheckInterval to 100ms to mitigate races of traffic alerts being buffered for too long. * Websocket: Set disconnected earlier in Shutdown This caused a possible race where state is still connected, but we start to trigger interested actors via ShutdownC and Wait. They may check state and then call Shutdown again, such as trafficMonitor * Websocket: Wait 5s for slow tests to pass traffic draining Keep getting failures upstream on test rigs. Think they can be very contended, so this pushes the boundary right out to 5s --- cmd/exchange_template/wrapper_file.tmpl | 2 +- engine/websocketroutine_manager_test.go | 2 +- exchanges/binance/binance_websocket.go | 2 +- exchanges/binance/binance_wrapper.go | 2 +- exchanges/binanceus/binanceus_websocket.go | 2 +- exchanges/binanceus/binanceus_wrapper.go | 2 +- exchanges/bitfinex/bitfinex_test.go | 2 +- exchanges/bitfinex/bitfinex_websocket.go | 2 +- exchanges/bitfinex/bitfinex_wrapper.go | 2 +- exchanges/bithumb/bithumb_websocket.go | 3 +- exchanges/bithumb/bithumb_wrapper.go | 2 +- exchanges/bitmex/bitmex_test.go | 2 +- exchanges/bitmex/bitmex_websocket.go | 2 +- exchanges/bitmex/bitmex_wrapper.go | 2 +- exchanges/bitstamp/bitstamp_websocket.go | 2 +- exchanges/bitstamp/bitstamp_wrapper.go | 2 +- exchanges/btcmarkets/btcmarkets_websocket.go | 2 +- exchanges/btcmarkets/btcmarkets_wrapper.go | 2 +- exchanges/btse/btse_websocket.go | 2 +- exchanges/btse/btse_wrapper.go | 2 +- exchanges/bybit/bybit.go | 2 - exchanges/bybit/bybit_inverse_websocket.go | 2 +- exchanges/bybit/bybit_linear_websocket.go | 2 +- exchanges/bybit/bybit_options_websocket.go | 2 +- exchanges/bybit/bybit_test.go | 7 +- exchanges/bybit/bybit_websocket.go | 2 +- exchanges/bybit/bybit_wrapper.go | 2 +- exchanges/coinbasepro/coinbasepro_test.go | 2 +- .../coinbasepro/coinbasepro_websocket.go | 2 +- exchanges/coinbasepro/coinbasepro_wrapper.go | 2 +- exchanges/coinut/coinut_test.go | 2 +- exchanges/coinut/coinut_websocket.go | 2 +- exchanges/coinut/coinut_wrapper.go | 2 +- exchanges/exchange_test.go | 6 +- exchanges/gateio/gateio_websocket.go | 2 +- exchanges/gateio/gateio_wrapper.go | 2 +- .../gateio/gateio_ws_delivery_futures.go | 2 +- exchanges/gateio/gateio_ws_futures.go | 2 +- exchanges/gateio/gateio_ws_option.go | 2 +- exchanges/gemini/gemini_test.go | 2 +- exchanges/gemini/gemini_websocket.go | 2 +- exchanges/gemini/gemini_wrapper.go | 2 +- exchanges/hitbtc/hitbtc_test.go | 2 +- exchanges/hitbtc/hitbtc_websocket.go | 2 +- exchanges/hitbtc/hitbtc_wrapper.go | 2 +- exchanges/huobi/huobi_test.go | 2 +- exchanges/huobi/huobi_websocket.go | 2 +- exchanges/huobi/huobi_wrapper.go | 2 +- exchanges/kraken/kraken_test.go | 2 +- exchanges/kraken/kraken_websocket.go | 2 +- exchanges/kraken/kraken_wrapper.go | 2 +- exchanges/kucoin/kucoin_websocket.go | 2 +- exchanges/kucoin/kucoin_wrapper.go | 2 +- exchanges/okcoin/okcoin_websocket.go | 2 +- exchanges/okcoin/okcoin_wrapper.go | 2 +- exchanges/okcoin/okcoin_ws_trade.go | 2 +- exchanges/okx/okx_websocket.go | 2 +- exchanges/okx/okx_wrapper.go | 2 +- exchanges/poloniex/poloniex_test.go | 2 +- exchanges/poloniex/poloniex_websocket.go | 2 +- exchanges/poloniex/poloniex_wrapper.go | 2 +- .../sharedtestvalues/sharedtestvalues.go | 1 - exchanges/stream/websocket.go | 392 +++----- exchanges/stream/websocket_connection.go | 25 +- exchanges/stream/websocket_test.go | 861 +++++++----------- exchanges/stream/websocket_types.go | 25 +- 66 files changed, 574 insertions(+), 862 deletions(-) diff --git a/cmd/exchange_template/wrapper_file.tmpl b/cmd/exchange_template/wrapper_file.tmpl index d57f96b5..e74ecbc3 100644 --- a/cmd/exchange_template/wrapper_file.tmpl +++ b/cmd/exchange_template/wrapper_file.tmpl @@ -125,7 +125,7 @@ func ({{.Variable}} *{{.CapitalName}}) SetDefaults() { exchange.RestSpot: {{.Name}}APIURL, // exchange.WebsocketSpot: {{.Name}}WSAPIURL, }) - {{.Variable}}.Websocket = stream.New() + {{.Variable}}.Websocket = stream.NewWebsocket() {{.Variable}}.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit {{.Variable}}.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout {{.Variable}}.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/engine/websocketroutine_manager_test.go b/engine/websocketroutine_manager_test.go index e082193b..c1c3541d 100644 --- a/engine/websocketroutine_manager_test.go +++ b/engine/websocketroutine_manager_test.go @@ -293,7 +293,7 @@ func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) { t.Fatal("unexpected data handlers registered") } - mock := stream.New() + mock := stream.NewWebsocket() mock.ToRoutine = make(chan interface{}) m.state = readyState err = m.websocketDataReceiver(mock) diff --git a/exchanges/binance/binance_websocket.go b/exchanges/binance/binance_websocket.go index 52cdd0cc..bd96d546 100644 --- a/exchanges/binance/binance_websocket.go +++ b/exchanges/binance/binance_websocket.go @@ -50,7 +50,7 @@ var ( // WsConnect initiates a websocket connection func (b *Binance) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer diff --git a/exchanges/binance/binance_wrapper.go b/exchanges/binance/binance_wrapper.go index bdf70e0d..d54f8924 100644 --- a/exchanges/binance/binance_wrapper.go +++ b/exchanges/binance/binance_wrapper.go @@ -238,7 +238,7 @@ func (b *Binance) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout } diff --git a/exchanges/binanceus/binanceus_websocket.go b/exchanges/binanceus/binanceus_websocket.go index 8f4d5c3c..14098c11 100644 --- a/exchanges/binanceus/binanceus_websocket.go +++ b/exchanges/binanceus/binanceus_websocket.go @@ -45,7 +45,7 @@ var ( // WsConnect initiates a websocket connection func (bi *Binanceus) WsConnect() error { if !bi.Websocket.IsEnabled() || !bi.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer dialer.HandshakeTimeout = bi.Config.HTTPTimeout diff --git a/exchanges/binanceus/binanceus_wrapper.go b/exchanges/binanceus/binanceus_wrapper.go index 3f078ede..ed2e5d10 100644 --- a/exchanges/binanceus/binanceus_wrapper.go +++ b/exchanges/binanceus/binanceus_wrapper.go @@ -162,7 +162,7 @@ func (bi *Binanceus) SetDefaults() { "%s setting default endpoints error %v", bi.Name, err) } - bi.Websocket = stream.New() + bi.Websocket = stream.NewWebsocket() bi.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit bi.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout bi.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bitfinex/bitfinex_test.go b/exchanges/bitfinex/bitfinex_test.go index c2bba90d..e7be0fbe 100644 --- a/exchanges/bitfinex/bitfinex_test.go +++ b/exchanges/bitfinex/bitfinex_test.go @@ -1128,7 +1128,7 @@ func TestGetDepositAddress(t *testing.T) { // TestWsAuth dials websocket, sends login request. func TestWsAuth(t *testing.T) { if !b.Websocket.IsEnabled() { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) if !b.API.AuthenticatedWebsocketSupport { diff --git a/exchanges/bitfinex/bitfinex_websocket.go b/exchanges/bitfinex/bitfinex_websocket.go index e1010eb1..ae7cded7 100644 --- a/exchanges/bitfinex/bitfinex_websocket.go +++ b/exchanges/bitfinex/bitfinex_websocket.go @@ -43,7 +43,7 @@ var cMtx sync.Mutex // WsConnect starts a new websocket connection func (b *Bitfinex) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bitfinex/bitfinex_wrapper.go b/exchanges/bitfinex/bitfinex_wrapper.go index 7e1ef64c..2aba9987 100644 --- a/exchanges/bitfinex/bitfinex_wrapper.go +++ b/exchanges/bitfinex/bitfinex_wrapper.go @@ -198,7 +198,7 @@ func (b *Bitfinex) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bithumb/bithumb_websocket.go b/exchanges/bithumb/bithumb_websocket.go index 3005f423..667ce131 100644 --- a/exchanges/bithumb/bithumb_websocket.go +++ b/exchanges/bithumb/bithumb_websocket.go @@ -2,7 +2,6 @@ package bithumb import ( "encoding/json" - "errors" "fmt" "net/http" "time" @@ -29,7 +28,7 @@ var ( // WsConnect initiates a websocket connection func (b *Bithumb) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer diff --git a/exchanges/bithumb/bithumb_wrapper.go b/exchanges/bithumb/bithumb_wrapper.go index 5dbc3146..24ea8980 100644 --- a/exchanges/bithumb/bithumb_wrapper.go +++ b/exchanges/bithumb/bithumb_wrapper.go @@ -150,7 +150,7 @@ func (b *Bithumb) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout } diff --git a/exchanges/bitmex/bitmex_test.go b/exchanges/bitmex/bitmex_test.go index b59cb125..e825ad8a 100644 --- a/exchanges/bitmex/bitmex_test.go +++ b/exchanges/bitmex/bitmex_test.go @@ -789,7 +789,7 @@ func TestGetDepositAddress(t *testing.T) { func TestWsAuth(t *testing.T) { t.Parallel() if !b.Websocket.IsEnabled() && !b.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(b) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bitmex/bitmex_websocket.go b/exchanges/bitmex/bitmex_websocket.go index 6d04c106..e1a47525 100644 --- a/exchanges/bitmex/bitmex_websocket.go +++ b/exchanges/bitmex/bitmex_websocket.go @@ -68,7 +68,7 @@ const ( // WsConnect initiates a new websocket connection func (b *Bitmex) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bitmex/bitmex_wrapper.go b/exchanges/bitmex/bitmex_wrapper.go index 151f1ef2..a0810b1a 100644 --- a/exchanges/bitmex/bitmex_wrapper.go +++ b/exchanges/bitmex/bitmex_wrapper.go @@ -175,7 +175,7 @@ func (b *Bitmex) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bitstamp/bitstamp_websocket.go b/exchanges/bitstamp/bitstamp_websocket.go index ab5465b5..98aa6201 100644 --- a/exchanges/bitstamp/bitstamp_websocket.go +++ b/exchanges/bitstamp/bitstamp_websocket.go @@ -45,7 +45,7 @@ var ( // WsConnect connects to a websocket feed func (b *Bitstamp) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bitstamp/bitstamp_wrapper.go b/exchanges/bitstamp/bitstamp_wrapper.go index 888fae0b..2fe7fa96 100644 --- a/exchanges/bitstamp/bitstamp_wrapper.go +++ b/exchanges/bitstamp/bitstamp_wrapper.go @@ -146,7 +146,7 @@ func (b *Bitstamp) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/btcmarkets/btcmarkets_websocket.go b/exchanges/btcmarkets/btcmarkets_websocket.go index 8c979b4d..01ba1a64 100644 --- a/exchanges/btcmarkets/btcmarkets_websocket.go +++ b/exchanges/btcmarkets/btcmarkets_websocket.go @@ -39,7 +39,7 @@ var ( // WsConnect connects to a websocket feed func (b *BTCMarkets) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/btcmarkets/btcmarkets_wrapper.go b/exchanges/btcmarkets/btcmarkets_wrapper.go index 17ee7277..8a924b08 100644 --- a/exchanges/btcmarkets/btcmarkets_wrapper.go +++ b/exchanges/btcmarkets/btcmarkets_wrapper.go @@ -150,7 +150,7 @@ func (b *BTCMarkets) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/btse/btse_websocket.go b/exchanges/btse/btse_websocket.go index 41f25c95..e32fb9a8 100644 --- a/exchanges/btse/btse_websocket.go +++ b/exchanges/btse/btse_websocket.go @@ -30,7 +30,7 @@ const ( // WsConnect connects the websocket client func (b *BTSE) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/btse/btse_wrapper.go b/exchanges/btse/btse_wrapper.go index f7ec3c8f..3cdbd56b 100644 --- a/exchanges/btse/btse_wrapper.go +++ b/exchanges/btse/btse_wrapper.go @@ -176,7 +176,7 @@ func (b *BTSE) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bybit/bybit.go b/exchanges/bybit/bybit.go index 5652c464..3588336f 100644 --- a/exchanges/bybit/bybit.go +++ b/exchanges/bybit/bybit.go @@ -21,7 +21,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" ) // Bybit is the overarching type across this package @@ -90,7 +89,6 @@ var ( errAPIKeyIsNotUnified = errors.New("api key is not unified") errEndpointAvailableForNormalAPIKeyHolders = errors.New("endpoint available for normal API key holders only") errInvalidContractLength = errors.New("contract length cannot be less than or equal to zero") - errWebsocketNotEnabled = errors.New(stream.WebsocketNotEnabled) ) var ( diff --git a/exchanges/bybit/bybit_inverse_websocket.go b/exchanges/bybit/bybit_inverse_websocket.go index 77f387ac..d1387c27 100644 --- a/exchanges/bybit/bybit_inverse_websocket.go +++ b/exchanges/bybit/bybit_inverse_websocket.go @@ -12,7 +12,7 @@ import ( // WsInverseConnect connects to inverse websocket feed func (by *Bybit) WsInverseConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.CoinMarginedFutures) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } by.Websocket.Conn.SetURL(inversePublic) var dialer websocket.Dialer diff --git a/exchanges/bybit/bybit_linear_websocket.go b/exchanges/bybit/bybit_linear_websocket.go index efc2f68d..9b3ed084 100644 --- a/exchanges/bybit/bybit_linear_websocket.go +++ b/exchanges/bybit/bybit_linear_websocket.go @@ -14,7 +14,7 @@ import ( // WsLinearConnect connects to linear a websocket feed func (by *Bybit) WsLinearConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.LinearContract) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } by.Websocket.Conn.SetURL(linearPublic) var dialer websocket.Dialer diff --git a/exchanges/bybit/bybit_options_websocket.go b/exchanges/bybit/bybit_options_websocket.go index 2f4abc7a..4bb25cef 100644 --- a/exchanges/bybit/bybit_options_websocket.go +++ b/exchanges/bybit/bybit_options_websocket.go @@ -14,7 +14,7 @@ import ( // WsOptionsConnect connects to options a websocket feed func (by *Bybit) WsOptionsConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.Options) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } by.Websocket.Conn.SetURL(optionPublic) var dialer websocket.Dialer diff --git a/exchanges/bybit/bybit_test.go b/exchanges/bybit/bybit_test.go index 58ca5bc4..092ca125 100644 --- a/exchanges/bybit/bybit_test.go +++ b/exchanges/bybit/bybit_test.go @@ -20,6 +20,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/margin" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" + "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -3064,7 +3065,7 @@ func TestWsLinearConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsLinearConnect() - if err != nil && !errors.Is(err, errWebsocketNotEnabled) { + if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { t.Error(err) } } @@ -3074,7 +3075,7 @@ func TestWsInverseConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsInverseConnect() - if err != nil && !errors.Is(err, errWebsocketNotEnabled) { + if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { t.Error(err) } } @@ -3084,7 +3085,7 @@ func TestWsOptionsConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsOptionsConnect() - if err != nil && !errors.Is(err, errWebsocketNotEnabled) { + if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { t.Error(err) } } diff --git a/exchanges/bybit/bybit_websocket.go b/exchanges/bybit/bybit_websocket.go index 857fd690..ff2698b8 100644 --- a/exchanges/bybit/bybit_websocket.go +++ b/exchanges/bybit/bybit_websocket.go @@ -57,7 +57,7 @@ const ( // WsConnect connects to a websocket feed func (by *Bybit) WsConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.Spot) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := by.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bybit/bybit_wrapper.go b/exchanges/bybit/bybit_wrapper.go index 28d4f150..b6e7f2be 100644 --- a/exchanges/bybit/bybit_wrapper.go +++ b/exchanges/bybit/bybit_wrapper.go @@ -216,7 +216,7 @@ func (by *Bybit) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - by.Websocket = stream.New() + by.Websocket = stream.NewWebsocket() by.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit by.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout by.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/coinbasepro/coinbasepro_test.go b/exchanges/coinbasepro/coinbasepro_test.go index 77f05838..0b04d4ab 100644 --- a/exchanges/coinbasepro/coinbasepro_test.go +++ b/exchanges/coinbasepro/coinbasepro_test.go @@ -681,7 +681,7 @@ func TestGetDepositAddress(t *testing.T) { // TestWsAuth dials websocket, sends login request. func TestWsAuth(t *testing.T) { if !c.Websocket.IsEnabled() && !c.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(c) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/coinbasepro/coinbasepro_websocket.go b/exchanges/coinbasepro/coinbasepro_websocket.go index e4b02b76..5946cf77 100644 --- a/exchanges/coinbasepro/coinbasepro_websocket.go +++ b/exchanges/coinbasepro/coinbasepro_websocket.go @@ -31,7 +31,7 @@ const ( // WsConnect initiates a websocket connection func (c *CoinbasePro) WsConnect() error { if !c.Websocket.IsEnabled() || !c.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/coinbasepro/coinbasepro_wrapper.go b/exchanges/coinbasepro/coinbasepro_wrapper.go index 47a20d9e..21a34e2e 100644 --- a/exchanges/coinbasepro/coinbasepro_wrapper.go +++ b/exchanges/coinbasepro/coinbasepro_wrapper.go @@ -145,7 +145,7 @@ func (c *CoinbasePro) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - c.Websocket = stream.New() + c.Websocket = stream.NewWebsocket() c.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit c.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout c.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/coinut/coinut_test.go b/exchanges/coinut/coinut_test.go index 3431d60f..0b165327 100644 --- a/exchanges/coinut/coinut_test.go +++ b/exchanges/coinut/coinut_test.go @@ -66,7 +66,7 @@ func setupWSTestAuth(t *testing.T) { } if !c.Websocket.IsEnabled() && !c.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(c) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } if sharedtestvalues.AreAPICredentialsSet(c) { c.Websocket.SetCanUseAuthenticatedEndpoints(true) diff --git a/exchanges/coinut/coinut_websocket.go b/exchanges/coinut/coinut_websocket.go index 2453816e..78b38987 100644 --- a/exchanges/coinut/coinut_websocket.go +++ b/exchanges/coinut/coinut_websocket.go @@ -41,7 +41,7 @@ var ( // WsConnect initiates a websocket connection func (c *COINUT) WsConnect() error { if !c.Websocket.IsEnabled() || !c.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/coinut/coinut_wrapper.go b/exchanges/coinut/coinut_wrapper.go index 503c2909..db4af53b 100644 --- a/exchanges/coinut/coinut_wrapper.go +++ b/exchanges/coinut/coinut_wrapper.go @@ -127,7 +127,7 @@ func (c *COINUT) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - c.Websocket = stream.New() + c.Websocket = stream.NewWebsocket() c.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit c.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout c.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/exchange_test.go b/exchanges/exchange_test.go index d41f499d..ad006049 100644 --- a/exchanges/exchange_test.go +++ b/exchanges/exchange_test.go @@ -198,7 +198,7 @@ func TestSetClientProxyAddress(t *testing.T) { Name: "rawr", Requester: requester} - newBase.Websocket = stream.New() + newBase.Websocket = stream.NewWebsocket() err = newBase.SetClientProxyAddress("") if err != nil { t.Error(err) @@ -1251,7 +1251,7 @@ func TestSetupDefaults(t *testing.T) { } // Test websocket support - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.Features.Supports.Websocket = true err = b.Websocket.Setup(&stream.WebsocketSetup{ ExchangeConfig: &config.Exchange{ @@ -1596,7 +1596,7 @@ func TestIsWebsocketEnabled(t *testing.T) { t.Error("exchange doesn't support websocket") } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() err := b.Websocket.Setup(&stream.WebsocketSetup{ ExchangeConfig: &config.Exchange{ Enabled: true, diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index c26d04af..3a3dddec 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -60,7 +60,7 @@ var fetchedCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsConnect initiates a websocket connection func (g *Gateio) WsConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.Spot) if err != nil { diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 3026efc0..adea4d2f 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -194,7 +194,7 @@ func (g *Gateio) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - g.Websocket = stream.New() + g.Websocket = stream.NewWebsocket() g.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit g.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout g.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index ba5c64af..449181c1 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -45,7 +45,7 @@ var fetchedFuturesCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsDeliveryFuturesConnect initiates a websocket connection for delivery futures account func (g *Gateio) WsDeliveryFuturesConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.DeliveryFutures) if err != nil { diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index c0411a58..20e293b9 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -64,7 +64,7 @@ var responseFuturesStream = make(chan stream.Response) // WsFuturesConnect initiates a websocket connection for futures account func (g *Gateio) WsFuturesConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.Futures) if err != nil { diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index d5340f0c..3278914f 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -70,7 +70,7 @@ var fetchedOptionsCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsOptionsConnect initiates a websocket connection to options websocket endpoints. func (g *Gateio) WsOptionsConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.Options) if err != nil { diff --git a/exchanges/gemini/gemini_test.go b/exchanges/gemini/gemini_test.go index 6a3b35e1..7477e78f 100644 --- a/exchanges/gemini/gemini_test.go +++ b/exchanges/gemini/gemini_test.go @@ -556,7 +556,7 @@ func TestWsAuth(t *testing.T) { if !g.Websocket.IsEnabled() && !g.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(g) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer go g.wsReadData() diff --git a/exchanges/gemini/gemini_websocket.go b/exchanges/gemini/gemini_websocket.go index 913c856d..43c2135e 100644 --- a/exchanges/gemini/gemini_websocket.go +++ b/exchanges/gemini/gemini_websocket.go @@ -39,7 +39,7 @@ var comms = make(chan stream.Response) // WsConnect initiates a websocket connection func (g *Gemini) WsConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer diff --git a/exchanges/gemini/gemini_wrapper.go b/exchanges/gemini/gemini_wrapper.go index fee75d6b..d2d89eb5 100644 --- a/exchanges/gemini/gemini_wrapper.go +++ b/exchanges/gemini/gemini_wrapper.go @@ -128,7 +128,7 @@ func (g *Gemini) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - g.Websocket = stream.New() + g.Websocket = stream.NewWebsocket() g.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit g.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout g.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/hitbtc/hitbtc_test.go b/exchanges/hitbtc/hitbtc_test.go index 3e629d0a..68b84951 100644 --- a/exchanges/hitbtc/hitbtc_test.go +++ b/exchanges/hitbtc/hitbtc_test.go @@ -466,7 +466,7 @@ func setupWsAuth(t *testing.T) { return } if !h.Websocket.IsEnabled() && !h.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(h) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer diff --git a/exchanges/hitbtc/hitbtc_websocket.go b/exchanges/hitbtc/hitbtc_websocket.go index 705f584f..deb88542 100644 --- a/exchanges/hitbtc/hitbtc_websocket.go +++ b/exchanges/hitbtc/hitbtc_websocket.go @@ -34,7 +34,7 @@ const ( // WsConnect starts a new connection with the websocket API func (h *HitBTC) WsConnect() error { if !h.Websocket.IsEnabled() || !h.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := h.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/hitbtc/hitbtc_wrapper.go b/exchanges/hitbtc/hitbtc_wrapper.go index 3b65fe64..7bcf7ada 100644 --- a/exchanges/hitbtc/hitbtc_wrapper.go +++ b/exchanges/hitbtc/hitbtc_wrapper.go @@ -147,7 +147,7 @@ func (h *HitBTC) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - h.Websocket = stream.New() + h.Websocket = stream.NewWebsocket() h.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit h.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout h.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/huobi/huobi_test.go b/exchanges/huobi/huobi_test.go index 4aafc7b7..16106daf 100644 --- a/exchanges/huobi/huobi_test.go +++ b/exchanges/huobi/huobi_test.go @@ -78,7 +78,7 @@ func setupWsTests(t *testing.T) { return } if !h.Websocket.IsEnabled() && !h.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(h) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } comms = make(chan WsMessage, sharedtestvalues.WebsocketChannelOverrideCapacity) go h.wsReadData() diff --git a/exchanges/huobi/huobi_websocket.go b/exchanges/huobi/huobi_websocket.go index f601b03d..92b5bf4c 100644 --- a/exchanges/huobi/huobi_websocket.go +++ b/exchanges/huobi/huobi_websocket.go @@ -62,7 +62,7 @@ var comms = make(chan WsMessage) // WsConnect initiates a new websocket connection func (h *HUOBI) WsConnect() error { if !h.Websocket.IsEnabled() || !h.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := h.wsDial(&dialer) diff --git a/exchanges/huobi/huobi_wrapper.go b/exchanges/huobi/huobi_wrapper.go index 3d1d576b..90d70491 100644 --- a/exchanges/huobi/huobi_wrapper.go +++ b/exchanges/huobi/huobi_wrapper.go @@ -202,7 +202,7 @@ func (h *HUOBI) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - h.Websocket = stream.New() + h.Websocket = stream.NewWebsocket() h.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit h.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout h.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/kraken/kraken_test.go b/exchanges/kraken/kraken_test.go index 263d70bf..8530cb53 100644 --- a/exchanges/kraken/kraken_test.go +++ b/exchanges/kraken/kraken_test.go @@ -1215,7 +1215,7 @@ func setupWsTests(t *testing.T) { return } if !k.Websocket.IsEnabled() && !k.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(k) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := k.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/kraken/kraken_websocket.go b/exchanges/kraken/kraken_websocket.go index 787d52b2..fd432516 100644 --- a/exchanges/kraken/kraken_websocket.go +++ b/exchanges/kraken/kraken_websocket.go @@ -87,7 +87,7 @@ var cancelOrdersStatus = make(map[int64]*struct { // WsConnect initiates a websocket connection func (k *Kraken) WsConnect() error { if !k.Websocket.IsEnabled() || !k.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer diff --git a/exchanges/kraken/kraken_wrapper.go b/exchanges/kraken/kraken_wrapper.go index de631e0d..58759175 100644 --- a/exchanges/kraken/kraken_wrapper.go +++ b/exchanges/kraken/kraken_wrapper.go @@ -209,7 +209,7 @@ func (k *Kraken) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - k.Websocket = stream.New() + k.Websocket = stream.NewWebsocket() k.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit k.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout k.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/kucoin/kucoin_websocket.go b/exchanges/kucoin/kucoin_websocket.go index 3f8a0c78..917bc1a9 100644 --- a/exchanges/kucoin/kucoin_websocket.go +++ b/exchanges/kucoin/kucoin_websocket.go @@ -97,7 +97,7 @@ var ( // WsConnect creates a new websocket connection. func (ku *Kucoin) WsConnect() error { if !ku.Websocket.IsEnabled() || !ku.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } fetchedFuturesSnapshotOrderbook = map[string]bool{} var dialer websocket.Dialer diff --git a/exchanges/kucoin/kucoin_wrapper.go b/exchanges/kucoin/kucoin_wrapper.go index 8a98a77a..d767a2cf 100644 --- a/exchanges/kucoin/kucoin_wrapper.go +++ b/exchanges/kucoin/kucoin_wrapper.go @@ -195,7 +195,7 @@ func (ku *Kucoin) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - ku.Websocket = stream.New() + ku.Websocket = stream.NewWebsocket() ku.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit ku.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout ku.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/okcoin/okcoin_websocket.go b/exchanges/okcoin/okcoin_websocket.go index e714aba7..0787775e 100644 --- a/exchanges/okcoin/okcoin_websocket.go +++ b/exchanges/okcoin/okcoin_websocket.go @@ -74,7 +74,7 @@ func isAuthenticatedChannel(channel string) bool { // WsConnect initiates a websocket connection func (o *Okcoin) WsConnect() error { if !o.Websocket.IsEnabled() || !o.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer dialer.ReadBufferSize = 8192 diff --git a/exchanges/okcoin/okcoin_wrapper.go b/exchanges/okcoin/okcoin_wrapper.go index 8519cdc1..af0655ef 100644 --- a/exchanges/okcoin/okcoin_wrapper.go +++ b/exchanges/okcoin/okcoin_wrapper.go @@ -150,7 +150,7 @@ func (o *Okcoin) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - o.Websocket = stream.New() + o.Websocket = stream.NewWebsocket() o.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit o.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout o.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/okcoin/okcoin_ws_trade.go b/exchanges/okcoin/okcoin_ws_trade.go index cd0c7f86..85b6102e 100644 --- a/exchanges/okcoin/okcoin_ws_trade.go +++ b/exchanges/okcoin/okcoin_ws_trade.go @@ -130,7 +130,7 @@ func (o *Okcoin) WsAmendMultipleOrder(args []AmendTradeOrderRequestParam) ([]Ame func (o *Okcoin) SendWebsocketRequest(operation string, data, result interface{}, authenticated bool) error { switch { case !o.Websocket.IsEnabled(): - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled case !o.Websocket.IsConnected(): return stream.ErrNotConnected case !o.Websocket.CanUseAuthenticatedEndpoints() && authenticated: diff --git a/exchanges/okx/okx_websocket.go b/exchanges/okx/okx_websocket.go index b4d211ee..54bc6b0e 100644 --- a/exchanges/okx/okx_websocket.go +++ b/exchanges/okx/okx_websocket.go @@ -216,7 +216,7 @@ const ( // WsConnect initiates a websocket connection func (ok *Okx) WsConnect() error { if !ok.Websocket.IsEnabled() || !ok.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer dialer.ReadBufferSize = 8192 diff --git a/exchanges/okx/okx_wrapper.go b/exchanges/okx/okx_wrapper.go index 0b5472b9..64e2e269 100644 --- a/exchanges/okx/okx_wrapper.go +++ b/exchanges/okx/okx_wrapper.go @@ -190,7 +190,7 @@ func (ok *Okx) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - ok.Websocket = stream.New() + ok.Websocket = stream.NewWebsocket() ok.WebsocketResponseMaxLimit = okxWebsocketResponseMaxLimit ok.WebsocketResponseCheckTimeout = okxWebsocketResponseMaxLimit ok.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/poloniex/poloniex_test.go b/exchanges/poloniex/poloniex_test.go index bab0ffd4..d74e2353 100644 --- a/exchanges/poloniex/poloniex_test.go +++ b/exchanges/poloniex/poloniex_test.go @@ -548,7 +548,7 @@ func TestGenerateNewAddress(t *testing.T) { func TestWsAuth(t *testing.T) { t.Parallel() if !p.Websocket.IsEnabled() && !p.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(p) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := p.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/poloniex/poloniex_websocket.go b/exchanges/poloniex/poloniex_websocket.go index 1be429a5..23774335 100644 --- a/exchanges/poloniex/poloniex_websocket.go +++ b/exchanges/poloniex/poloniex_websocket.go @@ -55,7 +55,7 @@ var ( // WsConnect initiates a websocket connection func (p *Poloniex) WsConnect() error { if !p.Websocket.IsEnabled() || !p.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := p.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/poloniex/poloniex_wrapper.go b/exchanges/poloniex/poloniex_wrapper.go index 97eb9293..57f28fac 100644 --- a/exchanges/poloniex/poloniex_wrapper.go +++ b/exchanges/poloniex/poloniex_wrapper.go @@ -159,7 +159,7 @@ func (p *Poloniex) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - p.Websocket = stream.New() + p.Websocket = stream.NewWebsocket() p.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit p.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout p.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/sharedtestvalues/sharedtestvalues.go b/exchanges/sharedtestvalues/sharedtestvalues.go index 60d51d82..54acf9dc 100644 --- a/exchanges/sharedtestvalues/sharedtestvalues.go +++ b/exchanges/sharedtestvalues/sharedtestvalues.go @@ -57,7 +57,6 @@ func GetWebsocketStructChannelOverride() chan struct{} { // NewTestWebsocket returns a test websocket object func NewTestWebsocket() *stream.Websocket { return &stream.Websocket{ - Init: true, DataHandler: make(chan interface{}, WebsocketChannelOverrideCapacity), ToRoutine: make(chan interface{}, 1000), TrafficAlert: make(chan struct{}), diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index aefc3400..404d76d8 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -16,47 +16,43 @@ import ( ) const ( - defaultJobBuffer = 5000 - // defaultTrafficPeriod defines a period of pause for the traffic monitor, - // as there are periods with large incoming traffic alerts which requires a - // timer reset, this limits work on this routine to a more effective rate - // of check. - defaultTrafficPeriod = time.Second + jobBuffer = 5000 ) +// Public websocket errors var ( - // ErrSubscriptionNotFound defines an error when a subscription is not found - ErrSubscriptionNotFound = errors.New("subscription not found") - // ErrSubscribedAlready defines an error when a channel is already subscribed - ErrSubscribedAlready = errors.New("duplicate subscription") - // ErrSubscriptionFailure defines an error when a subscription fails - ErrSubscriptionFailure = errors.New("subscription failure") - // ErrSubscriptionNotSupported defines an error when a subscription channel is not supported by an exchange + ErrWebsocketNotEnabled = errors.New("websocket not enabled") + ErrSubscriptionNotFound = errors.New("subscription not found") + ErrSubscribedAlready = errors.New("duplicate subscription") + ErrSubscriptionFailure = errors.New("subscription failure") ErrSubscriptionNotSupported = errors.New("subscription channel not supported ") - // ErrUnsubscribeFailure defines an error when a unsubscribe fails - ErrUnsubscribeFailure = errors.New("unsubscribe failure") - // ErrChannelInStateAlready defines an error when a subscription channel is already in a new state - ErrChannelInStateAlready = errors.New("channel already in state") - // ErrAlreadyDisabled is returned when you double-disable the websocket - ErrAlreadyDisabled = errors.New("websocket already disabled") - // ErrNotConnected defines an error when websocket is not connected - ErrNotConnected = errors.New("websocket is not connected") + ErrUnsubscribeFailure = errors.New("unsubscribe failure") + ErrChannelInStateAlready = errors.New("channel already in state") + ErrAlreadyDisabled = errors.New("websocket already disabled") + ErrNotConnected = errors.New("websocket is not connected") +) +// Private websocket errors +var ( errAlreadyRunning = errors.New("connection monitor is already running") errExchangeConfigIsNil = errors.New("exchange config is nil") + errExchangeConfigEmpty = errors.New("exchange config is empty") errWebsocketIsNil = errors.New("websocket is nil") errWebsocketSetupIsNil = errors.New("websocket setup is nil") errWebsocketAlreadyInitialised = errors.New("websocket already initialised") + errWebsocketAlreadyEnabled = errors.New("websocket already enabled") errWebsocketFeaturesIsUnset = errors.New("websocket features is unset") errConfigFeaturesIsNil = errors.New("exchange config features is nil") errDefaultURLIsEmpty = errors.New("default url is empty") errRunningURLIsEmpty = errors.New("running url cannot be empty") errInvalidWebsocketURL = errors.New("invalid websocket url") - errExchangeConfigNameUnset = errors.New("exchange config name unset") + errExchangeConfigNameEmpty = errors.New("exchange config name empty") errInvalidTrafficTimeout = errors.New("invalid traffic timeout") + errTrafficAlertNil = errors.New("traffic alert is nil") errWebsocketSubscriberUnset = errors.New("websocket subscriber function needs to be set") errWebsocketUnsubscriberUnset = errors.New("websocket unsubscriber functionality allowed but unsubscriber function not set") errWebsocketConnectorUnset = errors.New("websocket connector function not set") + errReadMessageErrorsNil = errors.New("read message errors is nil") errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set") errClosedConnection = errors.New("use of closed network connection") errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit") @@ -64,9 +60,18 @@ var ( errNoSubscriptionsSupplied = errors.New("no subscriptions supplied") errChannelAlreadySubscribed = errors.New("channel already subscribed") errInvalidChannelState = errors.New("invalid Channel state") + errSameProxyAddress = errors.New("cannot set proxy address to the same address") + errNoConnectFunc = errors.New("websocket connect func not set") + errAlreadyConnected = errors.New("websocket already connected") + errCannotShutdown = errors.New("websocket cannot shutdown") + errAlreadyReconnecting = errors.New("websocket in the process of reconnection") + errConnSetup = errors.New("error in connection setup") ) -var globalReporter Reporter +var ( + globalReporter Reporter + trafficCheckInterval = 100 * time.Millisecond +) // SetupGlobalReporter sets a reporter interface to be used // for all exchange requests @@ -74,13 +79,12 @@ func SetupGlobalReporter(r Reporter) { globalReporter = r } -// New initialises the websocket struct -func New() *Websocket { +// NewWebsocket initialises the websocket struct +func NewWebsocket() *Websocket { return &Websocket{ - Init: true, - DataHandler: make(chan interface{}, defaultJobBuffer), - ToRoutine: make(chan interface{}, defaultJobBuffer), - TrafficAlert: make(chan struct{}), + DataHandler: make(chan interface{}, jobBuffer), + ToRoutine: make(chan interface{}, jobBuffer), + TrafficAlert: make(chan struct{}, 1), ReadMessageErrors: make(chan error), Subscribe: make(chan []subscription.Subscription), Unsubscribe: make(chan []subscription.Subscription), @@ -98,7 +102,10 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { return errWebsocketSetupIsNil } - if !w.Init { + w.m.Lock() + defer w.m.Unlock() + + if w.IsInitialised() { return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyInitialised) } @@ -107,7 +114,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { } if s.ExchangeConfig.Name == "" { - return errExchangeConfigNameUnset + return errExchangeConfigNameEmpty } w.exchangeName = s.ExchangeConfig.Name w.verbose = s.ExchangeConfig.Verbose @@ -120,7 +127,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { if s.ExchangeConfig.Features == nil { return fmt.Errorf("%s %w", w.exchangeName, errConfigFeaturesIsNil) } - w.enabled = s.ExchangeConfig.Features.Enabled.Websocket + w.setEnabled(s.ExchangeConfig.Features.Enabled.Websocket) if s.Connector == nil { return fmt.Errorf("%s %w", w.exchangeName, errWebsocketConnectorUnset) @@ -188,28 +195,30 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { return fmt.Errorf("%s %w", w.exchangeName, errInvalidMaxSubscriptions) } w.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection + w.setState(disconnected) + return nil } // SetupNewConnection sets up an auth or unauth streaming connection func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { if w == nil { - return errors.New("setting up new connection error: websocket is nil") + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketIsNil) } if c == (ConnectionSetup{}) { - return errors.New("setting up new connection error: websocket connection configuration empty") + return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigEmpty) } if w.exchangeName == "" { - return errors.New("setting up new connection error: exchange name not set, please call setup first") + return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigNameEmpty) } if w.TrafficAlert == nil { - return errors.New("setting up new connection error: traffic alert is nil, please call setup first") + return fmt.Errorf("%w: %w", errConnSetup, errTrafficAlertNil) } if w.ReadMessageErrors == nil { - return errors.New("setting up new connection error: read message errors is nil, please call setup first") + return fmt.Errorf("%w: %w", errConnSetup, errReadMessageErrorsNil) } connectionURL := w.GetWebsocketURL() @@ -253,21 +262,19 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { // function func (w *Websocket) Connect() error { if w.connector == nil { - return errors.New("websocket connect function not set, cannot continue") + return errNoConnectFunc } w.m.Lock() defer w.m.Unlock() if !w.IsEnabled() { - return errors.New(WebsocketNotEnabled) + return ErrWebsocketNotEnabled } if w.IsConnecting() { - return fmt.Errorf("%v Websocket already attempting to connect", - w.exchangeName) + return fmt.Errorf("%v %w", w.exchangeName, errAlreadyReconnecting) } if w.IsConnected() { - return fmt.Errorf("%v Websocket already connected", - w.exchangeName) + return fmt.Errorf("%v %w", w.exchangeName, errAlreadyConnected) } w.subscriptionMutex.Lock() @@ -276,25 +283,19 @@ func (w *Websocket) Connect() error { w.dataMonitor() w.trafficMonitor() - w.setConnectingStatus(true) + w.setState(connecting) err := w.connector() if err != nil { - w.setConnectingStatus(false) - return fmt.Errorf("%v Error connecting %s", - w.exchangeName, err) + w.setState(disconnected) + return fmt.Errorf("%v Error connecting %w", w.exchangeName, err) } - w.setConnectedStatus(true) - w.setConnectingStatus(false) - w.setInit(true) + w.setState(connected) if !w.IsConnectionMonitorRunning() { err = w.connectionMonitor() if err != nil { - log.Errorf(log.WebsocketMgr, - "%s cannot start websocket connection monitor %v", - w.GetName(), - err) + log.Errorf(log.WebsocketMgr, "%s cannot start websocket connection monitor %v", w.GetName(), err) } } @@ -317,9 +318,10 @@ func (w *Websocket) Connect() error { } // Disable disables the exchange websocket protocol +// Note that connectionMonitor will be responsible for shutting down the websocket after disabling func (w *Websocket) Disable() error { if !w.IsEnabled() { - return fmt.Errorf("%w for exchange '%s'", ErrAlreadyDisabled, w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, ErrAlreadyDisabled) } w.setEnabled(false) @@ -329,8 +331,7 @@ func (w *Websocket) Disable() error { // Enable enables the exchange websocket protocol func (w *Websocket) Enable() error { if w.IsConnected() || w.IsEnabled() { - return fmt.Errorf("websocket is already enabled for exchange %s", - w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyEnabled) } w.setEnabled(true) @@ -369,9 +370,7 @@ func (w *Websocket) dataMonitor() { case <-w.ShutdownC: return default: - log.Warnf(log.WebsocketMgr, - "%s exchange backlog in websocket processing detected", - w.exchangeName) + log.Warnf(log.WebsocketMgr, "%s exchange backlog in websocket processing detected", w.exchangeName) select { case w.ToRoutine <- d: case <-w.ShutdownC: @@ -388,34 +387,25 @@ func (w *Websocket) connectionMonitor() error { if w.checkAndSetMonitorRunning() { return errAlreadyRunning } - w.fieldMutex.RLock() delay := w.connectionMonitorDelay - w.fieldMutex.RUnlock() go func() { timer := time.NewTimer(delay) for { if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: running connection monitor cycle\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: running connection monitor cycle", w.exchangeName) } if !w.IsEnabled() { if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: connectionMonitor - websocket disabled, shutting down\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: connectionMonitor - websocket disabled, shutting down", w.exchangeName) } if w.IsConnected() { - err := w.Shutdown() - if err != nil { + if err := w.Shutdown(); err != nil { log.Errorln(log.WebsocketMgr, err) } } if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: connection monitor exiting\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: connection monitor exiting", w.exchangeName) } timer.Stop() w.setConnectionMonitorRunning(false) @@ -424,11 +414,8 @@ func (w *Websocket) connectionMonitor() error { select { case err := <-w.ReadMessageErrors: if IsDisconnectionError(err) { - w.setInit(false) - log.Warnf(log.WebsocketMgr, - "%v websocket has been disconnected. Reason: %v", - w.exchangeName, err) - w.setConnectedStatus(false) + log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err) + w.setState(disconnected) } w.DataHandler <- err @@ -459,21 +446,16 @@ func (w *Websocket) Shutdown() error { defer w.m.Unlock() if !w.IsConnected() { - return fmt.Errorf("%v websocket: cannot shutdown %w", - w.exchangeName, - ErrNotConnected) + return fmt.Errorf("%v %w: %w", w.exchangeName, errCannotShutdown, ErrNotConnected) } // TODO: Interrupt connection and or close connection when it is re-established. if w.IsConnecting() { - return fmt.Errorf("%v websocket: cannot shutdown, in the process of reconnection", - w.exchangeName) + return fmt.Errorf("%v %w: %w ", w.exchangeName, errCannotShutdown, errAlreadyReconnecting) } if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: shutting down websocket\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: shutting down websocket", w.exchangeName) } defer w.Orderbook.FlushBuffer() @@ -495,15 +477,13 @@ func (w *Websocket) Shutdown() error { w.subscriptions = subscriptionMap{} w.subscriptionMutex.Unlock() + w.setState(disconnected) + close(w.ShutdownC) w.Wg.Wait() w.ShutdownC = make(chan struct{}) - w.setConnectedStatus(false) - w.setConnectingStatus(false) if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: completed websocket shutdown\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", w.exchangeName) } return nil } @@ -511,11 +491,11 @@ func (w *Websocket) Shutdown() error { // FlushChannels flushes channel subscriptions when there is a pair/asset change func (w *Websocket) FlushChannels() error { if !w.IsEnabled() { - return fmt.Errorf("%s websocket: service not enabled", w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, ErrWebsocketNotEnabled) } if !w.IsConnected() { - return fmt.Errorf("%s websocket: service not connected", w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, ErrNotConnected) } if w.features.Subscribe { @@ -565,9 +545,9 @@ func (w *Websocket) FlushChannels() error { return w.Connect() } -// trafficMonitor uses a timer of WebsocketTrafficLimitTime and once it expires, -// it will reconnect if the TrafficAlert channel has not received any data. The -// trafficTimer will reset on each traffic alert +// trafficMonitor waits trafficCheckInterval before checking for a trafficAlert +// 1 slot buffer means that connection will only write to trafficAlert once per trafficCheckInterval to avoid read/write flood in high traffic +// Otherwise we Shutdown the connection after trafficTimeout, unless it's connecting. connectionMonitor is responsible for Connecting again func (w *Websocket) trafficMonitor() { if w.IsTrafficMonitorRunning() { return @@ -576,183 +556,121 @@ func (w *Websocket) trafficMonitor() { w.Wg.Add(1) go func() { - var trafficTimer = time.NewTimer(w.trafficTimeout) - pause := make(chan struct{}) + t := time.NewTimer(w.trafficTimeout) for { select { case <-w.ShutdownC: if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: trafficMonitor shutdown message received\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown message received", w.exchangeName) } - trafficTimer.Stop() + t.Stop() w.setTrafficMonitorRunning(false) w.Wg.Done() return - case <-w.TrafficAlert: - if !trafficTimer.Stop() { - select { - case <-trafficTimer.C: - default: + case <-time.After(trafficCheckInterval): + select { + case <-w.TrafficAlert: + if !t.Stop() { + <-t.C } + t.Reset(w.trafficTimeout) + default: + } + case <-t.C: + checkAgain := w.IsConnecting() + select { + case <-w.TrafficAlert: + checkAgain = true + default: + } + if checkAgain { + t.Reset(w.trafficTimeout) + break } - w.setConnectedStatus(true) - trafficTimer.Reset(w.trafficTimeout) - case <-trafficTimer.C: // Falls through when timer runs out if w.verbose { - log.Warnf(log.WebsocketMgr, - "%v websocket: has not received a traffic alert in %v. Reconnecting", - w.exchangeName, - w.trafficTimeout) + log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", w.exchangeName, w.trafficTimeout) } - trafficTimer.Stop() - w.setTrafficMonitorRunning(false) - w.Wg.Done() // without this the w.Shutdown() call below will deadlock - if !w.IsConnecting() && w.IsConnected() { + w.setTrafficMonitorRunning(false) // Cannot defer lest Connect is called after Shutdown but before deferred call + w.Wg.Done() // Without this the w.Shutdown() call below will deadlock + if w.IsConnected() { err := w.Shutdown() if err != nil { - log.Errorf(log.WebsocketMgr, - "%v websocket: trafficMonitor shutdown err: %s", - w.exchangeName, err) + log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", w.exchangeName, err) } } - return } - - if w.IsConnected() { - // Routine pausing mechanism - go func(p chan<- struct{}) { - time.Sleep(defaultTrafficPeriod) - select { - case p <- struct{}{}: - default: - } - }(pause) - select { - case <-w.ShutdownC: - trafficTimer.Stop() - w.setTrafficMonitorRunning(false) - w.Wg.Done() - return - case <-pause: - } - } } }() } -func (w *Websocket) setConnectedStatus(b bool) { - w.fieldMutex.Lock() - w.connected = b - w.fieldMutex.Unlock() +func (w *Websocket) setState(s uint32) { + w.state.Store(s) } -// IsConnected returns status of connection +// IsInitialised returns whether the websocket has been Setup() already +func (w *Websocket) IsInitialised() bool { + return w.state.Load() != uninitialised +} + +// IsConnected returns whether the websocket is connected func (w *Websocket) IsConnected() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.connected + return w.state.Load() == connected } -func (w *Websocket) setConnectingStatus(b bool) { - w.fieldMutex.Lock() - w.connecting = b - w.fieldMutex.Unlock() -} - -// IsConnecting returns status of connecting +// IsConnecting returns whether the websocket is connecting func (w *Websocket) IsConnecting() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.connecting + return w.state.Load() == connecting } func (w *Websocket) setEnabled(b bool) { - w.fieldMutex.Lock() - w.enabled = b - w.fieldMutex.Unlock() + w.enabled.Store(b) } -// IsEnabled returns status of enabled +// IsEnabled returns whether the websocket is enabled func (w *Websocket) IsEnabled() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.enabled -} - -func (w *Websocket) setInit(b bool) { - w.fieldMutex.Lock() - w.Init = b - w.fieldMutex.Unlock() -} - -// IsInit returns status of init -func (w *Websocket) IsInit() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.Init + return w.enabled.Load() } func (w *Websocket) setTrafficMonitorRunning(b bool) { - w.fieldMutex.Lock() - w.trafficMonitorRunning = b - w.fieldMutex.Unlock() + w.trafficMonitorRunning.Store(b) } // IsTrafficMonitorRunning returns status of the traffic monitor func (w *Websocket) IsTrafficMonitorRunning() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.trafficMonitorRunning + return w.trafficMonitorRunning.Load() } func (w *Websocket) checkAndSetMonitorRunning() (alreadyRunning bool) { - w.fieldMutex.Lock() - defer w.fieldMutex.Unlock() - if w.connectionMonitorRunning { - return true - } - w.connectionMonitorRunning = true - return false + return !w.connectionMonitorRunning.CompareAndSwap(false, true) } func (w *Websocket) setConnectionMonitorRunning(b bool) { - w.fieldMutex.Lock() - w.connectionMonitorRunning = b - w.fieldMutex.Unlock() + w.connectionMonitorRunning.Store(b) } // IsConnectionMonitorRunning returns status of connection monitor func (w *Websocket) IsConnectionMonitorRunning() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.connectionMonitorRunning + return w.connectionMonitorRunning.Load() } func (w *Websocket) setDataMonitorRunning(b bool) { - w.fieldMutex.Lock() - w.dataMonitorRunning = b - w.fieldMutex.Unlock() + w.dataMonitorRunning.Store(b) } // IsDataMonitorRunning returns status of data monitor func (w *Websocket) IsDataMonitorRunning() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.dataMonitorRunning + return w.dataMonitorRunning.Load() } // CanUseAuthenticatedWebsocketForWrapper Handles a common check to // verify whether a wrapper can use an authenticated websocket endpoint func (w *Websocket) CanUseAuthenticatedWebsocketForWrapper() bool { - if w.IsConnected() && w.CanUseAuthenticatedEndpoints() { - return true - } else if w.IsConnected() && !w.CanUseAuthenticatedEndpoints() { - log.Infof(log.WebsocketMgr, - WebsocketNotAuthenticatedUsingRest, - w.exchangeName) + if w.IsConnected() { + if w.CanUseAuthenticatedEndpoints() { + return true + } + log.Infof(log.WebsocketMgr, WebsocketNotAuthenticatedUsingRest, w.exchangeName) } return false } @@ -820,28 +738,22 @@ func (w *Websocket) GetWebsocketURL() string { // SetProxyAddress sets websocket proxy address func (w *Websocket) SetProxyAddress(proxyAddr string) error { + w.m.Lock() + if proxyAddr != "" { - _, err := url.ParseRequestURI(proxyAddr) - if err != nil { - return fmt.Errorf("%v websocket: cannot set proxy address error '%v'", - w.exchangeName, - err) + if _, err := url.ParseRequestURI(proxyAddr); err != nil { + w.m.Unlock() + return fmt.Errorf("%v websocket: cannot set proxy address: %w", w.exchangeName, err) } if w.proxyAddr == proxyAddr { - return fmt.Errorf("%v websocket: cannot set proxy address to the same address '%v'", - w.exchangeName, - w.proxyAddr) + w.m.Unlock() + return fmt.Errorf("%v websocket: %w '%v'", w.exchangeName, errSameProxyAddress, w.proxyAddr) } - log.Debugf(log.ExchangeSys, - "%s websocket: setting websocket proxy: %s\n", - w.exchangeName, - proxyAddr) + log.Debugf(log.ExchangeSys, "%s websocket: setting websocket proxy: %s", w.exchangeName, proxyAddr) } else { - log.Debugf(log.ExchangeSys, - "%s websocket: removing websocket proxy\n", - w.exchangeName) + log.Debugf(log.ExchangeSys, "%s websocket: removing websocket proxy", w.exchangeName) } if w.Conn != nil { @@ -852,15 +764,17 @@ func (w *Websocket) SetProxyAddress(proxyAddr string) error { } w.proxyAddr = proxyAddr - if w.IsInit() && w.IsEnabled() { - if w.IsConnected() { - err := w.Shutdown() - if err != nil { - return err - } + + if w.IsConnected() { + w.m.Unlock() + if err := w.Shutdown(); err != nil { + return err } return w.Connect() } + + w.m.Unlock() + return nil } @@ -1035,20 +949,14 @@ func (w *Websocket) GetSubscriptions() []subscription.Subscription { return subs } -// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in -// a thread safe manner -func (w *Websocket) SetCanUseAuthenticatedEndpoints(val bool) { - w.fieldMutex.Lock() - defer w.fieldMutex.Unlock() - w.canUseAuthenticatedEndpoints = val +// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in a thread safe manner +func (w *Websocket) SetCanUseAuthenticatedEndpoints(b bool) { + w.canUseAuthenticatedEndpoints.Store(b) } -// CanUseAuthenticatedEndpoints gets canUseAuthenticatedEndpoints val in -// a thread safe manner +// CanUseAuthenticatedEndpoints gets canUseAuthenticatedEndpoints val in a thread safe manner func (w *Websocket) CanUseAuthenticatedEndpoints() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.canUseAuthenticatedEndpoints + return w.canUseAuthenticatedEndpoints.Load() } // IsDisconnectionError Determines if the error sent over chan ReadMessageErrors is a disconnection error diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 0bb1e660..4d7681f8 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -50,9 +50,7 @@ func (w *WebsocketConnection) SendMessageReturnResponse(signature, request inter return payload, nil case <-timer.C: timer.Stop() - return nil, fmt.Errorf("%s websocket connection: timeout waiting for response with signature: %v", - w.ExchangeName, - signature) + return nil, fmt.Errorf("%s websocket connection: timeout waiting for response with signature: %v", w.ExchangeName, signature) } } @@ -72,25 +70,14 @@ func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header w.Connection, conStatus, err = dialer.Dial(w.URL, headers) if err != nil { if conStatus != nil { - return fmt.Errorf("%s websocket connection: %v %v %v Error: %v", - w.ExchangeName, - w.URL, - conStatus, - conStatus.StatusCode, - err) + return fmt.Errorf("%s websocket connection: %v %v %v Error: %w", w.ExchangeName, w.URL, conStatus, conStatus.StatusCode, err) } - return fmt.Errorf("%s websocket connection: %v Error: %v", - w.ExchangeName, - w.URL, - err) + return fmt.Errorf("%s websocket connection: %v Error: %w", w.ExchangeName, w.URL, err) } defer conStatus.Body.Close() if w.Verbose { - log.Infof(log.WebsocketMgr, - "%v Websocket connected to %s\n", - w.ExchangeName, - w.URL) + log.Infof(log.WebsocketMgr, "%v Websocket connected to %s\n", w.ExchangeName, w.URL) } select { case w.Traffic <- struct{}{}: @@ -240,7 +227,7 @@ func (w *WebsocketConnection) ReadMessage() Response { select { case w.Traffic <- struct{}{}: - default: // causes contention, just bypass if there is no receiver. + default: // Non-Blocking write ensures 1 buffered signal per trafficCheckInterval to avoid flooding } var standardMessage []byte @@ -285,7 +272,7 @@ func (w *WebsocketConnection) parseBinaryResponse(resp []byte) ([]byte, error) { return standardMessage, reader.Close() } -// GenerateMessageID Creates a messageID to checkout +// GenerateMessageID Creates a random message ID func (w *WebsocketConnection) GenerateMessageID(highPrec bool) int64 { var min int64 = 1e8 var max int64 = 2e8 diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 3ab49e0d..0d4e9c02 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -9,6 +9,7 @@ import ( "fmt" "net" "net/http" + "os" "sort" "strconv" "strings" @@ -18,6 +19,7 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" @@ -30,6 +32,10 @@ const ( proxyURL = "http://212.186.171.4:80" // Replace with a usable proxy server ) +var ( + errDastardlyReason = errors.New("some dastardly reason") +) + var dialer websocket.Dialer type testStruct struct { @@ -68,7 +74,7 @@ var defaultSetup = &WebsocketSetup{ AuthenticatedWebsocketSupport: true, }, WebsocketTrafficTimeout: time.Second * 5, - Name: "exchangeName", + Name: "GTX", }, DefaultURL: "testDefaultURL", RunningURL: "wss://testRunningURL", @@ -92,416 +98,355 @@ type dodgyConnection struct { // override websocket connection method to produce a wicked terrible error func (d *dodgyConnection) Shutdown() error { - return errors.New("cannot shutdown due to some dastardly reason") + return fmt.Errorf("%w: %w", errCannotShutdown, errDastardlyReason) } // override websocket connection method to produce a wicked terrible error func (d *dodgyConnection) Connect() error { - return errors.New("cannot connect due to some dastardly reason") + return fmt.Errorf("cannot connect: %w", errDastardlyReason) +} + +func TestMain(m *testing.M) { + // Change trafficCheckInterval for TestTrafficMonitorTimeout before parallel tests to avoid racing + trafficCheckInterval = 50 * time.Millisecond + os.Exit(m.Run()) } func TestSetup(t *testing.T) { t.Parallel() var w *Websocket err := w.Setup(nil) - if !errors.Is(err, errWebsocketIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketIsNil) - } + assert.ErrorIs(t, err, errWebsocketIsNil, "Setup should error correctly") w = &Websocket{DataHandler: make(chan interface{})} err = w.Setup(nil) - if !errors.Is(err, errWebsocketSetupIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSetupIsNil) - } + assert.ErrorIs(t, err, errWebsocketSetupIsNil, "Setup should error correctly") websocketSetup := &WebsocketSetup{} - err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketAlreadyInitialised) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketAlreadyInitialised) - } - w.Init = true err = w.Setup(websocketSetup) - if !errors.Is(err, errExchangeConfigIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeConfigIsNil) - } + assert.ErrorIs(t, err, errExchangeConfigIsNil, "Setup should error correctly") websocketSetup.ExchangeConfig = &config.Exchange{} err = w.Setup(websocketSetup) - if !errors.Is(err, errExchangeConfigNameUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeConfigNameUnset) - } - websocketSetup.ExchangeConfig.Name = "testname" + assert.ErrorIs(t, err, errExchangeConfigNameEmpty, "Setup should error correctly") + websocketSetup.ExchangeConfig.Name = "testname" err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketFeaturesIsUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketFeaturesIsUnset) - } + assert.ErrorIs(t, err, errWebsocketFeaturesIsUnset, "Setup should error correctly") websocketSetup.Features = &protocol.Features{} err = w.Setup(websocketSetup) - if !errors.Is(err, errConfigFeaturesIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errConfigFeaturesIsNil) - } + assert.ErrorIs(t, err, errConfigFeaturesIsNil, "Setup should error correctly") websocketSetup.ExchangeConfig.Features = &config.FeaturesConfig{} err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketConnectorUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketConnectorUnset) - } + assert.ErrorIs(t, err, errWebsocketConnectorUnset, "Setup should error correctly") websocketSetup.Connector = func() error { return nil } err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketSubscriberUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSubscriberUnset) - } + assert.ErrorIs(t, err, errWebsocketSubscriberUnset, "Setup should error correctly") websocketSetup.Subscriber = func([]subscription.Subscription) error { return nil } websocketSetup.Features.Unsubscribe = true err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketUnsubscriberUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketUnsubscriberUnset) - } + assert.ErrorIs(t, err, errWebsocketUnsubscriberUnset, "Setup should error correctly") websocketSetup.Unsubscriber = func([]subscription.Subscription) error { return nil } err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketSubscriptionsGeneratorUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSubscriptionsGeneratorUnset) - } + assert.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset, "Setup should error correctly") websocketSetup.GenerateSubscriptions = func() ([]subscription.Subscription, error) { return nil, nil } err = w.Setup(websocketSetup) - if !errors.Is(err, errDefaultURLIsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, errDefaultURLIsEmpty) - } + assert.ErrorIs(t, err, errDefaultURLIsEmpty, "Setup should error correctly") websocketSetup.DefaultURL = "test" err = w.Setup(websocketSetup) - if !errors.Is(err, errRunningURLIsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, errRunningURLIsEmpty) - } + assert.ErrorIs(t, err, errRunningURLIsEmpty, "Setup should error correctly") websocketSetup.RunningURL = "http://www.google.com" err = w.Setup(websocketSetup) - if !errors.Is(err, errInvalidWebsocketURL) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidWebsocketURL) - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "Setup should error correctly") websocketSetup.RunningURL = "wss://www.google.com" websocketSetup.RunningURLAuth = "http://www.google.com" err = w.Setup(websocketSetup) - if !errors.Is(err, errInvalidWebsocketURL) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidWebsocketURL) - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "Setup should error correctly") websocketSetup.RunningURLAuth = "wss://www.google.com" err = w.Setup(websocketSetup) - if !errors.Is(err, errInvalidTrafficTimeout) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidTrafficTimeout) - } + assert.ErrorIs(t, err, errInvalidTrafficTimeout, "Setup should error correctly") websocketSetup.ExchangeConfig.WebsocketTrafficTimeout = time.Minute err = w.Setup(websocketSetup) - if !errors.Is(err, nil) { - t.Fatalf("received: %v but expected: %v", err, nil) - } + assert.NoError(t, err, "Setup should not error") } -func TestTrafficMonitorTimeout(t *testing.T) { +// TestTrafficMonitorTrafficAlerts ensures multiple traffic alerts work and only process one trafficAlert per interval +// ensures shutdown works after traffic alerts +func TestTrafficMonitorTrafficAlerts(t *testing.T) { t.Parallel() - ws := *New() - if err := ws.Setup(defaultSetup); err != nil { - t.Fatal(err) - } - ws.trafficTimeout = time.Second * 2 + ws := NewWebsocket() + err := ws.Setup(defaultSetup) + require.NoError(t, err, "Setup must not error") + + signal := struct{}{} + patience := 10 * time.Millisecond + ws.trafficTimeout = 200 * time.Millisecond ws.ShutdownC = make(chan struct{}) + ws.state.Store(connected) + + thenish := time.Now() ws.trafficMonitor() - if !ws.IsTrafficMonitorRunning() { - t.Fatal("traffic monitor should be running") + + assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + require.Equal(t, connected, ws.state.Load(), "websocket must be connected") + + for i := 0; i < 6; i++ { // Timeout will happen at 200ms so we want 6 * 50ms checks to pass + select { + case ws.TrafficAlert <- signal: + if i == 0 { + require.WithinDurationf(t, time.Now(), thenish, trafficCheckInterval, "First Non-blocking test must happen before the traffic is checked") + } + default: + require.Failf(t, "", "TrafficAlert should not block; Check #%d", i) + } + + select { + case ws.TrafficAlert <- signal: + require.Failf(t, "", "TrafficAlert should block after first slot used; Check #%d", i) + default: + if i == 0 { + require.WithinDuration(t, time.Now(), thenish, trafficCheckInterval, "First Blocking test must happen before the traffic is checked") + } + } + + require.Eventuallyf(t, func() bool { return len(ws.TrafficAlert) == 0 }, 5*time.Second, patience, "trafficAlert should be drained; Check #%d", i) + assert.Truef(t, ws.IsConnected(), "state should still be connected; Check #%d", i) } - // Deploy traffic alert - ws.TrafficAlert <- struct{}{} - // try to add another traffic monitor + + require.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") + assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") + }, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown once we stop feeding trafficAlerts") +} + +// TestTrafficMonitorConnecting ensures connecting status doesn't trigger shutdown +func TestTrafficMonitorConnecting(t *testing.T) { + t.Parallel() + ws := NewWebsocket() + err := ws.Setup(defaultSetup) + require.NoError(t, err, "Setup must not error") + + ws.ShutdownC = make(chan struct{}) + ws.state.Store(connecting) + ws.trafficTimeout = 50 * time.Millisecond ws.trafficMonitor() - if !ws.IsTrafficMonitorRunning() { - t.Fatal("traffic monitor should be running") + require.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + require.Equal(t, connecting, ws.state.Load(), "websocket must be connecting") + <-time.After(4 * ws.trafficTimeout) + require.Equal(t, connecting, ws.state.Load(), "websocket must still be connecting after several checks") + ws.state.Store(connected) + require.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") + assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") + }, 4*ws.trafficTimeout, 10*time.Millisecond, "trafficTimeout should trigger a shutdown after connecting status changes") +} + +// TestTrafficMonitorShutdown ensures shutdown is processed and waitgroup is cleared +func TestTrafficMonitorShutdown(t *testing.T) { + t.Parallel() + ws := NewWebsocket() + err := ws.Setup(defaultSetup) + require.NoError(t, err, "Setup must not error") + + ws.ShutdownC = make(chan struct{}) + ws.state.Store(connected) + ws.trafficTimeout = time.Minute + ws.trafficMonitor() + assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + + wgReady := make(chan bool) + go func() { + ws.Wg.Wait() + close(wgReady) + }() + select { + case <-wgReady: + require.Failf(t, "", "WaitGroup should be blocking still") + case <-time.After(trafficCheckInterval): } - // prevent shutdown routine - ws.setConnectedStatus(false) - // await timeout closure - ws.Wg.Wait() - if ws.IsTrafficMonitorRunning() { - t.Error("should be dead") + + close(ws.ShutdownC) + + <-time.After(2 * trafficCheckInterval) + assert.False(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be shutdown") + select { + case <-wgReady: + default: + require.Failf(t, "", "WaitGroup should be freed now") } } func TestIsDisconnectionError(t *testing.T) { t.Parallel() - isADisconnectionError := IsDisconnectionError(errors.New("errorText")) - if isADisconnectionError { - t.Error("Its not") - } - isADisconnectionError = IsDisconnectionError(&websocket.CloseError{ - Code: 1006, - Text: "errorText", - }) - if !isADisconnectionError { - t.Error("It is") - } - - isADisconnectionError = IsDisconnectionError(&net.OpError{ - Err: errClosedConnection, - }) - if isADisconnectionError { - t.Error("It's not") - } - - isADisconnectionError = IsDisconnectionError(&net.OpError{ - Err: errors.New("errText"), - }) - if !isADisconnectionError { - t.Error("It is") - } + assert.False(t, IsDisconnectionError(errors.New("errorText")), "IsDisconnectionError should return false") + assert.True(t, IsDisconnectionError(&websocket.CloseError{Code: 1006, Text: "errorText"}), "IsDisconnectionError should return true") + assert.False(t, IsDisconnectionError(&net.OpError{Err: errClosedConnection}), "IsDisconnectionError should return false") + assert.True(t, IsDisconnectionError(&net.OpError{Err: errors.New("errText")}), "IsDisconnectionError should return true") } func TestConnectionMessageErrors(t *testing.T) { t.Parallel() var wsWrong = &Websocket{} err := wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errNoConnectFunc, "Connect should error correctly") wsWrong.connector = func() error { return nil } err = wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "Connect should error correctly") wsWrong.setEnabled(true) - wsWrong.setConnectingStatus(true) + wsWrong.setState(connecting) wsWrong.Wg = &sync.WaitGroup{} err = wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errAlreadyReconnecting, "Connect should error correctly") - wsWrong.setConnectedStatus(false) - wsWrong.connector = func() error { return errors.New("edge case error of dooooooom") } + wsWrong.setState(disconnected) + wsWrong.connector = func() error { return errDastardlyReason } err = wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errDastardlyReason, "Connect should error correctly") - ws := *New() + ws := NewWebsocket() err = ws.Setup(defaultSetup) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Setup must not error") ws.trafficTimeout = time.Minute ws.connector = func() error { return nil } err = ws.Connect() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Connect must not error") ws.TrafficAlert <- struct{}{} - timer := time.NewTimer(900 * time.Millisecond) - ws.ReadMessageErrors <- errors.New("errorText") - select { - case err := <-ws.ToRoutine: - errText, ok := err.(error) - if !ok { - t.Error("unable to type assert error") - } else if errText.Error() != "errorText" { - t.Errorf("Expected 'errorText', received %v", err) - } - case <-timer.C: - t.Error("Timeout waiting for datahandler to receive error") - } - ws.ReadMessageErrors <- &websocket.CloseError{ - Code: 1006, - Text: "errorText", - } -outer: - for { + c := func(tb *assert.CollectT) { select { - case err := <-ws.ToRoutine: - if _, ok := err.(*websocket.CloseError); !ok { - t.Errorf("Error is not a disconnection error: %v", err) + case v := <-ws.ToRoutine: + switch err := v.(type) { + case *websocket.CloseError: + assert.Equal(tb, "SpecialText", err.Text, "Should get correct Close Error") + case error: + assert.ErrorIs(tb, err, errDastardlyReason, "Should get the correct error") } - case <-timer.C: - break outer + default: } } + + ws.ReadMessageErrors <- errDastardlyReason + assert.EventuallyWithT(t, c, 900*time.Millisecond, 10*time.Millisecond, "Should get an error down the routine") + + ws.ReadMessageErrors <- &websocket.CloseError{Code: 1006, Text: "SpecialText"} + assert.EventuallyWithT(t, c, 900*time.Millisecond, 10*time.Millisecond, "Should get an error down the routine") } func TestWebsocket(t *testing.T) { t.Parallel() - wsInit := Websocket{} - err := wsInit.Setup(&WebsocketSetup{ - ExchangeConfig: &config.Exchange{ - Features: &config.FeaturesConfig{ - Enabled: config.FeaturesEnabledConfig{Websocket: true}, - }, - Name: "test", - }, - }) - if !errors.Is(err, errWebsocketAlreadyInitialised) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketAlreadyInitialised) - } - ws := *New() - err = ws.SetProxyAddress("garbagio") - if err == nil { - t.Error("error cannot be nil") - } + ws := NewWebsocket() - ws.Conn = &WebsocketConnection{} + err := ws.SetProxyAddress("garbagio") + assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly") + + ws.Conn = &dodgyConnection{} ws.AuthConn = &WebsocketConnection{} ws.setEnabled(true) + + err = ws.Setup(defaultSetup) // Sets to enabled again + require.NoError(t, err, "Setup may not error") + + err = ws.Setup(defaultSetup) + assert.ErrorIs(t, err, errWebsocketAlreadyInitialised, "Setup should error correctly if called twice") + + assert.Equal(t, "GTX", ws.GetName(), "GetName should return correctly") + assert.True(t, ws.IsEnabled(), "Websocket should be enabled by Setup") + + ws.setEnabled(false) + assert.False(t, ws.IsEnabled(), "Websocket should be disabled by setEnabled(false)") + + ws.setEnabled(true) + assert.True(t, ws.IsEnabled(), "Websocket should be enabled by setEnabled(true)") + err = ws.SetProxyAddress("https://192.168.0.1:1337") - if err == nil { - t.Error("error cannot be nil") - } - ws.setConnectedStatus(true) - ws.ShutdownC = make(chan struct{}) - ws.Wg = &sync.WaitGroup{} - err = ws.SetProxyAddress("https://192.168.0.1:1336") - if err == nil { - t.Error("SetProxyAddress", err) - } + assert.NoError(t, err, "SetProxyAddress should not error when not yet connected") + + ws.setState(connected) err = ws.SetProxyAddress("https://192.168.0.1:1336") - if err == nil { - t.Error("SetProxyAddress", err) - } - ws.setEnabled(false) + assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Connect and error from there") + + err = ws.SetProxyAddress("https://192.168.0.1:1336") + assert.ErrorIs(t, err, errSameProxyAddress, "SetProxyAddress should error correctly") // removing proxy err = ws.SetProxyAddress("") - if err != nil { - t.Error(err) - } - // reinstate proxy - err = ws.SetProxyAddress("http://localhost:1337") - if err != nil { - t.Error(err) - } - // conflict proxy - err = ws.SetProxyAddress("http://localhost:1337") - if err == nil { - t.Error("error cannot be nil") - } - err = ws.Setup(defaultSetup) - if err != nil { - t.Fatal(err) - } - if ws.GetName() != "exchangeName" { - t.Error("WebsocketSetup") - } - - if !ws.IsEnabled() { - t.Error("WebsocketSetup") - } - - ws.setEnabled(false) - if ws.IsEnabled() { - t.Error("WebsocketSetup") - } - ws.setEnabled(true) - if !ws.IsEnabled() { - t.Error("WebsocketSetup") - } - - if ws.GetProxyAddress() != "http://localhost:1337" { - t.Error("WebsocketSetup") - } - - if ws.GetWebsocketURL() != "wss://testRunningURL" { - t.Error("WebsocketSetup") - } - if ws.trafficTimeout != time.Second*5 { - t.Error("WebsocketSetup") - } - // -- Not connected shutdown - err = ws.Shutdown() - if err == nil { - t.Fatal("should not be connected to able to shut down") - } - - ws.setConnectedStatus(true) - ws.Conn = &dodgyConnection{} - err = ws.Shutdown() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Shutdown and error from there") + assert.ErrorIs(t, err, errCannotShutdown, "SetProxyAddress should call Shutdown and error from there") ws.Conn = &WebsocketConnection{} + ws.setEnabled(true) - ws.setConnectedStatus(true) + // reinstate proxy + err = ws.SetProxyAddress("http://localhost:1337") + assert.NoError(t, err, "SetProxyAddress should not error") + assert.Equal(t, "http://localhost:1337", ws.GetProxyAddress(), "GetProxyAddress should return correctly") + assert.Equal(t, "wss://testRunningURL", ws.GetWebsocketURL(), "GetWebsocketURL should return correctly") + assert.Equal(t, time.Second*5, ws.trafficTimeout, "trafficTimeout should default correctly") + + ws.setState(connected) ws.AuthConn = &dodgyConnection{} err = ws.Shutdown() - if err == nil { - t.Fatal("error cannot be nil ") - } + assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy authConn") + assert.ErrorIs(t, err, errCannotShutdown, "Shutdown should error correctly with a dodgy authConn") ws.AuthConn = &WebsocketConnection{} - ws.setConnectedStatus(false) + ws.setState(disconnected) - // -- Normal connect err = ws.Connect() - if err != nil { - t.Fatal("WebsocketSetup", err) - } + assert.NoError(t, err, "Connect should not error") ws.defaultURL = "ws://demos.kaazing.com/echo" ws.defaultURLAuth = "ws://demos.kaazing.com/echo" err = ws.SetWebsocketURL("", false, false) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("ws://demos.kaazing.com/echo", false, false) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("", true, false) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("ws://demos.kaazing.com/echo", true, false) - if err != nil { - t.Fatal(err) - } - // Attempt reconnect + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("ws://demos.kaazing.com/echo", true, true) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error on reconnect") + // -- initiate the reconnect which is usually handled by connection monitor err = ws.Connect() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "ReConnect called manually should not error") + err = ws.Connect() - if err == nil { - t.Fatal("should already be connected") - } - // -- Normal shutdown + assert.ErrorIs(t, err, errAlreadyConnected, "ReConnect should error when already connected") + err = ws.Shutdown() - if err != nil { - t.Fatal("WebsocketSetup", err) - } + assert.NoError(t, err, "Shutdown should not error") ws.Wg.Wait() } // TestSubscribe logic test func TestSubscribeUnsubscribe(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() assert.NoError(t, ws.Setup(defaultSetup), "WS Setup should not error") fnSub := func(subs []subscription.Subscription) error { @@ -546,7 +491,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { // TestResubscribe tests Resubscribing to existing subscriptions func TestResubscribe(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() wackedOutSetup := *defaultSetup wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1 @@ -577,7 +522,7 @@ func TestResubscribe(t *testing.T) { // TestSubscriptionState tests Subscription state changes func TestSubscriptionState(t *testing.T) { t.Parallel() - ws := New() + ws := NewWebsocket() c := &subscription.Subscription{Key: 42, Channel: "Gophers", State: subscription.SubscribingState} assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), ErrSubscriptionNotFound, "Setting an imaginary sub should error") @@ -603,7 +548,7 @@ func TestSubscriptionState(t *testing.T) { // TestRemoveSubscriptions tests removing a subscription func TestRemoveSubscriptions(t *testing.T) { t.Parallel() - ws := New() + ws := NewWebsocket() c := &subscription.Subscription{Key: 42, Channel: "Unite!"} assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error") @@ -616,24 +561,18 @@ func TestRemoveSubscriptions(t *testing.T) { // TestConnectionMonitorNoConnection logic test func TestConnectionMonitorNoConnection(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() ws.connectionMonitorDelay = 500 ws.DataHandler = make(chan interface{}, 1) ws.ShutdownC = make(chan struct{}, 1) ws.exchangeName = "hello" ws.Wg = &sync.WaitGroup{} - ws.enabled = true + ws.setEnabled(true) err := ws.connectionMonitor() - if !errors.Is(err, nil) { - t.Fatalf("received: %v, but expected: %v", err, nil) - } - if !ws.IsConnectionMonitorRunning() { - t.Fatal("Should not have exited") - } + require.NoError(t, err, "connectionMonitor must not error") + assert.True(t, ws.IsConnectionMonitorRunning(), "IsConnectionMonitorRunning should return true") err = ws.connectionMonitor() - if !errors.Is(err, errAlreadyRunning) { - t.Fatalf("received: %v, but expected: %v", err, errAlreadyRunning) - } + assert.ErrorIs(t, err, errAlreadyRunning, "connectionMonitor should error correctly") } // TestGetSubscription logic test @@ -671,16 +610,10 @@ func TestGetSubscriptions(t *testing.T) { // TestSetCanUseAuthenticatedEndpoints logic test func TestSetCanUseAuthenticatedEndpoints(t *testing.T) { t.Parallel() - ws := *New() - result := ws.CanUseAuthenticatedEndpoints() - if result { - t.Error("expected `canUseAuthenticatedEndpoints` to be false") - } + ws := NewWebsocket() + assert.False(t, ws.CanUseAuthenticatedEndpoints(), "CanUseAuthenticatedEndpoints should return false") ws.SetCanUseAuthenticatedEndpoints(true) - result = ws.CanUseAuthenticatedEndpoints() - if !result { - t.Error("expected `canUseAuthenticatedEndpoints` to be true") - } + assert.True(t, ws.CanUseAuthenticatedEndpoints(), "CanUseAuthenticatedEndpoints should return true") } // TestDial logic test @@ -917,81 +850,53 @@ func TestParseBinaryResponse(t *testing.T) { } var b bytes.Buffer - w := gzip.NewWriter(&b) - _, err := w.Write([]byte("hello")) - if err != nil { - t.Error(err) - } - err = w.Close() - if err != nil { - t.Error(err) - } - var resp []byte + g := gzip.NewWriter(&b) + _, err := g.Write([]byte("hello")) + require.NoError(t, err, "gzip.Write must not error") + assert.NoError(t, g.Close(), "Close should not error") + + resp, err := wc.parseBinaryResponse(b.Bytes()) + assert.NoError(t, err, "parseBinaryResponse should not error parsing gzip") + assert.EqualValues(t, "hello", resp, "parseBinaryResponse should decode gzip") + + b.Reset() + f, err := flate.NewWriter(&b, 1) + require.NoError(t, err, "flate.NewWriter must not error") + _, err = f.Write([]byte("goodbye")) + require.NoError(t, err, "flate.Write must not error") + assert.NoError(t, f.Close(), "Close should not error") + resp, err = wc.parseBinaryResponse(b.Bytes()) - if err != nil { - t.Error(err) - } - if !strings.EqualFold(string(resp), "hello") { - t.Errorf("GZip conversion failed. Received: '%v', Expected: 'hello'", string(resp)) - } + assert.NoError(t, err, "parseBinaryResponse should not error parsing inflate") + assert.EqualValues(t, "goodbye", resp, "parseBinaryResponse should deflate") - var b2 bytes.Buffer - w2, err2 := flate.NewWriter(&b2, 1) - if err2 != nil { - t.Error(err2) - } - _, err2 = w2.Write([]byte("hello")) - if err2 != nil { - t.Error(err) - } - err2 = w2.Close() - if err2 != nil { - t.Error(err) - } - resp2, err3 := wc.parseBinaryResponse(b2.Bytes()) - if err3 != nil { - t.Error(err3) - } - if !strings.EqualFold(string(resp2), "hello") { - t.Errorf("Deflate conversion failed. Received: '%v', Expected: 'hello'", string(resp2)) - } - - _, err4 := wc.parseBinaryResponse([]byte{}) - if err4 == nil || err4.Error() != "unexpected EOF" { - t.Error("Expected error 'unexpected EOF'") - } + _, err = wc.parseBinaryResponse([]byte{}) + assert.ErrorContains(t, err, "unexpected EOF", "parseBinaryResponse should error on empty input") } // TestCanUseAuthenticatedWebsocketForWrapper logic test func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) { t.Parallel() ws := &Websocket{} - resp := ws.CanUseAuthenticatedWebsocketForWrapper() - if resp { - t.Error("Expected false, `connected` is false") - } - ws.setConnectedStatus(true) - resp = ws.CanUseAuthenticatedWebsocketForWrapper() - if resp { - t.Error("Expected false, `connected` is true and `CanUseAuthenticatedEndpoints` is false") - } - ws.canUseAuthenticatedEndpoints = true - resp = ws.CanUseAuthenticatedWebsocketForWrapper() - if !resp { - t.Error("Expected true, `connected` and `CanUseAuthenticatedEndpoints` is true") - } + assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") + + ws.setState(connected) + require.True(t, ws.IsConnected(), "IsConnected must return true") + assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") + + ws.SetCanUseAuthenticatedEndpoints(true) + assert.True(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return true") } func TestGenerateMessageID(t *testing.T) { t.Parallel() wc := WebsocketConnection{} - var id int64 - for i := 0; i < 10; i++ { - newID := wc.GenerateMessageID(true) - if id == newID { - t.Fatal("ID generation is not unique") - } - id = newID + const spins = 1000 + ids := make([]int64, spins) + for i := 0; i < spins; i++ { + id := wc.GenerateMessageID(true) + assert.NotContains(t, ids, id, "GenerateMessageID must not generate the same ID twice") + ids[i] = id } } @@ -1013,34 +918,22 @@ func BenchmarkGenerateMessageID_Low(b *testing.B) { func TestCheckWebsocketURL(t *testing.T) { err := checkWebsocketURL("") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "checkWebsocketURL should error correctly on empty string") err = checkWebsocketURL("wowowow:wowowowo") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "checkWebsocketURL should error correctly on bad format") err = checkWebsocketURL("://") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorContains(t, err, "missing protocol scheme", "checkWebsocketURL should error correctly on bad proto") err = checkWebsocketURL("http://www.google.com") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "checkWebsocketURL should error correctly on wrong proto") err = checkWebsocketURL("wss://websocketconnection.place") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "checkWebsocketURL should not error") err = checkWebsocketURL("ws://websocketconnection.place") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "checkWebsocketURL should not error") } func TestGetChannelDifference(t *testing.T) { @@ -1142,19 +1035,13 @@ func TestFlushChannels(t *testing.T) { dodgyWs := Websocket{} err := dodgyWs.FlushChannels() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "FlushChannels should error correctly") dodgyWs.setEnabled(true) err = dodgyWs.FlushChannels() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, ErrNotConnected, "FlushChannels should error correctly") - web := Websocket{ - enabled: true, - connected: true, + w := Websocket{ connector: connect, ShutdownC: make(chan struct{}), Subscriber: newgen.SUBME, @@ -1167,9 +1054,11 @@ func TestFlushChannels(t *testing.T) { // in FlushChannels() so the traffic monitor doesn't time out and turn // this to an unconnected state } + w.setEnabled(true) + w.setState(connected) problemFunc := func() ([]subscription.Subscription, error) { - return nil, errors.New("problems") + return nil, errDastardlyReason } noSub := func() ([]subscription.Subscription, error) { @@ -1179,53 +1068,40 @@ func TestFlushChannels(t *testing.T) { // Disable pair and flush system newgen.EnabledPairs = []currency.Pair{ currency.NewPair(currency.BTC, currency.AUD)} - web.GenerateSubs = func() ([]subscription.Subscription, error) { + w.GenerateSubs = func() ([]subscription.Subscription, error) { return []subscription.Subscription{{Channel: "test"}}, nil } - err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + err = w.FlushChannels() + assert.NoError(t, err, "FlushChannels should not error") - web.features.FullPayloadSubscribe = true - web.GenerateSubs = problemFunc - err = web.FlushChannels() // error on full subscribeToChannels - if err == nil { - t.Fatal("error cannot be nil") - } + w.features.FullPayloadSubscribe = true + w.GenerateSubs = problemFunc + err = w.FlushChannels() // error on full subscribeToChannels + assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly") - web.GenerateSubs = noSub - err = web.FlushChannels() // No subs to sub - if err != nil { - t.Fatal(err) - } + w.GenerateSubs = noSub + err = w.FlushChannels() // No subs to unsub + assert.NoError(t, err, "FlushChannels should not error") - web.GenerateSubs = newgen.generateSubs - subs, err := web.GenerateSubs() - if err != nil { - t.Fatal(err) - } - web.AddSuccessfulSubscriptions(subs...) - err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } - web.features.FullPayloadSubscribe = false - web.features.Subscribe = true + w.GenerateSubs = newgen.generateSubs + subs, err := w.GenerateSubs() + require.NoError(t, err, "GenerateSubs must not error") - web.GenerateSubs = problemFunc - err = web.FlushChannels() - if err == nil { - t.Fatal("error cannot be nil") - } + w.AddSuccessfulSubscriptions(subs...) + err = w.FlushChannels() + assert.NoError(t, err, "FlushChannels should not error") + w.features.FullPayloadSubscribe = false + w.features.Subscribe = true - web.GenerateSubs = newgen.generateSubs - err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } - web.subscriptionMutex.Lock() - web.subscriptions = subscriptionMap{ + w.GenerateSubs = problemFunc + err = w.FlushChannels() + assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly") + + w.GenerateSubs = newgen.generateSubs + err = w.FlushChannels() + assert.NoError(t, err, "FlushChannels should not error") + w.subscriptionMutex.Lock() + w.subscriptions = subscriptionMap{ 41: { Key: 41, Channel: "match channel", @@ -1237,46 +1113,34 @@ func TestFlushChannels(t *testing.T) { Pair: currency.NewPair(currency.THETA, currency.USDT), }, } - web.subscriptionMutex.Unlock() + w.subscriptionMutex.Unlock() - err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + err = w.FlushChannels() + assert.NoError(t, err, "FlushChannels should not error") - err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + err = w.FlushChannels() + assert.NoError(t, err, "FlushChannels should not error") - web.setConnectedStatus(true) - web.features.Unsubscribe = true - err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + w.setState(connected) + w.features.Unsubscribe = true + err = w.FlushChannels() + assert.NoError(t, err, "FlushChannels should not error") } func TestDisable(t *testing.T) { t.Parallel() - web := Websocket{ - enabled: true, - connected: true, + w := Websocket{ ShutdownC: make(chan struct{}), } - err := web.Disable() - if err != nil { - t.Fatal(err) - } - err = web.Disable() - if err == nil { - t.Fatal("should already be disabled") - } + w.setEnabled(true) + w.setState(connected) + require.NoError(t, w.Disable(), "Disable must not error") + assert.ErrorIs(t, w.Disable(), ErrAlreadyDisabled, "Disable should error correctly") } func TestEnable(t *testing.T) { t.Parallel() - web := Websocket{ + w := Websocket{ connector: connect, Wg: new(sync.WaitGroup), ShutdownC: make(chan struct{}), @@ -1286,98 +1150,59 @@ func TestEnable(t *testing.T) { Subscriber: func([]subscription.Subscription) error { return nil }, } - err := web.Enable() - if err != nil { - t.Fatal(err) - } - - err = web.Enable() - if err == nil { - t.Fatal("should already be enabled") - } - - fmt.Print() + require.NoError(t, w.Enable(), "Enable must not error") + assert.ErrorIs(t, w.Enable(), errWebsocketAlreadyEnabled, "Enable should error correctly") } func TestSetupNewConnection(t *testing.T) { t.Parallel() var nonsenseWebsock *Websocket err := nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errWebsocketIsNil, "SetupNewConnection should error correctly") nonsenseWebsock = &Websocket{} err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errExchangeConfigNameEmpty, "SetupNewConnection should error correctly") nonsenseWebsock = &Websocket{exchangeName: "test"} err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errTrafficAlertNil, "SetupNewConnection should error correctly") - nonsenseWebsock.TrafficAlert = make(chan struct{}) + nonsenseWebsock.TrafficAlert = make(chan struct{}, 1) err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errReadMessageErrorsNil, "SetupNewConnection should error correctly") - web := Websocket{ - connector: connect, - Wg: new(sync.WaitGroup), - ShutdownC: make(chan struct{}), - Init: true, - TrafficAlert: make(chan struct{}), - ReadMessageErrors: make(chan error), - DataHandler: make(chan interface{}), - } + web := NewWebsocket() err = web.Setup(defaultSetup) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "Setup should not error") + err = web.SetupNewConnection(ConnectionSetup{}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errExchangeConfigEmpty, "SetupNewConnection should error correctly") + err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err != nil { - t.Fatal(err) - } - err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring", - Authenticated: true}) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetupNewConnection should not error") + + err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring", Authenticated: true}) + assert.NoError(t, err, "SetupNewConnection should not error") } func TestWebsocketConnectionShutdown(t *testing.T) { t.Parallel() wc := WebsocketConnection{} err := wc.Shutdown() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "Shutdown should not error") err = wc.Dial(&websocket.Dialer{}, nil) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorContains(t, err, "malformed ws or wss URL", "Dial must error correctly") wc.URL = websocketTestURL err = wc.Dial(&websocket.Dialer{}, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Dial must not error") err = wc.Shutdown() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Shutdown must not error") } // TestLatency logic test @@ -1431,27 +1256,19 @@ func TestCheckSubscriptions(t *testing.T) { t.Parallel() ws := Websocket{} err := ws.checkSubscriptions(nil) - if !errors.Is(err, errNoSubscriptionsSupplied) { - t.Fatalf("received: %v, but expected: %v", err, errNoSubscriptionsSupplied) - } + assert.ErrorIs(t, err, errNoSubscriptionsSupplied, "checkSubscriptions should error correctly") ws.MaxSubscriptionsPerConnection = 1 err = ws.checkSubscriptions([]subscription.Subscription{{}, {}}) - if !errors.Is(err, errSubscriptionsExceedsLimit) { - t.Fatalf("received: %v, but expected: %v", err, errSubscriptionsExceedsLimit) - } + assert.ErrorIs(t, err, errSubscriptionsExceedsLimit, "checkSubscriptions should error correctly") ws.MaxSubscriptionsPerConnection = 2 ws.subscriptions = subscriptionMap{42: {Key: 42, Channel: "test"}} err = ws.checkSubscriptions([]subscription.Subscription{{Key: 42, Channel: "test"}}) - if !errors.Is(err, errChannelAlreadySubscribed) { - t.Fatalf("received: %v, but expected: %v", err, errChannelAlreadySubscribed) - } + assert.ErrorIs(t, err, errChannelAlreadySubscribed, "checkSubscriptions should error correctly") err = ws.checkSubscriptions([]subscription.Subscription{{}}) - if !errors.Is(err, nil) { - t.Fatalf("received: %v, but expected: %v", err, nil) - } + assert.NoError(t, err, "checkSubscriptions should not error") } diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 925c34b9..a783d585 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -2,6 +2,7 @@ package stream import ( "sync" + "sync/atomic" "time" "github.com/gorilla/websocket" @@ -15,8 +16,6 @@ import ( // Websocket functionality list and state consts const ( - // WebsocketNotEnabled alerts of a disabled websocket - WebsocketNotEnabled = "exchange_websocket_not_enabled" WebsocketNotAuthenticatedUsingRest = "%v - Websocket not authenticated, using REST\n" Ping = "ping" Pong = "pong" @@ -25,18 +24,23 @@ const ( type subscriptionMap map[any]*subscription.Subscription +const ( + uninitialised uint32 = iota + disconnected + connecting + connected +) + // Websocket defines a return type for websocket connections via the interface // wrapper for routine processing type Websocket struct { - canUseAuthenticatedEndpoints bool - enabled bool - Init bool - connected bool - connecting bool + canUseAuthenticatedEndpoints atomic.Bool + enabled atomic.Bool + state atomic.Uint32 verbose bool - connectionMonitorRunning bool - trafficMonitorRunning bool - dataMonitorRunning bool + connectionMonitorRunning atomic.Bool + trafficMonitorRunning atomic.Bool + dataMonitorRunning atomic.Bool trafficTimeout time.Duration connectionMonitorDelay time.Duration proxyAddr string @@ -46,7 +50,6 @@ type Websocket struct { runningURLAuth string exchangeName string m sync.Mutex - fieldMutex sync.RWMutex connector func() error subscriptionMutex sync.RWMutex