diff --git a/cmd/exchange_template/wrapper_file.tmpl b/cmd/exchange_template/wrapper_file.tmpl index d57f96b5..e74ecbc3 100644 --- a/cmd/exchange_template/wrapper_file.tmpl +++ b/cmd/exchange_template/wrapper_file.tmpl @@ -125,7 +125,7 @@ func ({{.Variable}} *{{.CapitalName}}) SetDefaults() { exchange.RestSpot: {{.Name}}APIURL, // exchange.WebsocketSpot: {{.Name}}WSAPIURL, }) - {{.Variable}}.Websocket = stream.New() + {{.Variable}}.Websocket = stream.NewWebsocket() {{.Variable}}.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit {{.Variable}}.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout {{.Variable}}.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/engine/websocketroutine_manager_test.go b/engine/websocketroutine_manager_test.go index e082193b..c1c3541d 100644 --- a/engine/websocketroutine_manager_test.go +++ b/engine/websocketroutine_manager_test.go @@ -293,7 +293,7 @@ func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) { t.Fatal("unexpected data handlers registered") } - mock := stream.New() + mock := stream.NewWebsocket() mock.ToRoutine = make(chan interface{}) m.state = readyState err = m.websocketDataReceiver(mock) diff --git a/exchanges/binance/binance_websocket.go b/exchanges/binance/binance_websocket.go index 52cdd0cc..bd96d546 100644 --- a/exchanges/binance/binance_websocket.go +++ b/exchanges/binance/binance_websocket.go @@ -50,7 +50,7 @@ var ( // WsConnect initiates a websocket connection func (b *Binance) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer diff --git a/exchanges/binance/binance_wrapper.go b/exchanges/binance/binance_wrapper.go index bdf70e0d..d54f8924 100644 --- a/exchanges/binance/binance_wrapper.go +++ b/exchanges/binance/binance_wrapper.go @@ -238,7 +238,7 @@ func (b *Binance) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout } diff --git a/exchanges/binanceus/binanceus_websocket.go b/exchanges/binanceus/binanceus_websocket.go index 8f4d5c3c..14098c11 100644 --- a/exchanges/binanceus/binanceus_websocket.go +++ b/exchanges/binanceus/binanceus_websocket.go @@ -45,7 +45,7 @@ var ( // WsConnect initiates a websocket connection func (bi *Binanceus) WsConnect() error { if !bi.Websocket.IsEnabled() || !bi.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer dialer.HandshakeTimeout = bi.Config.HTTPTimeout diff --git a/exchanges/binanceus/binanceus_wrapper.go b/exchanges/binanceus/binanceus_wrapper.go index 3f078ede..ed2e5d10 100644 --- a/exchanges/binanceus/binanceus_wrapper.go +++ b/exchanges/binanceus/binanceus_wrapper.go @@ -162,7 +162,7 @@ func (bi *Binanceus) SetDefaults() { "%s setting default endpoints error %v", bi.Name, err) } - bi.Websocket = stream.New() + bi.Websocket = stream.NewWebsocket() bi.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit bi.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout bi.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bitfinex/bitfinex_test.go b/exchanges/bitfinex/bitfinex_test.go index c2bba90d..e7be0fbe 100644 --- a/exchanges/bitfinex/bitfinex_test.go +++ b/exchanges/bitfinex/bitfinex_test.go @@ -1128,7 +1128,7 @@ func TestGetDepositAddress(t *testing.T) { // TestWsAuth dials websocket, sends login request. func TestWsAuth(t *testing.T) { if !b.Websocket.IsEnabled() { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) if !b.API.AuthenticatedWebsocketSupport { diff --git a/exchanges/bitfinex/bitfinex_websocket.go b/exchanges/bitfinex/bitfinex_websocket.go index e1010eb1..ae7cded7 100644 --- a/exchanges/bitfinex/bitfinex_websocket.go +++ b/exchanges/bitfinex/bitfinex_websocket.go @@ -43,7 +43,7 @@ var cMtx sync.Mutex // WsConnect starts a new websocket connection func (b *Bitfinex) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bitfinex/bitfinex_wrapper.go b/exchanges/bitfinex/bitfinex_wrapper.go index 7e1ef64c..2aba9987 100644 --- a/exchanges/bitfinex/bitfinex_wrapper.go +++ b/exchanges/bitfinex/bitfinex_wrapper.go @@ -198,7 +198,7 @@ func (b *Bitfinex) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bithumb/bithumb_websocket.go b/exchanges/bithumb/bithumb_websocket.go index 3005f423..667ce131 100644 --- a/exchanges/bithumb/bithumb_websocket.go +++ b/exchanges/bithumb/bithumb_websocket.go @@ -2,7 +2,6 @@ package bithumb import ( "encoding/json" - "errors" "fmt" "net/http" "time" @@ -29,7 +28,7 @@ var ( // WsConnect initiates a websocket connection func (b *Bithumb) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer diff --git a/exchanges/bithumb/bithumb_wrapper.go b/exchanges/bithumb/bithumb_wrapper.go index 5dbc3146..24ea8980 100644 --- a/exchanges/bithumb/bithumb_wrapper.go +++ b/exchanges/bithumb/bithumb_wrapper.go @@ -150,7 +150,7 @@ func (b *Bithumb) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout } diff --git a/exchanges/bitmex/bitmex_test.go b/exchanges/bitmex/bitmex_test.go index b59cb125..e825ad8a 100644 --- a/exchanges/bitmex/bitmex_test.go +++ b/exchanges/bitmex/bitmex_test.go @@ -789,7 +789,7 @@ func TestGetDepositAddress(t *testing.T) { func TestWsAuth(t *testing.T) { t.Parallel() if !b.Websocket.IsEnabled() && !b.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(b) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bitmex/bitmex_websocket.go b/exchanges/bitmex/bitmex_websocket.go index 6d04c106..e1a47525 100644 --- a/exchanges/bitmex/bitmex_websocket.go +++ b/exchanges/bitmex/bitmex_websocket.go @@ -68,7 +68,7 @@ const ( // WsConnect initiates a new websocket connection func (b *Bitmex) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bitmex/bitmex_wrapper.go b/exchanges/bitmex/bitmex_wrapper.go index 151f1ef2..a0810b1a 100644 --- a/exchanges/bitmex/bitmex_wrapper.go +++ b/exchanges/bitmex/bitmex_wrapper.go @@ -175,7 +175,7 @@ func (b *Bitmex) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bitstamp/bitstamp_websocket.go b/exchanges/bitstamp/bitstamp_websocket.go index ab5465b5..98aa6201 100644 --- a/exchanges/bitstamp/bitstamp_websocket.go +++ b/exchanges/bitstamp/bitstamp_websocket.go @@ -45,7 +45,7 @@ var ( // WsConnect connects to a websocket feed func (b *Bitstamp) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bitstamp/bitstamp_wrapper.go b/exchanges/bitstamp/bitstamp_wrapper.go index 888fae0b..2fe7fa96 100644 --- a/exchanges/bitstamp/bitstamp_wrapper.go +++ b/exchanges/bitstamp/bitstamp_wrapper.go @@ -146,7 +146,7 @@ func (b *Bitstamp) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/btcmarkets/btcmarkets_websocket.go b/exchanges/btcmarkets/btcmarkets_websocket.go index 8c979b4d..01ba1a64 100644 --- a/exchanges/btcmarkets/btcmarkets_websocket.go +++ b/exchanges/btcmarkets/btcmarkets_websocket.go @@ -39,7 +39,7 @@ var ( // WsConnect connects to a websocket feed func (b *BTCMarkets) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/btcmarkets/btcmarkets_wrapper.go b/exchanges/btcmarkets/btcmarkets_wrapper.go index 17ee7277..8a924b08 100644 --- a/exchanges/btcmarkets/btcmarkets_wrapper.go +++ b/exchanges/btcmarkets/btcmarkets_wrapper.go @@ -150,7 +150,7 @@ func (b *BTCMarkets) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/btse/btse_websocket.go b/exchanges/btse/btse_websocket.go index 41f25c95..e32fb9a8 100644 --- a/exchanges/btse/btse_websocket.go +++ b/exchanges/btse/btse_websocket.go @@ -30,7 +30,7 @@ const ( // WsConnect connects the websocket client func (b *BTSE) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/btse/btse_wrapper.go b/exchanges/btse/btse_wrapper.go index f7ec3c8f..3cdbd56b 100644 --- a/exchanges/btse/btse_wrapper.go +++ b/exchanges/btse/btse_wrapper.go @@ -176,7 +176,7 @@ func (b *BTSE) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bybit/bybit.go b/exchanges/bybit/bybit.go index 5652c464..3588336f 100644 --- a/exchanges/bybit/bybit.go +++ b/exchanges/bybit/bybit.go @@ -21,7 +21,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" ) // Bybit is the overarching type across this package @@ -90,7 +89,6 @@ var ( errAPIKeyIsNotUnified = errors.New("api key is not unified") errEndpointAvailableForNormalAPIKeyHolders = errors.New("endpoint available for normal API key holders only") errInvalidContractLength = errors.New("contract length cannot be less than or equal to zero") - errWebsocketNotEnabled = errors.New(stream.WebsocketNotEnabled) ) var ( diff --git a/exchanges/bybit/bybit_inverse_websocket.go b/exchanges/bybit/bybit_inverse_websocket.go index 77f387ac..d1387c27 100644 --- a/exchanges/bybit/bybit_inverse_websocket.go +++ b/exchanges/bybit/bybit_inverse_websocket.go @@ -12,7 +12,7 @@ import ( // WsInverseConnect connects to inverse websocket feed func (by *Bybit) WsInverseConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.CoinMarginedFutures) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } by.Websocket.Conn.SetURL(inversePublic) var dialer websocket.Dialer diff --git a/exchanges/bybit/bybit_linear_websocket.go b/exchanges/bybit/bybit_linear_websocket.go index efc2f68d..9b3ed084 100644 --- a/exchanges/bybit/bybit_linear_websocket.go +++ b/exchanges/bybit/bybit_linear_websocket.go @@ -14,7 +14,7 @@ import ( // WsLinearConnect connects to linear a websocket feed func (by *Bybit) WsLinearConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.LinearContract) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } by.Websocket.Conn.SetURL(linearPublic) var dialer websocket.Dialer diff --git a/exchanges/bybit/bybit_options_websocket.go b/exchanges/bybit/bybit_options_websocket.go index 2f4abc7a..4bb25cef 100644 --- a/exchanges/bybit/bybit_options_websocket.go +++ b/exchanges/bybit/bybit_options_websocket.go @@ -14,7 +14,7 @@ import ( // WsOptionsConnect connects to options a websocket feed func (by *Bybit) WsOptionsConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.Options) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } by.Websocket.Conn.SetURL(optionPublic) var dialer websocket.Dialer diff --git a/exchanges/bybit/bybit_test.go b/exchanges/bybit/bybit_test.go index 58ca5bc4..092ca125 100644 --- a/exchanges/bybit/bybit_test.go +++ b/exchanges/bybit/bybit_test.go @@ -20,6 +20,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/margin" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" + "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -3064,7 +3065,7 @@ func TestWsLinearConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsLinearConnect() - if err != nil && !errors.Is(err, errWebsocketNotEnabled) { + if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { t.Error(err) } } @@ -3074,7 +3075,7 @@ func TestWsInverseConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsInverseConnect() - if err != nil && !errors.Is(err, errWebsocketNotEnabled) { + if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { t.Error(err) } } @@ -3084,7 +3085,7 @@ func TestWsOptionsConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsOptionsConnect() - if err != nil && !errors.Is(err, errWebsocketNotEnabled) { + if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { t.Error(err) } } diff --git a/exchanges/bybit/bybit_websocket.go b/exchanges/bybit/bybit_websocket.go index 857fd690..ff2698b8 100644 --- a/exchanges/bybit/bybit_websocket.go +++ b/exchanges/bybit/bybit_websocket.go @@ -57,7 +57,7 @@ const ( // WsConnect connects to a websocket feed func (by *Bybit) WsConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.Spot) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := by.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bybit/bybit_wrapper.go b/exchanges/bybit/bybit_wrapper.go index 28d4f150..b6e7f2be 100644 --- a/exchanges/bybit/bybit_wrapper.go +++ b/exchanges/bybit/bybit_wrapper.go @@ -216,7 +216,7 @@ func (by *Bybit) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - by.Websocket = stream.New() + by.Websocket = stream.NewWebsocket() by.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit by.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout by.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/coinbasepro/coinbasepro_test.go b/exchanges/coinbasepro/coinbasepro_test.go index 77f05838..0b04d4ab 100644 --- a/exchanges/coinbasepro/coinbasepro_test.go +++ b/exchanges/coinbasepro/coinbasepro_test.go @@ -681,7 +681,7 @@ func TestGetDepositAddress(t *testing.T) { // TestWsAuth dials websocket, sends login request. func TestWsAuth(t *testing.T) { if !c.Websocket.IsEnabled() && !c.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(c) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/coinbasepro/coinbasepro_websocket.go b/exchanges/coinbasepro/coinbasepro_websocket.go index e4b02b76..5946cf77 100644 --- a/exchanges/coinbasepro/coinbasepro_websocket.go +++ b/exchanges/coinbasepro/coinbasepro_websocket.go @@ -31,7 +31,7 @@ const ( // WsConnect initiates a websocket connection func (c *CoinbasePro) WsConnect() error { if !c.Websocket.IsEnabled() || !c.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/coinbasepro/coinbasepro_wrapper.go b/exchanges/coinbasepro/coinbasepro_wrapper.go index 47a20d9e..21a34e2e 100644 --- a/exchanges/coinbasepro/coinbasepro_wrapper.go +++ b/exchanges/coinbasepro/coinbasepro_wrapper.go @@ -145,7 +145,7 @@ func (c *CoinbasePro) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - c.Websocket = stream.New() + c.Websocket = stream.NewWebsocket() c.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit c.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout c.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/coinut/coinut_test.go b/exchanges/coinut/coinut_test.go index 3431d60f..0b165327 100644 --- a/exchanges/coinut/coinut_test.go +++ b/exchanges/coinut/coinut_test.go @@ -66,7 +66,7 @@ func setupWSTestAuth(t *testing.T) { } if !c.Websocket.IsEnabled() && !c.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(c) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } if sharedtestvalues.AreAPICredentialsSet(c) { c.Websocket.SetCanUseAuthenticatedEndpoints(true) diff --git a/exchanges/coinut/coinut_websocket.go b/exchanges/coinut/coinut_websocket.go index 2453816e..78b38987 100644 --- a/exchanges/coinut/coinut_websocket.go +++ b/exchanges/coinut/coinut_websocket.go @@ -41,7 +41,7 @@ var ( // WsConnect initiates a websocket connection func (c *COINUT) WsConnect() error { if !c.Websocket.IsEnabled() || !c.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/coinut/coinut_wrapper.go b/exchanges/coinut/coinut_wrapper.go index 503c2909..db4af53b 100644 --- a/exchanges/coinut/coinut_wrapper.go +++ b/exchanges/coinut/coinut_wrapper.go @@ -127,7 +127,7 @@ func (c *COINUT) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - c.Websocket = stream.New() + c.Websocket = stream.NewWebsocket() c.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit c.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout c.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/exchange_test.go b/exchanges/exchange_test.go index d41f499d..ad006049 100644 --- a/exchanges/exchange_test.go +++ b/exchanges/exchange_test.go @@ -198,7 +198,7 @@ func TestSetClientProxyAddress(t *testing.T) { Name: "rawr", Requester: requester} - newBase.Websocket = stream.New() + newBase.Websocket = stream.NewWebsocket() err = newBase.SetClientProxyAddress("") if err != nil { t.Error(err) @@ -1251,7 +1251,7 @@ func TestSetupDefaults(t *testing.T) { } // Test websocket support - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.Features.Supports.Websocket = true err = b.Websocket.Setup(&stream.WebsocketSetup{ ExchangeConfig: &config.Exchange{ @@ -1596,7 +1596,7 @@ func TestIsWebsocketEnabled(t *testing.T) { t.Error("exchange doesn't support websocket") } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() err := b.Websocket.Setup(&stream.WebsocketSetup{ ExchangeConfig: &config.Exchange{ Enabled: true, diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index c26d04af..3a3dddec 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -60,7 +60,7 @@ var fetchedCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsConnect initiates a websocket connection func (g *Gateio) WsConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.Spot) if err != nil { diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 3026efc0..adea4d2f 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -194,7 +194,7 @@ func (g *Gateio) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - g.Websocket = stream.New() + g.Websocket = stream.NewWebsocket() g.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit g.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout g.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index ba5c64af..449181c1 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -45,7 +45,7 @@ var fetchedFuturesCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsDeliveryFuturesConnect initiates a websocket connection for delivery futures account func (g *Gateio) WsDeliveryFuturesConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.DeliveryFutures) if err != nil { diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index c0411a58..20e293b9 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -64,7 +64,7 @@ var responseFuturesStream = make(chan stream.Response) // WsFuturesConnect initiates a websocket connection for futures account func (g *Gateio) WsFuturesConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.Futures) if err != nil { diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index d5340f0c..3278914f 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -70,7 +70,7 @@ var fetchedOptionsCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsOptionsConnect initiates a websocket connection to options websocket endpoints. func (g *Gateio) WsOptionsConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.Options) if err != nil { diff --git a/exchanges/gemini/gemini_test.go b/exchanges/gemini/gemini_test.go index 6a3b35e1..7477e78f 100644 --- a/exchanges/gemini/gemini_test.go +++ b/exchanges/gemini/gemini_test.go @@ -556,7 +556,7 @@ func TestWsAuth(t *testing.T) { if !g.Websocket.IsEnabled() && !g.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(g) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer go g.wsReadData() diff --git a/exchanges/gemini/gemini_websocket.go b/exchanges/gemini/gemini_websocket.go index 913c856d..43c2135e 100644 --- a/exchanges/gemini/gemini_websocket.go +++ b/exchanges/gemini/gemini_websocket.go @@ -39,7 +39,7 @@ var comms = make(chan stream.Response) // WsConnect initiates a websocket connection func (g *Gemini) WsConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer diff --git a/exchanges/gemini/gemini_wrapper.go b/exchanges/gemini/gemini_wrapper.go index fee75d6b..d2d89eb5 100644 --- a/exchanges/gemini/gemini_wrapper.go +++ b/exchanges/gemini/gemini_wrapper.go @@ -128,7 +128,7 @@ func (g *Gemini) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - g.Websocket = stream.New() + g.Websocket = stream.NewWebsocket() g.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit g.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout g.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/hitbtc/hitbtc_test.go b/exchanges/hitbtc/hitbtc_test.go index 3e629d0a..68b84951 100644 --- a/exchanges/hitbtc/hitbtc_test.go +++ b/exchanges/hitbtc/hitbtc_test.go @@ -466,7 +466,7 @@ func setupWsAuth(t *testing.T) { return } if !h.Websocket.IsEnabled() && !h.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(h) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer diff --git a/exchanges/hitbtc/hitbtc_websocket.go b/exchanges/hitbtc/hitbtc_websocket.go index 705f584f..deb88542 100644 --- a/exchanges/hitbtc/hitbtc_websocket.go +++ b/exchanges/hitbtc/hitbtc_websocket.go @@ -34,7 +34,7 @@ const ( // WsConnect starts a new connection with the websocket API func (h *HitBTC) WsConnect() error { if !h.Websocket.IsEnabled() || !h.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := h.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/hitbtc/hitbtc_wrapper.go b/exchanges/hitbtc/hitbtc_wrapper.go index 3b65fe64..7bcf7ada 100644 --- a/exchanges/hitbtc/hitbtc_wrapper.go +++ b/exchanges/hitbtc/hitbtc_wrapper.go @@ -147,7 +147,7 @@ func (h *HitBTC) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - h.Websocket = stream.New() + h.Websocket = stream.NewWebsocket() h.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit h.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout h.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/huobi/huobi_test.go b/exchanges/huobi/huobi_test.go index 4aafc7b7..16106daf 100644 --- a/exchanges/huobi/huobi_test.go +++ b/exchanges/huobi/huobi_test.go @@ -78,7 +78,7 @@ func setupWsTests(t *testing.T) { return } if !h.Websocket.IsEnabled() && !h.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(h) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } comms = make(chan WsMessage, sharedtestvalues.WebsocketChannelOverrideCapacity) go h.wsReadData() diff --git a/exchanges/huobi/huobi_websocket.go b/exchanges/huobi/huobi_websocket.go index f601b03d..92b5bf4c 100644 --- a/exchanges/huobi/huobi_websocket.go +++ b/exchanges/huobi/huobi_websocket.go @@ -62,7 +62,7 @@ var comms = make(chan WsMessage) // WsConnect initiates a new websocket connection func (h *HUOBI) WsConnect() error { if !h.Websocket.IsEnabled() || !h.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := h.wsDial(&dialer) diff --git a/exchanges/huobi/huobi_wrapper.go b/exchanges/huobi/huobi_wrapper.go index 3d1d576b..90d70491 100644 --- a/exchanges/huobi/huobi_wrapper.go +++ b/exchanges/huobi/huobi_wrapper.go @@ -202,7 +202,7 @@ func (h *HUOBI) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - h.Websocket = stream.New() + h.Websocket = stream.NewWebsocket() h.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit h.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout h.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/kraken/kraken_test.go b/exchanges/kraken/kraken_test.go index 263d70bf..8530cb53 100644 --- a/exchanges/kraken/kraken_test.go +++ b/exchanges/kraken/kraken_test.go @@ -1215,7 +1215,7 @@ func setupWsTests(t *testing.T) { return } if !k.Websocket.IsEnabled() && !k.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(k) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := k.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/kraken/kraken_websocket.go b/exchanges/kraken/kraken_websocket.go index 787d52b2..fd432516 100644 --- a/exchanges/kraken/kraken_websocket.go +++ b/exchanges/kraken/kraken_websocket.go @@ -87,7 +87,7 @@ var cancelOrdersStatus = make(map[int64]*struct { // WsConnect initiates a websocket connection func (k *Kraken) WsConnect() error { if !k.Websocket.IsEnabled() || !k.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer diff --git a/exchanges/kraken/kraken_wrapper.go b/exchanges/kraken/kraken_wrapper.go index de631e0d..58759175 100644 --- a/exchanges/kraken/kraken_wrapper.go +++ b/exchanges/kraken/kraken_wrapper.go @@ -209,7 +209,7 @@ func (k *Kraken) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - k.Websocket = stream.New() + k.Websocket = stream.NewWebsocket() k.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit k.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout k.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/kucoin/kucoin_websocket.go b/exchanges/kucoin/kucoin_websocket.go index 3f8a0c78..917bc1a9 100644 --- a/exchanges/kucoin/kucoin_websocket.go +++ b/exchanges/kucoin/kucoin_websocket.go @@ -97,7 +97,7 @@ var ( // WsConnect creates a new websocket connection. func (ku *Kucoin) WsConnect() error { if !ku.Websocket.IsEnabled() || !ku.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } fetchedFuturesSnapshotOrderbook = map[string]bool{} var dialer websocket.Dialer diff --git a/exchanges/kucoin/kucoin_wrapper.go b/exchanges/kucoin/kucoin_wrapper.go index 8a98a77a..d767a2cf 100644 --- a/exchanges/kucoin/kucoin_wrapper.go +++ b/exchanges/kucoin/kucoin_wrapper.go @@ -195,7 +195,7 @@ func (ku *Kucoin) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - ku.Websocket = stream.New() + ku.Websocket = stream.NewWebsocket() ku.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit ku.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout ku.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/okcoin/okcoin_websocket.go b/exchanges/okcoin/okcoin_websocket.go index e714aba7..0787775e 100644 --- a/exchanges/okcoin/okcoin_websocket.go +++ b/exchanges/okcoin/okcoin_websocket.go @@ -74,7 +74,7 @@ func isAuthenticatedChannel(channel string) bool { // WsConnect initiates a websocket connection func (o *Okcoin) WsConnect() error { if !o.Websocket.IsEnabled() || !o.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer dialer.ReadBufferSize = 8192 diff --git a/exchanges/okcoin/okcoin_wrapper.go b/exchanges/okcoin/okcoin_wrapper.go index 8519cdc1..af0655ef 100644 --- a/exchanges/okcoin/okcoin_wrapper.go +++ b/exchanges/okcoin/okcoin_wrapper.go @@ -150,7 +150,7 @@ func (o *Okcoin) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - o.Websocket = stream.New() + o.Websocket = stream.NewWebsocket() o.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit o.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout o.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/okcoin/okcoin_ws_trade.go b/exchanges/okcoin/okcoin_ws_trade.go index cd0c7f86..85b6102e 100644 --- a/exchanges/okcoin/okcoin_ws_trade.go +++ b/exchanges/okcoin/okcoin_ws_trade.go @@ -130,7 +130,7 @@ func (o *Okcoin) WsAmendMultipleOrder(args []AmendTradeOrderRequestParam) ([]Ame func (o *Okcoin) SendWebsocketRequest(operation string, data, result interface{}, authenticated bool) error { switch { case !o.Websocket.IsEnabled(): - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled case !o.Websocket.IsConnected(): return stream.ErrNotConnected case !o.Websocket.CanUseAuthenticatedEndpoints() && authenticated: diff --git a/exchanges/okx/okx_websocket.go b/exchanges/okx/okx_websocket.go index b4d211ee..54bc6b0e 100644 --- a/exchanges/okx/okx_websocket.go +++ b/exchanges/okx/okx_websocket.go @@ -216,7 +216,7 @@ const ( // WsConnect initiates a websocket connection func (ok *Okx) WsConnect() error { if !ok.Websocket.IsEnabled() || !ok.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer dialer.ReadBufferSize = 8192 diff --git a/exchanges/okx/okx_wrapper.go b/exchanges/okx/okx_wrapper.go index 0b5472b9..64e2e269 100644 --- a/exchanges/okx/okx_wrapper.go +++ b/exchanges/okx/okx_wrapper.go @@ -190,7 +190,7 @@ func (ok *Okx) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - ok.Websocket = stream.New() + ok.Websocket = stream.NewWebsocket() ok.WebsocketResponseMaxLimit = okxWebsocketResponseMaxLimit ok.WebsocketResponseCheckTimeout = okxWebsocketResponseMaxLimit ok.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/poloniex/poloniex_test.go b/exchanges/poloniex/poloniex_test.go index bab0ffd4..d74e2353 100644 --- a/exchanges/poloniex/poloniex_test.go +++ b/exchanges/poloniex/poloniex_test.go @@ -548,7 +548,7 @@ func TestGenerateNewAddress(t *testing.T) { func TestWsAuth(t *testing.T) { t.Parallel() if !p.Websocket.IsEnabled() && !p.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(p) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := p.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/poloniex/poloniex_websocket.go b/exchanges/poloniex/poloniex_websocket.go index 1be429a5..23774335 100644 --- a/exchanges/poloniex/poloniex_websocket.go +++ b/exchanges/poloniex/poloniex_websocket.go @@ -55,7 +55,7 @@ var ( // WsConnect initiates a websocket connection func (p *Poloniex) WsConnect() error { if !p.Websocket.IsEnabled() || !p.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := p.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/poloniex/poloniex_wrapper.go b/exchanges/poloniex/poloniex_wrapper.go index 97eb9293..57f28fac 100644 --- a/exchanges/poloniex/poloniex_wrapper.go +++ b/exchanges/poloniex/poloniex_wrapper.go @@ -159,7 +159,7 @@ func (p *Poloniex) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - p.Websocket = stream.New() + p.Websocket = stream.NewWebsocket() p.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit p.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout p.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/sharedtestvalues/sharedtestvalues.go b/exchanges/sharedtestvalues/sharedtestvalues.go index 60d51d82..54acf9dc 100644 --- a/exchanges/sharedtestvalues/sharedtestvalues.go +++ b/exchanges/sharedtestvalues/sharedtestvalues.go @@ -57,7 +57,6 @@ func GetWebsocketStructChannelOverride() chan struct{} { // NewTestWebsocket returns a test websocket object func NewTestWebsocket() *stream.Websocket { return &stream.Websocket{ - Init: true, DataHandler: make(chan interface{}, WebsocketChannelOverrideCapacity), ToRoutine: make(chan interface{}, 1000), TrafficAlert: make(chan struct{}), diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index aefc3400..404d76d8 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -16,47 +16,43 @@ import ( ) const ( - defaultJobBuffer = 5000 - // defaultTrafficPeriod defines a period of pause for the traffic monitor, - // as there are periods with large incoming traffic alerts which requires a - // timer reset, this limits work on this routine to a more effective rate - // of check. - defaultTrafficPeriod = time.Second + jobBuffer = 5000 ) +// Public websocket errors var ( - // ErrSubscriptionNotFound defines an error when a subscription is not found - ErrSubscriptionNotFound = errors.New("subscription not found") - // ErrSubscribedAlready defines an error when a channel is already subscribed - ErrSubscribedAlready = errors.New("duplicate subscription") - // ErrSubscriptionFailure defines an error when a subscription fails - ErrSubscriptionFailure = errors.New("subscription failure") - // ErrSubscriptionNotSupported defines an error when a subscription channel is not supported by an exchange + ErrWebsocketNotEnabled = errors.New("websocket not enabled") + ErrSubscriptionNotFound = errors.New("subscription not found") + ErrSubscribedAlready = errors.New("duplicate subscription") + ErrSubscriptionFailure = errors.New("subscription failure") ErrSubscriptionNotSupported = errors.New("subscription channel not supported ") - // ErrUnsubscribeFailure defines an error when a unsubscribe fails - ErrUnsubscribeFailure = errors.New("unsubscribe failure") - // ErrChannelInStateAlready defines an error when a subscription channel is already in a new state - ErrChannelInStateAlready = errors.New("channel already in state") - // ErrAlreadyDisabled is returned when you double-disable the websocket - ErrAlreadyDisabled = errors.New("websocket already disabled") - // ErrNotConnected defines an error when websocket is not connected - ErrNotConnected = errors.New("websocket is not connected") + ErrUnsubscribeFailure = errors.New("unsubscribe failure") + ErrChannelInStateAlready = errors.New("channel already in state") + ErrAlreadyDisabled = errors.New("websocket already disabled") + ErrNotConnected = errors.New("websocket is not connected") +) +// Private websocket errors +var ( errAlreadyRunning = errors.New("connection monitor is already running") errExchangeConfigIsNil = errors.New("exchange config is nil") + errExchangeConfigEmpty = errors.New("exchange config is empty") errWebsocketIsNil = errors.New("websocket is nil") errWebsocketSetupIsNil = errors.New("websocket setup is nil") errWebsocketAlreadyInitialised = errors.New("websocket already initialised") + errWebsocketAlreadyEnabled = errors.New("websocket already enabled") errWebsocketFeaturesIsUnset = errors.New("websocket features is unset") errConfigFeaturesIsNil = errors.New("exchange config features is nil") errDefaultURLIsEmpty = errors.New("default url is empty") errRunningURLIsEmpty = errors.New("running url cannot be empty") errInvalidWebsocketURL = errors.New("invalid websocket url") - errExchangeConfigNameUnset = errors.New("exchange config name unset") + errExchangeConfigNameEmpty = errors.New("exchange config name empty") errInvalidTrafficTimeout = errors.New("invalid traffic timeout") + errTrafficAlertNil = errors.New("traffic alert is nil") errWebsocketSubscriberUnset = errors.New("websocket subscriber function needs to be set") errWebsocketUnsubscriberUnset = errors.New("websocket unsubscriber functionality allowed but unsubscriber function not set") errWebsocketConnectorUnset = errors.New("websocket connector function not set") + errReadMessageErrorsNil = errors.New("read message errors is nil") errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set") errClosedConnection = errors.New("use of closed network connection") errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit") @@ -64,9 +60,18 @@ var ( errNoSubscriptionsSupplied = errors.New("no subscriptions supplied") errChannelAlreadySubscribed = errors.New("channel already subscribed") errInvalidChannelState = errors.New("invalid Channel state") + errSameProxyAddress = errors.New("cannot set proxy address to the same address") + errNoConnectFunc = errors.New("websocket connect func not set") + errAlreadyConnected = errors.New("websocket already connected") + errCannotShutdown = errors.New("websocket cannot shutdown") + errAlreadyReconnecting = errors.New("websocket in the process of reconnection") + errConnSetup = errors.New("error in connection setup") ) -var globalReporter Reporter +var ( + globalReporter Reporter + trafficCheckInterval = 100 * time.Millisecond +) // SetupGlobalReporter sets a reporter interface to be used // for all exchange requests @@ -74,13 +79,12 @@ func SetupGlobalReporter(r Reporter) { globalReporter = r } -// New initialises the websocket struct -func New() *Websocket { +// NewWebsocket initialises the websocket struct +func NewWebsocket() *Websocket { return &Websocket{ - Init: true, - DataHandler: make(chan interface{}, defaultJobBuffer), - ToRoutine: make(chan interface{}, defaultJobBuffer), - TrafficAlert: make(chan struct{}), + DataHandler: make(chan interface{}, jobBuffer), + ToRoutine: make(chan interface{}, jobBuffer), + TrafficAlert: make(chan struct{}, 1), ReadMessageErrors: make(chan error), Subscribe: make(chan []subscription.Subscription), Unsubscribe: make(chan []subscription.Subscription), @@ -98,7 +102,10 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { return errWebsocketSetupIsNil } - if !w.Init { + w.m.Lock() + defer w.m.Unlock() + + if w.IsInitialised() { return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyInitialised) } @@ -107,7 +114,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { } if s.ExchangeConfig.Name == "" { - return errExchangeConfigNameUnset + return errExchangeConfigNameEmpty } w.exchangeName = s.ExchangeConfig.Name w.verbose = s.ExchangeConfig.Verbose @@ -120,7 +127,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { if s.ExchangeConfig.Features == nil { return fmt.Errorf("%s %w", w.exchangeName, errConfigFeaturesIsNil) } - w.enabled = s.ExchangeConfig.Features.Enabled.Websocket + w.setEnabled(s.ExchangeConfig.Features.Enabled.Websocket) if s.Connector == nil { return fmt.Errorf("%s %w", w.exchangeName, errWebsocketConnectorUnset) @@ -188,28 +195,30 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { return fmt.Errorf("%s %w", w.exchangeName, errInvalidMaxSubscriptions) } w.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection + w.setState(disconnected) + return nil } // SetupNewConnection sets up an auth or unauth streaming connection func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { if w == nil { - return errors.New("setting up new connection error: websocket is nil") + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketIsNil) } if c == (ConnectionSetup{}) { - return errors.New("setting up new connection error: websocket connection configuration empty") + return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigEmpty) } if w.exchangeName == "" { - return errors.New("setting up new connection error: exchange name not set, please call setup first") + return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigNameEmpty) } if w.TrafficAlert == nil { - return errors.New("setting up new connection error: traffic alert is nil, please call setup first") + return fmt.Errorf("%w: %w", errConnSetup, errTrafficAlertNil) } if w.ReadMessageErrors == nil { - return errors.New("setting up new connection error: read message errors is nil, please call setup first") + return fmt.Errorf("%w: %w", errConnSetup, errReadMessageErrorsNil) } connectionURL := w.GetWebsocketURL() @@ -253,21 +262,19 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { // function func (w *Websocket) Connect() error { if w.connector == nil { - return errors.New("websocket connect function not set, cannot continue") + return errNoConnectFunc } w.m.Lock() defer w.m.Unlock() if !w.IsEnabled() { - return errors.New(WebsocketNotEnabled) + return ErrWebsocketNotEnabled } if w.IsConnecting() { - return fmt.Errorf("%v Websocket already attempting to connect", - w.exchangeName) + return fmt.Errorf("%v %w", w.exchangeName, errAlreadyReconnecting) } if w.IsConnected() { - return fmt.Errorf("%v Websocket already connected", - w.exchangeName) + return fmt.Errorf("%v %w", w.exchangeName, errAlreadyConnected) } w.subscriptionMutex.Lock() @@ -276,25 +283,19 @@ func (w *Websocket) Connect() error { w.dataMonitor() w.trafficMonitor() - w.setConnectingStatus(true) + w.setState(connecting) err := w.connector() if err != nil { - w.setConnectingStatus(false) - return fmt.Errorf("%v Error connecting %s", - w.exchangeName, err) + w.setState(disconnected) + return fmt.Errorf("%v Error connecting %w", w.exchangeName, err) } - w.setConnectedStatus(true) - w.setConnectingStatus(false) - w.setInit(true) + w.setState(connected) if !w.IsConnectionMonitorRunning() { err = w.connectionMonitor() if err != nil { - log.Errorf(log.WebsocketMgr, - "%s cannot start websocket connection monitor %v", - w.GetName(), - err) + log.Errorf(log.WebsocketMgr, "%s cannot start websocket connection monitor %v", w.GetName(), err) } } @@ -317,9 +318,10 @@ func (w *Websocket) Connect() error { } // Disable disables the exchange websocket protocol +// Note that connectionMonitor will be responsible for shutting down the websocket after disabling func (w *Websocket) Disable() error { if !w.IsEnabled() { - return fmt.Errorf("%w for exchange '%s'", ErrAlreadyDisabled, w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, ErrAlreadyDisabled) } w.setEnabled(false) @@ -329,8 +331,7 @@ func (w *Websocket) Disable() error { // Enable enables the exchange websocket protocol func (w *Websocket) Enable() error { if w.IsConnected() || w.IsEnabled() { - return fmt.Errorf("websocket is already enabled for exchange %s", - w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyEnabled) } w.setEnabled(true) @@ -369,9 +370,7 @@ func (w *Websocket) dataMonitor() { case <-w.ShutdownC: return default: - log.Warnf(log.WebsocketMgr, - "%s exchange backlog in websocket processing detected", - w.exchangeName) + log.Warnf(log.WebsocketMgr, "%s exchange backlog in websocket processing detected", w.exchangeName) select { case w.ToRoutine <- d: case <-w.ShutdownC: @@ -388,34 +387,25 @@ func (w *Websocket) connectionMonitor() error { if w.checkAndSetMonitorRunning() { return errAlreadyRunning } - w.fieldMutex.RLock() delay := w.connectionMonitorDelay - w.fieldMutex.RUnlock() go func() { timer := time.NewTimer(delay) for { if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: running connection monitor cycle\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: running connection monitor cycle", w.exchangeName) } if !w.IsEnabled() { if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: connectionMonitor - websocket disabled, shutting down\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: connectionMonitor - websocket disabled, shutting down", w.exchangeName) } if w.IsConnected() { - err := w.Shutdown() - if err != nil { + if err := w.Shutdown(); err != nil { log.Errorln(log.WebsocketMgr, err) } } if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: connection monitor exiting\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: connection monitor exiting", w.exchangeName) } timer.Stop() w.setConnectionMonitorRunning(false) @@ -424,11 +414,8 @@ func (w *Websocket) connectionMonitor() error { select { case err := <-w.ReadMessageErrors: if IsDisconnectionError(err) { - w.setInit(false) - log.Warnf(log.WebsocketMgr, - "%v websocket has been disconnected. Reason: %v", - w.exchangeName, err) - w.setConnectedStatus(false) + log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err) + w.setState(disconnected) } w.DataHandler <- err @@ -459,21 +446,16 @@ func (w *Websocket) Shutdown() error { defer w.m.Unlock() if !w.IsConnected() { - return fmt.Errorf("%v websocket: cannot shutdown %w", - w.exchangeName, - ErrNotConnected) + return fmt.Errorf("%v %w: %w", w.exchangeName, errCannotShutdown, ErrNotConnected) } // TODO: Interrupt connection and or close connection when it is re-established. if w.IsConnecting() { - return fmt.Errorf("%v websocket: cannot shutdown, in the process of reconnection", - w.exchangeName) + return fmt.Errorf("%v %w: %w ", w.exchangeName, errCannotShutdown, errAlreadyReconnecting) } if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: shutting down websocket\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: shutting down websocket", w.exchangeName) } defer w.Orderbook.FlushBuffer() @@ -495,15 +477,13 @@ func (w *Websocket) Shutdown() error { w.subscriptions = subscriptionMap{} w.subscriptionMutex.Unlock() + w.setState(disconnected) + close(w.ShutdownC) w.Wg.Wait() w.ShutdownC = make(chan struct{}) - w.setConnectedStatus(false) - w.setConnectingStatus(false) if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: completed websocket shutdown\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", w.exchangeName) } return nil } @@ -511,11 +491,11 @@ func (w *Websocket) Shutdown() error { // FlushChannels flushes channel subscriptions when there is a pair/asset change func (w *Websocket) FlushChannels() error { if !w.IsEnabled() { - return fmt.Errorf("%s websocket: service not enabled", w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, ErrWebsocketNotEnabled) } if !w.IsConnected() { - return fmt.Errorf("%s websocket: service not connected", w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, ErrNotConnected) } if w.features.Subscribe { @@ -565,9 +545,9 @@ func (w *Websocket) FlushChannels() error { return w.Connect() } -// trafficMonitor uses a timer of WebsocketTrafficLimitTime and once it expires, -// it will reconnect if the TrafficAlert channel has not received any data. The -// trafficTimer will reset on each traffic alert +// trafficMonitor waits trafficCheckInterval before checking for a trafficAlert +// 1 slot buffer means that connection will only write to trafficAlert once per trafficCheckInterval to avoid read/write flood in high traffic +// Otherwise we Shutdown the connection after trafficTimeout, unless it's connecting. connectionMonitor is responsible for Connecting again func (w *Websocket) trafficMonitor() { if w.IsTrafficMonitorRunning() { return @@ -576,183 +556,121 @@ func (w *Websocket) trafficMonitor() { w.Wg.Add(1) go func() { - var trafficTimer = time.NewTimer(w.trafficTimeout) - pause := make(chan struct{}) + t := time.NewTimer(w.trafficTimeout) for { select { case <-w.ShutdownC: if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: trafficMonitor shutdown message received\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown message received", w.exchangeName) } - trafficTimer.Stop() + t.Stop() w.setTrafficMonitorRunning(false) w.Wg.Done() return - case <-w.TrafficAlert: - if !trafficTimer.Stop() { - select { - case <-trafficTimer.C: - default: + case <-time.After(trafficCheckInterval): + select { + case <-w.TrafficAlert: + if !t.Stop() { + <-t.C } + t.Reset(w.trafficTimeout) + default: + } + case <-t.C: + checkAgain := w.IsConnecting() + select { + case <-w.TrafficAlert: + checkAgain = true + default: + } + if checkAgain { + t.Reset(w.trafficTimeout) + break } - w.setConnectedStatus(true) - trafficTimer.Reset(w.trafficTimeout) - case <-trafficTimer.C: // Falls through when timer runs out if w.verbose { - log.Warnf(log.WebsocketMgr, - "%v websocket: has not received a traffic alert in %v. Reconnecting", - w.exchangeName, - w.trafficTimeout) + log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", w.exchangeName, w.trafficTimeout) } - trafficTimer.Stop() - w.setTrafficMonitorRunning(false) - w.Wg.Done() // without this the w.Shutdown() call below will deadlock - if !w.IsConnecting() && w.IsConnected() { + w.setTrafficMonitorRunning(false) // Cannot defer lest Connect is called after Shutdown but before deferred call + w.Wg.Done() // Without this the w.Shutdown() call below will deadlock + if w.IsConnected() { err := w.Shutdown() if err != nil { - log.Errorf(log.WebsocketMgr, - "%v websocket: trafficMonitor shutdown err: %s", - w.exchangeName, err) + log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", w.exchangeName, err) } } - return } - - if w.IsConnected() { - // Routine pausing mechanism - go func(p chan<- struct{}) { - time.Sleep(defaultTrafficPeriod) - select { - case p <- struct{}{}: - default: - } - }(pause) - select { - case <-w.ShutdownC: - trafficTimer.Stop() - w.setTrafficMonitorRunning(false) - w.Wg.Done() - return - case <-pause: - } - } } }() } -func (w *Websocket) setConnectedStatus(b bool) { - w.fieldMutex.Lock() - w.connected = b - w.fieldMutex.Unlock() +func (w *Websocket) setState(s uint32) { + w.state.Store(s) } -// IsConnected returns status of connection +// IsInitialised returns whether the websocket has been Setup() already +func (w *Websocket) IsInitialised() bool { + return w.state.Load() != uninitialised +} + +// IsConnected returns whether the websocket is connected func (w *Websocket) IsConnected() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.connected + return w.state.Load() == connected } -func (w *Websocket) setConnectingStatus(b bool) { - w.fieldMutex.Lock() - w.connecting = b - w.fieldMutex.Unlock() -} - -// IsConnecting returns status of connecting +// IsConnecting returns whether the websocket is connecting func (w *Websocket) IsConnecting() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.connecting + return w.state.Load() == connecting } func (w *Websocket) setEnabled(b bool) { - w.fieldMutex.Lock() - w.enabled = b - w.fieldMutex.Unlock() + w.enabled.Store(b) } -// IsEnabled returns status of enabled +// IsEnabled returns whether the websocket is enabled func (w *Websocket) IsEnabled() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.enabled -} - -func (w *Websocket) setInit(b bool) { - w.fieldMutex.Lock() - w.Init = b - w.fieldMutex.Unlock() -} - -// IsInit returns status of init -func (w *Websocket) IsInit() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.Init + return w.enabled.Load() } func (w *Websocket) setTrafficMonitorRunning(b bool) { - w.fieldMutex.Lock() - w.trafficMonitorRunning = b - w.fieldMutex.Unlock() + w.trafficMonitorRunning.Store(b) } // IsTrafficMonitorRunning returns status of the traffic monitor func (w *Websocket) IsTrafficMonitorRunning() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.trafficMonitorRunning + return w.trafficMonitorRunning.Load() } func (w *Websocket) checkAndSetMonitorRunning() (alreadyRunning bool) { - w.fieldMutex.Lock() - defer w.fieldMutex.Unlock() - if w.connectionMonitorRunning { - return true - } - w.connectionMonitorRunning = true - return false + return !w.connectionMonitorRunning.CompareAndSwap(false, true) } func (w *Websocket) setConnectionMonitorRunning(b bool) { - w.fieldMutex.Lock() - w.connectionMonitorRunning = b - w.fieldMutex.Unlock() + w.connectionMonitorRunning.Store(b) } // IsConnectionMonitorRunning returns status of connection monitor func (w *Websocket) IsConnectionMonitorRunning() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.connectionMonitorRunning + return w.connectionMonitorRunning.Load() } func (w *Websocket) setDataMonitorRunning(b bool) { - w.fieldMutex.Lock() - w.dataMonitorRunning = b - w.fieldMutex.Unlock() + w.dataMonitorRunning.Store(b) } // IsDataMonitorRunning returns status of data monitor func (w *Websocket) IsDataMonitorRunning() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.dataMonitorRunning + return w.dataMonitorRunning.Load() } // CanUseAuthenticatedWebsocketForWrapper Handles a common check to // verify whether a wrapper can use an authenticated websocket endpoint func (w *Websocket) CanUseAuthenticatedWebsocketForWrapper() bool { - if w.IsConnected() && w.CanUseAuthenticatedEndpoints() { - return true - } else if w.IsConnected() && !w.CanUseAuthenticatedEndpoints() { - log.Infof(log.WebsocketMgr, - WebsocketNotAuthenticatedUsingRest, - w.exchangeName) + if w.IsConnected() { + if w.CanUseAuthenticatedEndpoints() { + return true + } + log.Infof(log.WebsocketMgr, WebsocketNotAuthenticatedUsingRest, w.exchangeName) } return false } @@ -820,28 +738,22 @@ func (w *Websocket) GetWebsocketURL() string { // SetProxyAddress sets websocket proxy address func (w *Websocket) SetProxyAddress(proxyAddr string) error { + w.m.Lock() + if proxyAddr != "" { - _, err := url.ParseRequestURI(proxyAddr) - if err != nil { - return fmt.Errorf("%v websocket: cannot set proxy address error '%v'", - w.exchangeName, - err) + if _, err := url.ParseRequestURI(proxyAddr); err != nil { + w.m.Unlock() + return fmt.Errorf("%v websocket: cannot set proxy address: %w", w.exchangeName, err) } if w.proxyAddr == proxyAddr { - return fmt.Errorf("%v websocket: cannot set proxy address to the same address '%v'", - w.exchangeName, - w.proxyAddr) + w.m.Unlock() + return fmt.Errorf("%v websocket: %w '%v'", w.exchangeName, errSameProxyAddress, w.proxyAddr) } - log.Debugf(log.ExchangeSys, - "%s websocket: setting websocket proxy: %s\n", - w.exchangeName, - proxyAddr) + log.Debugf(log.ExchangeSys, "%s websocket: setting websocket proxy: %s", w.exchangeName, proxyAddr) } else { - log.Debugf(log.ExchangeSys, - "%s websocket: removing websocket proxy\n", - w.exchangeName) + log.Debugf(log.ExchangeSys, "%s websocket: removing websocket proxy", w.exchangeName) } if w.Conn != nil { @@ -852,15 +764,17 @@ func (w *Websocket) SetProxyAddress(proxyAddr string) error { } w.proxyAddr = proxyAddr - if w.IsInit() && w.IsEnabled() { - if w.IsConnected() { - err := w.Shutdown() - if err != nil { - return err - } + + if w.IsConnected() { + w.m.Unlock() + if err := w.Shutdown(); err != nil { + return err } return w.Connect() } + + w.m.Unlock() + return nil } @@ -1035,20 +949,14 @@ func (w *Websocket) GetSubscriptions() []subscription.Subscription { return subs } -// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in -// a thread safe manner -func (w *Websocket) SetCanUseAuthenticatedEndpoints(val bool) { - w.fieldMutex.Lock() - defer w.fieldMutex.Unlock() - w.canUseAuthenticatedEndpoints = val +// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in a thread safe manner +func (w *Websocket) SetCanUseAuthenticatedEndpoints(b bool) { + w.canUseAuthenticatedEndpoints.Store(b) } -// CanUseAuthenticatedEndpoints gets canUseAuthenticatedEndpoints val in -// a thread safe manner +// CanUseAuthenticatedEndpoints gets canUseAuthenticatedEndpoints val in a thread safe manner func (w *Websocket) CanUseAuthenticatedEndpoints() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.canUseAuthenticatedEndpoints + return w.canUseAuthenticatedEndpoints.Load() } // IsDisconnectionError Determines if the error sent over chan ReadMessageErrors is a disconnection error diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 0bb1e660..4d7681f8 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -50,9 +50,7 @@ func (w *WebsocketConnection) SendMessageReturnResponse(signature, request inter return payload, nil case <-timer.C: timer.Stop() - return nil, fmt.Errorf("%s websocket connection: timeout waiting for response with signature: %v", - w.ExchangeName, - signature) + return nil, fmt.Errorf("%s websocket connection: timeout waiting for response with signature: %v", w.ExchangeName, signature) } } @@ -72,25 +70,14 @@ func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header w.Connection, conStatus, err = dialer.Dial(w.URL, headers) if err != nil { if conStatus != nil { - return fmt.Errorf("%s websocket connection: %v %v %v Error: %v", - w.ExchangeName, - w.URL, - conStatus, - conStatus.StatusCode, - err) + return fmt.Errorf("%s websocket connection: %v %v %v Error: %w", w.ExchangeName, w.URL, conStatus, conStatus.StatusCode, err) } - return fmt.Errorf("%s websocket connection: %v Error: %v", - w.ExchangeName, - w.URL, - err) + return fmt.Errorf("%s websocket connection: %v Error: %w", w.ExchangeName, w.URL, err) } defer conStatus.Body.Close() if w.Verbose { - log.Infof(log.WebsocketMgr, - "%v Websocket connected to %s\n", - w.ExchangeName, - w.URL) + log.Infof(log.WebsocketMgr, "%v Websocket connected to %s\n", w.ExchangeName, w.URL) } select { case w.Traffic <- struct{}{}: @@ -240,7 +227,7 @@ func (w *WebsocketConnection) ReadMessage() Response { select { case w.Traffic <- struct{}{}: - default: // causes contention, just bypass if there is no receiver. + default: // Non-Blocking write ensures 1 buffered signal per trafficCheckInterval to avoid flooding } var standardMessage []byte @@ -285,7 +272,7 @@ func (w *WebsocketConnection) parseBinaryResponse(resp []byte) ([]byte, error) { return standardMessage, reader.Close() } -// GenerateMessageID Creates a messageID to checkout +// GenerateMessageID Creates a random message ID func (w *WebsocketConnection) GenerateMessageID(highPrec bool) int64 { var min int64 = 1e8 var max int64 = 2e8 diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 3ab49e0d..0d4e9c02 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -9,6 +9,7 @@ import ( "fmt" "net" "net/http" + "os" "sort" "strconv" "strings" @@ -18,6 +19,7 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" @@ -30,6 +32,10 @@ const ( proxyURL = "http://212.186.171.4:80" // Replace with a usable proxy server ) +var ( + errDastardlyReason = errors.New("some dastardly reason") +) + var dialer websocket.Dialer type testStruct struct { @@ -68,7 +74,7 @@ var defaultSetup = &WebsocketSetup{ AuthenticatedWebsocketSupport: true, }, WebsocketTrafficTimeout: time.Second * 5, - Name: "exchangeName", + Name: "GTX", }, DefaultURL: "testDefaultURL", RunningURL: "wss://testRunningURL", @@ -92,416 +98,355 @@ type dodgyConnection struct { // override websocket connection method to produce a wicked terrible error func (d *dodgyConnection) Shutdown() error { - return errors.New("cannot shutdown due to some dastardly reason") + return fmt.Errorf("%w: %w", errCannotShutdown, errDastardlyReason) } // override websocket connection method to produce a wicked terrible error func (d *dodgyConnection) Connect() error { - return errors.New("cannot connect due to some dastardly reason") + return fmt.Errorf("cannot connect: %w", errDastardlyReason) +} + +func TestMain(m *testing.M) { + // Change trafficCheckInterval for TestTrafficMonitorTimeout before parallel tests to avoid racing + trafficCheckInterval = 50 * time.Millisecond + os.Exit(m.Run()) } func TestSetup(t *testing.T) { t.Parallel() var w *Websocket err := w.Setup(nil) - if !errors.Is(err, errWebsocketIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketIsNil) - } + assert.ErrorIs(t, err, errWebsocketIsNil, "Setup should error correctly") w = &Websocket{DataHandler: make(chan interface{})} err = w.Setup(nil) - if !errors.Is(err, errWebsocketSetupIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSetupIsNil) - } + assert.ErrorIs(t, err, errWebsocketSetupIsNil, "Setup should error correctly") websocketSetup := &WebsocketSetup{} - err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketAlreadyInitialised) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketAlreadyInitialised) - } - w.Init = true err = w.Setup(websocketSetup) - if !errors.Is(err, errExchangeConfigIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeConfigIsNil) - } + assert.ErrorIs(t, err, errExchangeConfigIsNil, "Setup should error correctly") websocketSetup.ExchangeConfig = &config.Exchange{} err = w.Setup(websocketSetup) - if !errors.Is(err, errExchangeConfigNameUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeConfigNameUnset) - } - websocketSetup.ExchangeConfig.Name = "testname" + assert.ErrorIs(t, err, errExchangeConfigNameEmpty, "Setup should error correctly") + websocketSetup.ExchangeConfig.Name = "testname" err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketFeaturesIsUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketFeaturesIsUnset) - } + assert.ErrorIs(t, err, errWebsocketFeaturesIsUnset, "Setup should error correctly") websocketSetup.Features = &protocol.Features{} err = w.Setup(websocketSetup) - if !errors.Is(err, errConfigFeaturesIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errConfigFeaturesIsNil) - } + assert.ErrorIs(t, err, errConfigFeaturesIsNil, "Setup should error correctly") websocketSetup.ExchangeConfig.Features = &config.FeaturesConfig{} err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketConnectorUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketConnectorUnset) - } + assert.ErrorIs(t, err, errWebsocketConnectorUnset, "Setup should error correctly") websocketSetup.Connector = func() error { return nil } err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketSubscriberUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSubscriberUnset) - } + assert.ErrorIs(t, err, errWebsocketSubscriberUnset, "Setup should error correctly") websocketSetup.Subscriber = func([]subscription.Subscription) error { return nil } websocketSetup.Features.Unsubscribe = true err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketUnsubscriberUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketUnsubscriberUnset) - } + assert.ErrorIs(t, err, errWebsocketUnsubscriberUnset, "Setup should error correctly") websocketSetup.Unsubscriber = func([]subscription.Subscription) error { return nil } err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketSubscriptionsGeneratorUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSubscriptionsGeneratorUnset) - } + assert.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset, "Setup should error correctly") websocketSetup.GenerateSubscriptions = func() ([]subscription.Subscription, error) { return nil, nil } err = w.Setup(websocketSetup) - if !errors.Is(err, errDefaultURLIsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, errDefaultURLIsEmpty) - } + assert.ErrorIs(t, err, errDefaultURLIsEmpty, "Setup should error correctly") websocketSetup.DefaultURL = "test" err = w.Setup(websocketSetup) - if !errors.Is(err, errRunningURLIsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, errRunningURLIsEmpty) - } + assert.ErrorIs(t, err, errRunningURLIsEmpty, "Setup should error correctly") websocketSetup.RunningURL = "http://www.google.com" err = w.Setup(websocketSetup) - if !errors.Is(err, errInvalidWebsocketURL) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidWebsocketURL) - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "Setup should error correctly") websocketSetup.RunningURL = "wss://www.google.com" websocketSetup.RunningURLAuth = "http://www.google.com" err = w.Setup(websocketSetup) - if !errors.Is(err, errInvalidWebsocketURL) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidWebsocketURL) - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "Setup should error correctly") websocketSetup.RunningURLAuth = "wss://www.google.com" err = w.Setup(websocketSetup) - if !errors.Is(err, errInvalidTrafficTimeout) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidTrafficTimeout) - } + assert.ErrorIs(t, err, errInvalidTrafficTimeout, "Setup should error correctly") websocketSetup.ExchangeConfig.WebsocketTrafficTimeout = time.Minute err = w.Setup(websocketSetup) - if !errors.Is(err, nil) { - t.Fatalf("received: %v but expected: %v", err, nil) - } + assert.NoError(t, err, "Setup should not error") } -func TestTrafficMonitorTimeout(t *testing.T) { +// TestTrafficMonitorTrafficAlerts ensures multiple traffic alerts work and only process one trafficAlert per interval +// ensures shutdown works after traffic alerts +func TestTrafficMonitorTrafficAlerts(t *testing.T) { t.Parallel() - ws := *New() - if err := ws.Setup(defaultSetup); err != nil { - t.Fatal(err) - } - ws.trafficTimeout = time.Second * 2 + ws := NewWebsocket() + err := ws.Setup(defaultSetup) + require.NoError(t, err, "Setup must not error") + + signal := struct{}{} + patience := 10 * time.Millisecond + ws.trafficTimeout = 200 * time.Millisecond ws.ShutdownC = make(chan struct{}) + ws.state.Store(connected) + + thenish := time.Now() ws.trafficMonitor() - if !ws.IsTrafficMonitorRunning() { - t.Fatal("traffic monitor should be running") + + assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + require.Equal(t, connected, ws.state.Load(), "websocket must be connected") + + for i := 0; i < 6; i++ { // Timeout will happen at 200ms so we want 6 * 50ms checks to pass + select { + case ws.TrafficAlert <- signal: + if i == 0 { + require.WithinDurationf(t, time.Now(), thenish, trafficCheckInterval, "First Non-blocking test must happen before the traffic is checked") + } + default: + require.Failf(t, "", "TrafficAlert should not block; Check #%d", i) + } + + select { + case ws.TrafficAlert <- signal: + require.Failf(t, "", "TrafficAlert should block after first slot used; Check #%d", i) + default: + if i == 0 { + require.WithinDuration(t, time.Now(), thenish, trafficCheckInterval, "First Blocking test must happen before the traffic is checked") + } + } + + require.Eventuallyf(t, func() bool { return len(ws.TrafficAlert) == 0 }, 5*time.Second, patience, "trafficAlert should be drained; Check #%d", i) + assert.Truef(t, ws.IsConnected(), "state should still be connected; Check #%d", i) } - // Deploy traffic alert - ws.TrafficAlert <- struct{}{} - // try to add another traffic monitor + + require.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") + assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") + }, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown once we stop feeding trafficAlerts") +} + +// TestTrafficMonitorConnecting ensures connecting status doesn't trigger shutdown +func TestTrafficMonitorConnecting(t *testing.T) { + t.Parallel() + ws := NewWebsocket() + err := ws.Setup(defaultSetup) + require.NoError(t, err, "Setup must not error") + + ws.ShutdownC = make(chan struct{}) + ws.state.Store(connecting) + ws.trafficTimeout = 50 * time.Millisecond ws.trafficMonitor() - if !ws.IsTrafficMonitorRunning() { - t.Fatal("traffic monitor should be running") + require.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + require.Equal(t, connecting, ws.state.Load(), "websocket must be connecting") + <-time.After(4 * ws.trafficTimeout) + require.Equal(t, connecting, ws.state.Load(), "websocket must still be connecting after several checks") + ws.state.Store(connected) + require.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") + assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") + }, 4*ws.trafficTimeout, 10*time.Millisecond, "trafficTimeout should trigger a shutdown after connecting status changes") +} + +// TestTrafficMonitorShutdown ensures shutdown is processed and waitgroup is cleared +func TestTrafficMonitorShutdown(t *testing.T) { + t.Parallel() + ws := NewWebsocket() + err := ws.Setup(defaultSetup) + require.NoError(t, err, "Setup must not error") + + ws.ShutdownC = make(chan struct{}) + ws.state.Store(connected) + ws.trafficTimeout = time.Minute + ws.trafficMonitor() + assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + + wgReady := make(chan bool) + go func() { + ws.Wg.Wait() + close(wgReady) + }() + select { + case <-wgReady: + require.Failf(t, "", "WaitGroup should be blocking still") + case <-time.After(trafficCheckInterval): } - // prevent shutdown routine - ws.setConnectedStatus(false) - // await timeout closure - ws.Wg.Wait() - if ws.IsTrafficMonitorRunning() { - t.Error("should be dead") + + close(ws.ShutdownC) + + <-time.After(2 * trafficCheckInterval) + assert.False(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be shutdown") + select { + case <-wgReady: + default: + require.Failf(t, "", "WaitGroup should be freed now") } } func TestIsDisconnectionError(t *testing.T) { t.Parallel() - isADisconnectionError := IsDisconnectionError(errors.New("errorText")) - if isADisconnectionError { - t.Error("Its not") - } - isADisconnectionError = IsDisconnectionError(&websocket.CloseError{ - Code: 1006, - Text: "errorText", - }) - if !isADisconnectionError { - t.Error("It is") - } - - isADisconnectionError = IsDisconnectionError(&net.OpError{ - Err: errClosedConnection, - }) - if isADisconnectionError { - t.Error("It's not") - } - - isADisconnectionError = IsDisconnectionError(&net.OpError{ - Err: errors.New("errText"), - }) - if !isADisconnectionError { - t.Error("It is") - } + assert.False(t, IsDisconnectionError(errors.New("errorText")), "IsDisconnectionError should return false") + assert.True(t, IsDisconnectionError(&websocket.CloseError{Code: 1006, Text: "errorText"}), "IsDisconnectionError should return true") + assert.False(t, IsDisconnectionError(&net.OpError{Err: errClosedConnection}), "IsDisconnectionError should return false") + assert.True(t, IsDisconnectionError(&net.OpError{Err: errors.New("errText")}), "IsDisconnectionError should return true") } func TestConnectionMessageErrors(t *testing.T) { t.Parallel() var wsWrong = &Websocket{} err := wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errNoConnectFunc, "Connect should error correctly") wsWrong.connector = func() error { return nil } err = wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "Connect should error correctly") wsWrong.setEnabled(true) - wsWrong.setConnectingStatus(true) + wsWrong.setState(connecting) wsWrong.Wg = &sync.WaitGroup{} err = wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errAlreadyReconnecting, "Connect should error correctly") - wsWrong.setConnectedStatus(false) - wsWrong.connector = func() error { return errors.New("edge case error of dooooooom") } + wsWrong.setState(disconnected) + wsWrong.connector = func() error { return errDastardlyReason } err = wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errDastardlyReason, "Connect should error correctly") - ws := *New() + ws := NewWebsocket() err = ws.Setup(defaultSetup) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Setup must not error") ws.trafficTimeout = time.Minute ws.connector = func() error { return nil } err = ws.Connect() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Connect must not error") ws.TrafficAlert <- struct{}{} - timer := time.NewTimer(900 * time.Millisecond) - ws.ReadMessageErrors <- errors.New("errorText") - select { - case err := <-ws.ToRoutine: - errText, ok := err.(error) - if !ok { - t.Error("unable to type assert error") - } else if errText.Error() != "errorText" { - t.Errorf("Expected 'errorText', received %v", err) - } - case <-timer.C: - t.Error("Timeout waiting for datahandler to receive error") - } - ws.ReadMessageErrors <- &websocket.CloseError{ - Code: 1006, - Text: "errorText", - } -outer: - for { + c := func(tb *assert.CollectT) { select { - case err := <-ws.ToRoutine: - if _, ok := err.(*websocket.CloseError); !ok { - t.Errorf("Error is not a disconnection error: %v", err) + case v := <-ws.ToRoutine: + switch err := v.(type) { + case *websocket.CloseError: + assert.Equal(tb, "SpecialText", err.Text, "Should get correct Close Error") + case error: + assert.ErrorIs(tb, err, errDastardlyReason, "Should get the correct error") } - case <-timer.C: - break outer + default: } } + + ws.ReadMessageErrors <- errDastardlyReason + assert.EventuallyWithT(t, c, 900*time.Millisecond, 10*time.Millisecond, "Should get an error down the routine") + + ws.ReadMessageErrors <- &websocket.CloseError{Code: 1006, Text: "SpecialText"} + assert.EventuallyWithT(t, c, 900*time.Millisecond, 10*time.Millisecond, "Should get an error down the routine") } func TestWebsocket(t *testing.T) { t.Parallel() - wsInit := Websocket{} - err := wsInit.Setup(&WebsocketSetup{ - ExchangeConfig: &config.Exchange{ - Features: &config.FeaturesConfig{ - Enabled: config.FeaturesEnabledConfig{Websocket: true}, - }, - Name: "test", - }, - }) - if !errors.Is(err, errWebsocketAlreadyInitialised) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketAlreadyInitialised) - } - ws := *New() - err = ws.SetProxyAddress("garbagio") - if err == nil { - t.Error("error cannot be nil") - } + ws := NewWebsocket() - ws.Conn = &WebsocketConnection{} + err := ws.SetProxyAddress("garbagio") + assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly") + + ws.Conn = &dodgyConnection{} ws.AuthConn = &WebsocketConnection{} ws.setEnabled(true) + + err = ws.Setup(defaultSetup) // Sets to enabled again + require.NoError(t, err, "Setup may not error") + + err = ws.Setup(defaultSetup) + assert.ErrorIs(t, err, errWebsocketAlreadyInitialised, "Setup should error correctly if called twice") + + assert.Equal(t, "GTX", ws.GetName(), "GetName should return correctly") + assert.True(t, ws.IsEnabled(), "Websocket should be enabled by Setup") + + ws.setEnabled(false) + assert.False(t, ws.IsEnabled(), "Websocket should be disabled by setEnabled(false)") + + ws.setEnabled(true) + assert.True(t, ws.IsEnabled(), "Websocket should be enabled by setEnabled(true)") + err = ws.SetProxyAddress("https://192.168.0.1:1337") - if err == nil { - t.Error("error cannot be nil") - } - ws.setConnectedStatus(true) - ws.ShutdownC = make(chan struct{}) - ws.Wg = &sync.WaitGroup{} - err = ws.SetProxyAddress("https://192.168.0.1:1336") - if err == nil { - t.Error("SetProxyAddress", err) - } + assert.NoError(t, err, "SetProxyAddress should not error when not yet connected") + + ws.setState(connected) err = ws.SetProxyAddress("https://192.168.0.1:1336") - if err == nil { - t.Error("SetProxyAddress", err) - } - ws.setEnabled(false) + assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Connect and error from there") + + err = ws.SetProxyAddress("https://192.168.0.1:1336") + assert.ErrorIs(t, err, errSameProxyAddress, "SetProxyAddress should error correctly") // removing proxy err = ws.SetProxyAddress("") - if err != nil { - t.Error(err) - } - // reinstate proxy - err = ws.SetProxyAddress("http://localhost:1337") - if err != nil { - t.Error(err) - } - // conflict proxy - err = ws.SetProxyAddress("http://localhost:1337") - if err == nil { - t.Error("error cannot be nil") - } - err = ws.Setup(defaultSetup) - if err != nil { - t.Fatal(err) - } - if ws.GetName() != "exchangeName" { - t.Error("WebsocketSetup") - } - - if !ws.IsEnabled() { - t.Error("WebsocketSetup") - } - - ws.setEnabled(false) - if ws.IsEnabled() { - t.Error("WebsocketSetup") - } - ws.setEnabled(true) - if !ws.IsEnabled() { - t.Error("WebsocketSetup") - } - - if ws.GetProxyAddress() != "http://localhost:1337" { - t.Error("WebsocketSetup") - } - - if ws.GetWebsocketURL() != "wss://testRunningURL" { - t.Error("WebsocketSetup") - } - if ws.trafficTimeout != time.Second*5 { - t.Error("WebsocketSetup") - } - // -- Not connected shutdown - err = ws.Shutdown() - if err == nil { - t.Fatal("should not be connected to able to shut down") - } - - ws.setConnectedStatus(true) - ws.Conn = &dodgyConnection{} - err = ws.Shutdown() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Shutdown and error from there") + assert.ErrorIs(t, err, errCannotShutdown, "SetProxyAddress should call Shutdown and error from there") ws.Conn = &WebsocketConnection{} + ws.setEnabled(true) - ws.setConnectedStatus(true) + // reinstate proxy + err = ws.SetProxyAddress("http://localhost:1337") + assert.NoError(t, err, "SetProxyAddress should not error") + assert.Equal(t, "http://localhost:1337", ws.GetProxyAddress(), "GetProxyAddress should return correctly") + assert.Equal(t, "wss://testRunningURL", ws.GetWebsocketURL(), "GetWebsocketURL should return correctly") + assert.Equal(t, time.Second*5, ws.trafficTimeout, "trafficTimeout should default correctly") + + ws.setState(connected) ws.AuthConn = &dodgyConnection{} err = ws.Shutdown() - if err == nil { - t.Fatal("error cannot be nil ") - } + assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy authConn") + assert.ErrorIs(t, err, errCannotShutdown, "Shutdown should error correctly with a dodgy authConn") ws.AuthConn = &WebsocketConnection{} - ws.setConnectedStatus(false) + ws.setState(disconnected) - // -- Normal connect err = ws.Connect() - if err != nil { - t.Fatal("WebsocketSetup", err) - } + assert.NoError(t, err, "Connect should not error") ws.defaultURL = "ws://demos.kaazing.com/echo" ws.defaultURLAuth = "ws://demos.kaazing.com/echo" err = ws.SetWebsocketURL("", false, false) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("ws://demos.kaazing.com/echo", false, false) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("", true, false) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("ws://demos.kaazing.com/echo", true, false) - if err != nil { - t.Fatal(err) - } - // Attempt reconnect + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("ws://demos.kaazing.com/echo", true, true) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error on reconnect") + // -- initiate the reconnect which is usually handled by connection monitor err = ws.Connect() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "ReConnect called manually should not error") + err = ws.Connect() - if err == nil { - t.Fatal("should already be connected") - } - // -- Normal shutdown + assert.ErrorIs(t, err, errAlreadyConnected, "ReConnect should error when already connected") + err = ws.Shutdown() - if err != nil { - t.Fatal("WebsocketSetup", err) - } + assert.NoError(t, err, "Shutdown should not error") ws.Wg.Wait() } // TestSubscribe logic test func TestSubscribeUnsubscribe(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() assert.NoError(t, ws.Setup(defaultSetup), "WS Setup should not error") fnSub := func(subs []subscription.Subscription) error { @@ -546,7 +491,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { // TestResubscribe tests Resubscribing to existing subscriptions func TestResubscribe(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() wackedOutSetup := *defaultSetup wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1 @@ -577,7 +522,7 @@ func TestResubscribe(t *testing.T) { // TestSubscriptionState tests Subscription state changes func TestSubscriptionState(t *testing.T) { t.Parallel() - ws := New() + ws := NewWebsocket() c := &subscription.Subscription{Key: 42, Channel: "Gophers", State: subscription.SubscribingState} assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), ErrSubscriptionNotFound, "Setting an imaginary sub should error") @@ -603,7 +548,7 @@ func TestSubscriptionState(t *testing.T) { // TestRemoveSubscriptions tests removing a subscription func TestRemoveSubscriptions(t *testing.T) { t.Parallel() - ws := New() + ws := NewWebsocket() c := &subscription.Subscription{Key: 42, Channel: "Unite!"} assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error") @@ -616,24 +561,18 @@ func TestRemoveSubscriptions(t *testing.T) { // TestConnectionMonitorNoConnection logic test func TestConnectionMonitorNoConnection(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() ws.connectionMonitorDelay = 500 ws.DataHandler = make(chan interface{}, 1) ws.ShutdownC = make(chan struct{}, 1) ws.exchangeName = "hello" ws.Wg = &sync.WaitGroup{} - ws.enabled = true + ws.setEnabled(true) err := ws.connectionMonitor() - if !errors.Is(err, nil) { - t.Fatalf("received: %v, but expected: %v", err, nil) - } - if !ws.IsConnectionMonitorRunning() { - t.Fatal("Should not have exited") - } + require.NoError(t, err, "connectionMonitor must not error") + assert.True(t, ws.IsConnectionMonitorRunning(), "IsConnectionMonitorRunning should return true") err = ws.connectionMonitor() - if !errors.Is(err, errAlreadyRunning) { - t.Fatalf("received: %v, but expected: %v", err, errAlreadyRunning) - } + assert.ErrorIs(t, err, errAlreadyRunning, "connectionMonitor should error correctly") } // TestGetSubscription logic test @@ -671,16 +610,10 @@ func TestGetSubscriptions(t *testing.T) { // TestSetCanUseAuthenticatedEndpoints logic test func TestSetCanUseAuthenticatedEndpoints(t *testing.T) { t.Parallel() - ws := *New() - result := ws.CanUseAuthenticatedEndpoints() - if result { - t.Error("expected `canUseAuthenticatedEndpoints` to be false") - } + ws := NewWebsocket() + assert.False(t, ws.CanUseAuthenticatedEndpoints(), "CanUseAuthenticatedEndpoints should return false") ws.SetCanUseAuthenticatedEndpoints(true) - result = ws.CanUseAuthenticatedEndpoints() - if !result { - t.Error("expected `canUseAuthenticatedEndpoints` to be true") - } + assert.True(t, ws.CanUseAuthenticatedEndpoints(), "CanUseAuthenticatedEndpoints should return true") } // TestDial logic test @@ -917,81 +850,53 @@ func TestParseBinaryResponse(t *testing.T) { } var b bytes.Buffer - w := gzip.NewWriter(&b) - _, err := w.Write([]byte("hello")) - if err != nil { - t.Error(err) - } - err = w.Close() - if err != nil { - t.Error(err) - } - var resp []byte + g := gzip.NewWriter(&b) + _, err := g.Write([]byte("hello")) + require.NoError(t, err, "gzip.Write must not error") + assert.NoError(t, g.Close(), "Close should not error") + + resp, err := wc.parseBinaryResponse(b.Bytes()) + assert.NoError(t, err, "parseBinaryResponse should not error parsing gzip") + assert.EqualValues(t, "hello", resp, "parseBinaryResponse should decode gzip") + + b.Reset() + f, err := flate.NewWriter(&b, 1) + require.NoError(t, err, "flate.NewWriter must not error") + _, err = f.Write([]byte("goodbye")) + require.NoError(t, err, "flate.Write must not error") + assert.NoError(t, f.Close(), "Close should not error") + resp, err = wc.parseBinaryResponse(b.Bytes()) - if err != nil { - t.Error(err) - } - if !strings.EqualFold(string(resp), "hello") { - t.Errorf("GZip conversion failed. Received: '%v', Expected: 'hello'", string(resp)) - } + assert.NoError(t, err, "parseBinaryResponse should not error parsing inflate") + assert.EqualValues(t, "goodbye", resp, "parseBinaryResponse should deflate") - var b2 bytes.Buffer - w2, err2 := flate.NewWriter(&b2, 1) - if err2 != nil { - t.Error(err2) - } - _, err2 = w2.Write([]byte("hello")) - if err2 != nil { - t.Error(err) - } - err2 = w2.Close() - if err2 != nil { - t.Error(err) - } - resp2, err3 := wc.parseBinaryResponse(b2.Bytes()) - if err3 != nil { - t.Error(err3) - } - if !strings.EqualFold(string(resp2), "hello") { - t.Errorf("Deflate conversion failed. Received: '%v', Expected: 'hello'", string(resp2)) - } - - _, err4 := wc.parseBinaryResponse([]byte{}) - if err4 == nil || err4.Error() != "unexpected EOF" { - t.Error("Expected error 'unexpected EOF'") - } + _, err = wc.parseBinaryResponse([]byte{}) + assert.ErrorContains(t, err, "unexpected EOF", "parseBinaryResponse should error on empty input") } // TestCanUseAuthenticatedWebsocketForWrapper logic test func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) { t.Parallel() ws := &Websocket{} - resp := ws.CanUseAuthenticatedWebsocketForWrapper() - if resp { - t.Error("Expected false, `connected` is false") - } - ws.setConnectedStatus(true) - resp = ws.CanUseAuthenticatedWebsocketForWrapper() - if resp { - t.Error("Expected false, `connected` is true and `CanUseAuthenticatedEndpoints` is false") - } - ws.canUseAuthenticatedEndpoints = true - resp = ws.CanUseAuthenticatedWebsocketForWrapper() - if !resp { - t.Error("Expected true, `connected` and `CanUseAuthenticatedEndpoints` is true") - } + assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") + + ws.setState(connected) + require.True(t, ws.IsConnected(), "IsConnected must return true") + assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") + + ws.SetCanUseAuthenticatedEndpoints(true) + assert.True(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return true") } func TestGenerateMessageID(t *testing.T) { t.Parallel() wc := WebsocketConnection{} - var id int64 - for i := 0; i < 10; i++ { - newID := wc.GenerateMessageID(true) - if id == newID { - t.Fatal("ID generation is not unique") - } - id = newID + const spins = 1000 + ids := make([]int64, spins) + for i := 0; i < spins; i++ { + id := wc.GenerateMessageID(true) + assert.NotContains(t, ids, id, "GenerateMessageID must not generate the same ID twice") + ids[i] = id } } @@ -1013,34 +918,22 @@ func BenchmarkGenerateMessageID_Low(b *testing.B) { func TestCheckWebsocketURL(t *testing.T) { err := checkWebsocketURL("") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "checkWebsocketURL should error correctly on empty string") err = checkWebsocketURL("wowowow:wowowowo") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "checkWebsocketURL should error correctly on bad format") err = checkWebsocketURL("://") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorContains(t, err, "missing protocol scheme", "checkWebsocketURL should error correctly on bad proto") err = checkWebsocketURL("http://www.google.com") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "checkWebsocketURL should error correctly on wrong proto") err = checkWebsocketURL("wss://websocketconnection.place") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "checkWebsocketURL should not error") err = checkWebsocketURL("ws://websocketconnection.place") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "checkWebsocketURL should not error") } func TestGetChannelDifference(t *testing.T) { @@ -1142,19 +1035,13 @@ func TestFlushChannels(t *testing.T) { dodgyWs := Websocket{} err := dodgyWs.FlushChannels() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "FlushChannels should error correctly") dodgyWs.setEnabled(true) err = dodgyWs.FlushChannels() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, ErrNotConnected, "FlushChannels should error correctly") - web := Websocket{ - enabled: true, - connected: true, + w := Websocket{ connector: connect, ShutdownC: make(chan struct{}), Subscriber: newgen.SUBME, @@ -1167,9 +1054,11 @@ func TestFlushChannels(t *testing.T) { // in FlushChannels() so the traffic monitor doesn't time out and turn // this to an unconnected state } + w.setEnabled(true) + w.setState(connected) problemFunc := func() ([]subscription.Subscription, error) { - return nil, errors.New("problems") + return nil, errDastardlyReason } noSub := func() ([]subscription.Subscription, error) { @@ -1179,53 +1068,40 @@ func TestFlushChannels(t *testing.T) { // Disable pair and flush system newgen.EnabledPairs = []currency.Pair{ currency.NewPair(currency.BTC, currency.AUD)} - web.GenerateSubs = func() ([]subscription.Subscription, error) { + w.GenerateSubs = func() ([]subscription.Subscription, error) { return []subscription.Subscription{{Channel: "test"}}, nil } - err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + err = w.FlushChannels() + assert.NoError(t, err, "FlushChannels should not error") - web.features.FullPayloadSubscribe = true - web.GenerateSubs = problemFunc - err = web.FlushChannels() // error on full subscribeToChannels - if err == nil { - t.Fatal("error cannot be nil") - } + w.features.FullPayloadSubscribe = true + w.GenerateSubs = problemFunc + err = w.FlushChannels() // error on full subscribeToChannels + assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly") - web.GenerateSubs = noSub - err = web.FlushChannels() // No subs to sub - if err != nil { - t.Fatal(err) - } + w.GenerateSubs = noSub + err = w.FlushChannels() // No subs to unsub + assert.NoError(t, err, "FlushChannels should not error") - web.GenerateSubs = newgen.generateSubs - subs, err := web.GenerateSubs() - if err != nil { - t.Fatal(err) - } - web.AddSuccessfulSubscriptions(subs...) - err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } - web.features.FullPayloadSubscribe = false - web.features.Subscribe = true + w.GenerateSubs = newgen.generateSubs + subs, err := w.GenerateSubs() + require.NoError(t, err, "GenerateSubs must not error") - web.GenerateSubs = problemFunc - err = web.FlushChannels() - if err == nil { - t.Fatal("error cannot be nil") - } + w.AddSuccessfulSubscriptions(subs...) + err = w.FlushChannels() + assert.NoError(t, err, "FlushChannels should not error") + w.features.FullPayloadSubscribe = false + w.features.Subscribe = true - web.GenerateSubs = newgen.generateSubs - err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } - web.subscriptionMutex.Lock() - web.subscriptions = subscriptionMap{ + w.GenerateSubs = problemFunc + err = w.FlushChannels() + assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly") + + w.GenerateSubs = newgen.generateSubs + err = w.FlushChannels() + assert.NoError(t, err, "FlushChannels should not error") + w.subscriptionMutex.Lock() + w.subscriptions = subscriptionMap{ 41: { Key: 41, Channel: "match channel", @@ -1237,46 +1113,34 @@ func TestFlushChannels(t *testing.T) { Pair: currency.NewPair(currency.THETA, currency.USDT), }, } - web.subscriptionMutex.Unlock() + w.subscriptionMutex.Unlock() - err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + err = w.FlushChannels() + assert.NoError(t, err, "FlushChannels should not error") - err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + err = w.FlushChannels() + assert.NoError(t, err, "FlushChannels should not error") - web.setConnectedStatus(true) - web.features.Unsubscribe = true - err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + w.setState(connected) + w.features.Unsubscribe = true + err = w.FlushChannels() + assert.NoError(t, err, "FlushChannels should not error") } func TestDisable(t *testing.T) { t.Parallel() - web := Websocket{ - enabled: true, - connected: true, + w := Websocket{ ShutdownC: make(chan struct{}), } - err := web.Disable() - if err != nil { - t.Fatal(err) - } - err = web.Disable() - if err == nil { - t.Fatal("should already be disabled") - } + w.setEnabled(true) + w.setState(connected) + require.NoError(t, w.Disable(), "Disable must not error") + assert.ErrorIs(t, w.Disable(), ErrAlreadyDisabled, "Disable should error correctly") } func TestEnable(t *testing.T) { t.Parallel() - web := Websocket{ + w := Websocket{ connector: connect, Wg: new(sync.WaitGroup), ShutdownC: make(chan struct{}), @@ -1286,98 +1150,59 @@ func TestEnable(t *testing.T) { Subscriber: func([]subscription.Subscription) error { return nil }, } - err := web.Enable() - if err != nil { - t.Fatal(err) - } - - err = web.Enable() - if err == nil { - t.Fatal("should already be enabled") - } - - fmt.Print() + require.NoError(t, w.Enable(), "Enable must not error") + assert.ErrorIs(t, w.Enable(), errWebsocketAlreadyEnabled, "Enable should error correctly") } func TestSetupNewConnection(t *testing.T) { t.Parallel() var nonsenseWebsock *Websocket err := nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errWebsocketIsNil, "SetupNewConnection should error correctly") nonsenseWebsock = &Websocket{} err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errExchangeConfigNameEmpty, "SetupNewConnection should error correctly") nonsenseWebsock = &Websocket{exchangeName: "test"} err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errTrafficAlertNil, "SetupNewConnection should error correctly") - nonsenseWebsock.TrafficAlert = make(chan struct{}) + nonsenseWebsock.TrafficAlert = make(chan struct{}, 1) err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errReadMessageErrorsNil, "SetupNewConnection should error correctly") - web := Websocket{ - connector: connect, - Wg: new(sync.WaitGroup), - ShutdownC: make(chan struct{}), - Init: true, - TrafficAlert: make(chan struct{}), - ReadMessageErrors: make(chan error), - DataHandler: make(chan interface{}), - } + web := NewWebsocket() err = web.Setup(defaultSetup) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "Setup should not error") + err = web.SetupNewConnection(ConnectionSetup{}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errExchangeConfigEmpty, "SetupNewConnection should error correctly") + err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err != nil { - t.Fatal(err) - } - err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring", - Authenticated: true}) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetupNewConnection should not error") + + err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring", Authenticated: true}) + assert.NoError(t, err, "SetupNewConnection should not error") } func TestWebsocketConnectionShutdown(t *testing.T) { t.Parallel() wc := WebsocketConnection{} err := wc.Shutdown() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "Shutdown should not error") err = wc.Dial(&websocket.Dialer{}, nil) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorContains(t, err, "malformed ws or wss URL", "Dial must error correctly") wc.URL = websocketTestURL err = wc.Dial(&websocket.Dialer{}, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Dial must not error") err = wc.Shutdown() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Shutdown must not error") } // TestLatency logic test @@ -1431,27 +1256,19 @@ func TestCheckSubscriptions(t *testing.T) { t.Parallel() ws := Websocket{} err := ws.checkSubscriptions(nil) - if !errors.Is(err, errNoSubscriptionsSupplied) { - t.Fatalf("received: %v, but expected: %v", err, errNoSubscriptionsSupplied) - } + assert.ErrorIs(t, err, errNoSubscriptionsSupplied, "checkSubscriptions should error correctly") ws.MaxSubscriptionsPerConnection = 1 err = ws.checkSubscriptions([]subscription.Subscription{{}, {}}) - if !errors.Is(err, errSubscriptionsExceedsLimit) { - t.Fatalf("received: %v, but expected: %v", err, errSubscriptionsExceedsLimit) - } + assert.ErrorIs(t, err, errSubscriptionsExceedsLimit, "checkSubscriptions should error correctly") ws.MaxSubscriptionsPerConnection = 2 ws.subscriptions = subscriptionMap{42: {Key: 42, Channel: "test"}} err = ws.checkSubscriptions([]subscription.Subscription{{Key: 42, Channel: "test"}}) - if !errors.Is(err, errChannelAlreadySubscribed) { - t.Fatalf("received: %v, but expected: %v", err, errChannelAlreadySubscribed) - } + assert.ErrorIs(t, err, errChannelAlreadySubscribed, "checkSubscriptions should error correctly") err = ws.checkSubscriptions([]subscription.Subscription{{}}) - if !errors.Is(err, nil) { - t.Fatalf("received: %v, but expected: %v", err, nil) - } + assert.NoError(t, err, "checkSubscriptions should not error") } diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 925c34b9..a783d585 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -2,6 +2,7 @@ package stream import ( "sync" + "sync/atomic" "time" "github.com/gorilla/websocket" @@ -15,8 +16,6 @@ import ( // Websocket functionality list and state consts const ( - // WebsocketNotEnabled alerts of a disabled websocket - WebsocketNotEnabled = "exchange_websocket_not_enabled" WebsocketNotAuthenticatedUsingRest = "%v - Websocket not authenticated, using REST\n" Ping = "ping" Pong = "pong" @@ -25,18 +24,23 @@ const ( type subscriptionMap map[any]*subscription.Subscription +const ( + uninitialised uint32 = iota + disconnected + connecting + connected +) + // Websocket defines a return type for websocket connections via the interface // wrapper for routine processing type Websocket struct { - canUseAuthenticatedEndpoints bool - enabled bool - Init bool - connected bool - connecting bool + canUseAuthenticatedEndpoints atomic.Bool + enabled atomic.Bool + state atomic.Uint32 verbose bool - connectionMonitorRunning bool - trafficMonitorRunning bool - dataMonitorRunning bool + connectionMonitorRunning atomic.Bool + trafficMonitorRunning atomic.Bool + dataMonitorRunning atomic.Bool trafficTimeout time.Duration connectionMonitorDelay time.Duration proxyAddr string @@ -46,7 +50,6 @@ type Websocket struct { runningURLAuth string exchangeName string m sync.Mutex - fieldMutex sync.RWMutex connector func() error subscriptionMutex sync.RWMutex