diff --git a/CONTRIBUTORS b/CONTRIBUTORS index 4a1bcb6d..2c9cf17a 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -1,28 +1,28 @@ Thanks to the following contributors: thrasher- | https://github.com/thrasher- -shazbert | https://github.com/shazbert dependabot[bot] | https://github.com/apps/dependabot +shazbert | https://github.com/shazbert gloriousCode | https://github.com/gloriousCode gbjk | https://github.com/gbjk dependabot-preview[bot] | https://github.com/apps/dependabot-preview xtda | https://github.com/xtda lrascao | https://github.com/lrascao Beadko | https://github.com/Beadko -Rots | https://github.com/Rots -vazha | https://github.com/vazha ydm | https://github.com/ydm +vazha | https://github.com/vazha +Rots | https://github.com/Rots ermalguni | https://github.com/ermalguni MadCozBadd | https://github.com/MadCozBadd samuael | https://github.com/samuael vadimzhukck | https://github.com/vadimzhukck -140am | https://github.com/140am -marcofranssen | https://github.com/marcofranssen geseq | https://github.com/geseq +marcofranssen | https://github.com/marcofranssen +140am | https://github.com/140am +junnplus | https://github.com/junnplus TaltaM | https://github.com/TaltaM cranktakular | https://github.com/cranktakular dackroyd | https://github.com/dackroyd -junnplus | https://github.com/junnplus khcchiu | https://github.com/khcchiu yangrq1018 | https://github.com/yangrq1018 woshidama323 | https://github.com/woshidama323 diff --git a/README.md b/README.md index d2a397d2..ada6a237 100644 --- a/README.md +++ b/README.md @@ -156,29 +156,29 @@ Binaries will be published once the codebase reaches a stable condition. |User|Contribution Amount| |--|--| -| [thrasher-](https://github.com/thrasher-) | 708 | +| [thrasher-](https://github.com/thrasher-) | 711 | +| [dependabot[bot]](https://github.com/apps/dependabot) | 369 | | [shazbert](https://github.com/shazbert) | 362 | -| [dependabot[bot]](https://github.com/apps/dependabot) | 361 | | [gloriousCode](https://github.com/gloriousCode) | 237 | -| [gbjk](https://github.com/gbjk) | 121 | +| [gbjk](https://github.com/gbjk) | 123 | | [dependabot-preview[bot]](https://github.com/apps/dependabot-preview) | 88 | | [xtda](https://github.com/xtda) | 47 | | [lrascao](https://github.com/lrascao) | 27 | -| [Beadko](https://github.com/Beadko) | 22 | -| [Rots](https://github.com/Rots) | 15 | -| [vazha](https://github.com/vazha) | 15 | +| [Beadko](https://github.com/Beadko) | 24 | | [ydm](https://github.com/ydm) | 15 | +| [vazha](https://github.com/vazha) | 15 | +| [Rots](https://github.com/Rots) | 15 | | [ermalguni](https://github.com/ermalguni) | 14 | | [MadCozBadd](https://github.com/MadCozBadd) | 13 | | [samuael](https://github.com/samuael) | 11 | | [vadimzhukck](https://github.com/vadimzhukck) | 10 | -| [140am](https://github.com/140am) | 8 | -| [marcofranssen](https://github.com/marcofranssen) | 8 | | [geseq](https://github.com/geseq) | 8 | +| [marcofranssen](https://github.com/marcofranssen) | 8 | +| [140am](https://github.com/140am) | 8 | +| [junnplus](https://github.com/junnplus) | 8 | | [TaltaM](https://github.com/TaltaM) | 6 | | [cranktakular](https://github.com/cranktakular) | 6 | | [dackroyd](https://github.com/dackroyd) | 5 | -| [junnplus](https://github.com/junnplus) | 5 | | [khcchiu](https://github.com/khcchiu) | 5 | | [yangrq1018](https://github.com/yangrq1018) | 4 | | [woshidama323](https://github.com/woshidama323) | 3 | diff --git a/exchanges/stream/README.md b/cmd/documentation/internal_templates/internal_exchange_websocket_readme.tmpl similarity index 85% rename from exchanges/stream/README.md rename to cmd/documentation/internal_templates/internal_exchange_websocket_readme.tmpl index 3a02c2ef..5ea81254 100644 --- a/exchanges/stream/README.md +++ b/cmd/documentation/internal_templates/internal_exchange_websocket_readme.tmpl @@ -1,10 +1,8 @@ -# GoCryptoTrader Exchange Stream Package - -This package is part of the GoCryptoTrader project and is responsible for handling exchange streaming data. - +{{define "internal exchange websocket" -}} +{{template "header" .}} ## Overview -The `stream` package uses Gorilla Websocket and provides functionalities to connect to various cryptocurrency exchanges and handle real-time data streams. +The `websocket` package provides methods to manage connections and subscriptions for exchange websockets. ## Features @@ -21,13 +19,14 @@ The `stream` package uses Gorilla Websocket and provides functionalities to conn ## Usage ### Default single websocket connection -Here is a basic example of how to setup the `stream` package for websocket: + +Example setup for the `websocket` package connection: ```go package main import ( - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/request" ) @@ -38,16 +37,16 @@ type Exchange struct { // In the exchange wrapper this will set up the initial pointer field provided by exchange.Base func (e *Exchange) SetDefault() { - e.Websocket = stream.NewWebsocket() + e.Websocket = websocket.NewManager() e.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit e.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout e.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit } -// In the exchange wrapper this is the original setup pattern for the websocket services +// In the exchange wrapper this is the original setup pattern for the websocket services func (e *Exchange) Setup(exch *config.Exchange) error { // This sets up global connection, sub, unsub and generate subscriptions for each connection defined below. - if err := e.Websocket.Setup(&stream.WebsocketSetup{ + if err := e.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: connectionURLString, RunningURL: connectionURLString, @@ -63,7 +62,7 @@ func (e *Exchange) Setup(exch *config.Exchange) error { } // This is a public websocket connection - if err := ok.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + if err := ok.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ URL: connectionURLString, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exchangeWebsocketResponseMaxLimit, @@ -72,8 +71,8 @@ func (e *Exchange) Setup(exch *config.Exchange) error { return err } - // This is a private websocket connection - return ok.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + // This is a private websocket connection + return ok.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ URL: privateConnectionURLString, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exchangeWebsocketResponseMaxLimit, @@ -89,7 +88,7 @@ func (e *Exchange) Setup(exch *config.Exchange) error { ```go func (e *Exchange) Setup(exch *config.Exchange) error { // This sets up global connection, sub, unsub and generate subscriptions for each connection defined below. - if err := e.Websocket.Setup(&stream.WebsocketSetup{ + if err := e.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, Features: &e.Features.Supports.WebsocketCapabilities, FillsFeed: e.Features.Enabled.FillsFeed, @@ -100,7 +99,7 @@ func (e *Exchange) Setup(exch *config.Exchange) error { return err } // Spot connection - err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + err = g.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ URL: connectionURLStringForSpot, RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, @@ -117,7 +116,7 @@ func (e *Exchange) Setup(exch *config.Exchange) error { return err } // Futures connection - USDT margined - err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + err = g.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ URL: connectionURLStringForSpotForFutures, RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, @@ -134,4 +133,8 @@ func (e *Exchange) Setup(exch *config.Exchange) error { return err } } -``` \ No newline at end of file +``` + +{{template "contributions"}} +{{template "donations" .}} +{{end}} diff --git a/cmd/exchange_template/wrapper_file.tmpl b/cmd/exchange_template/wrapper_file.tmpl index f42dcc34..fca261ea 100644 --- a/cmd/exchange_template/wrapper_file.tmpl +++ b/cmd/exchange_template/wrapper_file.tmpl @@ -19,9 +19,9 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -108,7 +108,7 @@ func ({{.Variable}} *{{.CapitalName}}) SetDefaults() { exchange.RestSpot: {{.Name}}APIURL, // exchange.WebsocketSpot: {{.Name}}WSAPIURL, }) - {{.Variable}}.Websocket = stream.NewWebsocket() + {{.Variable}}.Websocket = websocket.NewManager() {{.Variable}}.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit {{.Variable}}.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout {{.Variable}}.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -138,7 +138,7 @@ func ({{.Variable}} *{{.CapitalName}}) Setup(exch *config.Exchange) error { // If websocket is supported, please fill out the following err = {{.Variable}}.Websocket.Setup( - &stream.WebsocketSetup{ + &websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: {{.Name}}WSAPIURL, RunningURL: wsRunningEndpoint, @@ -151,7 +151,7 @@ func ({{.Variable}} *{{.CapitalName}}) Setup(exch *config.Exchange) error { return err } - {{.Variable}}.WebsocketConn = &stream.WebsocketConnection{ + {{.Variable}}.WebsocketConn = &websocket.WebsocketConnection{ ExchangeName: {{.Variable}}.Name, URL: {{.Variable}}.Websocket.GetWebsocketURL(), ProxyURL: {{.Variable}}.Websocket.GetProxyAddress(), diff --git a/cmd/websocket_client/main.go b/cmd/websocket_client/main.go index 274f4a55..e15819a6 100644 --- a/cmd/websocket_client/main.go +++ b/cmd/websocket_client/main.go @@ -8,7 +8,7 @@ import ( "net/http" "strconv" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/config" @@ -18,7 +18,7 @@ import ( // Vars for the websocket client var ( - WSConn *websocket.Conn + WSConn *gws.Conn ) // WebsocketEvent is the struct used for websocket events @@ -89,7 +89,7 @@ func main() { strconv.Itoa(common.ExtractPort(listenAddr)))) log.Printf("Connecting to websocket host: %s", wsHost) - var dialer websocket.Dialer + var dialer gws.Dialer var resp *http.Response WSConn, resp, err = dialer.Dial(wsHost, http.Header{}) if err != nil { diff --git a/common/common.go b/common/common.go index 551a02d8..2958caed 100644 --- a/common/common.go +++ b/common/common.go @@ -21,6 +21,7 @@ import ( "sync/atomic" "time" "unicode" + "unsafe" "github.com/thrasher-corp/gocryptotrader/common/file" "github.com/thrasher-corp/gocryptotrader/log" @@ -80,6 +81,22 @@ var ( errHTTPClientInvalid = errors.New("custom http client cannot be nil") ) +// NilGuard returns an ErrNilPointer with the type of the first nil argument +func NilGuard(ptrs ...any) (errs error) { + for _, p := range ptrs { + /* Internally interfaces contain a type and a value address + Obviously can't compare to nil, since the types won't match, so we look into the interface + eface is the internal representation of any; e(mpty-inter)face + See: https://cs.opensource.google/go/go/+/refs/tags/go1.24.1:src/runtime/runtime2.go;l=184-187 + We optimize here by converting to [2]uintptr and just checking the address, instead of casting to a local eface type + */ + if (*[2]uintptr)(unsafe.Pointer(&p))[1] == 0 { + errs = AppendError(errs, fmt.Errorf("%w: %T", ErrNilPointer, p)) + } + } + return errs +} + // MatchesEmailPattern ensures that the string is an email address by regexp check func MatchesEmailPattern(value string) bool { if len(value) < 3 || len(value) > 254 { diff --git a/common/common_test.go b/common/common_test.go index 230324cf..8450bf72 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -722,3 +722,26 @@ func BenchmarkCounter(b *testing.B) { c.IncrementAndGet() } } + +func TestNilGuard(t *testing.T) { + t.Parallel() + err := NilGuard((*int)(nil)) + assert.ErrorIs(t, err, ErrNilPointer) + assert.ErrorContains(t, err, "*int") + + s := "normal input" + err = NilGuard(&s, 2, &[]int{4, 5, 6}, []int{1, 2, 3}, new(A)) + assert.NoError(t, err) + + err = NilGuard(&s, nil, (*int)(nil)) + assert.ErrorIs(t, err, ErrNilPointer) + assert.ErrorContains(t, err, "*int") + var mErr *multiError + require.ErrorAs(t, err, &mErr, "err must be a multiError") + assert.Len(t, mErr.Unwrap(), 2, "Should get 2 errors back") + + assert.ErrorIs(t, NilGuard(nil), ErrNilPointer, "Unusual input of an untyped nil should still error correctly") + + err = NilGuard() + require.NoError(t, err, "NilGuard with no arguments should not panic") +} diff --git a/communications/slack/slack.go b/communications/slack/slack.go index 46d9a0d2..f7d6c89c 100644 --- a/communications/slack/slack.go +++ b/communications/slack/slack.go @@ -12,7 +12,7 @@ import ( "sync" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/communications/base" "github.com/thrasher-corp/gocryptotrader/encoding/json" @@ -44,7 +44,7 @@ type Slack struct { TargetChannelID string Details Response ReconnectURL string - WebsocketConn *websocket.Conn + WebsocketConn *gws.Conn Connected bool Shutdown bool mu sync.Mutex @@ -195,7 +195,7 @@ func (s *Slack) NewConnection() error { // WebsocketConnect creates a websocket dialer amd initiates a websocket // connection func (s *Slack) WebsocketConnect() error { - var dialer websocket.Dialer + var dialer gws.Dialer var err error websocketURL := s.Details.URL @@ -376,7 +376,7 @@ func (s *Slack) WebsocketSend(eventType, text string) error { if s.WebsocketConn == nil { return errors.New("websocket not connected") } - return s.WebsocketConn.WriteMessage(websocket.TextMessage, data) + return s.WebsocketConn.WriteMessage(gws.TextMessage, data) } // HandleMessage handles incoming messages and/or commands from slack diff --git a/docs/ADD_NEW_EXCHANGE.md b/docs/ADD_NEW_EXCHANGE.md index d34f4061..e9334dff 100644 --- a/docs/ADD_NEW_EXCHANGE.md +++ b/docs/ADD_NEW_EXCHANGE.md @@ -1070,7 +1070,7 @@ func (f *FTX) Setup(exch *config.Exchange) error { } // Websocket details setup below - err = f.Websocket.Setup(&stream.WebsocketSetup{ + err = f.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, // DefaultURL defines the default endpoint in the event a rollback is // needed via gctcli. @@ -1100,7 +1100,7 @@ func (f *FTX) Setup(exch *config.Exchange) error { return err } // Sets up a new connection for the websocket, there are two separate connections denoted by the ConnectionSetup struct auth bool. - return f.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return f.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, // RateLimit int64 rudimentary rate limit that sleeps connection in milliseconds before sending designated payload diff --git a/engine/apiserver.go b/engine/apiserver.go index c5b8c8ca..2f46f32b 100644 --- a/engine/apiserver.go +++ b/engine/apiserver.go @@ -13,7 +13,7 @@ import ( "time" "github.com/gorilla/mux" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/config" @@ -489,13 +489,13 @@ func (c *websocketClient) read() { for { msgType, message, err := c.Conn.ReadMessage() if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + if gws.IsUnexpectedCloseError(err, gws.CloseGoingAway, gws.CloseAbnormalClosure) { log.Errorf(log.APIServerMgr, "websocket: client disconnected, err: %s\n", err) } break } - if msgType == websocket.TextMessage { + if msgType == gws.TextMessage { var evt WebsocketEvent err := json.Unmarshal(message, &evt) if err != nil { @@ -551,7 +551,7 @@ func (c *websocketClient) write() { for { message, ok := <-c.Send if !ok { - err := c.Conn.WriteMessage(websocket.CloseMessage, []byte{}) + err := c.Conn.WriteMessage(gws.CloseMessage, []byte{}) if err != nil { log.Errorln(log.APIServerMgr, err) } @@ -559,7 +559,7 @@ func (c *websocketClient) write() { return } - w, err := c.Conn.NextWriter(websocket.TextMessage) + w, err := c.Conn.NextWriter(gws.TextMessage) if err != nil { log.Errorf(log.APIServerMgr, "websocket: failed to create new io.writeCloser: %s\n", err) return @@ -628,7 +628,7 @@ func (m *apiServerManager) WebsocketClientHandler(w http.ResponseWriter, r *http return } - upgrader := websocket.Upgrader{ + upgrader := gws.Upgrader{ WriteBufferSize: 1024, ReadBufferSize: 1024, } diff --git a/engine/apiserver_types.go b/engine/apiserver_types.go index 95276646..c5337b93 100644 --- a/engine/apiserver_types.go +++ b/engine/apiserver_types.go @@ -6,7 +6,7 @@ import ( "sync" "github.com/gorilla/mux" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" @@ -62,7 +62,7 @@ type apiServerManager struct { // websocketClient stores information related to the websocket client type websocketClient struct { Hub *websocketHub - Conn *websocket.Conn + Conn *gws.Conn Authenticated bool authFailures int Send chan []byte diff --git a/engine/communication_manager.md b/engine/communication_manager.md index 5292f601..55257f32 100644 --- a/engine/communication_manager.md +++ b/engine/communication_manager.md @@ -1,4 +1,4 @@ -# GoCryptoTrader package Communication manager +# GoCryptoTrader package Communication Manager @@ -18,7 +18,7 @@ You can track ideas, planned features and what's in progress on our [GoCryptoTra Join our slack to discuss all things related to GoCryptoTrader! [GoCryptoTrader Slack](https://join.slack.com/t/gocryptotrader/shared_invite/enQtNTQ5NDAxMjA2Mjc5LTc5ZDE1ZTNiOGM3ZGMyMmY1NTAxYWZhODE0MWM5N2JlZDk1NDU0YTViYzk4NTk3OTRiMDQzNGQ1YTc4YmRlMTk) -## Current Features for Communication manager +## Current Features for Communication Manager + The communication manager subsystem is used to push events raised in GoCryptoTrader to any enabled communication system such as a Slack server + In order to modify the behaviour of the communication manager subsystem, you can edit the following inside your config file under `communications`: diff --git a/engine/connection_manager.md b/engine/connection_manager.md index e63697f5..e0c94a5c 100644 --- a/engine/connection_manager.md +++ b/engine/connection_manager.md @@ -1,4 +1,4 @@ -# GoCryptoTrader package Connection manager +# GoCryptoTrader package Connection Manager @@ -18,7 +18,7 @@ You can track ideas, planned features and what's in progress on our [GoCryptoTra Join our slack to discuss all things related to GoCryptoTrader! [GoCryptoTrader Slack](https://join.slack.com/t/gocryptotrader/shared_invite/enQtNTQ5NDAxMjA2Mjc5LTc5ZDE1ZTNiOGM3ZGMyMmY1NTAxYWZhODE0MWM5N2JlZDk1NDU0YTViYzk4NTk3OTRiMDQzNGQ1YTc4YmRlMTk) -## Current Features for Connection manager +## Current Features for Connection Manager + The connection manager subsystem is used to periodically check whether the application is connected to the internet and will provide alerts of any changes + In order to modify the behaviour of the connection manager subsystem, you can edit the following inside your config file under `connectionMonitor`: diff --git a/engine/currency_state_manager.md b/engine/currency_state_manager.md index cb6a44d9..908f2ed8 100644 --- a/engine/currency_state_manager.md +++ b/engine/currency_state_manager.md @@ -1,4 +1,4 @@ -# GoCryptoTrader package Currency state manager +# GoCryptoTrader package Currency State Manager @@ -18,7 +18,7 @@ You can track ideas, planned features and what's in progress on our [GoCryptoTra Join our slack to discuss all things related to GoCryptoTrader! [GoCryptoTrader Slack](https://join.slack.com/t/gocryptotrader/shared_invite/enQtNTQ5NDAxMjA2Mjc5LTc5ZDE1ZTNiOGM3ZGMyMmY1NTAxYWZhODE0MWM5N2JlZDk1NDU0YTViYzk4NTk3OTRiMDQzNGQ1YTc4YmRlMTk) -## Current Features for Currency state manager +## Current Features for Currency State Manager + The state manager keeps currency states up to date, which include: * Withdrawal - Determines if the currency is allowed to be withdrawn from the exchange. * Deposit - Determines if the currency is allowed to be deposited to an exchange. diff --git a/engine/database_connection.md b/engine/database_connection.md index ae941df5..a2fd5d9a 100644 --- a/engine/database_connection.md +++ b/engine/database_connection.md @@ -1,4 +1,4 @@ -# GoCryptoTrader package Database connection +# GoCryptoTrader package Database Connection @@ -18,7 +18,7 @@ You can track ideas, planned features and what's in progress on our [GoCryptoTra Join our slack to discuss all things related to GoCryptoTrader! [GoCryptoTrader Slack](https://join.slack.com/t/gocryptotrader/shared_invite/enQtNTQ5NDAxMjA2Mjc5LTc5ZDE1ZTNiOGM3ZGMyMmY1NTAxYWZhODE0MWM5N2JlZDk1NDU0YTViYzk4NTk3OTRiMDQzNGQ1YTc4YmRlMTk) -## Current Features for Database connection +## Current Features for Database Connection + The database connection manager subsystem is used to periodically check whether the application is connected to the database and will provide alerts of any changes + In order to modify the behaviour of the database connection manager subsystem, you can edit the following inside your config file under `database`: diff --git a/engine/datahistory_manager.md b/engine/datahistory_manager.md index bc60ebc3..aca38ff3 100644 --- a/engine/datahistory_manager.md +++ b/engine/datahistory_manager.md @@ -1,4 +1,4 @@ -# GoCryptoTrader package Datahistory manager +# GoCryptoTrader package Datahistory Manager diff --git a/engine/event_manager.md b/engine/event_manager.md index c1463284..5f9035f2 100644 --- a/engine/event_manager.md +++ b/engine/event_manager.md @@ -1,4 +1,4 @@ -# GoCryptoTrader package Event manager +# GoCryptoTrader package Event Manager @@ -18,7 +18,7 @@ You can track ideas, planned features and what's in progress on our [GoCryptoTra Join our slack to discuss all things related to GoCryptoTrader! [GoCryptoTrader Slack](https://join.slack.com/t/gocryptotrader/shared_invite/enQtNTQ5NDAxMjA2Mjc5LTc5ZDE1ZTNiOGM3ZGMyMmY1NTAxYWZhODE0MWM5N2JlZDk1NDU0YTViYzk4NTk3OTRiMDQzNGQ1YTc4YmRlMTk) -## Current Features for Event manager +## Current Features for Event Manager + The event manager subsystem is used to push events to communication systems such as Slack + The only configurable aspects of the event manager are the delays between receiving an event and pushing it and enabling verbose: diff --git a/engine/exchange_manager.md b/engine/exchange_manager.md index f44ae370..90c1196c 100644 --- a/engine/exchange_manager.md +++ b/engine/exchange_manager.md @@ -1,4 +1,4 @@ -# GoCryptoTrader package Exchange manager +# GoCryptoTrader package Exchange Manager @@ -18,7 +18,7 @@ You can track ideas, planned features and what's in progress on our [GoCryptoTra Join our slack to discuss all things related to GoCryptoTrader! [GoCryptoTrader Slack](https://join.slack.com/t/gocryptotrader/shared_invite/enQtNTQ5NDAxMjA2Mjc5LTc5ZDE1ZTNiOGM3ZGMyMmY1NTAxYWZhODE0MWM5N2JlZDk1NDU0YTViYzk4NTk3OTRiMDQzNGQ1YTc4YmRlMTk) -## Current Features for Exchange manager +## Current Features for Exchange Manager + The exchange manager subsystem is used load and store exchanges so that the engine Bot can use them to track orderbooks, submit orders etc etc + The exchange manager itself is not customisable, it is always enabled. + The exchange manager by default will load all exchanges that are enabled in your config, however, it will also load exchanges by request via GRPC commands diff --git a/engine/ntp_manager.md b/engine/ntp_manager.md index 81f90b19..2275c2a9 100644 --- a/engine/ntp_manager.md +++ b/engine/ntp_manager.md @@ -1,4 +1,4 @@ -# GoCryptoTrader package Ntp manager +# GoCryptoTrader package Ntp Manager @@ -18,7 +18,7 @@ You can track ideas, planned features and what's in progress on our [GoCryptoTra Join our slack to discuss all things related to GoCryptoTrader! [GoCryptoTrader Slack](https://join.slack.com/t/gocryptotrader/shared_invite/enQtNTQ5NDAxMjA2Mjc5LTc5ZDE1ZTNiOGM3ZGMyMmY1NTAxYWZhODE0MWM5N2JlZDk1NDU0YTViYzk4NTk3OTRiMDQzNGQ1YTc4YmRlMTk) -## Current Features for Ntp manager +## Current Features for Ntp Manager + The NTP manager subsystem is used highlight discrepancies between your system time and specified NTP server times + It is useful for debugging and understanding why a request to an exchange may be rejected + The NTP manager cannot update your system clock, so when it does alert you of issues, you must take it upon yourself to change your system time in the event your requests are being rejected for being too far out of sync diff --git a/engine/order_manager.md b/engine/order_manager.md index 46dd9ca6..9621c300 100644 --- a/engine/order_manager.md +++ b/engine/order_manager.md @@ -1,4 +1,4 @@ -# GoCryptoTrader package Order manager +# GoCryptoTrader package Order Manager @@ -18,7 +18,7 @@ You can track ideas, planned features and what's in progress on our [GoCryptoTra Join our slack to discuss all things related to GoCryptoTrader! [GoCryptoTrader Slack](https://join.slack.com/t/gocryptotrader/shared_invite/enQtNTQ5NDAxMjA2Mjc5LTc5ZDE1ZTNiOGM3ZGMyMmY1NTAxYWZhODE0MWM5N2JlZDk1NDU0YTViYzk4NTk3OTRiMDQzNGQ1YTc4YmRlMTk) -## Current Features for Order manager +## Current Features for Order Manager + The order manager subsystem stores and monitors all orders from enabled exchanges with API keys and `authenticatedSupport` enabled + It can be enabled or disabled via runtime command `-ordermanager=false` and defaults to true + All orders placed via GoCryptoTrader will be added to the order manager store diff --git a/engine/portfolio_manager.md b/engine/portfolio_manager.md index 37953ee8..68d05bbc 100644 --- a/engine/portfolio_manager.md +++ b/engine/portfolio_manager.md @@ -1,4 +1,4 @@ -# GoCryptoTrader package Portfolio manager +# GoCryptoTrader package Portfolio Manager @@ -18,7 +18,7 @@ You can track ideas, planned features and what's in progress on our [GoCryptoTra Join our slack to discuss all things related to GoCryptoTrader! [GoCryptoTrader Slack](https://join.slack.com/t/gocryptotrader/shared_invite/enQtNTQ5NDAxMjA2Mjc5LTc5ZDE1ZTNiOGM3ZGMyMmY1NTAxYWZhODE0MWM5N2JlZDk1NDU0YTViYzk4NTk3OTRiMDQzNGQ1YTc4YmRlMTk) -## Current Features for Portfolio manager +## Current Features for Portfolio Manager + The portfolio manager subsystem is used to synchronise and monitor wallet addresses + It can read addresses specified in your config file + If you have set API keys for an enabled exchange and enabled `authenticatedSupport`, it will store your exchange addresses diff --git a/engine/subsystem_types.md b/engine/subsystem_types.md index 507bde0c..978648e7 100644 --- a/engine/subsystem_types.md +++ b/engine/subsystem_types.md @@ -1,4 +1,4 @@ -# GoCryptoTrader package Subsystem types +# GoCryptoTrader package Subsystem Types @@ -18,7 +18,7 @@ You can track ideas, planned features and what's in progress on our [GoCryptoTra Join our slack to discuss all things related to GoCryptoTrader! [GoCryptoTrader Slack](https://join.slack.com/t/gocryptotrader/shared_invite/enQtNTQ5NDAxMjA2Mjc5LTc5ZDE1ZTNiOGM3ZGMyMmY1NTAxYWZhODE0MWM5N2JlZDk1NDU0YTViYzk4NTk3OTRiMDQzNGQ1YTc4YmRlMTk) -## Current Features for Subsystem types +## Current Features for Subsystem Types + Subsystem contains subsystems that are used at run time by an `engine.Engine`, however they can be setup and run individually. + Subsystems are designed to be self contained + All subsystems have a public `Setup(...) (..., error)` function to return a valid subsystem ready for use diff --git a/engine/sync_manager.md b/engine/sync_manager.md index 9ccb98ec..d60eb53d 100644 --- a/engine/sync_manager.md +++ b/engine/sync_manager.md @@ -1,4 +1,4 @@ -# GoCryptoTrader package Sync manager +# GoCryptoTrader package Sync Manager @@ -18,7 +18,7 @@ You can track ideas, planned features and what's in progress on our [GoCryptoTra Join our slack to discuss all things related to GoCryptoTrader! [GoCryptoTrader Slack](https://join.slack.com/t/gocryptotrader/shared_invite/enQtNTQ5NDAxMjA2Mjc5LTc5ZDE1ZTNiOGM3ZGMyMmY1NTAxYWZhODE0MWM5N2JlZDk1NDU0YTViYzk4NTk3OTRiMDQzNGQ1YTc4YmRlMTk) -## Current Features for Sync manager +## Current Features for Sync Manager + The currency pair syncer subsystem is used to keep all trades, tickers and orderbooks up to date for all enabled exchange asset currency pairs + It can sync data via a websocket connection or REST and will switch between them if there has been no updates + In order to modify the behaviour of the currency pair syncer subsystem, you can change runtime parameters as detailed below: diff --git a/engine/websocketroutine_manager.go b/engine/websocketroutine_manager.go index e595f076..3434339e 100644 --- a/engine/websocketroutine_manager.go +++ b/engine/websocketroutine_manager.go @@ -11,9 +11,9 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/fill" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -153,7 +153,7 @@ func (m *WebsocketRoutineManager) websocketRoutine() { // WebsocketDataReceiver handles websocket data coming from a websocket feed // associated with an exchange -func (m *WebsocketRoutineManager) websocketDataReceiver(ws *stream.Websocket) error { +func (m *WebsocketRoutineManager) websocketDataReceiver(ws *websocket.Manager) error { if m == nil { return fmt.Errorf("websocket routine manager %w", ErrNilSubsystem) } @@ -200,7 +200,7 @@ func (m *WebsocketRoutineManager) websocketDataHandler(exchName string, data any log.Infoln(log.WebsocketMgr, d) case error: return fmt.Errorf("exchange %s websocket error - %s", exchName, data) - case stream.FundingData: + case websocket.FundingData: if m.verbose { log.Infof(log.WebsocketMgr, "%s websocket %s %s funding updated %+v", exchName, @@ -244,7 +244,7 @@ func (m *WebsocketRoutineManager) websocketDataHandler(exchName string, data any } case order.Detail, ticker.Price, orderbook.Depth: return errUseAPointer - case stream.KlineData: + case websocket.KlineData: if m.verbose { log.Infof(log.WebsocketMgr, "%s websocket %s %s kline updated %+v", exchName, @@ -252,7 +252,7 @@ func (m *WebsocketRoutineManager) websocketDataHandler(exchName string, data any d.AssetType, d) } - case []stream.KlineData: + case []websocket.KlineData: for x := range d { if m.verbose { log.Infof(log.WebsocketMgr, "%s websocket %s %s kline updated %+v", @@ -333,7 +333,7 @@ func (m *WebsocketRoutineManager) websocketDataHandler(exchName string, data any } case order.ClassificationError: return fmt.Errorf("%w %s", d.Err, d.Error()) - case stream.UnhandledMessageWarning: + case websocket.UnhandledMessageWarning: log.Warnln(log.WebsocketMgr, d.Message) case account.Change: if m.verbose { diff --git a/engine/websocketroutine_manager.md b/engine/websocketroutine_manager.md index 1beb7ebf..14e2a6a9 100644 --- a/engine/websocketroutine_manager.md +++ b/engine/websocketroutine_manager.md @@ -1,4 +1,4 @@ -# GoCryptoTrader package Websocketroutine manager +# GoCryptoTrader package Websocketroutine Manager @@ -18,7 +18,7 @@ You can track ideas, planned features and what's in progress on our [GoCryptoTra Join our slack to discuss all things related to GoCryptoTrader! [GoCryptoTrader Slack](https://join.slack.com/t/gocryptotrader/shared_invite/enQtNTQ5NDAxMjA2Mjc5LTc5ZDE1ZTNiOGM3ZGMyMmY1NTAxYWZhODE0MWM5N2JlZDk1NDU0YTViYzk4NTk3OTRiMDQzNGQ1YTc4YmRlMTk) -## Current Features for Websocketroutine manager +## Current Features for Websocketroutine Manager + The websocket routine manager subsystem is used process websocket data in a unified manner across enabled exchanges with websocket support + It can help process orders to the order manager subsystem when it receives new data + Logs output of ticker and orderbook updates diff --git a/engine/websocketroutine_manager_test.go b/engine/websocketroutine_manager_test.go index de7d4f48..e8fcd299 100644 --- a/engine/websocketroutine_manager_test.go +++ b/engine/websocketroutine_manager_test.go @@ -12,8 +12,8 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" ) func TestWebsocketRoutineManagerSetup(t *testing.T) { @@ -159,7 +159,7 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) { if err == nil { t.Error("Error not handled correctly") } - err = m.websocketDataHandler(exchName, stream.FundingData{}) + err = m.websocketDataHandler(exchName, websocket.FundingData{}) if err != nil { t.Error(err) } @@ -171,7 +171,7 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) { if !errors.Is(err, nil) { t.Errorf("error '%v', expected '%v'", err, nil) } - err = m.websocketDataHandler(exchName, stream.KlineData{}) + err = m.websocketDataHandler(exchName, websocket.KlineData{}) if err != nil { t.Error(err) } @@ -224,10 +224,9 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) { t.Error(err) } - err = m.websocketDataHandler(exchName, stream.UnhandledMessageWarning{ + err = m.websocketDataHandler(exchName, websocket.UnhandledMessageWarning{ Message: "there's an issue here's a tissue", - }, - ) + }) if err != nil { t.Error(err) } @@ -294,7 +293,7 @@ func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) { t.Fatal("unexpected data handlers registered") } - mock := stream.NewWebsocket() + mock := websocket.NewManager() mock.ToRoutine = make(chan any) m.state = readyState err = m.websocketDataReceiver(mock) diff --git a/engine/withdraw_manager.md b/engine/withdraw_manager.md index 5d34f20e..009db25b 100644 --- a/engine/withdraw_manager.md +++ b/engine/withdraw_manager.md @@ -1,4 +1,4 @@ -# GoCryptoTrader package Withdraw manager +# GoCryptoTrader package Withdraw Manager @@ -18,7 +18,7 @@ You can track ideas, planned features and what's in progress on our [GoCryptoTra Join our slack to discuss all things related to GoCryptoTrader! [GoCryptoTrader Slack](https://join.slack.com/t/gocryptotrader/shared_invite/enQtNTQ5NDAxMjA2Mjc5LTc5ZDE1ZTNiOGM3ZGMyMmY1NTAxYWZhODE0MWM5N2JlZDk1NDU0YTViYzk4NTk3OTRiMDQzNGQ1YTc4YmRlMTk) -## Current Features for Withdraw manager +## Current Features for Withdraw Manager + The withdraw manager subsystem is responsible for the processing of withdrawal requests and submitting them to exchanges + The withdraw manager can be interacted with via GRPC commands such as `WithdrawFiatRequest` and `WithdrawCryptoRequest` + Supports caching of responses to allow for quick viewing of withdrawal events via GRPC diff --git a/exchanges/alphapoint/alphapoint.go b/exchanges/alphapoint/alphapoint.go index 2e105887..73ef53f2 100644 --- a/exchanges/alphapoint/alphapoint.go +++ b/exchanges/alphapoint/alphapoint.go @@ -9,7 +9,7 @@ import ( "strconv" "strings" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/encoding/json" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" @@ -44,7 +44,7 @@ const ( // Alphapoint is the overarching type across the alphapoint package type Alphapoint struct { exchange.Base - WebsocketConn *websocket.Conn + WebsocketConn *gws.Conn } // GetTicker returns current ticker information from Alphapoint for a selected diff --git a/exchanges/alphapoint/alphapoint_websocket.go b/exchanges/alphapoint/alphapoint_websocket.go index 81173762..4a6c7204 100644 --- a/exchanges/alphapoint/alphapoint_websocket.go +++ b/exchanges/alphapoint/alphapoint_websocket.go @@ -3,7 +3,7 @@ package alphapoint import ( "net/http" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/encoding/json" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/log" @@ -16,7 +16,7 @@ const ( // WebsocketClient starts a new webstocket connection func (a *Alphapoint) WebsocketClient() { for a.Enabled { - var dialer websocket.Dialer + var dialer gws.Dialer var err error var httpResp *http.Response endpoint, err := a.API.Endpoints.GetURL(exchange.WebsocketSpot) @@ -35,7 +35,7 @@ func (a *Alphapoint) WebsocketClient() { log.Debugf(log.ExchangeSys, "%s Connected to Websocket.\n", a.Name) } - err = a.WebsocketConn.WriteMessage(websocket.TextMessage, []byte(`{"messageType": "logon"}`)) + err = a.WebsocketConn.WriteMessage(gws.TextMessage, []byte(`{"messageType": "logon"}`)) if err != nil { log.Errorln(log.ExchangeSys, err) return @@ -48,7 +48,7 @@ func (a *Alphapoint) WebsocketClient() { break } - if msgType == websocket.TextMessage { + if msgType == gws.TextMessage { type MsgType struct { MessageType string `json:"messageType"` } diff --git a/exchanges/binance/binance_test.go b/exchanges/binance/binance_test.go index bb8261eb..19ef5afd 100644 --- a/exchanges/binance/binance_test.go +++ b/exchanges/binance/binance_test.go @@ -10,7 +10,7 @@ import ( "testing" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" @@ -1984,12 +1984,12 @@ func TestSubscribe(t *testing.T) { require.NoError(t, err, "generateSubscriptions must not error") if mockTests { exp := []string{"btcusdt@depth@100ms", "btcusdt@kline_1m", "btcusdt@ticker", "btcusdt@trade", "dogeusdt@depth@100ms", "dogeusdt@kline_1m", "dogeusdt@ticker", "dogeusdt@trade"} - mock := func(tb testing.TB, msg []byte, w *websocket.Conn) error { + mock := func(tb testing.TB, msg []byte, w *gws.Conn) error { tb.Helper() var req WsPayload require.NoError(tb, json.Unmarshal(msg, &req), "Unmarshal should not error") require.ElementsMatch(tb, req.Params, exp, "Params should have correct channels") - return w.WriteMessage(websocket.TextMessage, fmt.Appendf(nil, `{"result":null,"id":%d}`, req.ID)) + return w.WriteMessage(gws.TextMessage, fmt.Appendf(nil, `{"result":null,"id":%d}`, req.ID)) } b = testexch.MockWsInstance[Binance](t, mockws.CurryWsMockUpgrader(t, mock)) } else { @@ -2006,12 +2006,12 @@ func TestSubscribeBadResp(t *testing.T) { channels := subscription.List{ {Channel: "moons@ticker"}, } - mock := func(tb testing.TB, msg []byte, w *websocket.Conn) error { + mock := func(tb testing.TB, msg []byte, w *gws.Conn) error { tb.Helper() var req WsPayload err := json.Unmarshal(msg, &req) require.NoError(tb, err, "Unmarshal should not error") - return w.WriteMessage(websocket.TextMessage, fmt.Appendf(nil, `{"result":{"error":"carrots"},"id":%d}`, req.ID)) + return w.WriteMessage(gws.TextMessage, fmt.Appendf(nil, `{"result":{"error":"carrots"},"id":%d}`, req.ID)) } b := testexch.MockWsInstance[Binance](t, mockws.CurryWsMockUpgrader(t, mock)) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes err := b.Subscribe(channels) diff --git a/exchanges/binance/binance_websocket.go b/exchanges/binance/binance_websocket.go index cc3c2ad1..3bb1429a 100644 --- a/exchanges/binance/binance_websocket.go +++ b/exchanges/binance/binance_websocket.go @@ -11,7 +11,7 @@ import ( "time" "github.com/buger/jsonparser" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" @@ -19,10 +19,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -52,10 +52,10 @@ var ( // WsConnect initiates a websocket connection func (b *Binance) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - var dialer websocket.Dialer + var dialer gws.Dialer dialer.HandshakeTimeout = b.Config.HTTPTimeout dialer.Proxy = http.ProxyFromEnvironment var err error @@ -89,9 +89,9 @@ func (b *Binance) WsConnect() error { go b.KeepAuthKeyAlive() } - b.Websocket.Conn.SetupPingHandler(request.Unset, stream.PingHandler{ + b.Websocket.Conn.SetupPingHandler(request.Unset, websocket.PingHandler{ UseGorillaHandler: true, - MessageType: websocket.PongMessage, + MessageType: gws.PongMessage, Delay: pingDelay, }) @@ -178,7 +178,7 @@ func (b *Binance) wsHandleData(respRaw []byte) error { } jsonData, _, _, err := jsonparser.Get(respRaw, "data") if err != nil { - return fmt.Errorf("%s %s %s", b.Name, stream.UnhandledMessage, string(respRaw)) + return fmt.Errorf("%s %s %s", b.Name, websocket.UnhandledMessage, string(respRaw)) } var event string event, err = jsonparser.GetUnsafeString(jsonData, "e") @@ -297,13 +297,13 @@ func (b *Binance) wsHandleData(respRaw []byte) error { streamStr, err := jsonparser.GetUnsafeString(respRaw, "stream") if err != nil { if errors.Is(err, jsonparser.KeyPathNotFoundError) { - return fmt.Errorf("%s %s %s", b.Name, stream.UnhandledMessage, string(respRaw)) + return fmt.Errorf("%s %s %s", b.Name, websocket.UnhandledMessage, string(respRaw)) } return err } streamType := strings.Split(streamStr, "@") if len(streamType) <= 1 { - return fmt.Errorf("%s %s %s", b.Name, stream.UnhandledMessage, string(respRaw)) + return fmt.Errorf("%s %s %s", b.Name, websocket.UnhandledMessage, string(respRaw)) } var ( pair currency.Pair @@ -386,7 +386,7 @@ func (b *Binance) wsHandleData(respRaw []byte) error { b.Name, err) } - b.Websocket.DataHandler <- stream.KlineData{ + b.Websocket.DataHandler <- websocket.KlineData{ Timestamp: kline.EventTime.Time(), Pair: pair, AssetType: asset.Spot, @@ -421,7 +421,7 @@ func (b *Binance) wsHandleData(respRaw []byte) error { } return nil default: - return fmt.Errorf("%s %s %s", b.Name, stream.UnhandledMessage, string(respRaw)) + return fmt.Errorf("%s %s %s", b.Name, websocket.UnhandledMessage, string(respRaw)) } } diff --git a/exchanges/binance/binance_wrapper.go b/exchanges/binance/binance_wrapper.go index beff82b9..241bf706 100644 --- a/exchanges/binance/binance_wrapper.go +++ b/exchanges/binance/binance_wrapper.go @@ -27,11 +27,11 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket/buffer" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -196,7 +196,7 @@ func (b *Binance) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.NewWebsocket() + b.Websocket = websocket.NewManager() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout } @@ -217,7 +217,7 @@ func (b *Binance) Setup(exch *config.Exchange) error { if err != nil { return err } - err = b.Websocket.Setup(&stream.WebsocketSetup{ + err = b.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: binanceDefaultWebsocketURL, RunningURL: ePoint, @@ -236,7 +236,7 @@ func (b *Binance) Setup(exch *config.Exchange) error { return err } - return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, RateLimit: request.NewWeightedRateLimitByDuration(250 * time.Millisecond), diff --git a/exchanges/binanceus/binanceus.go b/exchanges/binanceus/binanceus.go index cdc46dc9..2faf63f8 100644 --- a/exchanges/binanceus/binanceus.go +++ b/exchanges/binanceus/binanceus.go @@ -1852,7 +1852,7 @@ func (bi *Binanceus) SendAuthHTTPRequest(ctx context.Context, ePath exchange.URL // GetWsAuthStreamKey this method 'Creates User Data Stream' will retrieve a key to use for authorised WS streaming // Same as that of Binance -// Start a new user data stream. The stream will close after 60 minutes unless a keepalive is sent. +// Start a new user data websocket. The stream will close after 60 minutes unless a keepalive is sent. // If the account has an active listenKey, // that listenKey will be returned and its validity will be extended for 60 minutes. func (bi *Binanceus) GetWsAuthStreamKey(ctx context.Context) (string, error) { @@ -1928,7 +1928,7 @@ func (bi *Binanceus) MaintainWsAuthStreamKey(ctx context.Context) error { }, request.AuthenticatedRequest) } -// CloseUserDataStream Close out a user data stream. +// CloseUserDataStream Close out a user data websocket. func (bi *Binanceus) CloseUserDataStream(ctx context.Context) error { endpointPath, err := bi.API.Endpoints.GetURL(exchange.RestSpotSupplementary) if err != nil { diff --git a/exchanges/binanceus/binanceus_websocket.go b/exchanges/binanceus/binanceus_websocket.go index 8b2515d4..df2bb03c 100644 --- a/exchanges/binanceus/binanceus_websocket.go +++ b/exchanges/binanceus/binanceus_websocket.go @@ -9,17 +9,17 @@ import ( "strings" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -46,9 +46,9 @@ var ( // WsConnect initiates a websocket connection func (bi *Binanceus) WsConnect() error { if !bi.Websocket.IsEnabled() || !bi.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - var dialer websocket.Dialer + var dialer gws.Dialer dialer.HandshakeTimeout = bi.Config.HTTPTimeout dialer.Proxy = http.ProxyFromEnvironment var err error @@ -82,9 +82,9 @@ func (bi *Binanceus) WsConnect() error { go bi.KeepAuthKeyAlive() } - bi.Websocket.Conn.SetupPingHandler(request.Unset, stream.PingHandler{ + bi.Websocket.Conn.SetupPingHandler(request.Unset, websocket.PingHandler{ UseGorillaHandler: true, - MessageType: websocket.PongMessage, + MessageType: gws.PongMessage, Delay: pingDelay, }) @@ -400,7 +400,7 @@ func (bi *Binanceus) wsHandleData(respRaw []byte) error { return err } - bi.Websocket.DataHandler <- stream.KlineData{ + bi.Websocket.DataHandler <- websocket.KlineData{ Timestamp: kline.EventTime, Pair: pair, AssetType: asset.Spot, @@ -465,8 +465,8 @@ func (bi *Binanceus) wsHandleData(respRaw []byte) error { bi.Websocket.DataHandler <- agg return nil default: - bi.Websocket.DataHandler <- stream.UnhandledMessageWarning{ - Message: bi.Name + stream.UnhandledMessage + string(respRaw), + bi.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + Message: bi.Name + websocket.UnhandledMessage + string(respRaw), } } } diff --git a/exchanges/binanceus/binanceus_wrapper.go b/exchanges/binanceus/binanceus_wrapper.go index 8a686fe1..7ba008f3 100644 --- a/exchanges/binanceus/binanceus_wrapper.go +++ b/exchanges/binanceus/binanceus_wrapper.go @@ -22,10 +22,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket/buffer" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -141,7 +141,7 @@ func (bi *Binanceus) SetDefaults() { "%s setting default endpoints error %v", bi.Name, err) } - bi.Websocket = stream.NewWebsocket() + bi.Websocket = websocket.NewManager() bi.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit bi.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout bi.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -167,7 +167,7 @@ func (bi *Binanceus) Setup(exch *config.Exchange) error { return err } - err = bi.Websocket.Setup(&stream.WebsocketSetup{ + err = bi.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: binanceusDefaultWebsocketURL, RunningURL: ePoint, @@ -186,7 +186,7 @@ func (bi *Binanceus) Setup(exch *config.Exchange) error { return err } - return bi.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return bi.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, RateLimit: request.NewWeightedRateLimitByDuration(300 * time.Millisecond), diff --git a/exchanges/bitfinex/bitfinex_test.go b/exchanges/bitfinex/bitfinex_test.go index 7ea1fb25..1c9dffbf 100644 --- a/exchanges/bitfinex/bitfinex_test.go +++ b/exchanges/bitfinex/bitfinex_test.go @@ -23,10 +23,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" @@ -1105,7 +1105,7 @@ func TestGetDepositAddress(t *testing.T) { // TestWSAuth dials websocket, sends login request. func TestWSAuth(t *testing.T) { if !b.Websocket.IsEnabled() { - t.Skip(stream.ErrWebsocketNotEnabled.Error()) + t.Skip(websocket.ErrWebsocketNotEnabled.Error()) } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) if !b.API.AuthenticatedWebsocketSupport { @@ -1333,8 +1333,8 @@ func TestWSSubscribedResponse(t *testing.T) { assert.NoError(t, err, "Setting a matcher should not error") err = b.wsHandleData([]byte(`{"event":"subscribed","channel":"ticker","chanId":224555,"subId":"waiter1","symbol":"tBTCUSD","pair":"BTCUSD"}`)) if assert.Error(t, err, "Should error if sub is not registered yet") { - assert.ErrorIs(t, err, stream.ErrSubscriptionFailure, "Should error SubFailure if sub isn't registered yet") - assert.ErrorIs(t, err, stream.ErrSubscriptionFailure, "Should error SubNotFound if sub isn't registered yet") + assert.ErrorIs(t, err, websocket.ErrSubscriptionFailure, "Should error SubFailure if sub isn't registered yet") + assert.ErrorIs(t, err, subscription.ErrNotFound, "Should error SubNotFound if sub isn't registered yet") assert.ErrorContains(t, err, "waiter1", "Should error containing subID if") } diff --git a/exchanges/bitfinex/bitfinex_websocket.go b/exchanges/bitfinex/bitfinex_websocket.go index ed27bad9..575d4a7b 100644 --- a/exchanges/bitfinex/bitfinex_websocket.go +++ b/exchanges/bitfinex/bitfinex_websocket.go @@ -16,7 +16,7 @@ import ( "github.com/Masterminds/sprig/v3" "github.com/buger/jsonparser" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/convert" "github.com/thrasher-corp/gocryptotrader/common/crypto" @@ -27,10 +27,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -100,7 +100,7 @@ var defaultSubscriptions = subscription.List{ {Enabled: true, Channel: subscription.OrderbookChannel, Asset: asset.All, Levels: 100, Params: map[string]any{"prec": "R0"}}, } -var comms = make(chan stream.Response) +var comms = make(chan websocket.Response) type checksum struct { Token uint32 @@ -123,9 +123,9 @@ var subscriptionNames = map[string]string{ // WsConnect starts a new websocket connection func (b *Bitfinex) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - var dialer websocket.Dialer + var dialer gws.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) if err != nil { return fmt.Errorf("%v unable to connect to Websocket. Error: %s", @@ -162,7 +162,7 @@ func (b *Bitfinex) WsConnect() error { } // wsReadData receives and passes on websocket messages for processing -func (b *Bitfinex) wsReadData(ws stream.Connection) { +func (b *Bitfinex) wsReadData(ws websocket.Connection) { defer b.Websocket.Wg.Done() for { resp := ws.ReadMessage() @@ -193,7 +193,7 @@ func (b *Bitfinex) WsDataHandler() { } return case resp := <-comms: - if resp.Type != websocket.TextMessage { + if resp.Type != gws.TextMessage { continue } err := b.wsHandleData(resp.Raw) @@ -494,8 +494,8 @@ func (b *Bitfinex) wsHandleData(respRaw []byte) error { b.Websocket.DataHandler <- wsFundingTrade } default: - b.Websocket.DataHandler <- stream.UnhandledMessageWarning{ - Message: b.Name + stream.UnhandledMessage + string(respRaw), + b.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + Message: b.Name + websocket.UnhandledMessage + string(respRaw), } return nil } @@ -581,12 +581,12 @@ func (b *Bitfinex) handleWSSubscribed(respRaw []byte) error { c := b.Websocket.GetSubscription(subID) if c == nil { - return fmt.Errorf("%w: %w subID: %s", stream.ErrSubscriptionFailure, subscription.ErrNotFound, subID) + return fmt.Errorf("%w: %w subID: %s", websocket.ErrSubscriptionFailure, subscription.ErrNotFound, subID) } chanID, err := jsonparser.GetInt(respRaw, "chanId") if err != nil { - return fmt.Errorf("%w: %w 'chanId': %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, common.ErrParsingWSField, err, c.Channel, c.Pairs) + return fmt.Errorf("%w: %w 'chanId': %w; Channel: %s Pair: %s", websocket.ErrSubscriptionFailure, common.ErrParsingWSField, err, c.Channel, c.Pairs) } // Note: chanID's int type avoids conflicts with the string type subID key because of the type difference @@ -596,7 +596,7 @@ func (b *Bitfinex) handleWSSubscribed(respRaw []byte) error { // subscribeToChan removes the old subID keyed Subscription err = b.Websocket.AddSuccessfulSubscriptions(b.Websocket.Conn, c) if err != nil { - return fmt.Errorf("%w: %w subID: %s", stream.ErrSubscriptionFailure, err, subID) + return fmt.Errorf("%w: %w subID: %s", websocket.ErrSubscriptionFailure, err, subID) } if b.Verbose { @@ -799,7 +799,7 @@ func (b *Bitfinex) handleWSCandleUpdate(c *subscription.Subscription, d []any) e return errors.New("invalid candleBundle length") } var err error - var klineData stream.KlineData + var klineData websocket.KlineData if klineData.Timestamp, err = convert.TimeFromUnixTimestampFloat(element[0]); err != nil { return fmt.Errorf("unable to convert candle timestamp: %w", err) } @@ -828,7 +828,7 @@ func (b *Bitfinex) handleWSCandleUpdate(c *subscription.Subscription, d []any) e return errors.New("invalid candleBundle length") } var err error - var klineData stream.KlineData + var klineData websocket.KlineData if klineData.Timestamp, err = convert.TimeFromUnixTimestampFloat(candleData); err != nil { return fmt.Errorf("unable to convert candle timestamp: %w", err) } diff --git a/exchanges/bitfinex/bitfinex_wrapper.go b/exchanges/bitfinex/bitfinex_wrapper.go index 0ad05652..d69590bc 100644 --- a/exchanges/bitfinex/bitfinex_wrapper.go +++ b/exchanges/bitfinex/bitfinex_wrapper.go @@ -25,10 +25,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket/buffer" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -167,7 +167,7 @@ func (b *Bitfinex) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.NewWebsocket() + b.Websocket = websocket.NewManager() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -193,7 +193,7 @@ func (b *Bitfinex) Setup(exch *config.Exchange) error { return err } - err = b.Websocket.Setup(&stream.WebsocketSetup{ + err = b.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: publicBitfinexWebsocketEndpoint, RunningURL: wsEndpoint, @@ -210,7 +210,7 @@ func (b *Bitfinex) Setup(exch *config.Exchange) error { return err } - err = b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + err = b.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, URL: publicBitfinexWebsocketEndpoint, @@ -219,7 +219,7 @@ func (b *Bitfinex) Setup(exch *config.Exchange) error { return err } - return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, URL: authenticatedBitfinexWebsocketEndpoint, diff --git a/exchanges/bithumb/bithumb_websocket.go b/exchanges/bithumb/bithumb_websocket.go index c6e1bccf..7fc953d7 100644 --- a/exchanges/bithumb/bithumb_websocket.go +++ b/exchanges/bithumb/bithumb_websocket.go @@ -9,17 +9,17 @@ import ( "time" "github.com/Masterminds/sprig/v3" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" ) const ( @@ -37,10 +37,10 @@ var defaultSubscriptions = subscription.List{ // WsConnect initiates a websocket connection func (b *Bithumb) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - var dialer websocket.Dialer + var dialer gws.Dialer dialer.HandshakeTimeout = b.Config.HTTPTimeout dialer.Proxy = http.ProxyFromEnvironment @@ -92,7 +92,7 @@ func (b *Bithumb) wsHandleData(respRaw []byte) error { } return fmt.Errorf("%s: %w", resp.ResponseMessage, - stream.ErrSubscriptionFailure) + websocket.ErrSubscriptionFailure) } switch resp.Type { diff --git a/exchanges/bithumb/bithumb_websocket_test.go b/exchanges/bithumb/bithumb_websocket_test.go index 39fbb93a..6498de85 100644 --- a/exchanges/bithumb/bithumb_websocket_test.go +++ b/exchanges/bithumb/bithumb_websocket_test.go @@ -1,7 +1,6 @@ package bithumb import ( - "errors" "testing" "time" @@ -11,9 +10,9 @@ import ( exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" ) @@ -27,12 +26,7 @@ var ( func TestWsHandleData(t *testing.T) { t.Parallel() - pairs := currency.Pairs{ - currency.Pair{ - Base: currency.BTC, - Quote: currency.USDT, - }, - } + pairs := currency.Pairs{currency.NewBTCUSDT()} dummy := Bithumb{ location: time.Local, @@ -53,9 +47,7 @@ func TestWsHandleData(t *testing.T) { }, }, }, - Websocket: &stream.Websocket{ - DataHandler: make(chan any, 1), - }, + Websocket: websocket.NewManager(), }, } @@ -64,36 +56,20 @@ func TestWsHandleData(t *testing.T) { welcomeMsg := []byte(`{"status":"0000","resmsg":"Connected Successfully"}`) err := dummy.wsHandleData(welcomeMsg) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) err = dummy.wsHandleData([]byte(`{"status":"1336","resmsg":"Failed"}`)) - if !errors.Is(err, stream.ErrSubscriptionFailure) { - t.Fatalf("received: %v but expected: %v", - err, - stream.ErrSubscriptionFailure) - } + require.ErrorIs(t, err, websocket.ErrSubscriptionFailure) + + err = dummy.wsHandleData(wsTransResp) + require.NoError(t, err) + + err = dummy.wsHandleData(wsOrderbookResp) + require.NoError(t, err) err = dummy.wsHandleData(wsTickerResp) - if !errors.Is(err, nil) { - t.Fatalf("received: %v but expected: %v", err, nil) - } - - handled := <-dummy.Websocket.DataHandler - if _, ok := handled.(*ticker.Price); !ok { - t.Fatal("unexpected value") - } - - err = dummy.wsHandleData(wsTransResp) // This doesn't pipe to datahandler - if !errors.Is(err, nil) { - t.Fatalf("received: %v but expected: %v", err, nil) - } - - err = dummy.wsHandleData(wsOrderbookResp) // This doesn't pipe to datahandler - if !errors.Is(err, nil) { - t.Fatalf("received: %v but expected: %v", err, nil) - } + require.NoError(t, err) + assert.IsType(t, new(ticker.Price), <-dummy.Websocket.DataHandler, "ticker should send a price to the DataHandler") } func TestSubToReq(t *testing.T) { diff --git a/exchanges/bithumb/bithumb_wrapper.go b/exchanges/bithumb/bithumb_wrapper.go index 3b17ef50..60a9c8f0 100644 --- a/exchanges/bithumb/bithumb_wrapper.go +++ b/exchanges/bithumb/bithumb_wrapper.go @@ -25,9 +25,9 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -131,7 +131,7 @@ func (b *Bithumb) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.NewWebsocket() + b.Websocket = websocket.NewManager() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout } @@ -155,7 +155,7 @@ func (b *Bithumb) Setup(exch *config.Exchange) error { if err != nil { return err } - err = b.Websocket.Setup(&stream.WebsocketSetup{ + err = b.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: wsEndpoint, RunningURL: ePoint, @@ -168,7 +168,7 @@ func (b *Bithumb) Setup(exch *config.Exchange) error { return err } - return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, RateLimit: request.NewWeightedRateLimitByDuration(time.Second), diff --git a/exchanges/bitmex/bitmex_test.go b/exchanges/bitmex/bitmex_test.go index f0f102e6..895ef758 100644 --- a/exchanges/bitmex/bitmex_test.go +++ b/exchanges/bitmex/bitmex_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" @@ -23,8 +23,8 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" @@ -662,9 +662,9 @@ func TestGetDepositAddress(t *testing.T) { func TestWsAuth(t *testing.T) { t.Parallel() if !b.Websocket.IsEnabled() && !b.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(b) { - t.Skip(stream.ErrWebsocketNotEnabled.Error()) + t.Skip(websocket.ErrWebsocketNotEnabled.Error()) } - var dialer websocket.Dialer + var dialer gws.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) require.NoError(t, err) diff --git a/exchanges/bitmex/bitmex_websocket.go b/exchanges/bitmex/bitmex_websocket.go index 0d667158..4bc15ddc 100644 --- a/exchanges/bitmex/bitmex_websocket.go +++ b/exchanges/bitmex/bitmex_websocket.go @@ -11,7 +11,7 @@ import ( "time" "github.com/buger/jsonparser" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" @@ -20,9 +20,9 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -84,9 +84,9 @@ var defaultSubscriptions = subscription.List{ // WsConnect initiates a new websocket connection func (b *Bitmex) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - var dialer websocket.Dialer + var dialer gws.Dialer if err := b.Websocket.Conn.Dial(&dialer, http.Header{}); err != nil { return err } @@ -119,7 +119,7 @@ const ( wsClosePacket = 2 ) -func (b *Bitmex) wsOpenStream(ctx context.Context, c stream.Connection, name string) error { +func (b *Bitmex) wsOpenStream(ctx context.Context, c websocket.Connection, name string) error { resp, err := c.SendMessageReturnResponse(ctx, request.Unset, "open:"+name, []any{wsOpenPacket, name, name}) if err != nil { return err @@ -401,7 +401,7 @@ func (b *Bitmex) wsHandleData(respRaw []byte) error { } b.Websocket.DataHandler <- response default: - b.Websocket.DataHandler <- stream.UnhandledMessageWarning{Message: b.Name + stream.UnhandledMessage + string(msg)} + b.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: b.Name + websocket.UnhandledMessage + string(msg)} } return nil diff --git a/exchanges/bitmex/bitmex_wrapper.go b/exchanges/bitmex/bitmex_wrapper.go index adb21ab7..b929a8a2 100644 --- a/exchanges/bitmex/bitmex_wrapper.go +++ b/exchanges/bitmex/bitmex_wrapper.go @@ -26,10 +26,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket/buffer" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -139,7 +139,7 @@ func (b *Bitmex) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.NewWebsocket() + b.Websocket = websocket.NewManager() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -165,7 +165,7 @@ func (b *Bitmex) Setup(exch *config.Exchange) error { return err } - err = b.Websocket.Setup(&stream.WebsocketSetup{ + err = b.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: bitmexWSURL, RunningURL: wsEndpoint, @@ -181,7 +181,7 @@ func (b *Bitmex) Setup(exch *config.Exchange) error { if err != nil { return err } - return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, URL: bitmexWSURL, diff --git a/exchanges/bitstamp/bitstamp_websocket.go b/exchanges/bitstamp/bitstamp_websocket.go index e79cd142..6e2de693 100644 --- a/exchanges/bitstamp/bitstamp_websocket.go +++ b/exchanges/bitstamp/bitstamp_websocket.go @@ -11,7 +11,7 @@ import ( "time" "github.com/buger/jsonparser" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" @@ -21,9 +21,9 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -57,9 +57,9 @@ var subscriptionNames = map[string]string{ // WsConnect connects to a websocket feed func (b *Bitstamp) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - var dialer websocket.Dialer + var dialer gws.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) if err != nil { return err @@ -67,8 +67,8 @@ func (b *Bitstamp) WsConnect() error { if b.Verbose { log.Debugf(log.ExchangeSys, "%s Connected to Websocket.\n", b.Name) } - b.Websocket.Conn.SetupPingHandler(request.Unset, stream.PingHandler{ - MessageType: websocket.TextMessage, + b.Websocket.Conn.SetupPingHandler(request.Unset, websocket.PingHandler{ + MessageType: gws.TextMessage, Message: hbMsg, Delay: hbInterval, }) @@ -123,7 +123,7 @@ func (b *Bitstamp) wsHandleData(respRaw []byte) error { } }() default: - b.Websocket.DataHandler <- stream.UnhandledMessageWarning{Message: b.Name + stream.UnhandledMessage + string(respRaw)} + b.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: b.Name + websocket.UnhandledMessage + string(respRaw)} } return nil } diff --git a/exchanges/bitstamp/bitstamp_wrapper.go b/exchanges/bitstamp/bitstamp_wrapper.go index ec54ad90..92e5da7a 100644 --- a/exchanges/bitstamp/bitstamp_wrapper.go +++ b/exchanges/bitstamp/bitstamp_wrapper.go @@ -23,9 +23,9 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -124,7 +124,7 @@ func (b *Bitstamp) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.NewWebsocket() + b.Websocket = websocket.NewManager() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -150,7 +150,7 @@ func (b *Bitstamp) Setup(exch *config.Exchange) error { return err } - err = b.Websocket.Setup(&stream.WebsocketSetup{ + err = b.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: bitstampWSURL, RunningURL: wsURL, @@ -164,7 +164,7 @@ func (b *Bitstamp) Setup(exch *config.Exchange) error { return err } - return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ URL: b.Websocket.GetWebsocketURL(), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/btcmarkets/btcmarkets_websocket.go b/exchanges/btcmarkets/btcmarkets_websocket.go index 37c70df3..21ac9296 100644 --- a/exchanges/btcmarkets/btcmarkets_websocket.go +++ b/exchanges/btcmarkets/btcmarkets_websocket.go @@ -11,7 +11,7 @@ import ( "text/template" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" @@ -20,10 +20,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -57,9 +57,9 @@ var subscriptionNames = map[string]string{ // WsConnect connects to a websocket feed func (b *BTCMarkets) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - var dialer websocket.Dialer + var dialer gws.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) if err != nil { return err @@ -352,7 +352,7 @@ func (b *BTCMarkets) wsHandleData(respRaw []byte) error { } return fmt.Errorf("%v websocket error. Code: %v Message: %v", b.Name, wsErr.Code, wsErr.Message) default: - b.Websocket.DataHandler <- stream.UnhandledMessageWarning{Message: b.Name + stream.UnhandledMessage + string(respRaw)} + b.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: b.Name + websocket.UnhandledMessage + string(respRaw)} return nil } return nil diff --git a/exchanges/btcmarkets/btcmarkets_wrapper.go b/exchanges/btcmarkets/btcmarkets_wrapper.go index 57ddfa7c..3386418d 100644 --- a/exchanges/btcmarkets/btcmarkets_wrapper.go +++ b/exchanges/btcmarkets/btcmarkets_wrapper.go @@ -24,10 +24,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket/buffer" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -128,7 +128,7 @@ func (b *BTCMarkets) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.NewWebsocket() + b.Websocket = websocket.NewManager() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -154,7 +154,7 @@ func (b *BTCMarkets) Setup(exch *config.Exchange) error { return err } - err = b.Websocket.Setup(&stream.WebsocketSetup{ + err = b.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: btcMarketsWSURL, RunningURL: wsURL, @@ -173,7 +173,7 @@ func (b *BTCMarkets) Setup(exch *config.Exchange) error { return err } - return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, }) diff --git a/exchanges/btse/btse_test.go b/exchanges/btse/btse_test.go index d643caf4..e7f6344a 100644 --- a/exchanges/btse/btse_test.go +++ b/exchanges/btse/btse_test.go @@ -21,10 +21,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" ) @@ -662,7 +662,7 @@ func TestWsUnexpectedData(t *testing.T) { t.Parallel() data := []byte(`{}`) err := b.wsHandleData(data) - assert.ErrorContains(t, err, stream.UnhandledMessage, "wsHandleData should error on empty message") + assert.ErrorContains(t, err, websocket.UnhandledMessage, "wsHandleData should error on empty message") } func TestGetFuturesContractDetails(t *testing.T) { diff --git a/exchanges/btse/btse_websocket.go b/exchanges/btse/btse_websocket.go index 66624cc2..967f7260 100644 --- a/exchanges/btse/btse_websocket.go +++ b/exchanges/btse/btse_websocket.go @@ -9,7 +9,7 @@ import ( "text/template" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" @@ -17,9 +17,9 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -41,15 +41,15 @@ var defaultSubscriptions = subscription.List{ // WsConnect connects the websocket client func (b *BTSE) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - var dialer websocket.Dialer + var dialer gws.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) if err != nil { return err } - b.Websocket.Conn.SetupPingHandler(request.Unset, stream.PingHandler{ - MessageType: websocket.PingMessage, + b.Websocket.Conn.SetupPingHandler(request.Unset, websocket.PingHandler{ + MessageType: gws.PingMessage, Delay: btseWebsocketTimer, }) @@ -146,7 +146,7 @@ func (b *BTSE) wsHandleData(respRaw []byte) error { if result["event"] != nil { event, ok := result["event"].(string) if !ok { - return errors.New(b.Name + stream.UnhandledMessage + string(respRaw)) + return errors.New(b.Name + websocket.UnhandledMessage + string(respRaw)) } switch event { case "subscribe": @@ -169,14 +169,14 @@ func (b *BTSE) wsHandleData(respRaw []byte) error { log.Infof(log.WebsocketMgr, "%v websocket authenticated: %v", b.Name, login.Success) } default: - return errors.New(b.Name + stream.UnhandledMessage + string(respRaw)) + return errors.New(b.Name + websocket.UnhandledMessage + string(respRaw)) } return nil } topic, ok := result["topic"].(string) if !ok { - return errors.New(b.Name + stream.UnhandledMessage + string(respRaw)) + return errors.New(b.Name + websocket.UnhandledMessage + string(respRaw)) } switch { case topic == "notificationApi": @@ -352,7 +352,7 @@ func (b *BTSE) wsHandleData(respRaw []byte) error { return err } default: - return errors.New(b.Name + stream.UnhandledMessage + string(respRaw)) + return errors.New(b.Name + websocket.UnhandledMessage + string(respRaw)) } return nil diff --git a/exchanges/btse/btse_wrapper.go b/exchanges/btse/btse_wrapper.go index a49f3866..db8b495d 100644 --- a/exchanges/btse/btse_wrapper.go +++ b/exchanges/btse/btse_wrapper.go @@ -25,9 +25,9 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -144,7 +144,7 @@ func (b *BTSE) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.NewWebsocket() + b.Websocket = websocket.NewManager() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -170,7 +170,7 @@ func (b *BTSE) Setup(exch *config.Exchange) error { return err } - err = b.Websocket.Setup(&stream.WebsocketSetup{ + err = b.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: btseWebsocket, RunningURL: wsRunningURL, @@ -184,7 +184,7 @@ func (b *BTSE) Setup(exch *config.Exchange) error { return err } - return b.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return b.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, }) diff --git a/exchanges/bybit/bybit_inverse_websocket.go b/exchanges/bybit/bybit_inverse_websocket.go index 4c43490d..890dea97 100644 --- a/exchanges/bybit/bybit_inverse_websocket.go +++ b/exchanges/bybit/bybit_inverse_websocket.go @@ -4,27 +4,27 @@ import ( "context" "net/http" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" ) // WsInverseConnect connects to inverse websocket feed func (by *Bybit) WsInverseConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.CoinMarginedFutures) { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } by.Websocket.Conn.SetURL(inversePublic) - var dialer websocket.Dialer + var dialer gws.Dialer err := by.Websocket.Conn.Dial(&dialer, http.Header{}) if err != nil { return err } - by.Websocket.Conn.SetupPingHandler(request.Unset, stream.PingHandler{ - MessageType: websocket.TextMessage, + by.Websocket.Conn.SetupPingHandler(request.Unset, websocket.PingHandler{ + MessageType: gws.TextMessage, Message: []byte(`{"op": "ping"}`), Delay: bybitWebsocketTimer, }) diff --git a/exchanges/bybit/bybit_linear_websocket.go b/exchanges/bybit/bybit_linear_websocket.go index d680ff61..8f8d5b52 100644 --- a/exchanges/bybit/bybit_linear_websocket.go +++ b/exchanges/bybit/bybit_linear_websocket.go @@ -4,27 +4,27 @@ import ( "context" "net/http" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" ) // WsLinearConnect connects to linear a websocket feed func (by *Bybit) WsLinearConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.LinearContract) { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } by.Websocket.Conn.SetURL(linearPublic) - var dialer websocket.Dialer + var dialer gws.Dialer err := by.Websocket.Conn.Dial(&dialer, http.Header{}) if err != nil { return err } - by.Websocket.Conn.SetupPingHandler(request.Unset, stream.PingHandler{ - MessageType: websocket.TextMessage, + by.Websocket.Conn.SetupPingHandler(request.Unset, websocket.PingHandler{ + MessageType: gws.TextMessage, Message: []byte(`{"op": "ping"}`), Delay: bybitWebsocketTimer, }) diff --git a/exchanges/bybit/bybit_options_websocket.go b/exchanges/bybit/bybit_options_websocket.go index 9dcd2175..dfffe3dd 100644 --- a/exchanges/bybit/bybit_options_websocket.go +++ b/exchanges/bybit/bybit_options_websocket.go @@ -5,22 +5,22 @@ import ( "net/http" "strconv" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" ) // WsOptionsConnect connects to options a websocket feed func (by *Bybit) WsOptionsConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.Options) { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } by.Websocket.Conn.SetURL(optionPublic) - var dialer websocket.Dialer + var dialer gws.Dialer err := by.Websocket.Conn.Dial(&dialer, http.Header{}) if err != nil { return err @@ -30,8 +30,8 @@ func (by *Bybit) WsOptionsConnect() error { if err != nil { return err } - by.Websocket.Conn.SetupPingHandler(request.Unset, stream.PingHandler{ - MessageType: websocket.TextMessage, + by.Websocket.Conn.SetupPingHandler(request.Unset, websocket.PingHandler{ + MessageType: gws.TextMessage, Message: pingData, Delay: bybitWebsocketTimer, }) diff --git a/exchanges/bybit/bybit_test.go b/exchanges/bybit/bybit_test.go index 428d39fe..7c7c6703 100644 --- a/exchanges/bybit/bybit_test.go +++ b/exchanges/bybit/bybit_test.go @@ -10,7 +10,7 @@ import ( "time" "github.com/gofrs/uuid" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" @@ -25,9 +25,9 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/margin" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" testws "github.com/thrasher-corp/gocryptotrader/internal/testing/websocket" @@ -3177,7 +3177,7 @@ func TestWsLinearConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsLinearConnect() - if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { + if err != nil && !errors.Is(err, websocket.ErrWebsocketNotEnabled) { t.Error(err) } } @@ -3188,7 +3188,7 @@ func TestWsInverseConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsInverseConnect() - if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { + if err != nil && !errors.Is(err, websocket.ErrWebsocketNotEnabled) { t.Error(err) } } @@ -3199,7 +3199,7 @@ func TestWsOptionsConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsOptionsConnect() - if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { + if err != nil && !errors.Is(err, websocket.ErrWebsocketNotEnabled) { t.Error(err) } } @@ -3813,7 +3813,7 @@ func TestAuthSubscribe(t *testing.T) { require.NoError(t, err, "ExpandTemplates must not error") b.Features.Subscriptions = subscription.List{} success := true - mock := func(tb testing.TB, msg []byte, w *websocket.Conn) error { + mock := func(tb testing.TB, msg []byte, w *gws.Conn) error { tb.Helper() var req SubscriptionArgument require.NoError(tb, json.Unmarshal(msg, &req), "Unmarshal must not error") @@ -3825,7 +3825,7 @@ func TestAuthSubscribe(t *testing.T) { Operation: req.Operation, }) require.NoError(tb, err, "Marshal must not error") - return w.WriteMessage(websocket.TextMessage, msg) + return w.WriteMessage(gws.TextMessage, msg) } b = testexch.MockWsInstance[Bybit](t, testws.CurryWsMockUpgrader(t, mock)) b.Websocket.AuthConn = b.Websocket.Conn diff --git a/exchanges/bybit/bybit_websocket.go b/exchanges/bybit/bybit_websocket.go index 24411de5..b0d0f0ec 100644 --- a/exchanges/bybit/bybit_websocket.go +++ b/exchanges/bybit/bybit_websocket.go @@ -9,7 +9,7 @@ import ( "text/template" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" @@ -21,10 +21,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" ) const ( @@ -80,15 +80,15 @@ var subscriptionNames = map[string]string{ // WsConnect connects to a websocket feed func (by *Bybit) WsConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.Spot) { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - var dialer websocket.Dialer + var dialer gws.Dialer err := by.Websocket.Conn.Dial(&dialer, http.Header{}) if err != nil { return err } - by.Websocket.Conn.SetupPingHandler(request.Unset, stream.PingHandler{ - MessageType: websocket.TextMessage, + by.Websocket.Conn.SetupPingHandler(request.Unset, websocket.PingHandler{ + MessageType: gws.TextMessage, Message: []byte(`{"op": "ping"}`), Delay: bybitWebsocketTimer, }) @@ -107,14 +107,14 @@ func (by *Bybit) WsConnect() error { // WsAuth sends an authentication message to receive auth data func (by *Bybit) WsAuth(ctx context.Context) error { - var dialer websocket.Dialer + var dialer gws.Dialer err := by.Websocket.AuthConn.Dial(&dialer, http.Header{}) if err != nil { return err } - by.Websocket.AuthConn.SetupPingHandler(request.Unset, stream.PingHandler{ - MessageType: websocket.TextMessage, + by.Websocket.AuthConn.SetupPingHandler(request.Unset, websocket.PingHandler{ + MessageType: gws.TextMessage, Message: []byte(`{"op":"ping"}`), Delay: bybitWebsocketTimer, }) @@ -214,7 +214,7 @@ func (by *Bybit) handleSpotSubscription(operation string, channelsToSubscribe su return fmt.Errorf("%s with request ID %s msg: %s", resp.Operation, resp.RequestID, resp.RetMsg) } - var conn stream.Connection + var conn websocket.Connection if payloads[a].auth { conn = by.Websocket.AuthConn } else { @@ -250,7 +250,7 @@ func (by *Bybit) GetSubscriptionTemplate(_ *subscription.Subscription) (*templat } // wsReadData receives and passes on websocket messages for processing -func (by *Bybit) wsReadData(assetType asset.Item, ws stream.Connection) { +func (by *Bybit) wsReadData(assetType asset.Item, ws websocket.Connection) { defer by.Websocket.Wg.Done() for { select { @@ -285,7 +285,7 @@ func (by *Bybit) wsHandleData(assetType asset.Item, respRaw []byte) error { } case "ping", "pong": default: - by.Websocket.DataHandler <- stream.UnhandledMessageWarning{ + by.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ Message: string(respRaw), } return nil @@ -487,13 +487,13 @@ func (by *Bybit) wsProcessLeverageTokenKline(assetType asset.Item, resp *Websock if err != nil { return err } - ltKline := make([]stream.KlineData, len(result)) + ltKline := make([]websocket.KlineData, len(result)) for x := range result { interval, err := stringToInterval(result[x].Interval) if err != nil { return err } - ltKline[x] = stream.KlineData{ + ltKline[x] = websocket.KlineData{ Timestamp: result[x].Timestamp.Time(), Pair: cp, AssetType: assetType, @@ -531,13 +531,13 @@ func (by *Bybit) wsProcessKline(assetType asset.Item, resp *WebsocketResponse, t if err != nil { return err } - spotCandlesticks := make([]stream.KlineData, len(result)) + spotCandlesticks := make([]websocket.KlineData, len(result)) for x := range result { interval, err := stringToInterval(result[x].Interval) if err != nil { return err } - spotCandlesticks[x] = stream.KlineData{ + spotCandlesticks[x] = websocket.KlineData{ Timestamp: result[x].Timestamp.Time(), Pair: cp, AssetType: assetType, diff --git a/exchanges/bybit/bybit_wrapper.go b/exchanges/bybit/bybit_wrapper.go index 0755ad37..ff5ce76e 100644 --- a/exchanges/bybit/bybit_wrapper.go +++ b/exchanges/bybit/bybit_wrapper.go @@ -25,10 +25,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket/buffer" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -206,7 +206,7 @@ func (by *Bybit) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - by.Websocket = stream.NewWebsocket() + by.Websocket = websocket.NewManager() by.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit by.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout by.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -234,7 +234,7 @@ func (by *Bybit) Setup(exch *config.Exchange) error { } err = by.Websocket.Setup( - &stream.WebsocketSetup{ + &websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: spotPublic, RunningURL: wsRunningEndpoint, @@ -253,7 +253,7 @@ func (by *Bybit) Setup(exch *config.Exchange) error { if err != nil { return err } - err = by.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + err = by.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ URL: by.Websocket.GetWebsocketURL(), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: bybitWebsocketTimer, @@ -262,7 +262,7 @@ func (by *Bybit) Setup(exch *config.Exchange) error { return err } - return by.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return by.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ URL: websocketPrivate, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/coinbasepro/coinbasepro_test.go b/exchanges/coinbasepro/coinbasepro_test.go index f2cb6ef2..663d719b 100644 --- a/exchanges/coinbasepro/coinbasepro_test.go +++ b/exchanges/coinbasepro/coinbasepro_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" @@ -20,8 +20,8 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" "github.com/thrasher-corp/gocryptotrader/portfolio/banking" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" @@ -662,9 +662,9 @@ func TestGetDepositAddress(t *testing.T) { // TestWsAuth dials websocket, sends login request. func TestWsAuth(t *testing.T) { if !c.Websocket.IsEnabled() && !c.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(c) { - t.Skip(stream.ErrWebsocketNotEnabled.Error()) + t.Skip(websocket.ErrWebsocketNotEnabled.Error()) } - var dialer websocket.Dialer + var dialer gws.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) require.NoError(t, err, "Dial must not error") go c.wsReadData() diff --git a/exchanges/coinbasepro/coinbasepro_websocket.go b/exchanges/coinbasepro/coinbasepro_websocket.go index c7fb88c4..4f04e318 100644 --- a/exchanges/coinbasepro/coinbasepro_websocket.go +++ b/exchanges/coinbasepro/coinbasepro_websocket.go @@ -8,7 +8,7 @@ import ( "strconv" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common/convert" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" @@ -17,10 +17,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" ) const ( @@ -30,9 +30,9 @@ const ( // WsConnect initiates a websocket connection func (c *CoinbasePro) WsConnect() error { if !c.Websocket.IsEnabled() || !c.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - var dialer websocket.Dialer + var dialer gws.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) if err != nil { return err @@ -253,7 +253,7 @@ func (c *CoinbasePro) wsHandleData(respRaw []byte) error { }) } default: - c.Websocket.DataHandler <- stream.UnhandledMessageWarning{Message: c.Name + stream.UnhandledMessage + string(respRaw)} + c.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: c.Name + websocket.UnhandledMessage + string(respRaw)} return nil } return nil diff --git a/exchanges/coinbasepro/coinbasepro_wrapper.go b/exchanges/coinbasepro/coinbasepro_wrapper.go index 34dcdde6..b76abb97 100644 --- a/exchanges/coinbasepro/coinbasepro_wrapper.go +++ b/exchanges/coinbasepro/coinbasepro_wrapper.go @@ -21,11 +21,11 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket/buffer" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -130,7 +130,7 @@ func (c *CoinbasePro) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - c.Websocket = stream.NewWebsocket() + c.Websocket = websocket.NewManager() c.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit c.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout c.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -156,7 +156,7 @@ func (c *CoinbasePro) Setup(exch *config.Exchange) error { return err } - err = c.Websocket.Setup(&stream.WebsocketSetup{ + err = c.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: coinbaseproWebsocketURL, RunningURL: wsRunningURL, @@ -174,7 +174,7 @@ func (c *CoinbasePro) Setup(exch *config.Exchange) error { return err } - return c.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return c.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, }) diff --git a/exchanges/coinut/coinut_test.go b/exchanges/coinut/coinut_test.go index a1e58c0e..33b57617 100644 --- a/exchanges/coinut/coinut_test.go +++ b/exchanges/coinut/coinut_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" @@ -20,7 +20,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -71,13 +71,13 @@ func setupWSTestAuth(t *testing.T) { } if !c.Websocket.IsEnabled() && !c.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(c) { - t.Skip(stream.ErrWebsocketNotEnabled.Error()) + t.Skip(websocket.ErrWebsocketNotEnabled.Error()) } if sharedtestvalues.AreAPICredentialsSet(c) { c.Websocket.SetCanUseAuthenticatedEndpoints(true) } - var dialer websocket.Dialer + var dialer gws.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) if err != nil { t.Fatal(err) diff --git a/exchanges/coinut/coinut_websocket.go b/exchanges/coinut/coinut_websocket.go index 22de4c79..243b6e5e 100644 --- a/exchanges/coinut/coinut_websocket.go +++ b/exchanges/coinut/coinut_websocket.go @@ -10,7 +10,7 @@ import ( "time" "github.com/buger/jsonparser" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" @@ -19,10 +19,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -41,9 +41,9 @@ var channels map[string]chan []byte // WsConnect initiates a websocket connection func (c *COINUT) WsConnect() error { if !c.Websocket.IsEnabled() || !c.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - var dialer websocket.Dialer + var dialer gws.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) if err != nil { return err @@ -361,7 +361,7 @@ func (c *COINUT) wsHandleData(_ context.Context, respRaw []byte) error { } c.Websocket.DataHandler <- o default: - c.Websocket.DataHandler <- stream.UnhandledMessageWarning{Message: c.Name + stream.UnhandledMessage + string(respRaw)} + c.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: c.Name + websocket.UnhandledMessage + string(respRaw)} return nil } return nil diff --git a/exchanges/coinut/coinut_wrapper.go b/exchanges/coinut/coinut_wrapper.go index b20dbbd5..cee815d2 100644 --- a/exchanges/coinut/coinut_wrapper.go +++ b/exchanges/coinut/coinut_wrapper.go @@ -23,10 +23,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket/buffer" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -104,7 +104,7 @@ func (c *COINUT) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - c.Websocket = stream.NewWebsocket() + c.Websocket = websocket.NewManager() c.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit c.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout c.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -130,7 +130,7 @@ func (c *COINUT) Setup(exch *config.Exchange) error { return err } - err = c.Websocket.Setup(&stream.WebsocketSetup{ + err = c.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: coinutWebsocketURL, RunningURL: wsRunningURL, @@ -148,7 +148,7 @@ func (c *COINUT) Setup(exch *config.Exchange) error { return err } - return c.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return c.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, RateLimit: request.NewWeightedRateLimitByDuration(33 * time.Millisecond), diff --git a/exchanges/deribit/deribit_websocket.go b/exchanges/deribit/deribit_websocket.go index 1f3f2dc7..92cda72a 100644 --- a/exchanges/deribit/deribit_websocket.go +++ b/exchanges/deribit/deribit_websocket.go @@ -10,7 +10,7 @@ import ( "text/template" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" @@ -21,10 +21,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -124,9 +124,9 @@ var ( // WsConnect starts a new connection with the websocket API func (d *Deribit) WsConnect() error { if !d.Websocket.IsEnabled() || !d.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - var dialer websocket.Dialer + var dialer gws.Dialer err := d.Websocket.Conn.Dial(&dialer, http.Header{}) if err != nil { return err @@ -299,8 +299,8 @@ func (d *Deribit) wsHandleData(respRaw []byte) error { case "trades": return d.processTrades(respRaw, channels) default: - d.Websocket.DataHandler <- stream.UnhandledMessageWarning{ - Message: d.Name + stream.UnhandledMessage + string(respRaw), + d.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + Message: d.Name + websocket.UnhandledMessage + string(respRaw), } return nil } @@ -312,8 +312,8 @@ func (d *Deribit) wsHandleData(respRaw []byte) error { return nil } default: - d.Websocket.DataHandler <- stream.UnhandledMessageWarning{ - Message: d.Name + stream.UnhandledMessage + string(respRaw), + d.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + Message: d.Name + websocket.UnhandledMessage + string(respRaw), } return nil } @@ -626,7 +626,7 @@ func (d *Deribit) processCandleChart(respRaw []byte, channels []string) error { if err != nil { return err } - d.Websocket.DataHandler <- stream.KlineData{ + d.Websocket.DataHandler <- websocket.KlineData{ Timestamp: time.UnixMilli(candleData.Tick), Pair: cp, AssetType: a, @@ -839,7 +839,7 @@ func (d *Deribit) handleSubscription(method string, subs subscription.List) erro subAck[c] = true } if len(subAck) != len(subs) { - err = stream.ErrSubscriptionFailure + err = websocket.ErrSubscriptionFailure } for _, s := range subs { if _, ok := subAck[s.QualifiedChannel]; ok { diff --git a/exchanges/deribit/deribit_wrapper.go b/exchanges/deribit/deribit_wrapper.go index 988ca210..501c6c60 100644 --- a/exchanges/deribit/deribit_wrapper.go +++ b/exchanges/deribit/deribit_wrapper.go @@ -25,10 +25,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket/buffer" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -150,7 +150,7 @@ func (d *Deribit) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - d.Websocket = stream.NewWebsocket() + d.Websocket = websocket.NewManager() d.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit d.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout d.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -170,7 +170,7 @@ func (d *Deribit) Setup(exch *config.Exchange) error { if err != nil { return err } - err = d.Websocket.Setup(&stream.WebsocketSetup{ + err = d.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: deribitWebsocketAddress, RunningURL: deribitWebsocketAddress, @@ -188,7 +188,7 @@ func (d *Deribit) Setup(exch *config.Exchange) error { return err } - return d.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return d.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ URL: d.Websocket.GetWebsocketURL(), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/deribit/deribit_ws_endpoints.go b/exchanges/deribit/deribit_ws_endpoints.go index 09c56c79..a8df3c75 100644 --- a/exchanges/deribit/deribit_ws_endpoints.go +++ b/exchanges/deribit/deribit_ws_endpoints.go @@ -2319,7 +2319,7 @@ func (d *Deribit) WSMovePositions(ccy currency.Code, sourceSubAccountUID, target return resp, d.SendWSRequest(nonMatchingEPL, movePositions, input, &resp, true) } -// WsSimulateBlockTrade checks if a block trade can be executed through the websocket stream. +// WsSimulateBlockTrade checks if a block trade can be executed through the websocket func (d *Deribit) WsSimulateBlockTrade(role string, trades []BlockTradeParam) (bool, error) { if role != roleMaker && role != roleTaker { return false, errInvalidTradeRole diff --git a/exchanges/exchange.go b/exchanges/exchange.go index 7d46edaa..96a68fa8 100644 --- a/exchanges/exchange.go +++ b/exchanges/exchange.go @@ -31,10 +31,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/banking" ) @@ -1056,7 +1056,7 @@ func (b *Base) SetGlobalPairsManager(request, config *currency.PairFormat, asset } // GetWebsocket returns a pointer to the exchange websocket -func (b *Base) GetWebsocket() (*stream.Websocket, error) { +func (b *Base) GetWebsocket() (*websocket.Manager, error) { if b.Websocket == nil { return nil, common.ErrFunctionNotSupported } @@ -1124,7 +1124,7 @@ func (b *Base) AuthenticateWebsocket(_ context.Context) error { } // CanUseAuthenticatedWebsocketEndpoints calls b.Websocket.CanUseAuthenticatedEndpoints -// Used to avoid import cycles on stream.websocket +// Used to avoid import cycles on websocket.Manager func (b *Base) CanUseAuthenticatedWebsocketEndpoints() bool { return b.Websocket != nil && b.Websocket.CanUseAuthenticatedEndpoints() } @@ -1601,7 +1601,7 @@ func (b *Base) GetKlineExtendedRequest(pair currency.Pair, a asset.Item, interva func (b *Base) Shutdown() error { if b.Websocket != nil { err := b.Websocket.Shutdown() - if err != nil && !errors.Is(err, stream.ErrNotConnected) { + if err != nil && !errors.Is(err, websocket.ErrNotConnected) { return err } } diff --git a/exchanges/exchange_test.go b/exchanges/exchange_test.go index 79c775d5..fd009f22 100644 --- a/exchanges/exchange_test.go +++ b/exchanges/exchange_test.go @@ -25,10 +25,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/portfolio/banking" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -206,7 +206,7 @@ func TestSetClientProxyAddress(t *testing.T) { Requester: requester, } - newBase.Websocket = stream.NewWebsocket() + newBase.Websocket = websocket.NewManager() err = newBase.SetClientProxyAddress("") if err != nil { t.Error(err) @@ -866,9 +866,9 @@ func TestSetupDefaults(t *testing.T) { } // Test websocket support - b.Websocket = stream.NewWebsocket() + b.Websocket = websocket.NewManager() b.Features.Supports.Websocket = true - err = b.Websocket.Setup(&stream.WebsocketSetup{ + err = b.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: &config.Exchange{ WebsocketTrafficTimeout: time.Second * 30, Name: "test", @@ -1193,8 +1193,8 @@ func TestIsWebsocketEnabled(t *testing.T) { t.Error("exchange doesn't support websocket") } - b.Websocket = stream.NewWebsocket() - err := b.Websocket.Setup(&stream.WebsocketSetup{ + b.Websocket = websocket.NewManager() + err := b.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: &config.Exchange{ Enabled: true, WebsocketTrafficTimeout: time.Second * 30, @@ -1603,7 +1603,7 @@ func TestGetWebsocket(t *testing.T) { if err == nil { t.Fatal("error cannot be nil") } - b.Websocket = &stream.Websocket{} + b.Websocket = websocket.NewManager() _, err = b.GetWebsocket() if err != nil { t.Fatal(err) @@ -1617,7 +1617,7 @@ func TestFlushWebsocketChannels(t *testing.T) { t.Fatal(err) } - b.Websocket = &stream.Websocket{} + b.Websocket = websocket.NewManager() err = b.FlushWebsocketChannels() if err == nil { t.Fatal(err) @@ -1631,7 +1631,7 @@ func TestSubscribeToWebsocketChannels(t *testing.T) { t.Fatal(err) } - b.Websocket = &stream.Websocket{} + b.Websocket = websocket.NewManager() err = b.SubscribeToWebsocketChannels(nil) if err == nil { t.Fatal(err) @@ -1643,7 +1643,7 @@ func TestUnsubscribeToWebsocketChannels(t *testing.T) { err := b.UnsubscribeToWebsocketChannels(nil) assert.ErrorIs(t, err, common.ErrFunctionNotSupported, "UnsubscribeToWebsocketChannels should error correctly with a nil Websocket") - b.Websocket = &stream.Websocket{} + b.Websocket = websocket.NewManager() err = b.UnsubscribeToWebsocketChannels(nil) assert.NoError(t, err, "UnsubscribeToWebsocketChannels from an empty/nil list should not error") } @@ -1655,7 +1655,7 @@ func TestGetSubscriptions(t *testing.T) { t.Fatal(err) } - b.Websocket = &stream.Websocket{} + b.Websocket = websocket.NewManager() _, err = b.GetSubscriptions() if err != nil { t.Fatal(err) @@ -2897,7 +2897,7 @@ func TestCanUseAuthenticatedWebsocketEndpoints(t *testing.T) { t.Parallel() e := &FakeBase{} assert.False(t, e.CanUseAuthenticatedWebsocketEndpoints(), "CanUseAuthenticatedWebsocketEndpoints should return false with nil websocket") - e.Websocket = stream.NewWebsocket() + e.Websocket = websocket.NewManager() assert.False(t, e.CanUseAuthenticatedWebsocketEndpoints()) e.Websocket.SetCanUseAuthenticatedEndpoints(true) assert.True(t, e.CanUseAuthenticatedWebsocketEndpoints()) diff --git a/exchanges/exchange_types.go b/exchanges/exchange_types.go index a967c69c..d1d557af 100644 --- a/exchanges/exchange_types.go +++ b/exchanges/exchange_types.go @@ -13,8 +13,8 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" ) // Endpoint authentication types @@ -246,7 +246,7 @@ type Base struct { WebsocketResponseCheckTimeout time.Duration WebsocketResponseMaxLimit time.Duration WebsocketOrderbookBufferLimit int64 - Websocket *stream.Websocket + Websocket *websocket.Manager *request.Requester Config *config.Exchange settingsMutex sync.RWMutex diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index d385ebe0..5c27d2e7 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -26,8 +26,8 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" @@ -3520,7 +3520,7 @@ func TestGenerateWebsocketMessageID(t *testing.T) { require.NotEmpty(t, g.GenerateWebsocketMessageID(false)) } -type DummyConnection struct{ stream.Connection } +type DummyConnection struct{ websocket.Connection } func (d *DummyConnection) GenerateMessageID(bool) int64 { return 1337 } func (d *DummyConnection) SendMessageReturnResponse(context.Context, request.EndpointLimit, any, any) ([]byte, error) { @@ -3532,12 +3532,12 @@ func TestHandleSubscriptions(t *testing.T) { subs := subscription.List{{Channel: subscription.OrderbookChannel}} - err := g.handleSubscription(context.Background(), &DummyConnection{}, subscribeEvent, subs, func(context.Context, stream.Connection, string, subscription.List) ([]WsInput, error) { + err := g.handleSubscription(context.Background(), &DummyConnection{}, subscribeEvent, subs, func(context.Context, websocket.Connection, string, subscription.List) ([]WsInput, error) { return []WsInput{{}}, nil }) require.NoError(t, err) - err = g.handleSubscription(context.Background(), &DummyConnection{}, unsubscribeEvent, subs, func(context.Context, stream.Connection, string, subscription.List) ([]WsInput, error) { + err = g.handleSubscription(context.Background(), &DummyConnection{}, unsubscribeEvent, subs, func(context.Context, websocket.Connection, string, subscription.List) ([]WsInput, error) { return []WsInput{{}}, nil }) require.NoError(t, err) diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index c9d4f8e4..11ff1a5d 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -15,7 +15,7 @@ import ( "github.com/Masterminds/sprig/v3" "github.com/buger/jsonparser" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" @@ -25,10 +25,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" ) const ( @@ -76,12 +76,12 @@ var subscriptionNames = map[string]string{ var standardMarginAssetTypes = []asset.Item{asset.Spot, asset.Margin, asset.CrossMargin} // WsConnectSpot initiates a websocket connection -func (g *Gateio) WsConnectSpot(ctx context.Context, conn stream.Connection) error { +func (g *Gateio) WsConnectSpot(ctx context.Context, conn websocket.Connection) error { err := g.CurrencyPairs.IsAssetEnabled(asset.Spot) if err != nil { return err } - err = conn.DialContext(ctx, &websocket.Dialer{}, http.Header{}) + err = conn.DialContext(ctx, &gws.Dialer{}, http.Header{}) if err != nil { return err } @@ -89,22 +89,22 @@ func (g *Gateio) WsConnectSpot(ctx context.Context, conn stream.Connection) erro if err != nil { return err } - conn.SetupPingHandler(websocketRateLimitNotNeededEPL, stream.PingHandler{ + conn.SetupPingHandler(websocketRateLimitNotNeededEPL, websocket.PingHandler{ Websocket: true, Delay: time.Second * 15, Message: pingMessage, - MessageType: websocket.TextMessage, + MessageType: gws.TextMessage, }) return nil } // authenticateSpot sends an authentication message to the websocket connection -func (g *Gateio) authenticateSpot(ctx context.Context, conn stream.Connection) error { +func (g *Gateio) authenticateSpot(ctx context.Context, conn websocket.Connection) error { return g.websocketLogin(ctx, conn, "spot.login") } // websocketLogin authenticates the websocket connection -func (g *Gateio) websocketLogin(ctx context.Context, conn stream.Connection, channel string) error { +func (g *Gateio) websocketLogin(ctx context.Context, conn websocket.Connection, channel string) error { if conn == nil { return fmt.Errorf("%w: %T", common.ErrNilPointer, conn) } @@ -209,10 +209,10 @@ func (g *Gateio) WsHandleSpotData(_ context.Context, respRaw []byte) error { return g.processCrossMarginLoans(respRaw) case spotPongChannel: default: - g.Websocket.DataHandler <- stream.UnhandledMessageWarning{ - Message: g.Name + stream.UnhandledMessage + string(respRaw), + g.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + Message: g.Name + websocket.UnhandledMessage + string(respRaw), } - return errors.New(stream.UnhandledMessage) + return errors.New(websocket.UnhandledMessage) } return nil } @@ -333,10 +333,10 @@ func (g *Gateio) processCandlestick(incoming []byte) error { return err } - out := make([]stream.KlineData, 0, len(standardMarginAssetTypes)) + out := make([]websocket.KlineData, 0, len(standardMarginAssetTypes)) for _, a := range standardMarginAssetTypes { if enabled, _ := g.CurrencyPairs.IsPairEnabled(currencyPair, a); enabled { - out = append(out, stream.KlineData{ + out = append(out, websocket.KlineData{ Pair: currencyPair, AssetType: a, Exchange: g.Name, @@ -670,7 +670,7 @@ func (g *Gateio) GetSubscriptionTemplate(_ *subscription.Subscription) (*templat } // manageSubs sends a websocket message to subscribe or unsubscribe from a list of channel -func (g *Gateio) manageSubs(ctx context.Context, event string, conn stream.Connection, subs subscription.List) error { +func (g *Gateio) manageSubs(ctx context.Context, event string, conn websocket.Connection, subs subscription.List) error { var errs error subs, errs = subs.ExpandTemplates(g) if errs != nil { @@ -706,7 +706,7 @@ func (g *Gateio) manageSubs(ctx context.Context, event string, conn stream.Conne } // manageSubReq constructs the subscription management message for a subscription -func (g *Gateio) manageSubReq(ctx context.Context, event string, conn stream.Connection, s *subscription.Subscription) (*WsInput, error) { +func (g *Gateio) manageSubReq(ctx context.Context, event string, conn websocket.Connection, s *subscription.Subscription) (*WsInput, error) { req := &WsInput{ ID: conn.GenerateMessageID(false), Event: event, @@ -733,12 +733,12 @@ func (g *Gateio) manageSubReq(ctx context.Context, event string, conn stream.Con } // Subscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) Subscribe(ctx context.Context, conn stream.Connection, subs subscription.List) error { +func (g *Gateio) Subscribe(ctx context.Context, conn websocket.Connection, subs subscription.List) error { return g.manageSubs(ctx, subscribeEvent, conn, subs) } // Unsubscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) Unsubscribe(ctx context.Context, conn stream.Connection, subs subscription.List) error { +func (g *Gateio) Unsubscribe(ctx context.Context, conn websocket.Connection, subs subscription.List) error { return g.manageSubs(ctx, unsubscribeEvent, conn, subs) } @@ -784,10 +784,10 @@ const subTplText = ` ` // GeneratePayload returns the payload for a websocket message -type GeneratePayload func(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) +type GeneratePayload func(ctx context.Context, conn websocket.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) // handleSubscription sends a websocket message to receive data from the channel -func (g *Gateio) handleSubscription(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List, generatePayload GeneratePayload) error { +func (g *Gateio) handleSubscription(ctx context.Context, conn websocket.Connection, event string, channelsToSubscribe subscription.List, generatePayload GeneratePayload) error { payloads, err := generatePayload(ctx, conn, event, channelsToSubscribe) if err != nil { return err diff --git a/exchanges/gateio/gateio_websocket_delivery_futures.go b/exchanges/gateio/gateio_websocket_delivery_futures.go index fd3ce062..5e766a6e 100644 --- a/exchanges/gateio/gateio_websocket_delivery_futures.go +++ b/exchanges/gateio/gateio_websocket_delivery_futures.go @@ -7,14 +7,14 @@ import ( "strconv" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" ) const ( @@ -37,12 +37,12 @@ var defaultDeliveryFuturesSubscriptions = []string{ var fetchedFuturesCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsDeliveryFuturesConnect initiates a websocket connection for delivery futures account -func (g *Gateio) WsDeliveryFuturesConnect(ctx context.Context, conn stream.Connection) error { +func (g *Gateio) WsDeliveryFuturesConnect(ctx context.Context, conn websocket.Connection) error { err := g.CurrencyPairs.IsAssetEnabled(asset.DeliveryFutures) if err != nil { return err } - err = conn.DialContext(ctx, &websocket.Dialer{}, http.Header{}) + err = conn.DialContext(ctx, &gws.Dialer{}, http.Header{}) if err != nil { return err } @@ -54,10 +54,10 @@ func (g *Gateio) WsDeliveryFuturesConnect(ctx context.Context, conn stream.Conne if err != nil { return err } - conn.SetupPingHandler(websocketRateLimitNotNeededEPL, stream.PingHandler{ + conn.SetupPingHandler(websocketRateLimitNotNeededEPL, websocket.PingHandler{ Websocket: true, Delay: time.Second * 5, - MessageType: websocket.PingMessage, + MessageType: gws.PingMessage, Message: pingMessage, }) return nil @@ -109,16 +109,16 @@ func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() (subscription.Lis } // DeliveryFuturesSubscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) DeliveryFuturesSubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { +func (g *Gateio) DeliveryFuturesSubscribe(ctx context.Context, conn websocket.Connection, channelsToUnsubscribe subscription.List) error { return g.handleSubscription(ctx, conn, subscribeEvent, channelsToUnsubscribe, g.generateDeliveryFuturesPayload) } // DeliveryFuturesUnsubscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) DeliveryFuturesUnsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { +func (g *Gateio) DeliveryFuturesUnsubscribe(ctx context.Context, conn websocket.Connection, channelsToUnsubscribe subscription.List) error { return g.handleSubscription(ctx, conn, unsubscribeEvent, channelsToUnsubscribe, g.generateDeliveryFuturesPayload) } -func (g *Gateio) generateDeliveryFuturesPayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { +func (g *Gateio) generateDeliveryFuturesPayload(ctx context.Context, conn websocket.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { if len(channelsToSubscribe) == 0 { return nil, errors.New("cannot generate payload, no channels supplied") } diff --git a/exchanges/gateio/gateio_websocket_futures.go b/exchanges/gateio/gateio_websocket_futures.go index 244afc56..c60f49ce 100644 --- a/exchanges/gateio/gateio_websocket_futures.go +++ b/exchanges/gateio/gateio_websocket_futures.go @@ -10,7 +10,7 @@ import ( "strings" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" "github.com/thrasher-corp/gocryptotrader/exchanges/account" @@ -19,10 +19,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" ) const ( @@ -58,12 +58,12 @@ var defaultFuturesSubscriptions = []string{ } // WsFuturesConnect initiates a websocket connection for futures account -func (g *Gateio) WsFuturesConnect(ctx context.Context, conn stream.Connection) error { +func (g *Gateio) WsFuturesConnect(ctx context.Context, conn websocket.Connection) error { err := g.CurrencyPairs.IsAssetEnabled(asset.Futures) if err != nil { return err } - err = conn.DialContext(ctx, &websocket.Dialer{}, http.Header{}) + err = conn.DialContext(ctx, &gws.Dialer{}, http.Header{}) if err != nil { return err } @@ -75,9 +75,9 @@ func (g *Gateio) WsFuturesConnect(ctx context.Context, conn stream.Connection) e if err != nil { return err } - conn.SetupPingHandler(websocketRateLimitNotNeededEPL, stream.PingHandler{ + conn.SetupPingHandler(websocketRateLimitNotNeededEPL, websocket.PingHandler{ Websocket: true, - MessageType: websocket.PingMessage, + MessageType: gws.PingMessage, Delay: time.Second * 15, Message: pingMessage, }) @@ -138,12 +138,12 @@ func (g *Gateio) GenerateFuturesDefaultSubscriptions(settlement currency.Code) ( } // FuturesSubscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) FuturesSubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { +func (g *Gateio) FuturesSubscribe(ctx context.Context, conn websocket.Connection, channelsToUnsubscribe subscription.List) error { return g.handleSubscription(ctx, conn, subscribeEvent, channelsToUnsubscribe, g.generateFuturesPayload) } // FuturesUnsubscribe sends a websocket message to stop receiving data from the channel -func (g *Gateio) FuturesUnsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { +func (g *Gateio) FuturesUnsubscribe(ctx context.Context, conn websocket.Connection, channelsToUnsubscribe subscription.List) error { return g.handleSubscription(ctx, conn, unsubscribeEvent, channelsToUnsubscribe, g.generateFuturesPayload) } @@ -199,14 +199,14 @@ func (g *Gateio) WsHandleFuturesData(_ context.Context, respRaw []byte, a asset. case futuresAutoOrdersChannel: return g.processFuturesAutoOrderPushData(respRaw) default: - g.Websocket.DataHandler <- stream.UnhandledMessageWarning{ - Message: g.Name + stream.UnhandledMessage + string(respRaw), + g.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + Message: g.Name + websocket.UnhandledMessage + string(respRaw), } - return errors.New(stream.UnhandledMessage) + return errors.New(websocket.UnhandledMessage) } } -func (g *Gateio) generateFuturesPayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { +func (g *Gateio) generateFuturesPayload(ctx context.Context, conn websocket.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { if len(channelsToSubscribe) == 0 { return nil, errors.New("cannot generate payload, no channels supplied") } @@ -375,7 +375,7 @@ func (g *Gateio) processFuturesCandlesticks(data []byte, assetType asset.Item) e if err != nil { return err } - klineDatas := make([]stream.KlineData, len(resp.Result)) + klineDatas := make([]websocket.KlineData, len(resp.Result)) for x := range resp.Result { icp := strings.Split(resp.Result[x].Name, currency.UnderscoreDelimiter) if len(icp) < 3 { @@ -385,7 +385,7 @@ func (g *Gateio) processFuturesCandlesticks(data []byte, assetType asset.Item) e if err != nil { return err } - klineDatas[x] = stream.KlineData{ + klineDatas[x] = websocket.KlineData{ Pair: currencyPair, AssetType: assetType, Exchange: g.Name, diff --git a/exchanges/gateio/gateio_websocket_option.go b/exchanges/gateio/gateio_websocket_option.go index 938ec208..849149f8 100644 --- a/exchanges/gateio/gateio_websocket_option.go +++ b/exchanges/gateio/gateio_websocket_option.go @@ -9,7 +9,7 @@ import ( "strings" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" "github.com/thrasher-corp/gocryptotrader/exchanges/account" @@ -18,10 +18,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -67,12 +67,12 @@ var defaultOptionsSubscriptions = []string{ var fetchedOptionsCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsOptionsConnect initiates a websocket connection to options websocket endpoints. -func (g *Gateio) WsOptionsConnect(ctx context.Context, conn stream.Connection) error { +func (g *Gateio) WsOptionsConnect(ctx context.Context, conn websocket.Connection) error { err := g.CurrencyPairs.IsAssetEnabled(asset.Options) if err != nil { return err } - err = conn.DialContext(ctx, &websocket.Dialer{}, http.Header{}) + err = conn.DialContext(ctx, &gws.Dialer{}, http.Header{}) if err != nil { return err } @@ -84,10 +84,10 @@ func (g *Gateio) WsOptionsConnect(ctx context.Context, conn stream.Connection) e if err != nil { return err } - conn.SetupPingHandler(websocketRateLimitNotNeededEPL, stream.PingHandler{ + conn.SetupPingHandler(websocketRateLimitNotNeededEPL, websocket.PingHandler{ Websocket: true, Delay: time.Second * 5, - MessageType: websocket.PingMessage, + MessageType: gws.PingMessage, Message: pingMessage, }) return nil @@ -169,7 +169,7 @@ getEnabledPairs: return subscriptions, nil } -func (g *Gateio) generateOptionsPayload(ctx context.Context, conn stream.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { +func (g *Gateio) generateOptionsPayload(ctx context.Context, conn websocket.Connection, event string, channelsToSubscribe subscription.List) ([]WsInput, error) { if len(channelsToSubscribe) == 0 { return nil, errors.New("cannot generate payload, no channels supplied") } @@ -283,12 +283,12 @@ func (g *Gateio) generateOptionsPayload(ctx context.Context, conn stream.Connect } // OptionsSubscribe sends a websocket message to stop receiving data for asset type options -func (g *Gateio) OptionsSubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { +func (g *Gateio) OptionsSubscribe(ctx context.Context, conn websocket.Connection, channelsToUnsubscribe subscription.List) error { return g.handleSubscription(ctx, conn, subscribeEvent, channelsToUnsubscribe, g.generateOptionsPayload) } // OptionsUnsubscribe sends a websocket message to stop receiving data for asset type options -func (g *Gateio) OptionsUnsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { +func (g *Gateio) OptionsUnsubscribe(ctx context.Context, conn websocket.Connection, channelsToUnsubscribe subscription.List) error { return g.handleSubscription(ctx, conn, unsubscribeEvent, channelsToUnsubscribe, g.generateOptionsPayload) } @@ -346,10 +346,10 @@ func (g *Gateio) WsHandleOptionsData(_ context.Context, respRaw []byte) error { case optionsPositionsChannel: return g.processOptionsPositionPushData(respRaw) default: - g.Websocket.DataHandler <- stream.UnhandledMessageWarning{ - Message: g.Name + stream.UnhandledMessage + string(respRaw), + g.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + Message: g.Name + websocket.UnhandledMessage + string(respRaw), } - return errors.New(stream.UnhandledMessage) + return errors.New(websocket.UnhandledMessage) } } @@ -464,7 +464,7 @@ func (g *Gateio) processOptionsCandlestickPushData(data []byte) error { if err != nil { return err } - klineDatas := make([]stream.KlineData, len(resp.Result)) + klineDatas := make([]websocket.KlineData, len(resp.Result)) for x := range resp.Result { icp := strings.Split(resp.Result[x].NameOfSubscription, currency.UnderscoreDelimiter) if len(icp) < 3 { @@ -474,7 +474,7 @@ func (g *Gateio) processOptionsCandlestickPushData(data []byte) error { if err != nil { return err } - klineDatas[x] = stream.KlineData{ + klineDatas[x] = websocket.KlineData{ Pair: currencyPair, AssetType: asset.Options, Exchange: g.Name, diff --git a/exchanges/gateio/gateio_websocket_request_spot_test.go b/exchanges/gateio/gateio_websocket_request_spot_test.go index 7933c117..cdf9e9f0 100644 --- a/exchanges/gateio/gateio_websocket_request_spot_test.go +++ b/exchanges/gateio/gateio_websocket_request_spot_test.go @@ -12,7 +12,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" ) @@ -21,18 +20,18 @@ func TestWebsocketLogin(t *testing.T) { err := g.websocketLogin(context.Background(), nil, "") require.ErrorIs(t, err, common.ErrNilPointer) - err = g.websocketLogin(context.Background(), &stream.WebsocketConnection{}, "") - require.ErrorIs(t, err, errChannelEmpty) - sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) testexch.UpdatePairsOnce(t, g) g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes - demonstrationConn, err := g.Websocket.GetConnection(asset.Spot) + c, err := g.Websocket.GetConnection(asset.Spot) require.NoError(t, err) - err = g.websocketLogin(context.Background(), demonstrationConn, "spot.login") + err = g.websocketLogin(context.Background(), c, "") + require.ErrorIs(t, err, errChannelEmpty) + + err = g.websocketLogin(context.Background(), c, "spot.login") require.NoError(t, err) } diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 30ff8dd1..c25541fd 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -26,10 +26,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" "github.com/thrasher-corp/gocryptotrader/types" @@ -176,7 +176,7 @@ func (g *Gateio) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - g.Websocket = stream.NewWebsocket() + g.Websocket = websocket.NewManager() g.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit g.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout g.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -197,7 +197,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { return err } - err = g.Websocket.Setup(&stream.WebsocketSetup{ + err = g.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, Features: &g.Features.Supports.WebsocketCapabilities, FillsFeed: g.Features.Enabled.FillsFeed, @@ -209,7 +209,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { return err } // Spot connection - err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + err = g.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ URL: gateioWebsocketEndpoint, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, @@ -226,7 +226,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { return err } // Futures connection - USDT margined - err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + err = g.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ URL: futuresWebsocketUsdtURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, @@ -245,7 +245,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { } // Futures connection - BTC margined - err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + err = g.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ URL: futuresWebsocketBtcURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, @@ -265,7 +265,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { // TODO: Add BTC margined delivery futures. // Futures connection - Delivery - USDT margined - err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + err = g.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ URL: deliveryRealUSDTTradingURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, @@ -284,7 +284,7 @@ func (g *Gateio) Setup(exch *config.Exchange) error { } // Futures connection - Options - return g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return g.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ URL: optionsWebsocketURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/gemini/gemini_test.go b/exchanges/gemini/gemini_test.go index b46502b6..b2146d4d 100644 --- a/exchanges/gemini/gemini_test.go +++ b/exchanges/gemini/gemini_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" @@ -19,8 +19,8 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" @@ -562,9 +562,9 @@ func TestWsAuth(t *testing.T) { if !g.Websocket.IsEnabled() && !g.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(g) { - t.Skip(stream.ErrWebsocketNotEnabled.Error()) + t.Skip(websocket.ErrWebsocketNotEnabled.Error()) } - var dialer websocket.Dialer + var dialer gws.Dialer go g.wsReadData() err = g.WsAuth(context.Background(), &dialer) if err != nil { diff --git a/exchanges/gemini/gemini_websocket.go b/exchanges/gemini/gemini_websocket.go index b78645e2..557ede68 100644 --- a/exchanges/gemini/gemini_websocket.go +++ b/exchanges/gemini/gemini_websocket.go @@ -12,7 +12,7 @@ import ( "text/template" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" @@ -22,9 +22,9 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -53,15 +53,15 @@ var subscriptionNames = map[string]string{ } // Instantiates a communications channel between websocket connections -var comms = make(chan stream.Response) +var comms = make(chan websocket.Response) // WsConnect initiates a websocket connection func (g *Gemini) WsConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - var dialer websocket.Dialer + var dialer gws.Dialer err := g.Websocket.Conn.Dial(&dialer, http.Header{}) if err != nil { return err @@ -128,7 +128,7 @@ func (g *Gemini) manageSubs(subs subscription.List, op wsSubOp) error { } // WsAuth will connect to Gemini's secure endpoint -func (g *Gemini) WsAuth(ctx context.Context, dialer *websocket.Dialer) error { +func (g *Gemini) WsAuth(ctx context.Context, dialer *gws.Dialer) error { if !g.IsWebsocketAuthenticationSupported() { return fmt.Errorf("%v AuthenticatedWebsocketAPISupport not enabled", g.Name) } @@ -175,14 +175,14 @@ func (g *Gemini) WsAuth(ctx context.Context, dialer *websocket.Dialer) error { } // wsFunnelConnectionData receives data from multiple connections and passes it to wsReadData -func (g *Gemini) wsFunnelConnectionData(ws stream.Connection) { +func (g *Gemini) wsFunnelConnectionData(ws websocket.Connection) { defer g.Websocket.Wg.Done() for { resp := ws.ReadMessage() if resp.Raw == nil { return } - comms <- stream.Response{Raw: resp.Raw} + comms <- websocket.Response{Raw: resp.Raw} } } @@ -398,7 +398,7 @@ func (g *Gemini) wsHandleData(respRaw []byte) error { if !ok { return errors.New("unable to type assert interval") } - g.Websocket.DataHandler <- stream.KlineData{ + g.Websocket.DataHandler <- websocket.KlineData{ Timestamp: time.UnixMilli(int64(candle.Changes[i][0])), Pair: pair, AssetType: asset.Spot, @@ -412,7 +412,7 @@ func (g *Gemini) wsHandleData(respRaw []byte) error { } } default: - g.Websocket.DataHandler <- stream.UnhandledMessageWarning{Message: g.Name + stream.UnhandledMessage + string(respRaw)} + g.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: g.Name + websocket.UnhandledMessage + string(respRaw)} return nil } } else if r, ok := result["result"].(string); ok { @@ -426,7 +426,7 @@ func (g *Gemini) wsHandleData(respRaw []byte) error { } return fmt.Errorf("%v Unhandled websocket error %s", g.Name, respRaw) default: - g.Websocket.DataHandler <- stream.UnhandledMessageWarning{Message: g.Name + stream.UnhandledMessage + string(respRaw)} + g.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: g.Name + websocket.UnhandledMessage + string(respRaw)} return nil } } diff --git a/exchanges/gemini/gemini_wrapper.go b/exchanges/gemini/gemini_wrapper.go index 5dcbbad1..e4f06d13 100644 --- a/exchanges/gemini/gemini_wrapper.go +++ b/exchanges/gemini/gemini_wrapper.go @@ -24,9 +24,9 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -106,7 +106,7 @@ func (g *Gemini) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - g.Websocket = stream.NewWebsocket() + g.Websocket = websocket.NewManager() g.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit g.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout g.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -139,7 +139,7 @@ func (g *Gemini) Setup(exch *config.Exchange) error { return err } - err = g.Websocket.Setup(&stream.WebsocketSetup{ + err = g.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: geminiWebsocketEndpoint, RunningURL: wsRunningURL, @@ -153,7 +153,7 @@ func (g *Gemini) Setup(exch *config.Exchange) error { return err } - err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + err = g.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, URL: geminiWebsocketEndpoint + "/v2/" + geminiWsMarketData, @@ -162,7 +162,7 @@ func (g *Gemini) Setup(exch *config.Exchange) error { return err } - return g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return g.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, URL: geminiWebsocketEndpoint + "/v1/" + geminiWsOrderEvents, diff --git a/exchanges/hitbtc/hitbtc_test.go b/exchanges/hitbtc/hitbtc_test.go index 2b73802b..d7540317 100644 --- a/exchanges/hitbtc/hitbtc_test.go +++ b/exchanges/hitbtc/hitbtc_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" @@ -20,8 +20,8 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" @@ -461,10 +461,10 @@ func setupWsAuth(t *testing.T) { return } if !h.Websocket.IsEnabled() && !h.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(h) { - t.Skip(stream.ErrWebsocketNotEnabled.Error()) + t.Skip(websocket.ErrWebsocketNotEnabled.Error()) } - var dialer websocket.Dialer + var dialer gws.Dialer err := h.Websocket.Conn.Dial(&dialer, http.Header{}) if err != nil { t.Fatal(err) diff --git a/exchanges/hitbtc/hitbtc_websocket.go b/exchanges/hitbtc/hitbtc_websocket.go index a4af461b..40096ec3 100644 --- a/exchanges/hitbtc/hitbtc_websocket.go +++ b/exchanges/hitbtc/hitbtc_websocket.go @@ -11,7 +11,7 @@ import ( "time" "github.com/Masterminds/sprig/v3" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" @@ -21,10 +21,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -53,9 +53,9 @@ var defaultSubscriptions = subscription.List{ // WsConnect starts a new connection with the websocket API func (h *HitBTC) WsConnect() error { if !h.Websocket.IsEnabled() || !h.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - var dialer websocket.Dialer + var dialer gws.Dialer err := h.Websocket.Conn.Dial(&dialer, http.Header{}) if err != nil { return err @@ -144,7 +144,7 @@ func (h *HitBTC) wsGetTableName(respRaw []byte) (string, error) { return "trading", nil } } - h.Websocket.DataHandler <- stream.UnhandledMessageWarning{Message: h.Name + stream.UnhandledMessage + string(respRaw)} + h.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: h.Name + websocket.UnhandledMessage + string(respRaw)} return "", nil } @@ -303,7 +303,7 @@ func (h *HitBTC) wsHandleData(respRaw []byte) error { return err } default: - h.Websocket.DataHandler <- stream.UnhandledMessageWarning{Message: h.Name + stream.UnhandledMessage + string(respRaw)} + h.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: h.Name + websocket.UnhandledMessage + string(respRaw)} return nil } return nil diff --git a/exchanges/hitbtc/hitbtc_wrapper.go b/exchanges/hitbtc/hitbtc_wrapper.go index bcf48ac4..cb272af1 100644 --- a/exchanges/hitbtc/hitbtc_wrapper.go +++ b/exchanges/hitbtc/hitbtc_wrapper.go @@ -23,10 +23,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket/buffer" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -125,7 +125,7 @@ func (h *HitBTC) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - h.Websocket = stream.NewWebsocket() + h.Websocket = websocket.NewManager() h.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit h.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout h.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -151,7 +151,7 @@ func (h *HitBTC) Setup(exch *config.Exchange) error { return err } - err = h.Websocket.Setup(&stream.WebsocketSetup{ + err = h.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: hitbtcWebsocketAddress, RunningURL: wsRunningURL, @@ -169,7 +169,7 @@ func (h *HitBTC) Setup(exch *config.Exchange) error { return err } - return h.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return h.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ RateLimit: request.NewWeightedRateLimitByDuration(20 * time.Millisecond), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/huobi/huobi_test.go b/exchanges/huobi/huobi_test.go index 0374d560..e0fe42c1 100644 --- a/exchanges/huobi/huobi_test.go +++ b/exchanges/huobi/huobi_test.go @@ -12,7 +12,7 @@ import ( "time" "github.com/buger/jsonparser" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" @@ -28,10 +28,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" mockws "github.com/thrasher-corp/gocryptotrader/internal/testing/websocket" @@ -1312,9 +1312,9 @@ func TestWSCandles(t *testing.T) { close(h.Websocket.DataHandler) require.Len(t, h.Websocket.DataHandler, 1, "Must see correct number of records") cAny := <-h.Websocket.DataHandler - c, ok := cAny.(stream.KlineData) + c, ok := cAny.(websocket.KlineData) require.True(t, ok, "Must get the correct type from DataHandler") - exp := stream.KlineData{ + exp := websocket.KlineData{ Timestamp: time.UnixMilli(1489474082831), Pair: btcusdtPair, AssetType: asset.Spot, @@ -1983,20 +1983,20 @@ func TestGenerateSubscriptions(t *testing.T) { testsubs.EqualLists(t, exp, subs) } -func wsFixture(tb testing.TB, msg []byte, w *websocket.Conn) error { +func wsFixture(tb testing.TB, msg []byte, w *gws.Conn) error { tb.Helper() action, _ := jsonparser.GetString(msg, "action") ch, _ := jsonparser.GetString(msg, "ch") if action == "req" && ch == "auth" { - return w.WriteMessage(websocket.TextMessage, []byte(`{"action":"req","code":200,"ch":"auth","data":{}}`)) + return w.WriteMessage(gws.TextMessage, []byte(`{"action":"req","code":200,"ch":"auth","data":{}}`)) } if action == "sub" { - return w.WriteMessage(websocket.TextMessage, []byte(`{"action":"sub","code":200,"ch":"`+ch+`"}`)) + return w.WriteMessage(gws.TextMessage, []byte(`{"action":"sub","code":200,"ch":"`+ch+`"}`)) } id, _ := jsonparser.GetString(msg, "id") sub, _ := jsonparser.GetString(msg, "sub") if id != "" && sub != "" { - return w.WriteMessage(websocket.TextMessage, []byte(`{"id":"`+id+`","status":"ok","subbed":"`+sub+`"}`)) + return w.WriteMessage(gws.TextMessage, []byte(`{"id":"`+id+`","status":"ok","subbed":"`+sub+`"}`)) } return fmt.Errorf("%w: %s", errors.New("Unhandled mock websocket message"), msg) } diff --git a/exchanges/huobi/huobi_websocket.go b/exchanges/huobi/huobi_websocket.go index db4347f2..19f8c0a8 100644 --- a/exchanges/huobi/huobi_websocket.go +++ b/exchanges/huobi/huobi_websocket.go @@ -12,7 +12,7 @@ import ( "time" "github.com/buger/jsonparser" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" @@ -23,10 +23,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -76,9 +76,9 @@ var subscriptionNames = map[string]string{ // WsConnect initiates a new websocket connection func (h *HUOBI) WsConnect() error { if !h.Websocket.IsEnabled() || !h.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - if err := h.Websocket.Conn.Dial(&websocket.Dialer{}, http.Header{}); err != nil { + if err := h.Websocket.Conn.Dial(&gws.Dialer{}, http.Header{}); err != nil { return err } @@ -100,7 +100,7 @@ func (h *HUOBI) WsConnect() error { } // wsReadMsgs reads and processes messages from a websocket connection -func (h *HUOBI) wsReadMsgs(s stream.Connection) { +func (h *HUOBI) wsReadMsgs(s websocket.Connection) { defer h.Websocket.Wg.Done() for { msg := s.ReadMessage() @@ -146,8 +146,8 @@ func (h *HUOBI) wsHandleData(respRaw []byte) error { return h.wsHandleChannelMsgs(s, respRaw) } - h.Websocket.DataHandler <- stream.UnhandledMessageWarning{ - Message: h.Name + stream.UnhandledMessage + string(respRaw), + h.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + Message: h.Name + websocket.UnhandledMessage + string(respRaw), } return nil @@ -208,7 +208,7 @@ func (h *HUOBI) wsHandleCandleMsg(s *subscription.Subscription, respRaw []byte) if err := json.Unmarshal(respRaw, &c); err != nil { return err } - h.Websocket.DataHandler <- stream.KlineData{ + h.Websocket.DataHandler <- websocket.KlineData{ Timestamp: c.Timestamp.Time(), Exchange: h.Name, AssetType: s.Asset, @@ -511,7 +511,7 @@ func (h *HUOBI) manageSubs(op string, subs subscription.List) error { return subscription.ErrBatchingNotSupported } s := subs[0] - var c stream.Connection + var c websocket.Connection var req any if s.Authenticated { c = h.Websocket.AuthConn @@ -528,7 +528,7 @@ func (h *HUOBI) manageSubs(op string, subs subscription.List) error { if op == wsSubOp { s.SetKey(s.QualifiedChannel) if err := h.Websocket.AddSubscriptions(c, s); err != nil { - return fmt.Errorf("%w: %s; error: %w", stream.ErrSubscriptionFailure, s, err) + return fmt.Errorf("%w: %s; error: %w", websocket.ErrSubscriptionFailure, s, err) } } ctx := context.Background() @@ -564,7 +564,7 @@ func (h *HUOBI) wsGenerateSignature(creds *account.Credentials, timestamp string } func (h *HUOBI) wsAuthConnect(ctx context.Context) error { - if err := h.Websocket.AuthConn.Dial(&websocket.Dialer{}, http.Header{}); err != nil { + if err := h.Websocket.AuthConn.Dial(&gws.Dialer{}, http.Header{}); err != nil { return fmt.Errorf("authenticated dial failed: %w", err) } if err := h.wsLogin(ctx); err != nil { @@ -603,7 +603,7 @@ func (h *HUOBI) wsLogin(ctx context.Context) error { } resp := c.ReadMessage() if resp.Raw == nil { - return &websocket.CloseError{Code: websocket.CloseAbnormalClosure} + return &gws.CloseError{Code: gws.CloseAbnormalClosure} } return getErrResp(resp.Raw) diff --git a/exchanges/huobi/huobi_wrapper.go b/exchanges/huobi/huobi_wrapper.go index 6556014f..b6d0a088 100644 --- a/exchanges/huobi/huobi_wrapper.go +++ b/exchanges/huobi/huobi_wrapper.go @@ -26,9 +26,9 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -166,7 +166,7 @@ func (h *HUOBI) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - h.Websocket = stream.NewWebsocket() + h.Websocket = websocket.NewManager() h.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit h.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout h.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -203,7 +203,7 @@ func (h *HUOBI) Setup(exch *config.Exchange) error { return err } - err = h.Websocket.Setup(&stream.WebsocketSetup{ + err = h.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: wsSpotURL + wsPublicPath, RunningURL: wsRunningURL, @@ -217,7 +217,7 @@ func (h *HUOBI) Setup(exch *config.Exchange) error { return err } - err = h.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + err = h.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ RateLimit: request.NewWeightedRateLimitByDuration(20 * time.Millisecond), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, @@ -226,7 +226,7 @@ func (h *HUOBI) Setup(exch *config.Exchange) error { return err } - return h.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return h.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ RateLimit: request.NewWeightedRateLimitByDuration(20 * time.Millisecond), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/interfaces.go b/exchanges/interfaces.go index 069a6c6c..946835d2 100644 --- a/exchanges/interfaces.go +++ b/exchanges/interfaces.go @@ -20,10 +20,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -76,7 +76,7 @@ type IBotExchange interface { DisableRateLimiter() error EnableRateLimiter() error GetServerTime(ctx context.Context, ai asset.Item) (time.Time, error) - GetWebsocket() (*stream.Websocket, error) + GetWebsocket() (*websocket.Manager, error) SubscribeToWebsocketChannels(channels subscription.List) error UnsubscribeToWebsocketChannels(channels subscription.List) error GetSubscriptions() (subscription.List, error) diff --git a/exchanges/kraken/kraken_types.go b/exchanges/kraken/kraken_types.go index 772f193b..ddda5031 100644 --- a/exchanges/kraken/kraken_types.go +++ b/exchanges/kraken/kraken_types.go @@ -76,11 +76,6 @@ const ( statusOpen = "open" krakenFormat = "2006-01-02T15:04:05.000Z" - - // ChannelOrderbookDepthKey configures the orderbook depth in stream.ChannelSubscription.Params - ChannelOrderbookDepthKey = "_depth" - // ChannelCandlesTimeframeKey configures the candle bar timeframe in stream.ChannelSubscription.Params - ChannelCandlesTimeframeKey = "_timeframe" ) var ( diff --git a/exchanges/kraken/kraken_websocket.go b/exchanges/kraken/kraken_websocket.go index a21f1b42..3d6ce3d0 100644 --- a/exchanges/kraken/kraken_websocket.go +++ b/exchanges/kraken/kraken_websocket.go @@ -13,7 +13,7 @@ import ( "time" "github.com/buger/jsonparser" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/convert" "github.com/thrasher-corp/gocryptotrader/currency" @@ -23,10 +23,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -58,6 +58,7 @@ const ( krakenWsAddOrderStatus = "addOrderStatus" krakenWsCancelOrderStatus = "cancelOrderStatus" krakenWsCancelAllOrderStatus = "cancelAllStatus" + krakenWsPong = "pong" krakenWsPingDelay = time.Second * 27 ) @@ -96,16 +97,16 @@ var defaultSubscriptions = subscription.List{ // WsConnect initiates a websocket connection func (k *Kraken) WsConnect() error { if !k.Websocket.IsEnabled() || !k.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - var dialer websocket.Dialer + var dialer gws.Dialer err := k.Websocket.Conn.Dial(&dialer, http.Header{}) if err != nil { return err } - comms := make(chan stream.Response) + comms := make(chan websocket.Response) k.Websocket.Wg.Add(2) go k.wsReadData(comms) go k.wsFunnelConnectionData(k.Websocket.Conn, comms) @@ -141,7 +142,7 @@ func (k *Kraken) WsConnect() error { } // wsFunnelConnectionData funnels both auth and public ws data into one manageable place -func (k *Kraken) wsFunnelConnectionData(ws stream.Connection, comms chan stream.Response) { +func (k *Kraken) wsFunnelConnectionData(ws websocket.Connection, comms chan websocket.Response) { defer k.Websocket.Wg.Done() for { resp := ws.ReadMessage() @@ -153,7 +154,7 @@ func (k *Kraken) wsFunnelConnectionData(ws stream.Connection, comms chan stream. } // wsReadData receives and passes on websocket messages for processing -func (k *Kraken) wsReadData(comms chan stream.Response) { +func (k *Kraken) wsReadData(comms chan websocket.Response) { defer k.Websocket.Wg.Done() for { @@ -226,16 +227,16 @@ func (k *Kraken) wsHandleData(respRaw []byte) error { } switch event { - case stream.Pong, krakenWsHeartbeat: + case krakenWsPong, krakenWsHeartbeat: return nil case krakenWsCancelOrderStatus, krakenWsCancelAllOrderStatus, krakenWsAddOrderStatus, krakenWsSubscriptionStatus: // All of these should have found a listener already - return fmt.Errorf("%w: %s %v", stream.ErrSignatureNotMatched, event, reqID) + return fmt.Errorf("%w: %s %v", websocket.ErrSignatureNotMatched, event, reqID) case krakenWsSystemStatus: return k.wsProcessSystemStatus(respRaw) default: - k.Websocket.DataHandler <- stream.UnhandledMessageWarning{ - Message: fmt.Sprintf("%s: %s", stream.UnhandledMessage, respRaw), + k.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + Message: fmt.Sprintf("%s: %s", websocket.UnhandledMessage, respRaw), } } @@ -243,11 +244,11 @@ func (k *Kraken) wsHandleData(respRaw []byte) error { } // startWsPingHandler sets up a websocket ping handler to maintain a connection -func (k *Kraken) startWsPingHandler(conn stream.Connection) { - conn.SetupPingHandler(request.Unset, stream.PingHandler{ +func (k *Kraken) startWsPingHandler(conn websocket.Connection) { + conn.SetupPingHandler(request.Unset, websocket.PingHandler{ Message: []byte(`{"event":"ping"}`), Delay: krakenWsPingDelay, - MessageType: websocket.TextMessage, + MessageType: gws.TextMessage, }) } @@ -973,7 +974,7 @@ func (k *Kraken) wsProcessCandle(c string, resp []any, pair currency.Pair) error } interval := parts[1] - k.Websocket.DataHandler <- stream.KlineData{ + k.Websocket.DataHandler <- websocket.KlineData{ AssetType: asset.Spot, Pair: pair, Timestamp: time.Now(), @@ -1097,7 +1098,7 @@ func (k *Kraken) manageSubs(op string, subs subscription.List) error { resps, err := conn.SendMessageReturnResponses(context.TODO(), request.Unset, r.RequestID, r, len(s.Pairs)) // Ignore an overall timeout, because we'll track individual subscriptions in handleSubResps - err = common.ExcludeError(err, stream.ErrSignatureTimeout) + err = common.ExcludeError(err, websocket.ErrSignatureTimeout) if err != nil { return fmt.Errorf("%w; Channel: %s Pair: %s", err, s.Channel, s.Pairs) } diff --git a/exchanges/kraken/kraken_wrapper.go b/exchanges/kraken/kraken_wrapper.go index 7d60d760..eda5df4b 100644 --- a/exchanges/kraken/kraken_wrapper.go +++ b/exchanges/kraken/kraken_wrapper.go @@ -26,10 +26,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket/buffer" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -169,7 +169,7 @@ func (k *Kraken) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - k.Websocket = stream.NewWebsocket() + k.Websocket = websocket.NewManager() k.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit k.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout k.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -194,7 +194,7 @@ func (k *Kraken) Setup(exch *config.Exchange) error { if err != nil { return err } - err = k.Websocket.Setup(&stream.WebsocketSetup{ + err = k.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: krakenWSURL, RunningURL: wsRunningURL, @@ -209,7 +209,7 @@ func (k *Kraken) Setup(exch *config.Exchange) error { return err } - err = k.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + err = k.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ RateLimit: request.NewWeightedRateLimitByDuration(50 * time.Millisecond), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, @@ -222,7 +222,7 @@ func (k *Kraken) Setup(exch *config.Exchange) error { if err != nil { return err } - return k.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return k.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ RateLimit: request.NewWeightedRateLimitByDuration(50 * time.Millisecond), ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, diff --git a/exchanges/kraken/mock_ws_test.go b/exchanges/kraken/mock_ws_test.go index 80e04cdb..ef1675ca 100644 --- a/exchanges/kraken/mock_ws_test.go +++ b/exchanges/kraken/mock_ws_test.go @@ -6,12 +6,12 @@ import ( "testing" "github.com/buger/jsonparser" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/thrasher-corp/gocryptotrader/encoding/json" ) -func mockWsServer(tb testing.TB, msg []byte, w *websocket.Conn) error { +func mockWsServer(tb testing.TB, msg []byte, w *gws.Conn) error { tb.Helper() event, err := jsonparser.GetUnsafeString(msg, "event") if err != nil { @@ -26,7 +26,7 @@ func mockWsServer(tb testing.TB, msg []byte, w *websocket.Conn) error { return nil } -func mockWsCancelOrders(tb testing.TB, msg []byte, w *websocket.Conn) error { +func mockWsCancelOrders(tb testing.TB, msg []byte, w *gws.Conn) error { tb.Helper() var req WsCancelOrderRequest if err := json.Unmarshal(msg, &req); err != nil { @@ -46,10 +46,10 @@ func mockWsCancelOrders(tb testing.TB, msg []byte, w *websocket.Conn) error { if err != nil { return err } - return w.WriteMessage(websocket.TextMessage, msg) + return w.WriteMessage(gws.TextMessage, msg) } -func mockWsAddOrder(tb testing.TB, msg []byte, w *websocket.Conn) error { +func mockWsAddOrder(tb testing.TB, msg []byte, w *gws.Conn) error { tb.Helper() var req WsAddOrderRequest if err := json.Unmarshal(msg, &req); err != nil { @@ -72,5 +72,5 @@ func mockWsAddOrder(tb testing.TB, msg []byte, w *websocket.Conn) error { if err != nil { return err } - return w.WriteMessage(websocket.TextMessage, msg) + return w.WriteMessage(gws.TextMessage, msg) } diff --git a/exchanges/kucoin/kucoin_websocket.go b/exchanges/kucoin/kucoin_websocket.go index bc73c5c9..df714da1 100644 --- a/exchanges/kucoin/kucoin_websocket.go +++ b/exchanges/kucoin/kucoin_websocket.go @@ -13,7 +13,7 @@ import ( "time" "github.com/buger/jsonparser" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" @@ -24,10 +24,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -119,12 +119,12 @@ var defaultSubscriptions = subscription.List{ // WsConnect creates a new websocket connection. func (ku *Kucoin) WsConnect() error { if !ku.Websocket.IsEnabled() || !ku.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } fetchedFuturesOrderbookMutex.Lock() fetchedFuturesOrderbook = map[string]bool{} fetchedFuturesOrderbookMutex.Unlock() - var dialer websocket.Dialer + var dialer gws.Dialer dialer.HandshakeTimeout = ku.Config.HTTPTimeout dialer.Proxy = http.ProxyFromEnvironment var instances *WSInstanceServers @@ -155,10 +155,10 @@ func (ku *Kucoin) WsConnect() error { } ku.Websocket.Wg.Add(1) go ku.wsReadData() - ku.Websocket.Conn.SetupPingHandler(request.Unset, stream.PingHandler{ + ku.Websocket.Conn.SetupPingHandler(request.Unset, websocket.PingHandler{ Delay: time.Millisecond * time.Duration(instances.InstanceServers[0].PingTimeout), Message: []byte(`{"type":"ping"}`), - MessageType: websocket.TextMessage, + MessageType: gws.TextMessage, }) ku.setupOrderbookManager() @@ -335,8 +335,8 @@ func (ku *Kucoin) wsHandleData(respData []byte) error { } return ku.processFuturesKline(resp.Data, instrumentInfos[1]) default: - ku.Websocket.DataHandler <- stream.UnhandledMessageWarning{ - Message: ku.Name + stream.UnhandledMessage + string(respData), + ku.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + Message: ku.Name + websocket.UnhandledMessage + string(respData), } return errors.New("push data not handled") } @@ -612,7 +612,7 @@ func (ku *Kucoin) processFuturesKline(respData []byte, intervalStr string) error if err != nil { return err } - ku.Websocket.DataHandler <- &stream.KlineData{ + ku.Websocket.DataHandler <- &websocket.KlineData{ Timestamp: resp.Time.Time(), AssetType: asset.Futures, Exchange: ku.Name, @@ -839,7 +839,7 @@ func (ku *Kucoin) processCandlesticks(respData []byte, instrument, intervalStrin if !ku.AssetWebsocketSupport.IsAssetWebsocketSupported(assets[x]) { continue } - ku.Websocket.DataHandler <- &stream.KlineData{ + ku.Websocket.DataHandler <- &websocket.KlineData{ Timestamp: response.Time.Time(), Pair: pair, AssetType: assets[x], diff --git a/exchanges/kucoin/kucoin_wrapper.go b/exchanges/kucoin/kucoin_wrapper.go index d80fba0f..93984734 100644 --- a/exchanges/kucoin/kucoin_wrapper.go +++ b/exchanges/kucoin/kucoin_wrapper.go @@ -26,10 +26,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket/buffer" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -164,7 +164,7 @@ func (ku *Kucoin) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - ku.Websocket = stream.NewWebsocket() + ku.Websocket = websocket.NewManager() ku.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit ku.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout ku.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -192,7 +192,7 @@ func (ku *Kucoin) Setup(exch *config.Exchange) error { return err } err = ku.Websocket.Setup( - &stream.WebsocketSetup{ + &websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: kucoinWebsocketURL, RunningURL: wsRunningEndpoint, @@ -211,7 +211,7 @@ func (ku *Kucoin) Setup(exch *config.Exchange) error { if err != nil { return err } - return ku.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return ku.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, RateLimit: request.NewRateLimitWithWeight(time.Second, 2, 1), diff --git a/exchanges/lbank/lbank.go b/exchanges/lbank/lbank.go index 7ff03da3..a2b7a7f4 100644 --- a/exchanges/lbank/lbank.go +++ b/exchanges/lbank/lbank.go @@ -21,14 +21,12 @@ import ( exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" ) // Lbank is the overarching type across this package type Lbank struct { exchange.Base - privateKey *rsa.PrivateKey - WebsocketConn *stream.WebsocketConnection + privateKey *rsa.PrivateKey } const ( diff --git a/exchanges/okx/okx_business_websocket.go b/exchanges/okx/okx_business_websocket.go index 5056c40d..dc55d492 100644 --- a/exchanges/okx/okx_business_websocket.go +++ b/exchanges/okx/okx_business_websocket.go @@ -7,15 +7,15 @@ import ( "strconv" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -46,9 +46,9 @@ var ( // WsConnectBusiness connects to a business websocket channel. func (ok *Okx) WsConnectBusiness() error { if !ok.Websocket.IsEnabled() || !ok.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - var dialer websocket.Dialer + var dialer gws.Dialer dialer.ReadBufferSize = 8192 dialer.WriteBufferSize = 8192 @@ -63,8 +63,8 @@ func (ok *Okx) WsConnectBusiness() error { log.Debugf(log.ExchangeSys, "Successful connection to %v\n", ok.Websocket.GetWebsocketURL()) } - ok.Websocket.Conn.SetupPingHandler(request.UnAuth, stream.PingHandler{ - MessageType: websocket.TextMessage, + ok.Websocket.Conn.SetupPingHandler(request.UnAuth, websocket.PingHandler{ + MessageType: gws.TextMessage, Message: pingMsg, Delay: time.Second * 20, }) @@ -150,7 +150,7 @@ func (ok *Okx) WsSpreadAuth(ctx context.Context) error { } } -// GenerateDefaultBusinessSubscriptions returns a list of default subscriptions to business stream. +// GenerateDefaultBusinessSubscriptions returns a list of default subscriptions to business websocket. func (ok *Okx) GenerateDefaultBusinessSubscriptions() ([]subscription.Subscription, error) { var subs []string var subscriptions []subscription.Subscription diff --git a/exchanges/okx/okx_websocket.go b/exchanges/okx/okx_websocket.go index dc6d82d5..1c09844e 100644 --- a/exchanges/okx/okx_websocket.go +++ b/exchanges/okx/okx_websocket.go @@ -12,7 +12,7 @@ import ( "text/template" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" @@ -22,10 +22,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/types" ) @@ -251,9 +251,9 @@ var subscriptionNames = map[string]string{ // WsConnect initiates a websocket connection func (ok *Okx) WsConnect() error { if !ok.Websocket.IsEnabled() || !ok.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - var dialer websocket.Dialer + var dialer gws.Dialer dialer.ReadBufferSize = 8192 dialer.WriteBufferSize = 8192 @@ -267,8 +267,8 @@ func (ok *Okx) WsConnect() error { log.Debugf(log.ExchangeSys, "Successful connection to %v\n", ok.Websocket.GetWebsocketURL()) } - ok.Websocket.Conn.SetupPingHandler(request.Unset, stream.PingHandler{ - MessageType: websocket.TextMessage, + ok.Websocket.Conn.SetupPingHandler(request.Unset, websocket.PingHandler{ + MessageType: gws.TextMessage, Message: pingMsg, Delay: time.Second * 20, }) @@ -291,15 +291,15 @@ func (ok *Okx) WsAuth(ctx context.Context) error { if err != nil { return err } - var dialer websocket.Dialer + var dialer gws.Dialer err = ok.Websocket.AuthConn.Dial(&dialer, http.Header{}) if err != nil { return err } ok.Websocket.Wg.Add(1) go ok.wsReadData(ok.Websocket.AuthConn) - ok.Websocket.AuthConn.SetupPingHandler(request.Unset, stream.PingHandler{ - MessageType: websocket.TextMessage, + ok.Websocket.AuthConn.SetupPingHandler(request.Unset, websocket.PingHandler{ + MessageType: gws.TextMessage, Message: pingMsg, Delay: time.Second * 20, }) @@ -368,7 +368,7 @@ func (ok *Okx) WsAuth(ctx context.Context) error { } // wsReadData sends msgs from public and auth websockets to data handler -func (ok *Okx) wsReadData(ws stream.Connection) { +func (ok *Okx) wsReadData(ws websocket.Connection) { defer ok.Websocket.Wg.Done() for { resp := ws.ReadMessage() @@ -654,7 +654,7 @@ func (ok *Okx) WsHandleData(respRaw []byte) error { var resp CopyTradingNotification return ok.wsProcessPushData(respRaw, &resp) default: - ok.Websocket.DataHandler <- stream.UnhandledMessageWarning{Message: ok.Name + stream.UnhandledMessage + string(respRaw)} + ok.Websocket.DataHandler <- websocket.UnhandledMessageWarning{Message: ok.Name + websocket.UnhandledMessage + string(respRaw)} return nil } } @@ -793,7 +793,7 @@ func (ok *Okx) wsProcessIndexCandles(respRaw []byte) error { candleInterval := strings.TrimPrefix(response.Argument.Channel, candle) for i := range response.Data { candlesData := response.Data[i] - myCandle := stream.KlineData{ + myCandle := websocket.KlineData{ Pair: pair, Exchange: ok.Name, Timestamp: time.UnixMilli(candlesData[0].Int64()), @@ -1397,7 +1397,7 @@ func (ok *Okx) wsProcessCandles(respRaw []byte) error { candleInterval := strings.TrimPrefix(response.Argument.Channel, candle) for i := range response.Data { for j := range assets { - ok.Websocket.DataHandler <- stream.KlineData{ + ok.Websocket.DataHandler <- websocket.KlineData{ Timestamp: time.UnixMilli(response.Data[i][0].Int64()), Pair: pair, AssetType: assets[j], @@ -1585,7 +1585,7 @@ func (ok *Okx) WsPlaceOrder(ctx context.Context, arg *PlaceOrderRequestParam) (* } } -// WsPlaceMultipleOrders creates an order through the websocket stream. +// WsPlaceMultipleOrders creates an order through the websocket func (ok *Okx) WsPlaceMultipleOrders(ctx context.Context, args []PlaceOrderRequestParam) ([]OrderData, error) { if len(args) == 0 { return nil, order.ErrSubmissionIsNil @@ -2051,7 +2051,7 @@ func (m *wsRequestDataChannelsMultiplexer) Shutdown() { close(m.shutdown) } -// wsChannelSubscription sends a subscription or unsubscription request for different channels through the websocket stream. +// wsChannelSubscription sends a subscription or unsubscription request for different channels through the websocket func (ok *Okx) wsChannelSubscription(ctx context.Context, operation, channel string, assetType asset.Item, pair currency.Pair, tInstrumentType, tInstrumentID, tUnderlying bool) error { if operation != operationSubscribe && operation != operationUnsubscribe { return errInvalidWebsocketEvent @@ -2107,7 +2107,7 @@ func (ok *Okx) wsChannelSubscription(ctx context.Context, operation, channel str // Private Channel Websocket methods -// wsAuthChannelSubscription send a subscription or unsubscription request for different channels through the websocket stream. +// wsAuthChannelSubscription send a subscription or unsubscription request for different channels through the websocket func (ok *Okx) wsAuthChannelSubscription(ctx context.Context, operation, channel string, assetType asset.Item, pair currency.Pair, uid, algoID string, params wsSubscriptionParameters) error { if operation != operationSubscribe && operation != operationUnsubscribe { return errInvalidWebsocketEvent diff --git a/exchanges/okx/okx_wrapper.go b/exchanges/okx/okx_wrapper.go index e1f2a611..d112846b 100644 --- a/exchanges/okx/okx_wrapper.go +++ b/exchanges/okx/okx_wrapper.go @@ -28,10 +28,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket/buffer" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -177,7 +177,7 @@ func (ok *Okx) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - ok.Websocket = stream.NewWebsocket() + ok.Websocket = websocket.NewManager() ok.WebsocketResponseMaxLimit = websocketResponseMaxLimit ok.WebsocketResponseCheckTimeout = websocketResponseMaxLimit ok.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -208,7 +208,7 @@ func (ok *Okx) Setup(exch *config.Exchange) error { if err != nil { return err } - if err := ok.Websocket.Setup(&stream.WebsocketSetup{ + if err := ok.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: apiWebsocketPublicURL, RunningURL: wsRunningEndpoint, @@ -228,7 +228,7 @@ func (ok *Okx) Setup(exch *config.Exchange) error { go ok.WsResponseMultiplexer.Run() - if err := ok.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + if err := ok.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ URL: apiWebsocketPublicURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: websocketResponseMaxLimit, @@ -237,7 +237,7 @@ func (ok *Okx) Setup(exch *config.Exchange) error { return err } - return ok.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return ok.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ URL: apiWebsocketPrivateURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: websocketResponseMaxLimit, diff --git a/exchanges/orderbook/orderbook_types.go b/exchanges/orderbook/orderbook_types.go index 4c9aec59..c5721bb0 100644 --- a/exchanges/orderbook/orderbook_types.go +++ b/exchanges/orderbook/orderbook_types.go @@ -203,7 +203,7 @@ type Movement struct { // FullBookSideConsumed defines if the orderbook liquidty has been consumed // by the requested amount. This might not represent the actual book on the // exchange as they might restrict the amount of information being passed - // back from either a REST request or websocket stream. + // back from either a REST request or websocket update FullBookSideConsumed bool } diff --git a/exchanges/poloniex/poloniex_test.go b/exchanges/poloniex/poloniex_test.go index f1ebd0f6..740724e8 100644 --- a/exchanges/poloniex/poloniex_test.go +++ b/exchanges/poloniex/poloniex_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" @@ -19,7 +19,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -547,9 +547,9 @@ func TestGenerateNewAddress(t *testing.T) { func TestWsAuth(t *testing.T) { t.Parallel() if !p.Websocket.IsEnabled() && !p.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(p) { - t.Skip(stream.ErrWebsocketNotEnabled.Error()) + t.Skip(websocket.ErrWebsocketNotEnabled.Error()) } - var dialer websocket.Dialer + var dialer gws.Dialer err := p.Websocket.Conn.Dial(&dialer, http.Header{}) if err != nil { t.Fatal(err) diff --git a/exchanges/poloniex/poloniex_websocket.go b/exchanges/poloniex/poloniex_websocket.go index f9276957..c972e644 100644 --- a/exchanges/poloniex/poloniex_websocket.go +++ b/exchanges/poloniex/poloniex_websocket.go @@ -9,7 +9,7 @@ import ( "strings" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" @@ -19,10 +19,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -56,9 +56,9 @@ var ( // WsConnect initiates a websocket connection func (p *Poloniex) WsConnect() error { if !p.Websocket.IsEnabled() || !p.IsEnabled() { - return stream.ErrWebsocketNotEnabled + return websocket.ErrWebsocketNotEnabled } - var dialer websocket.Dialer + var dialer gws.Dialer err := p.Websocket.Conn.Dial(&dialer, http.Header{}) if err != nil { return err @@ -262,8 +262,8 @@ func (p *Poloniex) wsHandleData(respRaw []byte) error { return fmt.Errorf("websocket process trades update: %w", err) } default: - p.Websocket.DataHandler <- stream.UnhandledMessageWarning{ - Message: p.Name + stream.UnhandledMessage + string(respRaw), + p.Websocket.DataHandler <- websocket.UnhandledMessageWarning{ + Message: p.Name + websocket.UnhandledMessage + string(respRaw), } } } diff --git a/exchanges/poloniex/poloniex_wrapper.go b/exchanges/poloniex/poloniex_wrapper.go index a77b9931..2993ab5a 100644 --- a/exchanges/poloniex/poloniex_wrapper.go +++ b/exchanges/poloniex/poloniex_wrapper.go @@ -24,10 +24,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket/buffer" "github.com/thrasher-corp/gocryptotrader/log" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -137,7 +137,7 @@ func (p *Poloniex) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - p.Websocket = stream.NewWebsocket() + p.Websocket = websocket.NewManager() p.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit p.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout p.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -163,7 +163,7 @@ func (p *Poloniex) Setup(exch *config.Exchange) error { return err } - err = p.Websocket.Setup(&stream.WebsocketSetup{ + err = p.Websocket.Setup(&websocket.ManagerSetup{ ExchangeConfig: exch, DefaultURL: poloniexWebsocketAddress, RunningURL: wsRunningURL, @@ -181,7 +181,7 @@ func (p *Poloniex) Setup(exch *config.Exchange) error { return err } - return p.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + return p.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, }) diff --git a/exchanges/sharedtestvalues/customex.go b/exchanges/sharedtestvalues/customex.go index 34b4b8e1..38ddc20a 100644 --- a/exchanges/sharedtestvalues/customex.go +++ b/exchanges/sharedtestvalues/customex.go @@ -16,10 +16,10 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -282,7 +282,7 @@ func (c *CustomEx) EnableRateLimiter() error { } // GetWebsocket is a mock method for CustomEx -func (c *CustomEx) GetWebsocket() (*stream.Websocket, error) { +func (c *CustomEx) GetWebsocket() (*websocket.Manager, error) { return nil, nil } diff --git a/exchanges/sharedtestvalues/sharedtestvalues.go b/exchanges/sharedtestvalues/sharedtestvalues.go index d5a93d12..b4780850 100644 --- a/exchanges/sharedtestvalues/sharedtestvalues.go +++ b/exchanges/sharedtestvalues/sharedtestvalues.go @@ -14,7 +14,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/currency" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" ) // This package is only to be referenced in test files @@ -51,8 +51,8 @@ func GetWebsocketStructChannelOverride() chan struct{} { } // NewTestWebsocket returns a test websocket object -func NewTestWebsocket() *stream.Websocket { - w := stream.NewWebsocket() +func NewTestWebsocket() *websocket.Manager { + w := websocket.NewManager() w.DataHandler = make(chan any, WebsocketChannelOverrideCapacity) w.ToRoutine = make(chan any, 1000) return w diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go deleted file mode 100644 index bb70a5ea..00000000 --- a/exchanges/stream/stream_types.go +++ /dev/null @@ -1,162 +0,0 @@ -package stream - -import ( - "context" - "net/http" - "time" - - "github.com/gorilla/websocket" - "github.com/thrasher-corp/gocryptotrader/currency" - "github.com/thrasher-corp/gocryptotrader/exchanges/asset" - "github.com/thrasher-corp/gocryptotrader/exchanges/order" - "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" -) - -// Connection defines a streaming services connection -type Connection interface { - Dial(*websocket.Dialer, http.Header) error - DialContext(context.Context, *websocket.Dialer, http.Header) error - ReadMessage() Response - SetupPingHandler(request.EndpointLimit, PingHandler) - // GenerateMessageID generates a message ID for the individual connection. If a bespoke function is set - // (by using SetupNewConnection) it will use that, otherwise it will use the defaultGenerateMessageID function - // defined in websocket_connection.go. - GenerateMessageID(highPrecision bool) int64 - // SendMessageReturnResponse will send a WS message to the connection and wait for response - SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature, request any) ([]byte, error) - // SendMessageReturnResponses will send a WS message to the connection and wait for N responses - SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature, request any, expected int) ([][]byte, error) - // SendMessageReturnResponsesWithInspector will send a WS message to the connection and wait for N responses with message inspection - SendMessageReturnResponsesWithInspector(ctx context.Context, epl request.EndpointLimit, signature, request any, expected int, messageInspector Inspector) ([][]byte, error) - // SendRawMessage sends a message over the connection without JSON encoding it - SendRawMessage(ctx context.Context, epl request.EndpointLimit, messageType int, message []byte) error - // SendJSONMessage sends a JSON encoded message over the connection - SendJSONMessage(ctx context.Context, epl request.EndpointLimit, payload any) error - SetURL(string) - SetProxy(string) - GetURL() string - Shutdown() error -} - -// Inspector is used to verify messages via SendMessageReturnResponsesWithInspection -// It inspects the []bytes websocket message and returns true if the message is the final message in a sequence of expected messages -type Inspector interface { - IsFinal([]byte) bool -} - -// Response defines generalised data from the stream connection -type Response struct { - Type int - Raw []byte -} - -// ConnectionSetup defines variables for an individual stream connection -type ConnectionSetup struct { - ResponseCheckTimeout time.Duration - ResponseMaxLimit time.Duration - RateLimit *request.RateLimiterWithWeight - Authenticated bool - ConnectionLevelReporter Reporter - - // URL defines the websocket server URL to connect to - URL string - // Connector is the function that will be called to connect to the - // exchange's websocket server. This will be called once when the stream - // service is started. Any bespoke connection logic should be handled here. - Connector func(ctx context.Context, conn Connection) error - // GenerateSubscriptions is a function that will be called to generate a - // list of subscriptions to be made to the exchange's websocket server. - GenerateSubscriptions func() (subscription.List, error) - // Subscriber is a function that will be called to send subscription - // messages based on the exchange's websocket server requirements to - // subscribe to specific channels. - Subscriber func(ctx context.Context, conn Connection, sub subscription.List) error - // Unsubscriber is a function that will be called to send unsubscription - // messages based on the exchange's websocket server requirements to - // unsubscribe from specific channels. NOTE: IF THE FEATURE IS ENABLED. - Unsubscriber func(ctx context.Context, conn Connection, unsub subscription.List) error - // Handler defines the function that will be called when a message is - // received from the exchange's websocket server. This function should - // handle the incoming message and pass it to the appropriate data handler. - Handler func(ctx context.Context, incoming []byte) error - // BespokeGenerateMessageID is a function that returns a unique message ID. - // This is useful for when an exchange connection requires a unique or - // structured message ID for each message sent. - BespokeGenerateMessageID func(highPrecision bool) int64 - // Authenticate will be called to authenticate the connection - Authenticate func(ctx context.Context, conn Connection) error - // MessageFilter defines the criteria used to match messages to a specific connection. - // The filter enables precise routing and handling of messages for distinct connection contexts. - MessageFilter any -} - -// ConnectionWrapper contains the connection setup details to be used when -// attempting a new connection. It also contains the subscriptions that are -// associated with the specific connection. -type ConnectionWrapper struct { - // Setup contains the connection setup details - Setup *ConnectionSetup - // Subscriptions contains the subscriptions that are associated with the - // specific connection(s) - Subscriptions *subscription.Store - // Connection contains the active connection based off the connection - // details above. - Connection Connection // TODO: Upgrade to slice of connections. -} - -// PingHandler container for ping handler settings -type PingHandler struct { - Websocket bool - UseGorillaHandler bool - MessageType int - Message []byte - Delay time.Duration -} - -// FundingData defines funding data -type FundingData struct { - Timestamp time.Time - CurrencyPair currency.Pair - AssetType asset.Item - Exchange string - Amount float64 - Rate float64 - Period int64 - Side order.Side -} - -// KlineData defines kline feed -type KlineData struct { - Timestamp time.Time - Pair currency.Pair - AssetType asset.Item - Exchange string - StartTime time.Time - CloseTime time.Time - Interval string - OpenPrice float64 - ClosePrice float64 - HighPrice float64 - LowPrice float64 - Volume float64 -} - -// WebsocketPositionUpdated reflects a change in orders/contracts on an exchange -type WebsocketPositionUpdated struct { - Timestamp time.Time - Pair currency.Pair - AssetType asset.Item - Exchange string -} - -// UnhandledMessageWarning defines a container for unhandled message warnings -type UnhandledMessageWarning struct { - Message string -} - -// Reporter interface groups observability functionality over -// Websocket request latency. -type Reporter interface { - Latency(name string, message []byte, t time.Duration) -} diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go deleted file mode 100644 index 2a436604..00000000 --- a/exchanges/stream/websocket.go +++ /dev/null @@ -1,1310 +0,0 @@ -package stream - -import ( - "context" - "errors" - "fmt" - "net/url" - "reflect" - "slices" - "sync" - "time" - - "github.com/thrasher-corp/gocryptotrader/common" - "github.com/thrasher-corp/gocryptotrader/config" - "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" - "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" - "github.com/thrasher-corp/gocryptotrader/log" -) - -const jobBuffer = 5000 - -// Public websocket errors -var ( - ErrWebsocketNotEnabled = errors.New("websocket not enabled") - ErrSubscriptionFailure = errors.New("subscription failure") - ErrUnsubscribeFailure = errors.New("unsubscribe failure") - ErrSubscriptionsNotAdded = errors.New("subscriptions not added") - ErrSubscriptionsNotRemoved = errors.New("subscriptions not removed") - ErrAlreadyDisabled = errors.New("websocket already disabled") - ErrNotConnected = errors.New("websocket is not connected") - ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature") - ErrRequestRouteNotFound = errors.New("request route not found") - ErrSignatureNotSet = errors.New("signature not set") - ErrRequestPayloadNotSet = errors.New("request payload not set") -) - -// Private websocket errors -var ( - errExchangeConfigIsNil = errors.New("exchange config is nil") - errWebsocketIsNil = errors.New("websocket is nil") - errWebsocketSetupIsNil = errors.New("websocket setup is nil") - errWebsocketAlreadyInitialised = errors.New("websocket already initialised") - errWebsocketAlreadyEnabled = errors.New("websocket already enabled") - errWebsocketFeaturesIsUnset = errors.New("websocket features is unset") - errConfigFeaturesIsNil = errors.New("exchange config features is nil") - errDefaultURLIsEmpty = errors.New("default url is empty") - errRunningURLIsEmpty = errors.New("running url cannot be empty") - errInvalidWebsocketURL = errors.New("invalid websocket url") - errExchangeConfigNameEmpty = errors.New("exchange config name empty") - errInvalidTrafficTimeout = errors.New("invalid traffic timeout") - errTrafficAlertNil = errors.New("traffic alert is nil") - errWebsocketSubscriberUnset = errors.New("websocket subscriber function needs to be set") - errWebsocketUnsubscriberUnset = errors.New("websocket unsubscriber functionality allowed but unsubscriber function not set") - errWebsocketConnectorUnset = errors.New("websocket connector function not set") - errWebsocketDataHandlerUnset = errors.New("websocket data handler not set") - errReadMessageErrorsNil = errors.New("read message errors is nil") - errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set") - errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit") - errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0") - errSameProxyAddress = errors.New("cannot set proxy address to the same address") - errNoConnectFunc = errors.New("websocket connect func not set") - errAlreadyConnected = errors.New("websocket already connected") - errCannotShutdown = errors.New("websocket cannot shutdown") - errAlreadyReconnecting = errors.New("websocket in the process of reconnection") - errConnSetup = errors.New("error in connection setup") - errNoPendingConnections = errors.New("no pending connections, call SetupNewConnection first") - errConnectionWrapperDuplication = errors.New("connection wrapper duplication") - errCannotChangeConnectionURL = errors.New("cannot change connection URL when using multi connection management") - errExchangeConfigEmpty = errors.New("exchange config is empty") - errCannotObtainOutboundConnection = errors.New("cannot obtain outbound connection") - errMessageFilterNotSet = errors.New("message filter not set") - errMessageFilterNotComparable = errors.New("message filter is not comparable") -) - -var globalReporter Reporter - -// SetupGlobalReporter sets a reporter interface to be used -// for all exchange requests -func SetupGlobalReporter(r Reporter) { - globalReporter = r -} - -// NewWebsocket initialises the websocket struct -func NewWebsocket() *Websocket { - return &Websocket{ - DataHandler: make(chan any, jobBuffer), - ToRoutine: make(chan any, jobBuffer), - ShutdownC: make(chan struct{}), - TrafficAlert: make(chan struct{}, 1), - // ReadMessageErrors is buffered for an edge case when `Connect` fails - // after subscriptions are made but before the connectionMonitor has - // started. This allows the error to be read and handled in the - // connectionMonitor and start a connection cycle again. - ReadMessageErrors: make(chan error, 1), - Match: NewMatch(), - subscriptions: subscription.NewStore(), - features: &protocol.Features{}, - Orderbook: buffer.Orderbook{}, - connections: make(map[Connection]*ConnectionWrapper), - } -} - -// Setup sets main variables for websocket connection -func (w *Websocket) Setup(s *WebsocketSetup) error { - if w == nil { - return errWebsocketIsNil - } - - if s == nil { - return errWebsocketSetupIsNil - } - - w.m.Lock() - defer w.m.Unlock() - - if w.IsInitialised() { - return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyInitialised) - } - - if s.ExchangeConfig == nil { - return errExchangeConfigIsNil - } - - if s.ExchangeConfig.Name == "" { - return errExchangeConfigNameEmpty - } - w.exchangeName = s.ExchangeConfig.Name - w.verbose = s.ExchangeConfig.Verbose - - if s.Features == nil { - return fmt.Errorf("%s %w", w.exchangeName, errWebsocketFeaturesIsUnset) - } - w.features = s.Features - - if s.ExchangeConfig.Features == nil { - return fmt.Errorf("%s %w", w.exchangeName, errConfigFeaturesIsNil) - } - w.setEnabled(s.ExchangeConfig.Features.Enabled.Websocket) - - w.useMultiConnectionManagement = s.UseMultiConnectionManagement - - if !w.useMultiConnectionManagement { - // TODO: Remove this block when all exchanges are updated and backwards - // compatibility is no longer required. - if s.Connector == nil { - return fmt.Errorf("%w: %w", errConnSetup, errWebsocketConnectorUnset) - } - if s.Subscriber == nil { - return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriberUnset) - } - if s.Unsubscriber == nil && w.features.Unsubscribe { - return fmt.Errorf("%w: %w", errConnSetup, errWebsocketUnsubscriberUnset) - } - if s.GenerateSubscriptions == nil { - return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriptionsGeneratorUnset) - } - if s.DefaultURL == "" { - return fmt.Errorf("%s websocket %w", w.exchangeName, errDefaultURLIsEmpty) - } - w.defaultURL = s.DefaultURL - if s.RunningURL == "" { - return fmt.Errorf("%s websocket %w", w.exchangeName, errRunningURLIsEmpty) - } - - w.connector = s.Connector - w.Subscriber = s.Subscriber - w.Unsubscriber = s.Unsubscriber - w.GenerateSubs = s.GenerateSubscriptions - - err := w.SetWebsocketURL(s.RunningURL, false, false) - if err != nil { - return fmt.Errorf("%s %w", w.exchangeName, err) - } - - if s.RunningURLAuth != "" { - err = w.SetWebsocketURL(s.RunningURLAuth, true, false) - if err != nil { - return fmt.Errorf("%s %w", w.exchangeName, err) - } - } - } - - w.connectionMonitorDelay = s.ExchangeConfig.ConnectionMonitorDelay - if w.connectionMonitorDelay <= 0 { - w.connectionMonitorDelay = config.DefaultConnectionMonitorDelay - } - - if s.ExchangeConfig.WebsocketTrafficTimeout < time.Second { - return fmt.Errorf("%s %w cannot be less than %s", - w.exchangeName, - errInvalidTrafficTimeout, - time.Second) - } - w.trafficTimeout = s.ExchangeConfig.WebsocketTrafficTimeout - - w.SetCanUseAuthenticatedEndpoints(s.ExchangeConfig.API.AuthenticatedWebsocketSupport) - - if err := w.Orderbook.Setup(s.ExchangeConfig, &s.OrderbookBufferConfig, w.DataHandler); err != nil { - return err - } - - w.Trade.Setup(s.TradeFeed, w.DataHandler) - w.Fills.Setup(s.FillsFeed, w.DataHandler) - - if s.MaxWebsocketSubscriptionsPerConnection < 0 { - return fmt.Errorf("%s %w", w.exchangeName, errInvalidMaxSubscriptions) - } - w.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection - w.setState(disconnectedState) - - w.rateLimitDefinitions = s.RateLimitDefinitions - return nil -} - -// SetupNewConnection sets up an auth or unauth streaming connection -func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { - if w == nil { - return fmt.Errorf("%w: %w", errConnSetup, errWebsocketIsNil) - } - - if c == nil || c.ResponseCheckTimeout == 0 && - c.ResponseMaxLimit == 0 && - c.RateLimit == nil && - c.URL == "" && - c.ConnectionLevelReporter == nil && - c.BespokeGenerateMessageID == nil { - return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigEmpty) - } - - if w.exchangeName == "" { - return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigNameEmpty) - } - if w.TrafficAlert == nil { - return fmt.Errorf("%w: %w", errConnSetup, errTrafficAlertNil) - } - if w.ReadMessageErrors == nil { - return fmt.Errorf("%w: %w", errConnSetup, errReadMessageErrorsNil) - } - if c.ConnectionLevelReporter == nil { - c.ConnectionLevelReporter = w.ExchangeLevelReporter - } - if c.ConnectionLevelReporter == nil { - c.ConnectionLevelReporter = globalReporter - } - - if w.useMultiConnectionManagement { - // The connection and supporting functions are defined per connection - // and the connection wrapper is stored in the connection manager. - if c.URL == "" { - return fmt.Errorf("%w: %w", errConnSetup, errDefaultURLIsEmpty) - } - if c.Connector == nil { - return fmt.Errorf("%w: %w", errConnSetup, errWebsocketConnectorUnset) - } - if c.GenerateSubscriptions == nil { - return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriptionsGeneratorUnset) - } - if c.Subscriber == nil { - return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriberUnset) - } - if c.Unsubscriber == nil && w.features.Unsubscribe { - return fmt.Errorf("%w: %w", errConnSetup, errWebsocketUnsubscriberUnset) - } - if c.Handler == nil { - return fmt.Errorf("%w: %w", errConnSetup, errWebsocketDataHandlerUnset) - } - - if c.MessageFilter != nil && !reflect.TypeOf(c.MessageFilter).Comparable() { - return errMessageFilterNotComparable - } - - for x := range w.connectionManager { - // Below allows for multiple connections to the same URL with different outbound request signatures. This - // allows for easier determination of inbound and outbound messages. e.g. Gateio cross_margin, margin on - // a spot connection. - if w.connectionManager[x].Setup.URL == c.URL && c.MessageFilter == w.connectionManager[x].Setup.MessageFilter { - return fmt.Errorf("%w: %w", errConnSetup, errConnectionWrapperDuplication) - } - } - w.connectionManager = append(w.connectionManager, &ConnectionWrapper{ - Setup: c, - Subscriptions: subscription.NewStore(), - }) - return nil - } - - if c.Authenticated { - w.AuthConn = w.getConnectionFromSetup(c) - } else { - w.Conn = w.getConnectionFromSetup(c) - } - - return nil -} - -// getConnectionFromSetup returns a websocket connection from a setup -// configuration. This is used for setting up new connections on the fly. -func (w *Websocket) getConnectionFromSetup(c *ConnectionSetup) *WebsocketConnection { - connectionURL := w.GetWebsocketURL() - if c.URL != "" { - connectionURL = c.URL - } - return &WebsocketConnection{ - ExchangeName: w.exchangeName, - URL: connectionURL, - ProxyURL: w.GetProxyAddress(), - Verbose: w.verbose, - ResponseMaxLimit: c.ResponseMaxLimit, - Traffic: w.TrafficAlert, - readMessageErrors: w.ReadMessageErrors, - shutdown: w.ShutdownC, - Wg: &w.Wg, - Match: w.Match, - RateLimit: c.RateLimit, - Reporter: c.ConnectionLevelReporter, - bespokeGenerateMessageID: c.BespokeGenerateMessageID, - RateLimitDefinitions: w.rateLimitDefinitions, - } -} - -// Connect initiates a websocket connection by using a package defined connection -// function -func (w *Websocket) Connect() error { - w.m.Lock() - defer w.m.Unlock() - return w.connect() -} - -func (w *Websocket) connect() error { - if !w.IsEnabled() { - return ErrWebsocketNotEnabled - } - if w.IsConnecting() { - return fmt.Errorf("%v %w", w.exchangeName, errAlreadyReconnecting) - } - if w.IsConnected() { - return fmt.Errorf("%v %w", w.exchangeName, errAlreadyConnected) - } - - if w.subscriptions == nil { - return fmt.Errorf("%w: subscriptions", common.ErrNilPointer) - } - w.subscriptions.Clear() - - w.setState(connectingState) - - w.Wg.Add(2) - go w.monitorFrame(&w.Wg, w.monitorData) - go w.monitorFrame(&w.Wg, w.monitorTraffic) - - if !w.useMultiConnectionManagement { - if w.connector == nil { - return fmt.Errorf("%v %w", w.exchangeName, errNoConnectFunc) - } - err := w.connector() - if err != nil { - w.setState(disconnectedState) - return fmt.Errorf("%v Error connecting %w", w.exchangeName, err) - } - w.setState(connectedState) - - if w.connectionMonitorRunning.CompareAndSwap(false, true) { - // This oversees all connections and does not need to be part of wait group management. - go w.monitorFrame(nil, w.monitorConnection) - } - - subs, err := w.GenerateSubs() // regenerate state on new connection - if err != nil { - return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) - } - if len(subs) != 0 { - if err := w.SubscribeToChannels(nil, subs); err != nil { - return err - } - - if missing := w.subscriptions.Missing(subs); len(missing) > 0 { - return fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotAdded, missing) - } - } - return nil - } - - if len(w.connectionManager) == 0 { - w.setState(disconnectedState) - return fmt.Errorf("cannot connect: %w", errNoPendingConnections) - } - - // multiConnectFatalError is a fatal error that will cause all connections to - // be shutdown and the websocket to be disconnected. - var multiConnectFatalError error - - // subscriptionError is a non-fatal error that does not shutdown connections - var subscriptionError error - - // TODO: Implement concurrency below. - for i := range w.connectionManager { - if w.connectionManager[i].Setup.GenerateSubscriptions == nil { - multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errWebsocketSubscriptionsGeneratorUnset) - break - } - - subs, err := w.connectionManager[i].Setup.GenerateSubscriptions() // regenerate state on new connection - if err != nil { - multiConnectFatalError = fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) - break - } - - if len(subs) == 0 { - // If no subscriptions are generated, we skip the connection - if w.verbose { - log.Warnf(log.WebsocketMgr, "%s websocket: no subscriptions generated", w.exchangeName) - } - continue - } - - if w.connectionManager[i].Setup.Connector == nil { - multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errNoConnectFunc) - break - } - if w.connectionManager[i].Setup.Handler == nil { - multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errWebsocketDataHandlerUnset) - break - } - if w.connectionManager[i].Setup.Subscriber == nil { - multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, w.connectionManager[i].Setup.URL, errWebsocketSubscriberUnset) - break - } - - // TODO: Add window for max subscriptions per connection, to spawn new connections if needed. - - conn := w.getConnectionFromSetup(w.connectionManager[i].Setup) - - err = w.connectionManager[i].Setup.Connector(context.TODO(), conn) - if err != nil { - multiConnectFatalError = fmt.Errorf("%v Error connecting %w", w.exchangeName, err) - break - } - - if !conn.IsConnected() { - multiConnectFatalError = fmt.Errorf("%s websocket: [conn:%d] [URL:%s] failed to connect", w.exchangeName, i+1, conn.URL) - break - } - - w.connections[conn] = w.connectionManager[i] - w.connectionManager[i].Connection = conn - - w.Wg.Add(1) - go w.Reader(context.TODO(), conn, w.connectionManager[i].Setup.Handler) - - if w.connectionManager[i].Setup.Authenticate != nil && w.CanUseAuthenticatedEndpoints() { - err = w.connectionManager[i].Setup.Authenticate(context.TODO(), conn) - if err != nil { - multiConnectFatalError = fmt.Errorf("%s websocket: [conn:%d] [URL:%s] failed to authenticate %w", w.exchangeName, i+1, conn.URL, err) - break - } - } - - err = w.connectionManager[i].Setup.Subscriber(context.TODO(), conn, subs) - if err != nil { - subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("%v Error subscribing %w", w.exchangeName, err)) - continue - } - - if missing := w.connectionManager[i].Subscriptions.Missing(subs); len(missing) > 0 { - subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotAdded, missing)) - continue - } - - if w.verbose { - log.Debugf(log.WebsocketMgr, "%s websocket: [conn:%d] [URL:%s] connected. [Subscribed: %d]", - w.exchangeName, - i+1, - conn.URL, - len(subs)) - } - } - - if multiConnectFatalError != nil { - // Roll back any successful connections and flush subscriptions - for x := range w.connectionManager { - if w.connectionManager[x].Connection != nil { - if err := w.connectionManager[x].Connection.Shutdown(); err != nil { - log.Errorln(log.WebsocketMgr, err) - } - w.connectionManager[x].Connection = nil - } - w.connectionManager[x].Subscriptions.Clear() - } - clear(w.connections) - w.setState(disconnectedState) // Flip from connecting to disconnected. - - // Drain residual error in the single buffered channel, this mitigates - // the cycle when `Connect` is called again and the connectionMonitor - // starts but there is an old error in the channel. - drain(w.ReadMessageErrors) - - return multiConnectFatalError - } - - // Assume connected state here. All connections have been established. - // All subscriptions have been sent and stored. All data received is being - // handled by the appropriate data handler. - w.setState(connectedState) - - if w.connectionMonitorRunning.CompareAndSwap(false, true) { - // This oversees all connections and does not need to be part of wait group management. - go w.monitorFrame(nil, w.monitorConnection) - } - - return subscriptionError -} - -// Disable disables the exchange websocket protocol -// Note that connectionMonitor will be responsible for shutting down the websocket after disabling -func (w *Websocket) Disable() error { - if !w.IsEnabled() { - return fmt.Errorf("%s %w", w.exchangeName, ErrAlreadyDisabled) - } - - w.setEnabled(false) - return nil -} - -// Enable enables the exchange websocket protocol -func (w *Websocket) Enable() error { - if w.IsConnected() || w.IsEnabled() { - return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyEnabled) - } - - w.setEnabled(true) - return w.Connect() -} - -// Shutdown attempts to shut down a websocket connection and associated routines -// by using a package defined shutdown function -func (w *Websocket) Shutdown() error { - w.m.Lock() - defer w.m.Unlock() - return w.shutdown() -} - -func (w *Websocket) shutdown() error { - if !w.IsConnected() { - return fmt.Errorf("%v %w: %w", w.exchangeName, errCannotShutdown, ErrNotConnected) - } - - // TODO: Interrupt connection and or close connection when it is re-established. - if w.IsConnecting() { - return fmt.Errorf("%v %w: %w ", w.exchangeName, errCannotShutdown, errAlreadyReconnecting) - } - - if w.verbose { - log.Debugf(log.WebsocketMgr, "%v websocket: shutting down websocket", w.exchangeName) - } - - defer w.Orderbook.FlushBuffer() - - // During the shutdown process, all errors are treated as non-fatal to avoid issues when the connection has already - // been closed. In such cases, attempting to close the connection may result in a - // "failed to send closeNotify alert (but connection was closed anyway)" error. Treating these errors as non-fatal - // prevents the shutdown process from being interrupted, which could otherwise trigger a continuous traffic monitor - // cycle and potentially block the initiation of a new connection. - var nonFatalCloseConnectionErrors error - - // Shutdown managed connections - for x := range w.connectionManager { - if w.connectionManager[x].Connection != nil { - if err := w.connectionManager[x].Connection.Shutdown(); err != nil { - nonFatalCloseConnectionErrors = common.AppendError(nonFatalCloseConnectionErrors, err) - } - w.connectionManager[x].Connection = nil - // Flush any subscriptions from last connection across any managed connections - w.connectionManager[x].Subscriptions.Clear() - } - } - // Clean map of old connections - clear(w.connections) - - if w.Conn != nil { - if err := w.Conn.Shutdown(); err != nil { - nonFatalCloseConnectionErrors = common.AppendError(nonFatalCloseConnectionErrors, err) - } - } - if w.AuthConn != nil { - if err := w.AuthConn.Shutdown(); err != nil { - nonFatalCloseConnectionErrors = common.AppendError(nonFatalCloseConnectionErrors, err) - } - } - // flush any subscriptions from last connection if needed - w.subscriptions.Clear() - - w.setState(disconnectedState) - - close(w.ShutdownC) - w.Wg.Wait() - w.ShutdownC = make(chan struct{}) - if w.verbose { - log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", w.exchangeName) - } - - // Drain residual error in the single buffered channel, this mitigates - // the cycle when `Connect` is called again and the connectionMonitor - // starts but there is an old error in the channel. - drain(w.ReadMessageErrors) - - if nonFatalCloseConnectionErrors != nil { - log.Warnf(log.WebsocketMgr, "%v websocket: shutdown error: %v", w.exchangeName, nonFatalCloseConnectionErrors) - } - - return nil -} - -// FlushChannels flushes channel subscriptions when there is a pair/asset change -func (w *Websocket) FlushChannels() error { - if !w.IsEnabled() { - return fmt.Errorf("%s %w", w.exchangeName, ErrWebsocketNotEnabled) - } - - if !w.IsConnected() { - return fmt.Errorf("%s %w", w.exchangeName, ErrNotConnected) - } - - // If the exchange does not support subscribing and or unsubscribing the full connection needs to be flushed to - // maintain consistency. - if !w.features.Subscribe || !w.features.Unsubscribe { - w.m.Lock() - defer w.m.Unlock() - if err := w.shutdown(); err != nil { - return err - } - return w.connect() - } - - if !w.useMultiConnectionManagement { - newSubs, err := w.GenerateSubs() - if err != nil { - return err - } - return w.updateChannelSubscriptions(nil, w.subscriptions, newSubs) - } - - for x := range w.connectionManager { - newSubs, err := w.connectionManager[x].Setup.GenerateSubscriptions() - if err != nil { - return err - } - - // Case if there is nothing to unsubscribe from and the connection is nil - if len(newSubs) == 0 && w.connectionManager[x].Connection == nil { - continue - } - - // If there are subscriptions to subscribe to but no connection to subscribe to, establish a new connection. - if w.connectionManager[x].Connection == nil { - conn := w.getConnectionFromSetup(w.connectionManager[x].Setup) - if err := w.connectionManager[x].Setup.Connector(context.TODO(), conn); err != nil { - return err - } - w.Wg.Add(1) - go w.Reader(context.TODO(), conn, w.connectionManager[x].Setup.Handler) - w.connections[conn] = w.connectionManager[x] - w.connectionManager[x].Connection = conn - } - - err = w.updateChannelSubscriptions(w.connectionManager[x].Connection, w.connectionManager[x].Subscriptions, newSubs) - if err != nil { - return err - } - - // If there are no subscriptions to subscribe to, close the connection as it is no longer needed. - if w.connectionManager[x].Subscriptions.Len() == 0 { - delete(w.connections, w.connectionManager[x].Connection) // Remove from lookup map - if err := w.connectionManager[x].Connection.Shutdown(); err != nil { - log.Warnf(log.WebsocketMgr, "%v websocket: failed to shutdown connection: %v", w.exchangeName, err) - } - w.connectionManager[x].Connection = nil - } - } - return nil -} - -// updateChannelSubscriptions subscribes or unsubscribes from channels and checks that the correct number of channels -// have been subscribed to or unsubscribed from. -func (w *Websocket) updateChannelSubscriptions(c Connection, store *subscription.Store, incoming subscription.List) error { - subs, unsubs := store.Diff(incoming) - if len(unsubs) != 0 { - if err := w.UnsubscribeChannels(c, unsubs); err != nil { - return err - } - - if contained := store.Contained(unsubs); len(contained) > 0 { - return fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotRemoved, contained) - } - } - if len(subs) != 0 { - if err := w.SubscribeToChannels(c, subs); err != nil { - return err - } - - if missing := store.Missing(subs); len(missing) > 0 { - return fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotAdded, missing) - } - } - return nil -} - -func (w *Websocket) setState(s uint32) { - w.state.Store(s) -} - -// IsInitialised returns whether the websocket has been Setup() already -func (w *Websocket) IsInitialised() bool { - return w.state.Load() != uninitialisedState -} - -// IsConnected returns whether the websocket is connected -func (w *Websocket) IsConnected() bool { - return w.state.Load() == connectedState -} - -// IsConnecting returns whether the websocket is connecting -func (w *Websocket) IsConnecting() bool { - return w.state.Load() == connectingState -} - -func (w *Websocket) setEnabled(b bool) { - w.enabled.Store(b) -} - -// IsEnabled returns whether the websocket is enabled -func (w *Websocket) IsEnabled() bool { - return w.enabled.Load() -} - -// CanUseAuthenticatedWebsocketForWrapper Handles a common check to -// verify whether a wrapper can use an authenticated websocket endpoint -func (w *Websocket) CanUseAuthenticatedWebsocketForWrapper() bool { - if w.IsConnected() { - if w.CanUseAuthenticatedEndpoints() { - return true - } - log.Infof(log.WebsocketMgr, WebsocketNotAuthenticatedUsingRest, w.exchangeName) - } - return false -} - -// SetWebsocketURL sets websocket URL and can refresh underlying connections -func (w *Websocket) SetWebsocketURL(url string, auth, reconnect bool) error { - if w.useMultiConnectionManagement { - // TODO: Add functionality for multi-connection management to change URL - return fmt.Errorf("%s: %w", w.exchangeName, errCannotChangeConnectionURL) - } - defaultVals := url == "" || url == config.WebsocketURLNonDefaultMessage - if auth { - if defaultVals { - url = w.defaultURLAuth - } - - err := checkWebsocketURL(url) - if err != nil { - return err - } - w.runningURLAuth = url - - if w.verbose { - log.Debugf(log.WebsocketMgr, "%s websocket: setting authenticated websocket URL: %s\n", w.exchangeName, url) - } - - if w.AuthConn != nil { - w.AuthConn.SetURL(url) - } - } else { - if defaultVals { - url = w.defaultURL - } - err := checkWebsocketURL(url) - if err != nil { - return err - } - w.runningURL = url - - if w.verbose { - log.Debugf(log.WebsocketMgr, "%s websocket: setting unauthenticated websocket URL: %s\n", w.exchangeName, url) - } - - if w.Conn != nil { - w.Conn.SetURL(url) - } - } - - if w.IsConnected() && reconnect { - log.Debugf(log.WebsocketMgr, "%s websocket: flushing websocket connection to %s\n", w.exchangeName, url) - return w.Shutdown() - } - return nil -} - -// GetWebsocketURL returns the running websocket URL -func (w *Websocket) GetWebsocketURL() string { - return w.runningURL -} - -// SetProxyAddress sets websocket proxy address -func (w *Websocket) SetProxyAddress(proxyAddr string) error { - w.m.Lock() - defer w.m.Unlock() - if proxyAddr != "" { - if _, err := url.ParseRequestURI(proxyAddr); err != nil { - return fmt.Errorf("%v websocket: cannot set proxy address: %w", w.exchangeName, err) - } - - if w.proxyAddr == proxyAddr { - return fmt.Errorf("%v websocket: %w '%v'", w.exchangeName, errSameProxyAddress, w.proxyAddr) - } - - log.Debugf(log.ExchangeSys, "%s websocket: setting websocket proxy: %s", w.exchangeName, proxyAddr) - } else { - log.Debugf(log.ExchangeSys, "%s websocket: removing websocket proxy", w.exchangeName) - } - - for _, wrapper := range w.connectionManager { - if wrapper.Connection != nil { - wrapper.Connection.SetProxy(proxyAddr) - } - } - if w.Conn != nil { - w.Conn.SetProxy(proxyAddr) - } - if w.AuthConn != nil { - w.AuthConn.SetProxy(proxyAddr) - } - - w.proxyAddr = proxyAddr - - if !w.IsConnected() { - return nil - } - if err := w.shutdown(); err != nil { - return err - } - return w.connect() -} - -// GetProxyAddress returns the current websocket proxy -func (w *Websocket) GetProxyAddress() string { - return w.proxyAddr -} - -// GetName returns exchange name -func (w *Websocket) GetName() string { - return w.exchangeName -} - -// UnsubscribeChannels unsubscribes from a list of websocket channel -func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.List) error { - if len(channels) == 0 { - return nil // No channels to unsubscribe from is not an error - } - if wrapper, ok := w.connections[conn]; ok && conn != nil { - return w.unsubscribe(wrapper.Subscriptions, channels, func(channels subscription.List) error { - return wrapper.Setup.Unsubscriber(context.TODO(), conn, channels) - }) - } - - if w.Unsubscriber == nil { - return fmt.Errorf("%w: Global Unsubscriber not set", common.ErrNilPointer) - } - - return w.unsubscribe(w.subscriptions, channels, func(channels subscription.List) error { - return w.Unsubscriber(channels) - }) -} - -func (w *Websocket) unsubscribe(store *subscription.Store, channels subscription.List, unsub func(channels subscription.List) error) error { - if store == nil { - return nil // No channels to unsubscribe from is not an error - } - for _, s := range channels { - if store.Get(s) == nil { - return fmt.Errorf("%w: %s", subscription.ErrNotFound, s) - } - } - return unsub(channels) -} - -// ResubscribeToChannel resubscribes to channel -// Sets state to Resubscribing, and exchanges which want to maintain a lock on it can respect this state and not RemoveSubscription -// Errors if subscription is already subscribing -func (w *Websocket) ResubscribeToChannel(conn Connection, s *subscription.Subscription) error { - l := subscription.List{s} - if err := s.SetState(subscription.ResubscribingState); err != nil { - return fmt.Errorf("%w: %s", err, s) - } - if err := w.UnsubscribeChannels(conn, l); err != nil { - return err - } - return w.SubscribeToChannels(conn, l) -} - -// SubscribeToChannels subscribes to websocket channels using the exchange specific Subscriber method -// Errors are returned for duplicates or exceeding max Subscriptions -func (w *Websocket) SubscribeToChannels(conn Connection, subs subscription.List) error { - if slices.Contains(subs, nil) { - return fmt.Errorf("%w: List parameter contains an nil element", common.ErrNilPointer) - } - if err := w.checkSubscriptions(conn, subs); err != nil { - return err - } - - if wrapper, ok := w.connections[conn]; ok && conn != nil { - return wrapper.Setup.Subscriber(context.TODO(), conn, subs) - } - - if w.Subscriber == nil { - return fmt.Errorf("%w: Global Subscriber not set", common.ErrNilPointer) - } - - if err := w.Subscriber(subs); err != nil { - return fmt.Errorf("%w: %w", ErrSubscriptionFailure, err) - } - return nil -} - -// AddSubscriptions adds subscriptions to the subscription store -// Sets state to Subscribing unless the state is already set -func (w *Websocket) AddSubscriptions(conn Connection, subs ...*subscription.Subscription) error { - if w == nil { - return fmt.Errorf("%w: AddSubscriptions called on nil Websocket", common.ErrNilPointer) - } - var subscriptionStore **subscription.Store - if wrapper, ok := w.connections[conn]; ok && conn != nil { - subscriptionStore = &wrapper.Subscriptions - } else { - subscriptionStore = &w.subscriptions - } - - if *subscriptionStore == nil { - *subscriptionStore = subscription.NewStore() - } - var errs error - for _, s := range subs { - if s.State() == subscription.InactiveState { - if err := s.SetState(subscription.SubscribingState); err != nil { - errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s)) - } - } - if err := (*subscriptionStore).Add(s); err != nil { - errs = common.AppendError(errs, err) - } - } - return errs -} - -// AddSuccessfulSubscriptions marks subscriptions as subscribed and adds them to the subscription store -func (w *Websocket) AddSuccessfulSubscriptions(conn Connection, subs ...*subscription.Subscription) error { - if w == nil { - return fmt.Errorf("%w: AddSuccessfulSubscriptions called on nil Websocket", common.ErrNilPointer) - } - - var subscriptionStore **subscription.Store - if wrapper, ok := w.connections[conn]; ok && conn != nil { - subscriptionStore = &wrapper.Subscriptions - } else { - subscriptionStore = &w.subscriptions - } - - if *subscriptionStore == nil { - *subscriptionStore = subscription.NewStore() - } - - var errs error - for _, s := range subs { - if err := s.SetState(subscription.SubscribedState); err != nil { - errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s)) - } - if err := (*subscriptionStore).Add(s); err != nil { - errs = common.AppendError(errs, err) - } - } - return errs -} - -// RemoveSubscriptions removes subscriptions from the subscription list and sets the status to Unsubscribed -func (w *Websocket) RemoveSubscriptions(conn Connection, subs ...*subscription.Subscription) error { - if w == nil { - return fmt.Errorf("%w: RemoveSubscriptions called on nil Websocket", common.ErrNilPointer) - } - - var subscriptionStore *subscription.Store - if wrapper, ok := w.connections[conn]; ok && conn != nil { - subscriptionStore = wrapper.Subscriptions - } else { - subscriptionStore = w.subscriptions - } - - if subscriptionStore == nil { - return fmt.Errorf("%w: RemoveSubscriptions called on uninitialised Websocket", common.ErrNilPointer) - } - - var errs error - for _, s := range subs { - if err := s.SetState(subscription.UnsubscribedState); err != nil { - errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s)) - } - if err := subscriptionStore.Remove(s); err != nil { - errs = common.AppendError(errs, err) - } - } - return errs -} - -// GetSubscription returns a subscription at the key provided -// returns nil if no subscription is at that key or the key is nil -// Keys can implement subscription.MatchableKey in order to provide custom matching logic -func (w *Websocket) GetSubscription(key any) *subscription.Subscription { - if w == nil || key == nil { - return nil - } - for _, c := range w.connectionManager { - if c.Subscriptions == nil { - continue - } - sub := c.Subscriptions.Get(key) - if sub != nil { - return sub - } - } - if w.subscriptions == nil { - return nil - } - return w.subscriptions.Get(key) -} - -// GetSubscriptions returns a new slice of the subscriptions -func (w *Websocket) GetSubscriptions() subscription.List { - if w == nil { - return nil - } - var subs subscription.List - for _, c := range w.connectionManager { - if c.Subscriptions != nil { - subs = append(subs, c.Subscriptions.List()...) - } - } - if w.subscriptions != nil { - subs = append(subs, w.subscriptions.List()...) - } - return subs -} - -// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in a thread safe manner -func (w *Websocket) SetCanUseAuthenticatedEndpoints(b bool) { - w.canUseAuthenticatedEndpoints.Store(b) -} - -// CanUseAuthenticatedEndpoints gets canUseAuthenticatedEndpoints val in a thread safe manner -func (w *Websocket) CanUseAuthenticatedEndpoints() bool { - return w.canUseAuthenticatedEndpoints.Load() -} - -// checkWebsocketURL checks for a valid websocket url -func checkWebsocketURL(s string) error { - u, err := url.Parse(s) - if err != nil { - return err - } - if u.Scheme != "ws" && u.Scheme != "wss" { - return fmt.Errorf("cannot set %w %s", errInvalidWebsocketURL, s) - } - return nil -} - -// checkSubscriptions checks subscriptions against the max subscription limit and if the subscription already exists -// The subscription state is not considered when counting existing subscriptions -func (w *Websocket) checkSubscriptions(conn Connection, subs subscription.List) error { - var subscriptionStore *subscription.Store - if wrapper, ok := w.connections[conn]; ok && conn != nil { - subscriptionStore = wrapper.Subscriptions - } else { - subscriptionStore = w.subscriptions - } - if subscriptionStore == nil { - return fmt.Errorf("%w: Websocket.subscriptions", common.ErrNilPointer) - } - - existing := subscriptionStore.Len() - if w.MaxSubscriptionsPerConnection > 0 && existing+len(subs) > w.MaxSubscriptionsPerConnection { - return fmt.Errorf("%w: current subscriptions: %v, incoming subscriptions: %v, max subscriptions per connection: %v - please reduce enabled pairs", - errSubscriptionsExceedsLimit, - existing, - len(subs), - w.MaxSubscriptionsPerConnection) - } - - for _, s := range subs { - if s.State() == subscription.ResubscribingState { - continue - } - if found := subscriptionStore.Get(s); found != nil { - return fmt.Errorf("%w: %s", subscription.ErrDuplicate, s) - } - } - - return nil -} - -// Reader reads and handles data from a specific connection -func (w *Websocket) Reader(ctx context.Context, conn Connection, handler func(ctx context.Context, message []byte) error) { - defer w.Wg.Done() - for { - resp := conn.ReadMessage() - if resp.Raw == nil { - return // Connection has been closed - } - if err := handler(ctx, resp.Raw); err != nil { - w.DataHandler <- fmt.Errorf("connection URL:[%v] error: %w", conn.GetURL(), err) - } - } -} - -func drain(ch <-chan error) { - for { - select { - case <-ch: - default: - return - } - } -} - -// ClosureFrame is a closure function that wraps monitoring variables with observer, if the return is true the frame will exit -type ClosureFrame func() func() bool - -// monitorFrame monitors a specific websocket component or critical system. It will exit if the observer returns true -// This is used for monitoring data throughput, connection status and other critical websocket components. The waitgroup -// is optional and is used to signal when the monitor has finished. -func (w *Websocket) monitorFrame(wg *sync.WaitGroup, fn ClosureFrame) { - if wg != nil { - defer w.Wg.Done() - } - observe := fn() - for { - if observe() { - return - } - } -} - -// monitorData monitors data throughput and logs if there is a back log of data -func (w *Websocket) monitorData() func() bool { - dropped := 0 - return func() bool { return w.observeData(&dropped) } -} - -// observeData observes data throughput and logs if there is a back log of data -func (w *Websocket) observeData(dropped *int) (exit bool) { - select { - case <-w.ShutdownC: - return true - case d := <-w.DataHandler: - select { - case w.ToRoutine <- d: - if *dropped != 0 { - log.Infof(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer recovered; %d messages were dropped", w.exchangeName, dropped) - *dropped = 0 - } - default: - if *dropped == 0 { - // If this becomes prone to flapping we could drain the buffer, but that's extreme and we'd like to avoid it if possible - log.Warnf(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer full; dropping messages", w.exchangeName) - } - *dropped++ - } - return false - } -} - -// monitorConnection monitors the connection and attempts to reconnect if the connection is lost -func (w *Websocket) monitorConnection() func() bool { - timer := time.NewTimer(w.connectionMonitorDelay) - return func() bool { return w.observeConnection(timer) } -} - -// observeConnection observes the connection and attempts to reconnect if the connection is lost -func (w *Websocket) observeConnection(t *time.Timer) (exit bool) { - select { - case err := <-w.ReadMessageErrors: - if errors.Is(err, errConnectionFault) { - log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err) - if w.IsConnected() { - if shutdownErr := w.Shutdown(); shutdownErr != nil { - log.Errorf(log.WebsocketMgr, "%v websocket: connectionMonitor shutdown err: %s", w.exchangeName, shutdownErr) - } - } - } - // Speedier reconnection, instead of waiting for the next cycle. - if w.IsEnabled() && (!w.IsConnected() && !w.IsConnecting()) { - if connectErr := w.Connect(); connectErr != nil { - log.Errorln(log.WebsocketMgr, connectErr) - } - } - w.DataHandler <- err // hand over the error to the data handler (shutdown and reconnection is priority) - case <-t.C: - if w.verbose { - log.Debugf(log.WebsocketMgr, "%v websocket: running connection monitor cycle", w.exchangeName) - } - if !w.IsEnabled() { - if w.verbose { - log.Debugf(log.WebsocketMgr, "%v websocket: connectionMonitor - websocket disabled, shutting down", w.exchangeName) - } - if w.IsConnected() { - if err := w.Shutdown(); err != nil { - log.Errorln(log.WebsocketMgr, err) - } - } - if w.verbose { - log.Debugf(log.WebsocketMgr, "%v websocket: connection monitor exiting", w.exchangeName) - } - t.Stop() - w.connectionMonitorRunning.Store(false) - return true - } - if !w.IsConnecting() && !w.IsConnected() { - err := w.Connect() - if err != nil { - log.Errorln(log.WebsocketMgr, err) - } - } - t.Reset(w.connectionMonitorDelay) - } - return false -} - -// monitorTraffic monitors to see if there has been traffic within the trafficTimeout time window. If there is no traffic -// the connection is shutdown and will be reconnected by the connectionMonitor routine. -func (w *Websocket) monitorTraffic() func() bool { - timer := time.NewTimer(w.trafficTimeout) - return func() bool { return w.observeTraffic(timer) } -} - -func (w *Websocket) observeTraffic(t *time.Timer) bool { - select { - case <-w.ShutdownC: - if w.verbose { - log.Debugf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown message received", w.exchangeName) - } - case <-t.C: - if w.IsConnecting() || signalReceived(w.TrafficAlert) { - t.Reset(w.trafficTimeout) - return false - } - if w.verbose { - log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", w.exchangeName, w.trafficTimeout) - } - if w.IsConnected() { - go func() { // Without this the w.Shutdown() call below will deadlock - if err := w.Shutdown(); err != nil { - log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", w.exchangeName, err) - } - }() - } - } - t.Stop() - return true -} - -// signalReceived checks if a signal has been received, this also clears the signal. -func signalReceived(ch chan struct{}) bool { - select { - case <-ch: - return true - default: - return false - } -} - -// GetConnection returns a connection by message filter (defined in exchange package _wrapper.go websocket connection) -// for request and response handling in a multi connection context. -func (w *Websocket) GetConnection(messageFilter any) (Connection, error) { - if w == nil { - return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, w) - } - - if messageFilter == nil { - return nil, errMessageFilterNotSet - } - - w.m.Lock() - defer w.m.Unlock() - - if !w.useMultiConnectionManagement { - return nil, fmt.Errorf("%s: multi connection management not enabled %w please use exported Conn and AuthConn fields", w.exchangeName, errCannotObtainOutboundConnection) - } - - if !w.IsConnected() { - return nil, ErrNotConnected - } - - for _, wrapper := range w.connectionManager { - if wrapper.Setup.MessageFilter == messageFilter { - if wrapper.Connection == nil { - return nil, fmt.Errorf("%s: %s %w associated with message filter: '%v'", w.exchangeName, wrapper.Setup.URL, ErrNotConnected, messageFilter) - } - return wrapper.Connection, nil - } - } - - return nil, fmt.Errorf("%s: %w associated with message filter: '%v'", w.exchangeName, ErrRequestRouteNotFound, messageFilter) -} diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go deleted file mode 100644 index 63118f34..00000000 --- a/exchanges/stream/websocket_connection.go +++ /dev/null @@ -1,381 +0,0 @@ -package stream - -import ( - "bytes" - "compress/flate" - "compress/gzip" - "context" - "crypto/rand" - "errors" - "fmt" - "io" - "math/big" - "net" - "net/http" - "net/url" - "strings" - "sync/atomic" - "time" - - "github.com/gorilla/websocket" - "github.com/thrasher-corp/gocryptotrader/encoding/json" - "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/log" -) - -var ( - // errConnectionFault is a connection fault error which alerts the system that a connection cycle needs to take place. - errConnectionFault = errors.New("connection fault") - errWebsocketIsDisconnected = errors.New("websocket connection is disconnected") - errRateLimitNotFound = errors.New("rate limit definition not found") -) - -// Dial sets proxy urls and then connects to the websocket -func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header) error { - return w.DialContext(context.Background(), dialer, headers) -} - -// DialContext sets proxy urls and then connects to the websocket -func (w *WebsocketConnection) DialContext(ctx context.Context, dialer *websocket.Dialer, headers http.Header) error { - if w.ProxyURL != "" { - proxy, err := url.Parse(w.ProxyURL) - if err != nil { - return err - } - dialer.Proxy = http.ProxyURL(proxy) - } - - var err error - var conStatus *http.Response - w.Connection, conStatus, err = dialer.DialContext(ctx, w.URL, headers) - if err != nil { - if conStatus != nil { - _ = conStatus.Body.Close() - return fmt.Errorf("%s websocket connection: %v %v %v Error: %w", w.ExchangeName, w.URL, conStatus, conStatus.StatusCode, err) - } - return fmt.Errorf("%s websocket connection: %v Error: %w", w.ExchangeName, w.URL, err) - } - _ = conStatus.Body.Close() - - if w.Verbose { - log.Infof(log.WebsocketMgr, "%v Websocket connected to %s\n", w.ExchangeName, w.URL) - } - select { - case w.Traffic <- struct{}{}: - default: - } - w.setConnectedStatus(true) - return nil -} - -// SendJSONMessage sends a JSON encoded message over the connection -func (w *WebsocketConnection) SendJSONMessage(ctx context.Context, epl request.EndpointLimit, data any) error { - return w.writeToConn(ctx, epl, func() error { - if request.IsVerbose(ctx, w.Verbose) { - if msg, err := json.Marshal(data); err == nil { // WriteJSON will error for us anyway - log.Debugf(log.WebsocketMgr, "%v %v: Sending message: %v", w.ExchangeName, removeURLQueryString(w.URL), string(msg)) - } - } - return w.Connection.WriteJSON(data) - }) -} - -// SendRawMessage sends a message over the connection without JSON encoding it -func (w *WebsocketConnection) SendRawMessage(ctx context.Context, epl request.EndpointLimit, messageType int, message []byte) error { - return w.writeToConn(ctx, epl, func() error { - if request.IsVerbose(ctx, w.Verbose) { - log.Debugf(log.WebsocketMgr, "%v %v: Sending message: %v", w.ExchangeName, removeURLQueryString(w.URL), string(message)) - } - return w.Connection.WriteMessage(messageType, message) - }) -} - -func (w *WebsocketConnection) writeToConn(ctx context.Context, epl request.EndpointLimit, writeConn func() error) error { - if !w.IsConnected() { - return fmt.Errorf("%v websocket connection: cannot send message %w", w.ExchangeName, errWebsocketIsDisconnected) - } - - var rl *request.RateLimiterWithWeight - if w.RateLimitDefinitions != nil { - var ok bool - if rl, ok = w.RateLimitDefinitions[epl]; !ok && w.RateLimit == nil { - // Return an error if no specific connection rate limit is found for the endpoint but a global rate limit is - // set. This ensures the system attempts to apply rate limiting, prioritizing endpoint-specific limits - // if they are defined. - return fmt.Errorf("%s websocket connection: %w for %v", w.ExchangeName, errRateLimitNotFound, epl) - } - } - - if rl == nil { - // If a global rate limit definition is not found, use the connection rate limit as a fallback. - rl = w.RateLimit - } - - if rl != nil { - if err := request.RateLimit(ctx, rl); err != nil { - return fmt.Errorf("%s websocket connection: rate limit error: %w", w.ExchangeName, err) - } - } - // This lock acts as a rolling gate to prevent WriteMessage panics. Acquire after rate limit check. - w.writeControl.Lock() - defer w.writeControl.Unlock() - return writeConn() -} - -// SetupPingHandler will automatically send ping or pong messages based on -// WebsocketPingHandler configuration -func (w *WebsocketConnection) SetupPingHandler(epl request.EndpointLimit, handler PingHandler) { - if handler.UseGorillaHandler { - w.Connection.SetPingHandler(func(msg string) error { - err := w.Connection.WriteControl(handler.MessageType, []byte(msg), time.Now().Add(handler.Delay)) - if err == websocket.ErrCloseSent { - return nil - } else if e, ok := err.(net.Error); ok && e.Timeout() { - return nil - } - return err - }) - return - } - w.Wg.Add(1) - go func() { - defer w.Wg.Done() - ticker := time.NewTicker(handler.Delay) - for { - select { - case <-w.shutdown: - ticker.Stop() - return - case <-ticker.C: - err := w.SendRawMessage(context.TODO(), epl, handler.MessageType, handler.Message) - if err != nil { - log.Errorf(log.WebsocketMgr, "%v websocket connection: ping handler failed to send message [%s]: %v", w.ExchangeName, handler.Message, err) - return - } - } - } - }() -} - -// setConnectedStatus sets connection status if changed it will return true. -// TODO: Swap out these atomic switches and opt for sync.RWMutex. -func (w *WebsocketConnection) setConnectedStatus(b bool) bool { - if b { - return atomic.SwapInt32(&w.connected, 1) == 0 - } - return atomic.SwapInt32(&w.connected, 0) == 1 -} - -// IsConnected exposes websocket connection status -func (w *WebsocketConnection) IsConnected() bool { - return atomic.LoadInt32(&w.connected) == 1 -} - -// ReadMessage reads messages, can handle text, gzip and binary -func (w *WebsocketConnection) ReadMessage() Response { - mType, resp, err := w.Connection.ReadMessage() - if err != nil { - // If any error occurs, a Response{Raw: nil, Type: 0} is returned, causing the - // reader routine to exit. This leaves the connection without an active reader, - // leading to potential buffer issue from the ongoing websocket writes. - // Such errors are passed to `w.readMessageErrors` when the connection is active. - // The `connectionMonitor` handles these errors by flushing the buffer, reconnecting, - // and resubscribing to the websocket to restore the connection. - if w.setConnectedStatus(false) { - // NOTE: When w.setConnectedStatus() returns true the underlying - // state was changed and this infers that the connection was - // externally closed and an error is reported else Shutdown() - // method on WebsocketConnection type has been called and can - // be skipped. - select { - case w.readMessageErrors <- fmt.Errorf("%w: %w", err, errConnectionFault): - default: - // bypass if there is no receiver, as this stops it returning - // when shutdown is called. - log.Warnf(log.WebsocketMgr, "%s failed to relay error: %v", w.ExchangeName, err) - } - } - return Response{} - } - - select { - case w.Traffic <- struct{}{}: - default: // Non-Blocking write ensures 1 buffered signal per trafficCheckInterval to avoid flooding - } - - var standardMessage []byte - switch mType { - case websocket.TextMessage: - standardMessage = resp - case websocket.BinaryMessage: - standardMessage, err = w.parseBinaryResponse(resp) - if err != nil { - log.Errorf(log.WebsocketMgr, "%v %v: Parse binary response error: %v", w.ExchangeName, removeURLQueryString(w.URL), err) - return Response{Raw: []byte(``)} // Non-nil response to avoid the reader returning on this case. - } - } - if w.Verbose { - log.Debugf(log.WebsocketMgr, "%v %v: Message received: %v", w.ExchangeName, removeURLQueryString(w.URL), string(standardMessage)) - } - return Response{Raw: standardMessage, Type: mType} -} - -// parseBinaryResponse parses a websocket binary response into a usable byte array -func (w *WebsocketConnection) parseBinaryResponse(resp []byte) ([]byte, error) { - var reader io.ReadCloser - var err error - if len(resp) >= 2 && resp[0] == 31 && resp[1] == 139 { // Detect GZIP - reader, err = gzip.NewReader(bytes.NewReader(resp)) - if err != nil { - return nil, err - } - } else { - reader = flate.NewReader(bytes.NewReader(resp)) - } - standardMessage, err := io.ReadAll(reader) - if err != nil { - return nil, err - } - return standardMessage, reader.Close() -} - -// GenerateMessageID generates a message ID for the individual connection. -// If a bespoke function is set (by using SetupNewConnection) it will use that, -// otherwise it will use the defaultGenerateMessageID function. -func (w *WebsocketConnection) GenerateMessageID(highPrec bool) int64 { - if w.bespokeGenerateMessageID != nil { - return w.bespokeGenerateMessageID(highPrec) - } - return w.defaultGenerateMessageID(highPrec) -} - -// defaultGenerateMessageID generates the default message ID -func (w *WebsocketConnection) defaultGenerateMessageID(highPrec bool) int64 { - var minValue int64 = 1e8 - var maxValue int64 = 2e8 - if highPrec { - maxValue = 2e12 - minValue = 1e12 - } - // utilization of hard coded positive numbers and default crypto/rand - // io.reader will panic on error instead of returning - randomNumber, err := rand.Int(rand.Reader, big.NewInt(maxValue-minValue+1)) - if err != nil { - panic(err) - } - return randomNumber.Int64() + minValue -} - -// Shutdown shuts down and closes specific connection -func (w *WebsocketConnection) Shutdown() error { - if w == nil || w.Connection == nil { - return nil - } - w.setConnectedStatus(false) - w.writeControl.Lock() - defer w.writeControl.Unlock() - return w.Connection.NetConn().Close() -} - -// SetURL sets connection URL -func (w *WebsocketConnection) SetURL(url string) { - w.URL = url -} - -// SetProxy sets connection proxy -func (w *WebsocketConnection) SetProxy(proxy string) { - w.ProxyURL = proxy -} - -// GetURL returns the connection URL -func (w *WebsocketConnection) GetURL() string { - return w.URL -} - -// SendMessageReturnResponse will send a WS message to the connection and wait for response -func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature, request any) ([]byte, error) { - resps, err := w.SendMessageReturnResponses(ctx, epl, signature, request, 1) - if err != nil { - return nil, err - } - return resps[0], nil -} - -// SendMessageReturnResponses will send a WS message to the connection and wait for N responses -// An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked -func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int) ([][]byte, error) { - return w.SendMessageReturnResponsesWithInspector(ctx, epl, signature, payload, expected, nil) -} - -// SendMessageReturnResponsesWithInspector will send a WS message to the connection and wait for N responses -// An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked -func (w *WebsocketConnection) SendMessageReturnResponsesWithInspector(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int, messageInspector Inspector) ([][]byte, error) { - outbound, err := json.Marshal(payload) - if err != nil { - return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err) - } - - ch, err := w.Match.Set(signature, expected) - if err != nil { - return nil, err - } - - start := time.Now() - err = w.SendRawMessage(ctx, epl, websocket.TextMessage, outbound) - if err != nil { - return nil, err - } - - resps, err := w.waitForResponses(ctx, signature, ch, expected, messageInspector) - if err != nil { - return nil, err - } - - if w.Reporter != nil { - w.Reporter.Latency(w.ExchangeName, outbound, time.Since(start)) - } - - return resps, err -} - -// waitForResponses waits for N responses from a channel -func (w *WebsocketConnection) waitForResponses(ctx context.Context, signature any, ch <-chan []byte, expected int, messageInspector Inspector) ([][]byte, error) { - timeout := time.NewTimer(w.ResponseMaxLimit * time.Duration(expected)) - defer timeout.Stop() - - resps := make([][]byte, 0, expected) -inspection: - for range expected { - select { - case resp := <-ch: - resps = append(resps, resp) - // Checks recently received message to determine if this is in fact the final message in a sequence of messages. - if messageInspector != nil && messageInspector.IsFinal(resp) { - w.Match.RemoveSignature(signature) - break inspection - } - case <-timeout.C: - w.Match.RemoveSignature(signature) - return nil, fmt.Errorf("%s %w %v", w.ExchangeName, ErrSignatureTimeout, signature) - case <-ctx.Done(): - w.Match.RemoveSignature(signature) - return nil, ctx.Err() - } - } - - // Only check context verbosity. If the exchange is verbose, it will log the responses in the ReadMessage() call. - if request.IsVerbose(ctx, false) { - for i := range resps { - log.Debugf(log.WebsocketMgr, "%v %v: Received response [%d/%d]: %v", w.ExchangeName, removeURLQueryString(w.URL), i+1, len(resps), string(resps[i])) - } - } - - return resps, nil -} - -func removeURLQueryString(url string) string { - if index := strings.Index(url, "?"); index != -1 { - return url[:index] - } - return url -} diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go deleted file mode 100644 index c42a6b39..00000000 --- a/exchanges/stream/websocket_types.go +++ /dev/null @@ -1,187 +0,0 @@ -package stream - -import ( - "sync" - "sync/atomic" - "time" - - "github.com/gorilla/websocket" - "github.com/thrasher-corp/gocryptotrader/config" - "github.com/thrasher-corp/gocryptotrader/exchanges/fill" - "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" - "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" - "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" - "github.com/thrasher-corp/gocryptotrader/exchanges/trade" -) - -// Websocket functionality list and state consts -const ( - WebsocketNotAuthenticatedUsingRest = "%v - Websocket not authenticated, using REST\n" - Ping = "ping" - Pong = "pong" - UnhandledMessage = " - Unhandled websocket message: " -) - -const ( - uninitialisedState uint32 = iota - disconnectedState - connectingState - connectedState -) - -// Websocket defines a return type for websocket connections via the interface -// wrapper for routine processing -type Websocket struct { - canUseAuthenticatedEndpoints atomic.Bool - enabled atomic.Bool - state atomic.Uint32 - verbose bool - connectionMonitorRunning atomic.Bool - trafficTimeout time.Duration - connectionMonitorDelay time.Duration - proxyAddr string - defaultURL string - defaultURLAuth string - runningURL string - runningURLAuth string - exchangeName string - m sync.Mutex - connector func() error - - // connectionManager stores all *potential* connections for the exchange, organised within ConnectionWrapper structs. - // Each ConnectionWrapper one connection (will be expanded soon) tailored for specific exchange functionalities or asset types. // TODO: Expand this to support multiple connections per ConnectionWrapper - // For example, separate connections can be used for Spot, Margin, and Futures trading. This structure is especially useful - // for exchanges that differentiate between trading pairs by using different connection endpoints or protocols for various asset classes. - // If an exchange does not require such differentiation, all connections may be managed under a single ConnectionWrapper. - connectionManager []*ConnectionWrapper - // connections holds a look up table for all connections to their corresponding ConnectionWrapper and subscription holder - connections map[Connection]*ConnectionWrapper - - subscriptions *subscription.Store - - // Subscriber function for exchange specific subscribe implementation - Subscriber func(subscription.List) error - // Subscriber function for exchange specific unsubscribe implementation - Unsubscriber func(subscription.List) error - // GenerateSubs function for exchange specific generating subscriptions from Features.Subscriptions, Pairs and Assets - GenerateSubs func() (subscription.List, error) - - useMultiConnectionManagement bool - - DataHandler chan any - ToRoutine chan any - - Match *Match - - // shutdown synchronises shutdown event across routines - ShutdownC chan struct{} - Wg sync.WaitGroup - - // Orderbook is a local buffer of orderbooks - Orderbook buffer.Orderbook - - // Trade is a notifier of occurring trades - Trade trade.Trade - - // Fills is a notifier of occurring fills - Fills fill.Fills - - // trafficAlert monitors if there is a halt in traffic throughput - TrafficAlert chan struct{} - // ReadMessageErrors will received all errors from ws.ReadMessage() and - // verify if its a disconnection - ReadMessageErrors chan error - features *protocol.Features - - // Standard stream connection - Conn Connection - // Authenticated stream connection - AuthConn Connection - - // Latency reporter - ExchangeLevelReporter Reporter - - // MaxSubScriptionsPerConnection defines the maximum number of - // subscriptions per connection that is allowed by the exchange. - MaxSubscriptionsPerConnection int - - // rateLimitDefinitions contains the rate limiters shared between Websocket and REST connections for all potential - // endpoints. - rateLimitDefinitions request.RateLimitDefinitions -} - -// WebsocketSetup defines variables for setting up a websocket connection -type WebsocketSetup struct { - ExchangeConfig *config.Exchange - DefaultURL string - RunningURL string - RunningURLAuth string - Connector func() error - Subscriber func(subscription.List) error - Unsubscriber func(subscription.List) error - GenerateSubscriptions func() (subscription.List, error) - Features *protocol.Features - - // Local orderbook buffer config values - OrderbookBufferConfig buffer.Config - - // UseMultiConnectionManagement allows the connections to be managed by the - // connection manager. If false, this will default to the global fields - // provided in this struct. - UseMultiConnectionManagement bool - - TradeFeed bool - - // Fill data config values - FillsFeed bool - - // MaxWebsocketSubscriptionsPerConnection defines the maximum number of - // subscriptions per connection that is allowed by the exchange. - MaxWebsocketSubscriptionsPerConnection int - - // RateLimitDefinitions contains the rate limiters shared between WebSocket and REST connections for all endpoints. - // These rate limits take precedence over any rate limits specified in individual connection configurations. - // If no connection-specific rate limit is provided and the endpoint does not match any of these definitions, - // an error will be returned. However, if a connection configuration includes its own rate limit, - // it will fall back to that configuration’s rate limit without raising an error. - RateLimitDefinitions request.RateLimitDefinitions -} - -// WebsocketConnection contains all the data needed to send a message to a WS -// connection -type WebsocketConnection struct { - Verbose bool - connected int32 - - // Gorilla websocket does not allow more than one goroutine to utilise - // writes methods - writeControl sync.Mutex - - // RateLimit is a rate limiter for the connection itself - RateLimit *request.RateLimiterWithWeight - // RateLimitDefinitions contains the rate limiters shared between WebSocket and REST connections for all - // potential endpoints. - RateLimitDefinitions request.RateLimitDefinitions - - ExchangeName string - URL string - ProxyURL string - Wg *sync.WaitGroup - Connection *websocket.Conn - - // shutdown synchronises shutdown event across routines associated with this connection only e.g. ping handler - shutdown chan struct{} - - Match *Match - ResponseMaxLimit time.Duration - Traffic chan struct{} - readMessageErrors chan error - - // bespokeGenerateMessageID is a function that returns a unique message ID - // defined externally. This is used for exchanges that require a unique - // message ID for each message sent. - bespokeGenerateMessageID func(highPrecision bool) int64 - - Reporter Reporter -} diff --git a/internal/exchange/websocket/README.md b/internal/exchange/websocket/README.md new file mode 100644 index 00000000..bd038473 --- /dev/null +++ b/internal/exchange/websocket/README.md @@ -0,0 +1,174 @@ +# GoCryptoTrader package Websocket + + + + +[![Build Status](https://github.com/thrasher-corp/gocryptotrader/actions/workflows/tests.yml/badge.svg?branch=master)](https://github.com/thrasher-corp/gocryptotrader/actions/workflows/tests.yml) +[![Software License](https://img.shields.io/badge/License-MIT-orange.svg?style=flat-square)](https://github.com/thrasher-corp/gocryptotrader/blob/master/LICENSE) +[![GoDoc](https://godoc.org/github.com/thrasher-corp/gocryptotrader?status.svg)](https://godoc.org/github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket) +[![Coverage Status](https://codecov.io/gh/thrasher-corp/gocryptotrader/graph/badge.svg?token=41784B23TS)](https://codecov.io/gh/thrasher-corp/gocryptotrader) +[![Go Report Card](https://goreportcard.com/badge/github.com/thrasher-corp/gocryptotrader)](https://goreportcard.com/report/github.com/thrasher-corp/gocryptotrader) + + +This websocket package is part of the GoCryptoTrader codebase. + +## This is still in active development + +You can track ideas, planned features and what's in progress on our [GoCryptoTrader Kanban board](https://github.com/orgs/thrasher-corp/projects/3). + +Join our slack to discuss all things related to GoCryptoTrader! [GoCryptoTrader Slack](https://join.slack.com/t/gocryptotrader/shared_invite/enQtNTQ5NDAxMjA2Mjc5LTc5ZDE1ZTNiOGM3ZGMyMmY1NTAxYWZhODE0MWM5N2JlZDk1NDU0YTViYzk4NTk3OTRiMDQzNGQ1YTc4YmRlMTk) + +## Overview + +The `websocket` package provides methods to manage connections and subscriptions for exchange websockets. + +## Features + +- Handle real-time market data streams +- Unified interface for managing data streams +- Multi-connection management - a system that can be used to manage multiple connections to the same exchange +- Connection monitoring - a system that can be used to monitor the health of the websocket connections. This can be used to check if the connection is still alive and if it is not, it will attempt to reconnect +- Traffic monitoring - will reconnect if no message is sent for a period of time defined in your config +- Subscription management - a system that can be used to manage subscriptions to various data streams +- Rate limiting - a system that can be used to rate limit the number of requests sent to the exchange +- Message ID generation - a system that can be used to generate message IDs for websocket requests +- Websocket message response matching - can be used to match websocket responses to the requests that were sent + +## Usage + +### Default single websocket connection + +Example setup for the `websocket` package connection: + +```go +package main + +import ( + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket" + exchange "github.com/thrasher-corp/gocryptotrader/exchanges" + "github.com/thrasher-corp/gocryptotrader/exchanges/request" +) + +type Exchange struct { + exchange.Base +} + +// In the exchange wrapper this will set up the initial pointer field provided by exchange.Base +func (e *Exchange) SetDefault() { + e.Websocket = websocket.NewManager() + e.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit + e.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout + e.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit +} + +// In the exchange wrapper this is the original setup pattern for the websocket services +func (e *Exchange) Setup(exch *config.Exchange) error { + // This sets up global connection, sub, unsub and generate subscriptions for each connection defined below. + if err := e.Websocket.Setup(&websocket.ManagerSetup{ + ExchangeConfig: exch, + DefaultURL: connectionURLString, + RunningURL: connectionURLString, + Connector: e.WsConnect, + Subscriber: e.Subscribe, + Unsubscriber: e.Unsubscribe, + GenerateSubscriptions: e.GenerateDefaultSubscriptions, + Features: &e.Features.Supports.WebsocketCapabilities, + MaxWebsocketSubscriptionsPerConnection: 240, + OrderbookBufferConfig: buffer.Config{ Checksum: e.CalculateUpdateOrderbookChecksum }, + }); err != nil { + return err + } + + // This is a public websocket connection + if err := ok.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ + URL: connectionURLString, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exchangeWebsocketResponseMaxLimit, + RateLimit: request.NewRateLimitWithWeight(time.Second, 2, 1), + }); err != nil { + return err + } + + // This is a private websocket connection + return ok.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ + URL: privateConnectionURLString, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exchangeWebsocketResponseMaxLimit, + Authenticated: true, + RateLimit: request.NewRateLimitWithWeight(time.Second, 2, 1), + }) +} +``` + +### Multiple websocket connections + The example below provides the now optional multi connection management system which allows for more connections + to be maintained and established based off URL, connections types, asset types etc. +```go +func (e *Exchange) Setup(exch *config.Exchange) error { + // This sets up global connection, sub, unsub and generate subscriptions for each connection defined below. + if err := e.Websocket.Setup(&websocket.ManagerSetup{ + ExchangeConfig: exch, + Features: &e.Features.Supports.WebsocketCapabilities, + FillsFeed: e.Features.Enabled.FillsFeed, + TradeFeed: e.Features.Enabled.TradeFeed, + UseMultiConnectionManagement: true, + }) + if err != nil { + return err + } + // Spot connection + err = g.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ + URL: connectionURLStringForSpot, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + // Custom handlers for the specific connection: + Handler: e.WsHandleSpotData, + Subscriber: e.SpotSubscribe, + Unsubscriber: e.SpotUnsubscribe, + GenerateSubscriptions: e.GenerateDefaultSubscriptionsSpot, + Connector: e.WsConnectSpot, + BespokeGenerateMessageID: e.GenerateWebsocketMessageID, + }) + if err != nil { + return err + } + // Futures connection - USDT margined + err = g.Websocket.SetupNewConnection(&websocket.ConnectionSetup{ + URL: connectionURLStringForSpotForFutures, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + // Custom handlers for the specific connection: + Handler: func(ctx context.Context, incoming []byte) error { return e.WsHandleFuturesData(ctx, incoming, asset.Futures) }, + Subscriber: e.FuturesSubscribe, + Unsubscriber: e.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return e.GenerateFuturesDefaultSubscriptions(currency.USDT) }, + Connector: e.WsFuturesConnect, + BespokeGenerateMessageID: e.GenerateWebsocketMessageID, + }) + if err != nil { + return err + } +} +``` + + +## Contribution + +Please feel free to submit any pull requests or suggest any desired features to be added. + +When submitting a PR, please abide by our coding guidelines: + ++ Code must adhere to the official Go [formatting](https://golang.org/doc/effective_go.html#formatting) guidelines (i.e. uses [gofmt](https://golang.org/cmd/gofmt/)). ++ Code must be documented adhering to the official Go [commentary](https://golang.org/doc/effective_go.html#commentary) guidelines. ++ Code must adhere to our [coding style](https://github.com/thrasher-corp/gocryptotrader/blob/master/doc/coding_style.md). ++ Pull requests need to be based on and opened against the `master` branch. + +## Donations + + + +If this framework helped you in any way, or you would like to support the developers working on it, please donate Bitcoin to: + +***bc1qk0jareu4jytc0cfrhr5wgshsq8282awpavfahc*** diff --git a/exchanges/stream/buffer/buffer.go b/internal/exchange/websocket/buffer/buffer.go similarity index 98% rename from exchanges/stream/buffer/buffer.go rename to internal/exchange/websocket/buffer/buffer.go index 5cc0915b..c5f6a7cd 100644 --- a/exchanges/stream/buffer/buffer.go +++ b/internal/exchange/websocket/buffer/buffer.go @@ -35,8 +35,7 @@ var ( // Setup sets private variables func (w *Orderbook) Setup(exchangeConfig *config.Exchange, c *Config, dataHandler chan<- any) error { - if exchangeConfig == nil { // exchange config fields are checked in stream package - // prior to calling this, so further checks are not needed. + if exchangeConfig == nil { // exchange config fields are checked in websocket package prior to calling this, so further checks are not needed return fmt.Errorf(packageError, errExchangeConfigNil) } if c == nil { @@ -50,8 +49,7 @@ func (w *Orderbook) Setup(exchangeConfig *config.Exchange, c *Config, dataHandle return fmt.Errorf(packageError, errIssueBufferEnabledButNoLimit) } - // NOTE: These variables are set by config.json under "orderbook" for each - // individual exchange. + // NOTE: These variables are set by config.json under "orderbook" for each individual exchange w.bufferEnabled = exchangeConfig.Orderbook.WebsocketBufferEnabled w.obBufferLimit = exchangeConfig.Orderbook.WebsocketBufferLimit diff --git a/exchanges/stream/buffer/buffer_test.go b/internal/exchange/websocket/buffer/buffer_test.go similarity index 100% rename from exchanges/stream/buffer/buffer_test.go rename to internal/exchange/websocket/buffer/buffer_test.go diff --git a/exchanges/stream/buffer/buffer_types.go b/internal/exchange/websocket/buffer/buffer_types.go similarity index 100% rename from exchanges/stream/buffer/buffer_types.go rename to internal/exchange/websocket/buffer/buffer_types.go diff --git a/internal/exchange/websocket/connection.go b/internal/exchange/websocket/connection.go new file mode 100644 index 00000000..762f2036 --- /dev/null +++ b/internal/exchange/websocket/connection.go @@ -0,0 +1,481 @@ +package websocket + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "context" + "crypto/rand" + "errors" + "fmt" + "io" + "math/big" + "net" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + gws "github.com/gorilla/websocket" + "github.com/thrasher-corp/gocryptotrader/encoding/json" + "github.com/thrasher-corp/gocryptotrader/exchanges/request" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" + "github.com/thrasher-corp/gocryptotrader/log" +) + +var ( + // errConnectionFault is a connection fault error which alerts the system that a connection cycle needs to take place. + errConnectionFault = errors.New("connection fault") + errWebsocketIsDisconnected = errors.New("websocket connection is disconnected") + errRateLimitNotFound = errors.New("rate limit definition not found") +) + +// Connection defines the interface for websocket connections +type Connection interface { + Dial(*gws.Dialer, http.Header) error + DialContext(context.Context, *gws.Dialer, http.Header) error + ReadMessage() Response + SetupPingHandler(request.EndpointLimit, PingHandler) + // GenerateMessageID generates a message ID for the individual connection. If a bespoke function is set + // (by using SetupNewConnection) it will use that, otherwise it will use the defaultGenerateMessageID function + // defined in websocket_connection.go. + GenerateMessageID(highPrecision bool) int64 + // SendMessageReturnResponse will send a WS message to the connection and wait for response + SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature, request any) ([]byte, error) + // SendMessageReturnResponses will send a WS message to the connection and wait for N responses + SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature, request any, expected int) ([][]byte, error) + // SendMessageReturnResponsesWithInspector will send a WS message to the connection and wait for N responses with message inspection + SendMessageReturnResponsesWithInspector(ctx context.Context, epl request.EndpointLimit, signature, request any, expected int, messageInspector Inspector) ([][]byte, error) + // SendRawMessage sends a message over the connection without JSON encoding it + SendRawMessage(ctx context.Context, epl request.EndpointLimit, messageType int, message []byte) error + // SendJSONMessage sends a JSON encoded message over the connection + SendJSONMessage(ctx context.Context, epl request.EndpointLimit, payload any) error + SetURL(string) + SetProxy(string) + GetURL() string + Shutdown() error +} + +// ConnectionSetup defines variables for an individual stream connection +type ConnectionSetup struct { + ResponseCheckTimeout time.Duration + ResponseMaxLimit time.Duration + RateLimit *request.RateLimiterWithWeight + Authenticated bool + ConnectionLevelReporter Reporter + + // URL defines the websocket server URL to connect to + URL string + // Connector is the function that will be called to connect to the + // exchange's websocket server. This will be called once when the stream + // service is started. Any bespoke connection logic should be handled here. + Connector func(ctx context.Context, conn Connection) error + // GenerateSubscriptions is a function that will be called to generate a + // list of subscriptions to be made to the exchange's websocket server. + GenerateSubscriptions func() (subscription.List, error) + // Subscriber is a function that will be called to send subscription + // messages based on the exchange's websocket server requirements to + // subscribe to specific channels. + Subscriber func(ctx context.Context, conn Connection, sub subscription.List) error + // Unsubscriber is a function that will be called to send unsubscription + // messages based on the exchange's websocket server requirements to + // unsubscribe from specific channels. NOTE: IF THE FEATURE IS ENABLED. + Unsubscriber func(ctx context.Context, conn Connection, unsub subscription.List) error + // Handler defines the function that will be called when a message is + // received from the exchange's websocket server. This function should + // handle the incoming message and pass it to the appropriate data handler. + Handler func(ctx context.Context, incoming []byte) error + // BespokeGenerateMessageID is a function that returns a unique message ID. + // This is useful for when an exchange connection requires a unique or + // structured message ID for each message sent. + BespokeGenerateMessageID func(highPrecision bool) int64 + Authenticate func(ctx context.Context, conn Connection) error + // MessageFilter defines the criteria used to match messages to a specific connection. + // The filter enables precise routing and handling of messages for distinct connection contexts. + MessageFilter any +} + +// Inspector is used to verify messages via SendMessageReturnResponsesWithInspection +// It inspects the []bytes websocket message and returns true if the message is the final message in a sequence of expected messages +type Inspector interface { + IsFinal([]byte) bool +} + +// Response defines generalised data from the websocket connection +type Response struct { + Type int + Raw []byte +} + +// connection contains all the data needed to send a message to a websocket connection +type connection struct { + Verbose bool + connected int32 + writeControl sync.Mutex // Gorilla websocket does not allow more than one goroutine to utilise write methods + RateLimit *request.RateLimiterWithWeight // RateLimit is a rate limiter for the connection itself + RateLimitDefinitions request.RateLimitDefinitions // RateLimitDefinitions contains the rate limiters shared between WebSocket and REST connections + Reporter Reporter + ExchangeName string + URL string + ProxyURL string + Wg *sync.WaitGroup + Connection *gws.Conn + shutdown chan struct{} + Match *Match + ResponseMaxLimit time.Duration + Traffic chan struct{} + readMessageErrors chan error + bespokeGenerateMessageID func(highPrecision bool) int64 +} + +// Dial sets proxy urls and then connects to the websocket +func (c *connection) Dial(dialer *gws.Dialer, headers http.Header) error { + return c.DialContext(context.Background(), dialer, headers) +} + +// DialContext sets proxy urls and then connects to the websocket +func (c *connection) DialContext(ctx context.Context, dialer *gws.Dialer, headers http.Header) error { + if c.ProxyURL != "" { + proxy, err := url.Parse(c.ProxyURL) + if err != nil { + return err + } + dialer.Proxy = http.ProxyURL(proxy) + } + + var err error + var conStatus *http.Response + c.Connection, conStatus, err = dialer.DialContext(ctx, c.URL, headers) + if err != nil { + if conStatus != nil { + _ = conStatus.Body.Close() + return fmt.Errorf("%s websocket connection: %v %v %v Error: %w", c.ExchangeName, c.URL, conStatus, conStatus.StatusCode, err) + } + return fmt.Errorf("%s websocket connection: %v Error: %w", c.ExchangeName, c.URL, err) + } + _ = conStatus.Body.Close() + + if c.Verbose { + log.Infof(log.WebsocketMgr, "%v Websocket connected to %s\n", c.ExchangeName, c.URL) + } + select { + case c.Traffic <- struct{}{}: + default: + } + c.setConnectedStatus(true) + return nil +} + +// SendJSONMessage sends a JSON encoded message over the connection +func (c *connection) SendJSONMessage(ctx context.Context, epl request.EndpointLimit, data any) error { + return c.writeToConn(ctx, epl, func() error { + if request.IsVerbose(ctx, c.Verbose) { + if msg, err := json.Marshal(data); err == nil { // WriteJSON will error for us anyway + log.Debugf(log.WebsocketMgr, "%v %v: Sending message: %v", c.ExchangeName, removeURLQueryString(c.URL), string(msg)) + } + } + return c.Connection.WriteJSON(data) + }) +} + +// SendRawMessage sends a message over the connection without JSON encoding it +func (c *connection) SendRawMessage(ctx context.Context, epl request.EndpointLimit, messageType int, message []byte) error { + return c.writeToConn(ctx, epl, func() error { + if request.IsVerbose(ctx, c.Verbose) { + log.Debugf(log.WebsocketMgr, "%v %v: Sending message: %v", c.ExchangeName, removeURLQueryString(c.URL), string(message)) + } + return c.Connection.WriteMessage(messageType, message) + }) +} + +func (c *connection) writeToConn(ctx context.Context, epl request.EndpointLimit, writeConn func() error) error { + if !c.IsConnected() { + return fmt.Errorf("%v websocket connection: cannot send message %w", c.ExchangeName, errWebsocketIsDisconnected) + } + + var rl *request.RateLimiterWithWeight + if c.RateLimitDefinitions != nil { + var ok bool + if rl, ok = c.RateLimitDefinitions[epl]; !ok && c.RateLimit == nil { + // Return an error if no specific connection rate limit is found for the endpoint but a global rate limit is + // set. This ensures the system attempts to apply rate limiting, prioritizing endpoint-specific limits + // if they are defined. + return fmt.Errorf("%s websocket connection: %w for %v", c.ExchangeName, errRateLimitNotFound, epl) + } + } + + if rl == nil { + // If a global rate limit definition is not found, use the connection rate limit as a fallback. + rl = c.RateLimit + } + + if rl != nil { + if err := request.RateLimit(ctx, rl); err != nil { + return fmt.Errorf("%s websocket connection: rate limit error: %w", c.ExchangeName, err) + } + } + // This lock acts as a rolling gate to prevent WriteMessage panics. Acquire after rate limit check. + c.writeControl.Lock() + defer c.writeControl.Unlock() + return writeConn() +} + +// SetupPingHandler will automatically send ping or pong messages based on +// WebsocketPingHandler configuration +func (c *connection) SetupPingHandler(epl request.EndpointLimit, handler PingHandler) { + if handler.UseGorillaHandler { + c.Connection.SetPingHandler(func(msg string) error { + err := c.Connection.WriteControl(handler.MessageType, []byte(msg), time.Now().Add(handler.Delay)) + if err == gws.ErrCloseSent { + return nil + } else if e, ok := err.(net.Error); ok && e.Timeout() { + return nil + } + return err + }) + return + } + c.Wg.Add(1) + go func() { + defer c.Wg.Done() + ticker := time.NewTicker(handler.Delay) + for { + select { + case <-c.shutdown: + ticker.Stop() + return + case <-ticker.C: + err := c.SendRawMessage(context.TODO(), epl, handler.MessageType, handler.Message) + if err != nil { + log.Errorf(log.WebsocketMgr, "%v websocket connection: ping handler failed to send message [%s]: %v", c.ExchangeName, handler.Message, err) + return + } + } + } + }() +} + +// setConnectedStatus sets connection status if changed it will return true. +// TODO: Swap out these atomic switches and opt for sync.RWMutex. +func (c *connection) setConnectedStatus(b bool) bool { + if b { + return atomic.SwapInt32(&c.connected, 1) == 0 + } + return atomic.SwapInt32(&c.connected, 0) == 1 +} + +// IsConnected exposes websocket connection status +func (c *connection) IsConnected() bool { + return atomic.LoadInt32(&c.connected) == 1 +} + +// ReadMessage reads messages, can handle text, gzip and binary +func (c *connection) ReadMessage() Response { + mType, resp, err := c.Connection.ReadMessage() + if err != nil { + // If any error occurs, a Response{Raw: nil, Type: 0} is returned, causing the + // reader routine to exit. This leaves the connection without an active reader, + // leading to potential buffer issue from the ongoing websocket writes. + // Such errors are passed to `c.readMessageErrors` when the connection is active. + // The `connectionMonitor` handles these errors by flushing the buffer, reconnecting, + // and resubscribing to the websocket to restore the connection. + if c.setConnectedStatus(false) { + // NOTE: When c.setConnectedStatus() returns true the underlying + // state was changed and this infers that the connection was + // externally closed and an error is reported else Shutdown() + // method on WebsocketConnection type has been called and can + // be skipped. + select { + case c.readMessageErrors <- fmt.Errorf("%w: %w", err, errConnectionFault): + default: + // bypass if there is no receiver, as this stops it returning + // when shutdown is called. + log.Warnf(log.WebsocketMgr, "%s failed to relay error: %v", c.ExchangeName, err) + } + } + return Response{} + } + + select { + case c.Traffic <- struct{}{}: + default: // Non-Blocking write ensures 1 buffered signal per trafficCheckInterval to avoid flooding + } + + var standardMessage []byte + switch mType { + case gws.TextMessage: + standardMessage = resp + case gws.BinaryMessage: + standardMessage, err = c.parseBinaryResponse(resp) + if err != nil { + log.Errorf(log.WebsocketMgr, "%v %v: Parse binary response error: %v", c.ExchangeName, removeURLQueryString(c.URL), err) + return Response{Raw: []byte(``)} // Non-nil response to avoid the reader returning on this case. + } + } + if c.Verbose { + log.Debugf(log.WebsocketMgr, "%v %v: Message received: %v", c.ExchangeName, removeURLQueryString(c.URL), string(standardMessage)) + } + return Response{Raw: standardMessage, Type: mType} +} + +// parseBinaryResponse parses a websocket binary response into a usable byte array +func (c *connection) parseBinaryResponse(resp []byte) ([]byte, error) { + var reader io.ReadCloser + var err error + if len(resp) >= 2 && resp[0] == 31 && resp[1] == 139 { // Detect GZIP + reader, err = gzip.NewReader(bytes.NewReader(resp)) + if err != nil { + return nil, err + } + } else { + reader = flate.NewReader(bytes.NewReader(resp)) + } + standardMessage, err := io.ReadAll(reader) + if err != nil { + return nil, err + } + return standardMessage, reader.Close() +} + +// GenerateMessageID generates a message ID for the individual connection. +// If a bespoke function is set (by using SetupNewConnection) it will use that, +// otherwise it will use the defaultGenerateMessageID function. +func (c *connection) GenerateMessageID(highPrec bool) int64 { + if c.bespokeGenerateMessageID != nil { + return c.bespokeGenerateMessageID(highPrec) + } + return c.defaultGenerateMessageID(highPrec) +} + +// defaultGenerateMessageID generates the default message ID +func (c *connection) defaultGenerateMessageID(highPrec bool) int64 { + var minValue int64 = 1e8 + var maxValue int64 = 2e8 + if highPrec { + maxValue = 2e12 + minValue = 1e12 + } + // utilization of hard coded positive numbers and default crypto/rand + // io.reader will panic on error instead of returning + randomNumber, err := rand.Int(rand.Reader, big.NewInt(maxValue-minValue+1)) + if err != nil { + panic(err) + } + return randomNumber.Int64() + minValue +} + +// Shutdown shuts down and closes specific connection +func (c *connection) Shutdown() error { + if c == nil || c.Connection == nil { + return nil + } + c.setConnectedStatus(false) + c.writeControl.Lock() + defer c.writeControl.Unlock() + return c.Connection.NetConn().Close() +} + +// SetURL sets connection URL +func (c *connection) SetURL(url string) { + c.URL = url +} + +// SetProxy sets connection proxy +func (c *connection) SetProxy(proxy string) { + c.ProxyURL = proxy +} + +// GetURL returns the connection URL +func (c *connection) GetURL() string { + return c.URL +} + +// SendMessageReturnResponse will send a WS message to the connection and wait for response +func (c *connection) SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature, request any) ([]byte, error) { + resps, err := c.SendMessageReturnResponses(ctx, epl, signature, request, 1) + if err != nil { + return nil, err + } + return resps[0], nil +} + +// SendMessageReturnResponses will send a WS message to the connection and wait for N responses +// An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked +func (c *connection) SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int) ([][]byte, error) { + return c.SendMessageReturnResponsesWithInspector(ctx, epl, signature, payload, expected, nil) +} + +// SendMessageReturnResponsesWithInspector will send a WS message to the connection and wait for N responses +// An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked +func (c *connection) SendMessageReturnResponsesWithInspector(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int, messageInspector Inspector) ([][]byte, error) { + outbound, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err) + } + + ch, err := c.Match.Set(signature, expected) + if err != nil { + return nil, err + } + + start := time.Now() + err = c.SendRawMessage(ctx, epl, gws.TextMessage, outbound) + if err != nil { + return nil, err + } + + resps, err := c.waitForResponses(ctx, signature, ch, expected, messageInspector) + if err != nil { + return nil, err + } + + if c.Reporter != nil { + c.Reporter.Latency(c.ExchangeName, outbound, time.Since(start)) + } + + return resps, err +} + +// waitForResponses waits for N responses from a channel +func (c *connection) waitForResponses(ctx context.Context, signature any, ch <-chan []byte, expected int, messageInspector Inspector) ([][]byte, error) { + timeout := time.NewTimer(c.ResponseMaxLimit * time.Duration(expected)) + defer timeout.Stop() + + resps := make([][]byte, 0, expected) +inspection: + for range expected { + select { + case resp := <-ch: + resps = append(resps, resp) + // Checks recently received message to determine if this is in fact the final message in a sequence of messages. + if messageInspector != nil && messageInspector.IsFinal(resp) { + c.Match.RemoveSignature(signature) + break inspection + } + case <-timeout.C: + c.Match.RemoveSignature(signature) + return nil, fmt.Errorf("%s %w %v", c.ExchangeName, ErrSignatureTimeout, signature) + case <-ctx.Done(): + c.Match.RemoveSignature(signature) + return nil, ctx.Err() + } + } + + // Only check context verbosity. If the exchange is verbose, it will log the responses in the ReadMessage() call. + if request.IsVerbose(ctx, false) { + for i := range resps { + log.Debugf(log.WebsocketMgr, "%v %v: Received response [%d/%d]: %v", c.ExchangeName, removeURLQueryString(c.URL), i+1, len(resps), string(resps[i])) + } + } + + return resps, nil +} + +func removeURLQueryString(url string) string { + if index := strings.Index(url, "?"); index != -1 { + return url[:index] + } + return url +} diff --git a/internal/exchange/websocket/manager.go b/internal/exchange/websocket/manager.go new file mode 100644 index 00000000..d283761f --- /dev/null +++ b/internal/exchange/websocket/manager.go @@ -0,0 +1,1070 @@ +package websocket + +import ( + "context" + "errors" + "fmt" + "net/url" + "reflect" + "sync" + "sync/atomic" + "time" + + "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/config" + "github.com/thrasher-corp/gocryptotrader/exchanges/fill" + "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" + "github.com/thrasher-corp/gocryptotrader/exchanges/request" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" + "github.com/thrasher-corp/gocryptotrader/exchanges/trade" + "github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket/buffer" + "github.com/thrasher-corp/gocryptotrader/log" +) + +// Public websocket errors +var ( + ErrWebsocketNotEnabled = errors.New("websocket not enabled") + ErrAlreadyDisabled = errors.New("websocket already disabled") + ErrNotConnected = errors.New("websocket is not connected") + ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature") + ErrRequestRouteNotFound = errors.New("request route not found") + ErrSignatureNotSet = errors.New("signature not set") +) + +// Private websocket errors +var ( + errWebsocketAlreadyInitialised = errors.New("websocket already initialised") + errWebsocketAlreadyEnabled = errors.New("websocket already enabled") + errDefaultURLIsEmpty = errors.New("default url is empty") + errRunningURLIsEmpty = errors.New("running url cannot be empty") + errInvalidWebsocketURL = errors.New("invalid websocket url") + errExchangeConfigNameEmpty = errors.New("exchange config name empty") + errInvalidTrafficTimeout = errors.New("invalid traffic timeout") + errTrafficAlertNil = errors.New("traffic alert is nil") + errWebsocketSubscriberUnset = errors.New("websocket subscriber function needs to be set") + errWebsocketUnsubscriberUnset = errors.New("websocket unsubscriber functionality allowed but unsubscriber function not set") + errWebsocketConnectorUnset = errors.New("websocket connector function not set") + errWebsocketDataHandlerUnset = errors.New("websocket data handler not set") + errReadMessageErrorsNil = errors.New("read message errors is nil") + errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set") + errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0") + errSameProxyAddress = errors.New("cannot set proxy address to the same address") + errNoConnectFunc = errors.New("websocket connect func not set") + errAlreadyConnected = errors.New("websocket already connected") + errCannotShutdown = errors.New("websocket cannot shutdown") + errAlreadyReconnecting = errors.New("websocket in the process of reconnection") + errConnSetup = errors.New("error in connection setup") + errNoPendingConnections = errors.New("no pending connections, call SetupNewConnection first") + errDuplicateConnectionSetup = errors.New("duplicate connection setup") + errCannotChangeConnectionURL = errors.New("cannot change connection URL when using multi connection management") + errExchangeConfigEmpty = errors.New("exchange config is empty") + errCannotObtainOutboundConnection = errors.New("cannot obtain outbound connection") + errMessageFilterNotComparable = errors.New("message filter is not comparable") +) + +// Websocket functionality list and state consts +const ( + UnhandledMessage = " - Unhandled websocket message: " + jobBuffer = 5000 +) + +const ( + uninitialisedState uint32 = iota + disconnectedState + connectingState + connectedState +) + +// Manager provides connection and subscription management and routing +type Manager struct { + enabled atomic.Bool + state atomic.Uint32 + verbose bool + canUseAuthenticatedEndpoints atomic.Bool + connectionMonitorRunning atomic.Bool + trafficTimeout time.Duration + connectionMonitorDelay time.Duration + proxyAddr string + defaultURL string + defaultURLAuth string + runningURL string + runningURLAuth string + exchangeName string + features *protocol.Features + m sync.Mutex + connections map[Connection]*connectionWrapper + subscriptions *subscription.Store + connector func() error + rateLimitDefinitions request.RateLimitDefinitions // rate limiters shared between Websocket and REST connections + Subscriber func(subscription.List) error + Unsubscriber func(subscription.List) error + GenerateSubs func() (subscription.List, error) + useMultiConnectionManagement bool + DataHandler chan any + ToRoutine chan any + Match *Match + ShutdownC chan struct{} + Wg sync.WaitGroup + Orderbook buffer.Orderbook + Trade trade.Trade // Trade is a notifier for trades + Fills fill.Fills // Fills is a notifier for fills + TrafficAlert chan struct{} + ReadMessageErrors chan error + Conn Connection // Public connection + AuthConn Connection // Authenticated Private connection + ExchangeLevelReporter Reporter // Latency reporter + MaxSubscriptionsPerConnection int + + // connectionManager stores all *potential* connections for the exchange, organised within connectionWrapper structs. + // Each connectionWrapper one connection (will be expanded soon) tailored for specific exchange functionalities or asset types. // TODO: Expand this to support multiple connections per connectionWrapper + // For example, separate connections can be used for Spot, Margin, and Futures trading. This structure is especially useful + // for exchanges that differentiate between trading pairs by using different connection endpoints or protocols for various asset classes. + // If an exchange does not require such differentiation, all connections may be managed under a single connectionWrapper. + + connectionManager []*connectionWrapper +} + +// ManagerSetup defines variables for setting up a websocket manager +type ManagerSetup struct { + ExchangeConfig *config.Exchange + DefaultURL string + RunningURL string + RunningURLAuth string + Connector func() error + Subscriber func(subscription.List) error + Unsubscriber func(subscription.List) error + GenerateSubscriptions func() (subscription.List, error) + Features *protocol.Features + OrderbookBufferConfig buffer.Config + + // UseMultiConnectionManagement allows the connections to be managed by the + // connection manager. If false, this will default to the global fields + // provided in this struct. + UseMultiConnectionManagement bool + + TradeFeed bool + FillsFeed bool + + MaxWebsocketSubscriptionsPerConnection int + + // RateLimitDefinitions contains the rate limiters shared between WebSocket and REST connections for all endpoints. + // These rate limits take precedence over any rate limits specified in individual connection configurations. + // If no connection-specific rate limit is provided and the endpoint does not match any of these definitions, + // an error will be returned. However, if a connection configuration includes its own rate limit, + // it will fall back to that configuration’s rate limit without raising an error. + RateLimitDefinitions request.RateLimitDefinitions +} + +// connectionWrapper contains the connection setup details to be used when +// attempting a new connection. It also contains the subscriptions that are +// associated with the specific connection. +type connectionWrapper struct { + setup *ConnectionSetup + subscriptions *subscription.Store + connection Connection +} + +var globalReporter Reporter + +// SetupGlobalReporter sets a reporter interface to be used +// for all exchange requests +func SetupGlobalReporter(r Reporter) { + globalReporter = r +} + +// NewManager initialises the websocket struct +func NewManager() *Manager { + return &Manager{ + DataHandler: make(chan any, jobBuffer), + ToRoutine: make(chan any, jobBuffer), + ShutdownC: make(chan struct{}), + TrafficAlert: make(chan struct{}, 1), + // ReadMessageErrors is buffered for an edge case when `Connect` fails + // after subscriptions are made but before the connectionMonitor has + // started. This allows the error to be read and handled in the + // connectionMonitor and start a connection cycle again. + ReadMessageErrors: make(chan error, 1), + Match: NewMatch(), + subscriptions: subscription.NewStore(), + features: &protocol.Features{}, + Orderbook: buffer.Orderbook{}, + connections: make(map[Connection]*connectionWrapper), + } +} + +// Setup sets main variables for websocket connection +func (m *Manager) Setup(s *ManagerSetup) error { + if err := common.NilGuard(m, s); err != nil { + return err + } + if s.ExchangeConfig == nil { + return fmt.Errorf("%w: ManagerSetup.ExchangeConfig", common.ErrNilPointer) + } + if s.ExchangeConfig.Features == nil { + return fmt.Errorf("%w: ManagerSetup.ExchangeConfig.Features", common.ErrNilPointer) + } + if s.Features == nil { + return fmt.Errorf("%w: ManagerSetup.Features", common.ErrNilPointer) + } + + m.m.Lock() + defer m.m.Unlock() + + if m.IsInitialised() { + return fmt.Errorf("%s %w", m.exchangeName, errWebsocketAlreadyInitialised) + } + + if s.ExchangeConfig.Name == "" { + return errExchangeConfigNameEmpty + } + m.exchangeName = s.ExchangeConfig.Name + m.verbose = s.ExchangeConfig.Verbose + + m.features = s.Features + + m.setEnabled(s.ExchangeConfig.Features.Enabled.Websocket) + + m.useMultiConnectionManagement = s.UseMultiConnectionManagement + + if !m.useMultiConnectionManagement { + // TODO: Remove this block when all exchanges are updated and backwards + // compatibility is no longer required. + if s.Connector == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketConnectorUnset) + } + if s.Subscriber == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriberUnset) + } + if s.Unsubscriber == nil && m.features.Unsubscribe { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketUnsubscriberUnset) + } + if s.GenerateSubscriptions == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriptionsGeneratorUnset) + } + if s.DefaultURL == "" { + return fmt.Errorf("%s websocket %w", m.exchangeName, errDefaultURLIsEmpty) + } + m.defaultURL = s.DefaultURL + if s.RunningURL == "" { + return fmt.Errorf("%s websocket %w", m.exchangeName, errRunningURLIsEmpty) + } + + m.connector = s.Connector + m.Subscriber = s.Subscriber + m.Unsubscriber = s.Unsubscriber + m.GenerateSubs = s.GenerateSubscriptions + + err := m.SetWebsocketURL(s.RunningURL, false, false) + if err != nil { + return fmt.Errorf("%s %w", m.exchangeName, err) + } + + if s.RunningURLAuth != "" { + err = m.SetWebsocketURL(s.RunningURLAuth, true, false) + if err != nil { + return fmt.Errorf("%s %w", m.exchangeName, err) + } + } + } + + m.connectionMonitorDelay = s.ExchangeConfig.ConnectionMonitorDelay + if m.connectionMonitorDelay <= 0 { + m.connectionMonitorDelay = config.DefaultConnectionMonitorDelay + } + + if s.ExchangeConfig.WebsocketTrafficTimeout < time.Second { + return fmt.Errorf("%s %w cannot be less than %s", + m.exchangeName, + errInvalidTrafficTimeout, + time.Second) + } + m.trafficTimeout = s.ExchangeConfig.WebsocketTrafficTimeout + + m.SetCanUseAuthenticatedEndpoints(s.ExchangeConfig.API.AuthenticatedWebsocketSupport) + + if err := m.Orderbook.Setup(s.ExchangeConfig, &s.OrderbookBufferConfig, m.DataHandler); err != nil { + return err + } + + m.Trade.Setup(s.TradeFeed, m.DataHandler) + m.Fills.Setup(s.FillsFeed, m.DataHandler) + + if s.MaxWebsocketSubscriptionsPerConnection < 0 { + return fmt.Errorf("%s %w", m.exchangeName, errInvalidMaxSubscriptions) + } + m.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection + m.setState(disconnectedState) + + m.rateLimitDefinitions = s.RateLimitDefinitions + return nil +} + +// SetupNewConnection sets up an auth or unauth streaming connection +func (m *Manager) SetupNewConnection(c *ConnectionSetup) error { + if err := common.NilGuard(m, c); err != nil { + return err + } + + if c == nil || c.ResponseCheckTimeout == 0 && + c.ResponseMaxLimit == 0 && + c.RateLimit == nil && + c.URL == "" && + c.ConnectionLevelReporter == nil && + c.BespokeGenerateMessageID == nil { + return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigEmpty) + } + + if m.exchangeName == "" { + return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigNameEmpty) + } + if m.TrafficAlert == nil { + return fmt.Errorf("%w: %w", errConnSetup, errTrafficAlertNil) + } + if m.ReadMessageErrors == nil { + return fmt.Errorf("%w: %w", errConnSetup, errReadMessageErrorsNil) + } + if c.ConnectionLevelReporter == nil { + c.ConnectionLevelReporter = m.ExchangeLevelReporter + } + if c.ConnectionLevelReporter == nil { + c.ConnectionLevelReporter = globalReporter + } + + if m.useMultiConnectionManagement { + // The connection and supporting functions are defined per connection + // and the connection wrapper is stored in the connection manager. + if c.URL == "" { + return fmt.Errorf("%w: %w", errConnSetup, errDefaultURLIsEmpty) + } + if c.Connector == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketConnectorUnset) + } + if c.GenerateSubscriptions == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriptionsGeneratorUnset) + } + if c.Subscriber == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriberUnset) + } + if c.Unsubscriber == nil && m.features.Unsubscribe { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketUnsubscriberUnset) + } + if c.Handler == nil { + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketDataHandlerUnset) + } + + if c.MessageFilter != nil && !reflect.TypeOf(c.MessageFilter).Comparable() { + return errMessageFilterNotComparable + } + + for x := range m.connectionManager { + // Below allows for multiple connections to the same URL with different outbound request signatures. This + // allows for easier determination of inbound and outbound messages. e.g. Gateio cross_margin, margin on + // a spot connection. + if m.connectionManager[x].setup.URL == c.URL && c.MessageFilter == m.connectionManager[x].setup.MessageFilter { + return fmt.Errorf("%w: %w", errConnSetup, errDuplicateConnectionSetup) + } + } + m.connectionManager = append(m.connectionManager, &connectionWrapper{ + setup: c, + subscriptions: subscription.NewStore(), + }) + return nil + } + + if c.Authenticated { + m.AuthConn = m.getConnectionFromSetup(c) + } else { + m.Conn = m.getConnectionFromSetup(c) + } + + return nil +} + +// getConnectionFromSetup returns a websocket connection from a setup +// configuration. This is used for setting up new connections on the fly. +func (m *Manager) getConnectionFromSetup(c *ConnectionSetup) *connection { + connectionURL := m.GetWebsocketURL() + if c.URL != "" { + connectionURL = c.URL + } + return &connection{ + ExchangeName: m.exchangeName, + URL: connectionURL, + ProxyURL: m.GetProxyAddress(), + Verbose: m.verbose, + ResponseMaxLimit: c.ResponseMaxLimit, + Traffic: m.TrafficAlert, + readMessageErrors: m.ReadMessageErrors, + shutdown: m.ShutdownC, + Wg: &m.Wg, + Match: m.Match, + RateLimit: c.RateLimit, + Reporter: c.ConnectionLevelReporter, + bespokeGenerateMessageID: c.BespokeGenerateMessageID, + RateLimitDefinitions: m.rateLimitDefinitions, + } +} + +// Connect initiates a websocket connection by using a package defined connection +// function +func (m *Manager) Connect() error { + m.m.Lock() + defer m.m.Unlock() + return m.connect() +} + +func (m *Manager) connect() error { + if !m.IsEnabled() { + return ErrWebsocketNotEnabled + } + if m.IsConnecting() { + return fmt.Errorf("%v %w", m.exchangeName, errAlreadyReconnecting) + } + if m.IsConnected() { + return fmt.Errorf("%v %w", m.exchangeName, errAlreadyConnected) + } + + if m.subscriptions == nil { + return fmt.Errorf("%w: subscriptions", common.ErrNilPointer) + } + m.subscriptions.Clear() + + m.setState(connectingState) + + m.Wg.Add(2) + go m.monitorFrame(&m.Wg, m.monitorData) + go m.monitorFrame(&m.Wg, m.monitorTraffic) + + if !m.useMultiConnectionManagement { + if m.connector == nil { + return fmt.Errorf("%v %w", m.exchangeName, errNoConnectFunc) + } + err := m.connector() + if err != nil { + m.setState(disconnectedState) + return fmt.Errorf("%v Error connecting %w", m.exchangeName, err) + } + m.setState(connectedState) + + if m.connectionMonitorRunning.CompareAndSwap(false, true) { + // This oversees all connections and does not need to be part of wait group management. + go m.monitorFrame(nil, m.monitorConnection) + } + + subs, err := m.GenerateSubs() // regenerate state on new connection + if err != nil { + return fmt.Errorf("%s websocket: %w", m.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) + } + if len(subs) != 0 { + if err := m.SubscribeToChannels(nil, subs); err != nil { + return err + } + + if missing := m.subscriptions.Missing(subs); len(missing) > 0 { + return fmt.Errorf("%v %w `%s`", m.exchangeName, ErrSubscriptionsNotAdded, missing) + } + } + return nil + } + + if len(m.connectionManager) == 0 { + m.setState(disconnectedState) + return fmt.Errorf("cannot connect: %w", errNoPendingConnections) + } + + // multiConnectFatalError is a fatal error that will cause all connections to + // be shutdown and the websocket to be disconnected. + var multiConnectFatalError error + + // subscriptionError is a non-fatal error that does not shutdown connections + var subscriptionError error + + // TODO: Implement concurrency below. + for i := range m.connectionManager { + if m.connectionManager[i].setup.GenerateSubscriptions == nil { + multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, m.connectionManager[i].setup.URL, errWebsocketSubscriptionsGeneratorUnset) + break + } + + subs, err := m.connectionManager[i].setup.GenerateSubscriptions() // regenerate state on new connection + if err != nil { + multiConnectFatalError = fmt.Errorf("%s websocket: %w", m.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) + break + } + + if len(subs) == 0 { + // If no subscriptions are generated, we skip the connection + if m.verbose { + log.Warnf(log.WebsocketMgr, "%s websocket: no subscriptions generated", m.exchangeName) + } + continue + } + + if m.connectionManager[i].setup.Connector == nil { + multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, m.connectionManager[i].setup.URL, errNoConnectFunc) + break + } + if m.connectionManager[i].setup.Handler == nil { + multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, m.connectionManager[i].setup.URL, errWebsocketDataHandlerUnset) + break + } + if m.connectionManager[i].setup.Subscriber == nil { + multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, m.connectionManager[i].setup.URL, errWebsocketSubscriberUnset) + break + } + + // TODO: Add window for max subscriptions per connection, to spawn new connections if needed. + + conn := m.getConnectionFromSetup(m.connectionManager[i].setup) + + err = m.connectionManager[i].setup.Connector(context.TODO(), conn) + if err != nil { + multiConnectFatalError = fmt.Errorf("%v Error connecting %w", m.exchangeName, err) + break + } + + if !conn.IsConnected() { + multiConnectFatalError = fmt.Errorf("%s websocket: [conn:%d] [URL:%s] failed to connect", m.exchangeName, i+1, conn.URL) + break + } + + m.connections[conn] = m.connectionManager[i] + m.connectionManager[i].connection = conn + + m.Wg.Add(1) + go m.Reader(context.TODO(), conn, m.connectionManager[i].setup.Handler) + + if m.connectionManager[i].setup.Authenticate != nil && m.CanUseAuthenticatedEndpoints() { + err = m.connectionManager[i].setup.Authenticate(context.TODO(), conn) + if err != nil { + multiConnectFatalError = fmt.Errorf("%s websocket: [conn:%d] [URL:%s] failed to authenticate %w", m.exchangeName, i+1, conn.URL, err) + break + } + } + + err = m.connectionManager[i].setup.Subscriber(context.TODO(), conn, subs) + if err != nil { + subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("%v Error subscribing %w", m.exchangeName, err)) + continue + } + + if missing := m.connectionManager[i].subscriptions.Missing(subs); len(missing) > 0 { + subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("%v %w `%s`", m.exchangeName, ErrSubscriptionsNotAdded, missing)) + continue + } + + if m.verbose { + log.Debugf(log.WebsocketMgr, "%s websocket: [conn:%d] [URL:%s] connected. [Subscribed: %d]", + m.exchangeName, + i+1, + conn.URL, + len(subs)) + } + } + + if multiConnectFatalError != nil { + // Roll back any successful connections and flush subscriptions + for x := range m.connectionManager { + if m.connectionManager[x].connection != nil { + if err := m.connectionManager[x].connection.Shutdown(); err != nil { + log.Errorln(log.WebsocketMgr, err) + } + m.connectionManager[x].connection = nil + } + m.connectionManager[x].subscriptions.Clear() + } + clear(m.connections) + m.setState(disconnectedState) // Flip from connecting to disconnected. + + // Drain residual error in the single buffered channel, this mitigates + // the cycle when `Connect` is called again and the connectionMonitor + // starts but there is an old error in the channel. + drain(m.ReadMessageErrors) + + return multiConnectFatalError + } + + // Assume connected state here. All connections have been established. + // All subscriptions have been sent and stored. All data received is being + // handled by the appropriate data handler. + m.setState(connectedState) + + if m.connectionMonitorRunning.CompareAndSwap(false, true) { + // This oversees all connections and does not need to be part of wait group management. + go m.monitorFrame(nil, m.monitorConnection) + } + + return subscriptionError +} + +// Disable disables the exchange websocket protocol +// Note that connectionMonitor will be responsible for shutting down the websocket after disabling +func (m *Manager) Disable() error { + if !m.IsEnabled() { + return fmt.Errorf("%s %w", m.exchangeName, ErrAlreadyDisabled) + } + + m.setEnabled(false) + return nil +} + +// Enable enables the exchange websocket protocol +func (m *Manager) Enable() error { + if m.IsConnected() || m.IsEnabled() { + return fmt.Errorf("%s %w", m.exchangeName, errWebsocketAlreadyEnabled) + } + + m.setEnabled(true) + return m.Connect() +} + +// Shutdown attempts to shut down a websocket connection and associated routines +// by using a package defined shutdown function +func (m *Manager) Shutdown() error { + m.m.Lock() + defer m.m.Unlock() + return m.shutdown() +} + +func (m *Manager) shutdown() error { + if !m.IsConnected() { + return fmt.Errorf("%v %w: %w", m.exchangeName, errCannotShutdown, ErrNotConnected) + } + + // TODO: Interrupt connection and or close connection when it is re-established. + if m.IsConnecting() { + return fmt.Errorf("%v %w: %w ", m.exchangeName, errCannotShutdown, errAlreadyReconnecting) + } + + if m.verbose { + log.Debugf(log.WebsocketMgr, "%v websocket: shutting down websocket", m.exchangeName) + } + + defer m.Orderbook.FlushBuffer() + + // During the shutdown process, all errors are treated as non-fatal to avoid issues when the connection has already + // been closed. In such cases, attempting to close the connection may result in a + // "failed to send closeNotify alert (but connection was closed anyway)" error. Treating these errors as non-fatal + // prevents the shutdown process from being interrupted, which could otherwise trigger a continuous traffic monitor + // cycle and potentially block the initiation of a new connection. + var nonFatalCloseConnectionErrors error + + // Shutdown managed connections + for x := range m.connectionManager { + if m.connectionManager[x].connection != nil { + if err := m.connectionManager[x].connection.Shutdown(); err != nil { + nonFatalCloseConnectionErrors = common.AppendError(nonFatalCloseConnectionErrors, err) + } + m.connectionManager[x].connection = nil + // Flush any subscriptions from last connection across any managed connections + m.connectionManager[x].subscriptions.Clear() + } + } + // Clean map of old connections + clear(m.connections) + + if m.Conn != nil { + if err := m.Conn.Shutdown(); err != nil { + nonFatalCloseConnectionErrors = common.AppendError(nonFatalCloseConnectionErrors, err) + } + } + if m.AuthConn != nil { + if err := m.AuthConn.Shutdown(); err != nil { + nonFatalCloseConnectionErrors = common.AppendError(nonFatalCloseConnectionErrors, err) + } + } + // flush any subscriptions from last connection if needed + m.subscriptions.Clear() + + m.setState(disconnectedState) + + close(m.ShutdownC) + m.Wg.Wait() + m.ShutdownC = make(chan struct{}) + if m.verbose { + log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", m.exchangeName) + } + + // Drain residual error in the single buffered channel, this mitigates + // the cycle when `Connect` is called again and the connectionMonitor + // starts but there is an old error in the channel. + drain(m.ReadMessageErrors) + + if nonFatalCloseConnectionErrors != nil { + log.Warnf(log.WebsocketMgr, "%v websocket: shutdown error: %v", m.exchangeName, nonFatalCloseConnectionErrors) + } + + return nil +} + +func (m *Manager) setState(s uint32) { + m.state.Store(s) +} + +// IsInitialised returns whether the websocket has been Setup() already +func (m *Manager) IsInitialised() bool { + return m.state.Load() != uninitialisedState +} + +// IsConnected returns whether the websocket is connected +func (m *Manager) IsConnected() bool { + return m.state.Load() == connectedState +} + +// IsConnecting returns whether the websocket is connecting +func (m *Manager) IsConnecting() bool { + return m.state.Load() == connectingState +} + +func (m *Manager) setEnabled(b bool) { + m.enabled.Store(b) +} + +// IsEnabled returns whether the websocket is enabled +func (m *Manager) IsEnabled() bool { + return m.enabled.Load() +} + +// CanUseAuthenticatedWebsocketForWrapper Handles a common check to +// verify whether a wrapper can use an authenticated websocket endpoint +func (m *Manager) CanUseAuthenticatedWebsocketForWrapper() bool { + if m.IsConnected() { + if m.CanUseAuthenticatedEndpoints() { + return true + } + log.Infof(log.WebsocketMgr, "%v - Websocket not authenticated, using REST\n", m.exchangeName) + } + return false +} + +// SetWebsocketURL sets websocket URL and can refresh underlying connections +func (m *Manager) SetWebsocketURL(url string, auth, reconnect bool) error { + if m.useMultiConnectionManagement { + // TODO: Add functionality for multi-connection management to change URL + return fmt.Errorf("%s: %w", m.exchangeName, errCannotChangeConnectionURL) + } + defaultVals := url == "" || url == config.WebsocketURLNonDefaultMessage + if auth { + if defaultVals { + url = m.defaultURLAuth + } + + err := checkWebsocketURL(url) + if err != nil { + return err + } + m.runningURLAuth = url + + if m.verbose { + log.Debugf(log.WebsocketMgr, "%s websocket: setting authenticated websocket URL: %s\n", m.exchangeName, url) + } + + if m.AuthConn != nil { + m.AuthConn.SetURL(url) + } + } else { + if defaultVals { + url = m.defaultURL + } + err := checkWebsocketURL(url) + if err != nil { + return err + } + m.runningURL = url + + if m.verbose { + log.Debugf(log.WebsocketMgr, "%s websocket: setting unauthenticated websocket URL: %s\n", m.exchangeName, url) + } + + if m.Conn != nil { + m.Conn.SetURL(url) + } + } + + if m.IsConnected() && reconnect { + log.Debugf(log.WebsocketMgr, "%s websocket: flushing websocket connection to %s\n", m.exchangeName, url) + return m.Shutdown() + } + return nil +} + +// GetWebsocketURL returns the running websocket URL +func (m *Manager) GetWebsocketURL() string { + return m.runningURL +} + +// SetProxyAddress sets websocket proxy address +func (m *Manager) SetProxyAddress(proxyAddr string) error { + m.m.Lock() + defer m.m.Unlock() + if proxyAddr != "" { + if _, err := url.ParseRequestURI(proxyAddr); err != nil { + return fmt.Errorf("%v websocket: cannot set proxy address: %w", m.exchangeName, err) + } + + if m.proxyAddr == proxyAddr { + return fmt.Errorf("%v websocket: %w '%v'", m.exchangeName, errSameProxyAddress, m.proxyAddr) + } + + log.Debugf(log.ExchangeSys, "%s websocket: setting websocket proxy: %s", m.exchangeName, proxyAddr) + } else { + log.Debugf(log.ExchangeSys, "%s websocket: removing websocket proxy", m.exchangeName) + } + + for _, wrapper := range m.connectionManager { + if wrapper.connection != nil { + wrapper.connection.SetProxy(proxyAddr) + } + } + if m.Conn != nil { + m.Conn.SetProxy(proxyAddr) + } + if m.AuthConn != nil { + m.AuthConn.SetProxy(proxyAddr) + } + + m.proxyAddr = proxyAddr + + if !m.IsConnected() { + return nil + } + if err := m.shutdown(); err != nil { + return err + } + return m.connect() +} + +// GetProxyAddress returns the current websocket proxy +func (m *Manager) GetProxyAddress() string { + return m.proxyAddr +} + +// GetName returns exchange name +func (m *Manager) GetName() string { + return m.exchangeName +} + +// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in a thread safe manner +func (m *Manager) SetCanUseAuthenticatedEndpoints(b bool) { + m.canUseAuthenticatedEndpoints.Store(b) +} + +// CanUseAuthenticatedEndpoints gets canUseAuthenticatedEndpoints val in a thread safe manner +func (m *Manager) CanUseAuthenticatedEndpoints() bool { + return m.canUseAuthenticatedEndpoints.Load() +} + +// checkWebsocketURL checks for a valid websocket url +func checkWebsocketURL(s string) error { + u, err := url.Parse(s) + if err != nil { + return err + } + if u.Scheme != "ws" && u.Scheme != "wss" { + return fmt.Errorf("cannot set %w %s", errInvalidWebsocketURL, s) + } + return nil +} + +// Reader reads and handles data from a specific connection +func (m *Manager) Reader(ctx context.Context, conn Connection, handler func(ctx context.Context, message []byte) error) { + defer m.Wg.Done() + for { + resp := conn.ReadMessage() + if resp.Raw == nil { + return // Connection has been closed + } + if err := handler(ctx, resp.Raw); err != nil { + m.DataHandler <- fmt.Errorf("connection URL:[%v] error: %w", conn.GetURL(), err) + } + } +} + +func drain(ch <-chan error) { + for { + select { + case <-ch: + default: + return + } + } +} + +// ClosureFrame is a closure function that wraps monitoring variables with observer, if the return is true the frame will exit +type ClosureFrame func() func() bool + +// monitorFrame monitors a specific websocket component or critical system. It will exit if the observer returns true +// This is used for monitoring data throughput, connection status and other critical websocket components. The waitgroup +// is optional and is used to signal when the monitor has finished. +func (m *Manager) monitorFrame(wg *sync.WaitGroup, fn ClosureFrame) { + if wg != nil { + defer m.Wg.Done() + } + observe := fn() + for { + if observe() { + return + } + } +} + +// monitorData monitors data throughput and logs if there is a back log of data +func (m *Manager) monitorData() func() bool { + dropped := 0 + return func() bool { return m.observeData(&dropped) } +} + +// observeData observes data throughput and logs if there is a back log of data +func (m *Manager) observeData(dropped *int) (exit bool) { + select { + case <-m.ShutdownC: + return true + case d := <-m.DataHandler: + select { + case m.ToRoutine <- d: + if *dropped != 0 { + log.Infof(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer recovered; %d messages were dropped", m.exchangeName, dropped) + *dropped = 0 + } + default: + if *dropped == 0 { + // If this becomes prone to flapping we could drain the buffer, but that's extreme and we'd like to avoid it if possible + log.Warnf(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer full; dropping messages", m.exchangeName) + } + *dropped++ + } + return false + } +} + +// monitorConnection monitors the connection and attempts to reconnect if the connection is lost +func (m *Manager) monitorConnection() func() bool { + timer := time.NewTimer(m.connectionMonitorDelay) + return func() bool { return m.observeConnection(timer) } +} + +// observeConnection observes the connection and attempts to reconnect if the connection is lost +func (m *Manager) observeConnection(t *time.Timer) (exit bool) { + select { + case err := <-m.ReadMessageErrors: + if errors.Is(err, errConnectionFault) { + log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", m.exchangeName, err) + if m.IsConnected() { + if shutdownErr := m.Shutdown(); shutdownErr != nil { + log.Errorf(log.WebsocketMgr, "%v websocket: connectionMonitor shutdown err: %s", m.exchangeName, shutdownErr) + } + } + } + // Speedier reconnection, instead of waiting for the next cycle. + if m.IsEnabled() && (!m.IsConnected() && !m.IsConnecting()) { + if connectErr := m.Connect(); connectErr != nil { + log.Errorln(log.WebsocketMgr, connectErr) + } + } + m.DataHandler <- err // hand over the error to the data handler (shutdown and reconnection is priority) + case <-t.C: + if m.verbose { + log.Debugf(log.WebsocketMgr, "%v websocket: running connection monitor cycle", m.exchangeName) + } + if !m.IsEnabled() { + if m.verbose { + log.Debugf(log.WebsocketMgr, "%v websocket: connectionMonitor - websocket disabled, shutting down", m.exchangeName) + } + if m.IsConnected() { + if err := m.Shutdown(); err != nil { + log.Errorln(log.WebsocketMgr, err) + } + } + if m.verbose { + log.Debugf(log.WebsocketMgr, "%v websocket: connection monitor exiting", m.exchangeName) + } + t.Stop() + m.connectionMonitorRunning.Store(false) + return true + } + if !m.IsConnecting() && !m.IsConnected() { + err := m.Connect() + if err != nil { + log.Errorln(log.WebsocketMgr, err) + } + } + t.Reset(m.connectionMonitorDelay) + } + return false +} + +// monitorTraffic monitors to see if there has been traffic within the trafficTimeout time window. If there is no traffic +// the connection is shutdown and will be reconnected by the connectionMonitor routine. +func (m *Manager) monitorTraffic() func() bool { + timer := time.NewTimer(m.trafficTimeout) + return func() bool { return m.observeTraffic(timer) } +} + +func (m *Manager) observeTraffic(t *time.Timer) bool { + select { + case <-m.ShutdownC: + if m.verbose { + log.Debugf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown message received", m.exchangeName) + } + case <-t.C: + if m.IsConnecting() || signalReceived(m.TrafficAlert) { + t.Reset(m.trafficTimeout) + return false + } + if m.verbose { + log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", m.exchangeName, m.trafficTimeout) + } + if m.IsConnected() { + go func() { // Without this the m.Shutdown() call below will deadlock + if err := m.Shutdown(); err != nil { + log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", m.exchangeName, err) + } + }() + } + } + t.Stop() + return true +} + +// signalReceived checks if a signal has been received, this also clears the signal. +func signalReceived(ch chan struct{}) bool { + select { + case <-ch: + return true + default: + return false + } +} + +// GetConnection returns a connection by message filter (defined in exchange package _wrapper.go websocket connection) +// for request and response handling in a multi connection context. +func (m *Manager) GetConnection(messageFilter any) (Connection, error) { + if err := common.NilGuard(m); err != nil { + return nil, err + } + if messageFilter == nil { + return nil, fmt.Errorf("%w: messageFilter", common.ErrNilPointer) + } + + m.m.Lock() + defer m.m.Unlock() + + if !m.useMultiConnectionManagement { + return nil, fmt.Errorf("%s: multi connection management not enabled %w please use exported Conn and AuthConn fields", m.exchangeName, errCannotObtainOutboundConnection) + } + + if !m.IsConnected() { + return nil, ErrNotConnected + } + + for _, wrapper := range m.connectionManager { + if wrapper.setup.MessageFilter == messageFilter { + if wrapper.connection == nil { + return nil, fmt.Errorf("%s: %s %w associated with message filter: '%v'", m.exchangeName, wrapper.setup.URL, ErrNotConnected, messageFilter) + } + return wrapper.connection, nil + } + } + + return nil, fmt.Errorf("%s: %w associated with message filter: '%v'", m.exchangeName, ErrRequestRouteNotFound, messageFilter) +} diff --git a/exchanges/stream/websocket_test.go b/internal/exchange/websocket/manager_test.go similarity index 66% rename from exchanges/stream/websocket_test.go rename to internal/exchange/websocket/manager_test.go index 5b512e28..3b32c7a5 100644 --- a/exchanges/stream/websocket_test.go +++ b/internal/exchange/websocket/manager_test.go @@ -1,4 +1,4 @@ -package stream +package websocket import ( "bytes" @@ -6,6 +6,7 @@ import ( "compress/gzip" "context" "errors" + "fmt" "net/http" "net/http/httptest" "strconv" @@ -14,7 +15,7 @@ import ( "testing" "time" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" @@ -28,6 +29,7 @@ import ( ) const ( + Ping = "ping" useProxyTests = false // Disabled by default. Freely available proxy servers that work all the time are difficult to find proxyURL = "http://212.186.171.4:80" // Replace with a usable proxy server ) @@ -36,7 +38,7 @@ var errDastardlyReason = errors.New("some dastardly reason") type testStruct struct { Error error - WC WebsocketConnection + WC connection } type testRequest struct { @@ -61,8 +63,8 @@ type testSubKey struct { Mood string } -func newDefaultSetup() *WebsocketSetup { - return &WebsocketSetup{ +func newDefaultSetup() *ManagerSetup { + return &ManagerSetup{ ExchangeConfig: &config.Exchange{ Features: &config.FeaturesConfig{ Enabled: config.FeaturesEnabledConfig{Websocket: true}, @@ -92,32 +94,31 @@ func newDefaultSetup() *WebsocketSetup { func TestSetup(t *testing.T) { t.Parallel() - var w *Websocket + var w *Manager err := w.Setup(nil) - assert.ErrorIs(t, err, errWebsocketIsNil) + assert.ErrorContains(t, err, "nil pointer: *websocket.Manager") - w = &Websocket{DataHandler: make(chan any)} + w = &Manager{DataHandler: make(chan any)} err = w.Setup(nil) - assert.ErrorIs(t, err, errWebsocketSetupIsNil) - - websocketSetup := &WebsocketSetup{} + assert.ErrorContains(t, err, "nil pointer: *websocket.ManagerSetup") + websocketSetup := &ManagerSetup{} err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errExchangeConfigIsNil) + assert.ErrorContains(t, err, "nil pointer: ManagerSetup.Exchange") websocketSetup.ExchangeConfig = &config.Exchange{} err = w.Setup(websocketSetup) + assert.ErrorContains(t, err, "nil pointer: ManagerSetup.ExchangeConfig.Features") + + websocketSetup.ExchangeConfig.Features = &config.FeaturesConfig{} + err = w.Setup(websocketSetup) + assert.ErrorContains(t, err, "nil pointer: ManagerSetup.Features") + + websocketSetup.Features = &protocol.Features{} + err = w.Setup(websocketSetup) assert.ErrorIs(t, err, errExchangeConfigNameEmpty) websocketSetup.ExchangeConfig.Name = "testname" - err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errWebsocketFeaturesIsUnset) - - websocketSetup.Features = &protocol.Features{} - err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errConfigFeaturesIsNil) - - websocketSetup.ExchangeConfig.Features = &config.FeaturesConfig{} websocketSetup.Subscriber = func(subscription.List) error { return nil } // kicks off the setup err = w.Setup(websocketSetup) assert.ErrorIs(t, err, errWebsocketConnectorUnset) @@ -164,7 +165,7 @@ func TestSetup(t *testing.T) { func TestConnectionMessageErrors(t *testing.T) { t.Parallel() - wsWrong := &Websocket{} + wsWrong := &Manager{} wsWrong.connector = func() error { return nil } err := wsWrong.Connect() assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "Connect should error correctly") @@ -185,7 +186,7 @@ func TestConnectionMessageErrors(t *testing.T) { err = wsWrong.Connect() assert.ErrorIs(t, err, errDastardlyReason, "Connect should error correctly") - ws := NewWebsocket() + ws := NewManager() err = ws.Setup(newDefaultSetup()) require.NoError(t, err, "Setup must not error") ws.trafficTimeout = time.Minute @@ -207,9 +208,9 @@ func TestConnectionMessageErrors(t *testing.T) { checkToRoutineResult := func(t *testing.T) { t.Helper() v, ok := <-ws.ToRoutine - require.True(t, ok, "ToRoutine should not be closed on us") + require.True(t, ok, "ToRoutine must not be closed on us") switch err := v.(type) { - case *websocket.CloseError: + case *gws.CloseError: assert.Equal(t, "SpecialText", err.Text, "Should get correct Close Error") case error: assert.ErrorIs(t, err, errDastardlyReason, "Should get the correct error") @@ -222,7 +223,7 @@ func TestConnectionMessageErrors(t *testing.T) { ws.ReadMessageErrors <- errDastardlyReason checkToRoutineResult(t) - ws.ReadMessageErrors <- &websocket.CloseError{Code: 1006, Text: "SpecialText"} + ws.ReadMessageErrors <- &gws.CloseError{Code: 1006, Text: "SpecialText"} checkToRoutineResult(t) // Test individual connection defined functions @@ -237,87 +238,87 @@ func TestConnectionMessageErrors(t *testing.T) { mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) defer mock.Close() - ws.connectionManager = []*ConnectionWrapper{{Setup: &ConnectionSetup{URL: "ws" + mock.URL[len("http"):] + "/ws"}}} + ws.connectionManager = []*connectionWrapper{{setup: &ConnectionSetup{URL: "ws" + mock.URL[len("http"):] + "/ws"}}} err = ws.Connect() require.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset) - ws.connectionManager[0].Setup.Authenticate = func(context.Context, Connection) error { return errDastardlyReason } + ws.connectionManager[0].setup.Authenticate = func(context.Context, Connection) error { return errDastardlyReason } - ws.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) { + ws.connectionManager[0].setup.GenerateSubscriptions = func() (subscription.List, error) { return nil, errDastardlyReason } err = ws.Connect() require.ErrorIs(t, err, errDastardlyReason) - ws.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) { + ws.connectionManager[0].setup.GenerateSubscriptions = func() (subscription.List, error) { return subscription.List{{Channel: "test"}}, nil } err = ws.Connect() require.ErrorIs(t, err, errNoConnectFunc) - ws.connectionManager[0].Setup.Connector = func(context.Context, Connection) error { + ws.connectionManager[0].setup.Connector = func(context.Context, Connection) error { return errDastardlyReason } err = ws.Connect() require.ErrorIs(t, err, errWebsocketDataHandlerUnset) - ws.connectionManager[0].Setup.Handler = func(context.Context, []byte) error { + ws.connectionManager[0].setup.Handler = func(context.Context, []byte) error { return errDastardlyReason } err = ws.Connect() require.ErrorIs(t, err, errWebsocketSubscriberUnset) - ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error { + ws.connectionManager[0].setup.Subscriber = func(context.Context, Connection, subscription.List) error { return errDastardlyReason } err = ws.Connect() require.ErrorIs(t, err, errDastardlyReason) - ws.connectionManager[0].Setup.Connector = func(ctx context.Context, conn Connection) error { - return conn.DialContext(ctx, websocket.DefaultDialer, nil) + ws.connectionManager[0].setup.Connector = func(ctx context.Context, conn Connection) error { + return conn.DialContext(ctx, gws.DefaultDialer, nil) } err = ws.Connect() require.ErrorIs(t, err, errDastardlyReason) - ws.connectionManager[0].Setup.Handler = func(context.Context, []byte) error { + ws.connectionManager[0].setup.Handler = func(context.Context, []byte) error { return errDastardlyReason } err = ws.Connect() require.ErrorIs(t, err, errDastardlyReason) - ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error { + ws.connectionManager[0].setup.Subscriber = func(context.Context, Connection, subscription.List) error { return errDastardlyReason } - ws.connectionManager[0].Setup.Authenticate = nil + ws.connectionManager[0].setup.Authenticate = nil err = ws.Connect() require.ErrorIs(t, err, errDastardlyReason) require.NoError(t, ws.shutdown()) - ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error { + ws.connectionManager[0].setup.Subscriber = func(context.Context, Connection, subscription.List) error { return nil } err = ws.Connect() require.ErrorIs(t, err, ErrSubscriptionsNotAdded) require.NoError(t, ws.shutdown()) - ws.connectionManager[0].Subscriptions = subscription.NewStore() - ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error { - return ws.connectionManager[0].Subscriptions.Add(&subscription.Subscription{Channel: "test"}) + ws.connectionManager[0].subscriptions = subscription.NewStore() + ws.connectionManager[0].setup.Subscriber = func(context.Context, Connection, subscription.List) error { + return ws.connectionManager[0].subscriptions.Add(&subscription.Subscription{Channel: "test"}) } err = ws.Connect() require.NoError(t, err) - err = ws.connectionManager[0].Connection.SendRawMessage(context.Background(), request.Unset, websocket.TextMessage, []byte("test")) + err = ws.connectionManager[0].connection.SendRawMessage(context.Background(), request.Unset, gws.TextMessage, []byte("test")) require.NoError(t, err) require.NoError(t, err) require.NoError(t, ws.Shutdown()) } -func TestWebsocket(t *testing.T) { +func TestManager(t *testing.T) { t.Parallel() - ws := NewWebsocket() + ws := NewManager() err := ws.SetProxyAddress("garbagio") assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly") @@ -412,237 +413,15 @@ func TestWebsocket(t *testing.T) { ws.useMultiConnectionManagement = true - ws.connectionManager = []*ConnectionWrapper{{Setup: &ConnectionSetup{URL: "ws://demos.kaazing.com/echo"}, Connection: &WebsocketConnection{}}} + ws.connectionManager = []*connectionWrapper{{setup: &ConnectionSetup{URL: "ws://demos.kaazing.com/echo"}, connection: &connection{}}} err = ws.SetProxyAddress("https://192.168.0.1:1337") require.NoError(t, err) } -func currySimpleSub(w *Websocket) func(subscription.List) error { - return func(subs subscription.List) error { - return w.AddSuccessfulSubscriptions(nil, subs...) - } -} - -func currySimpleSubConn(w *Websocket) func(context.Context, Connection, subscription.List) error { - return func(_ context.Context, conn Connection, subs subscription.List) error { - return w.AddSuccessfulSubscriptions(conn, subs...) - } -} - -func currySimpleUnsub(w *Websocket) func(subscription.List) error { - return func(unsubs subscription.List) error { - return w.RemoveSubscriptions(nil, unsubs...) - } -} - -func currySimpleUnsubConn(w *Websocket) func(context.Context, Connection, subscription.List) error { - return func(_ context.Context, conn Connection, unsubs subscription.List) error { - return w.RemoveSubscriptions(conn, unsubs...) - } -} - -// TestSubscribe logic test -func TestSubscribeUnsubscribe(t *testing.T) { - t.Parallel() - ws := NewWebsocket() - assert.NoError(t, ws.Setup(newDefaultSetup()), "WS Setup should not error") - - ws.Subscriber = currySimpleSub(ws) - ws.Unsubscriber = currySimpleUnsub(ws) - - subs, err := ws.GenerateSubs() - require.NoError(t, err, "Generating test subscriptions should not error") - assert.ErrorIs(t, new(Websocket).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function") - assert.NoError(t, ws.UnsubscribeChannels(nil, nil), "Unsubscribing from nil should not error") - assert.ErrorIs(t, ws.UnsubscribeChannels(nil, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed") - assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return") - assert.NoError(t, ws.SubscribeToChannels(nil, subs), "Basic Subscribing should not error") - assert.Len(t, ws.GetSubscriptions(), 4, "Should have 4 subscriptions") - bySub := ws.GetSubscription(subscription.Subscription{Channel: "TestSub"}) - if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") { - assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel") - assert.Same(t, bySub, subs[0], "GetSubscription returns the same pointer") - } - if assert.NotNil(t, ws.GetSubscription("purple"), "GetSubscription by string key should find a channel") { - assert.Equal(t, "TestSub2", ws.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel") - } - if assert.NotNil(t, ws.GetSubscription(testSubKey{"mauve"}), "GetSubscription by type key should find a channel") { - assert.Equal(t, "TestSub3", ws.GetSubscription(testSubKey{"mauve"}).Channel, "GetSubscription by type key should return a pointer a copy of the right channel") - } - if assert.NotNil(t, ws.GetSubscription(42), "GetSubscription by int key should find a channel") { - assert.Equal(t, "TestSub4", ws.GetSubscription(42).Channel, "GetSubscription by int key should return a pointer a copy of the right channel") - } - assert.Nil(t, ws.GetSubscription(nil), "GetSubscription by nil should return nil") - assert.Nil(t, ws.GetSubscription(45), "GetSubscription by invalid key should return nil") - assert.ErrorIs(t, ws.SubscribeToChannels(nil, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed") - assert.NoError(t, ws.SubscribeToChannels(nil, nil), "Subscribe to an nil List should not error") - assert.NoError(t, ws.UnsubscribeChannels(nil, subs), "Unsubscribing should not error") - - ws.Subscriber = func(subscription.List) error { return errDastardlyReason } - assert.ErrorIs(t, ws.SubscribeToChannels(nil, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber") - - err = ws.SubscribeToChannels(nil, subscription.List{nil}) - assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription") - - multi := NewWebsocket() - set := newDefaultSetup() - set.UseMultiConnectionManagement = true - assert.NoError(t, multi.Setup(set)) - - amazingCandidate := &ConnectionSetup{ - URL: "AMAZING", - Connector: func(context.Context, Connection) error { return nil }, - GenerateSubscriptions: ws.GenerateSubs, - Subscriber: func(ctx context.Context, c Connection, s subscription.List) error { - return currySimpleSubConn(multi)(ctx, c, s) - }, - Unsubscriber: func(ctx context.Context, c Connection, s subscription.List) error { - return currySimpleUnsubConn(multi)(ctx, c, s) - }, - Handler: func(context.Context, []byte) error { return nil }, - } - require.NoError(t, multi.SetupNewConnection(amazingCandidate)) - - amazingConn := multi.getConnectionFromSetup(amazingCandidate) - multi.connections = map[Connection]*ConnectionWrapper{ - amazingConn: multi.connectionManager[0], - } - - subs, err = amazingCandidate.GenerateSubscriptions() - require.NoError(t, err, "Generating test subscriptions should not error") - assert.ErrorIs(t, new(Websocket).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function") - assert.ErrorIs(t, new(Websocket).UnsubscribeChannels(amazingConn, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function") - assert.NoError(t, multi.UnsubscribeChannels(amazingConn, nil), "Unsubscribing from nil should not error") - assert.ErrorIs(t, multi.UnsubscribeChannels(amazingConn, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed") - assert.Nil(t, multi.GetSubscription(42), "GetSubscription on empty internal map should return") - - assert.ErrorIs(t, multi.SubscribeToChannels(nil, subs), common.ErrNilPointer, "If no connection is set, Subscribe should error") - - assert.NoError(t, multi.SubscribeToChannels(amazingConn, subs), "Basic Subscribing should not error") - assert.Len(t, multi.GetSubscriptions(), 4, "Should have 4 subscriptions") - bySub = multi.GetSubscription(subscription.Subscription{Channel: "TestSub"}) - if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") { - assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel") - assert.Same(t, bySub, subs[0], "GetSubscription returns the same pointer") - } - if assert.NotNil(t, multi.GetSubscription("purple"), "GetSubscription by string key should find a channel") { - assert.Equal(t, "TestSub2", multi.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel") - } - if assert.NotNil(t, multi.GetSubscription(testSubKey{"mauve"}), "GetSubscription by type key should find a channel") { - assert.Equal(t, "TestSub3", multi.GetSubscription(testSubKey{"mauve"}).Channel, "GetSubscription by type key should return a pointer a copy of the right channel") - } - if assert.NotNil(t, multi.GetSubscription(42), "GetSubscription by int key should find a channel") { - assert.Equal(t, "TestSub4", multi.GetSubscription(42).Channel, "GetSubscription by int key should return a pointer a copy of the right channel") - } - assert.Nil(t, multi.GetSubscription(nil), "GetSubscription by nil should return nil") - assert.Nil(t, multi.GetSubscription(45), "GetSubscription by invalid key should return nil") - assert.ErrorIs(t, multi.SubscribeToChannels(amazingConn, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed") - assert.NoError(t, multi.SubscribeToChannels(amazingConn, nil), "Subscribe to an nil List should not error") - assert.NoError(t, multi.UnsubscribeChannels(amazingConn, subs), "Unsubscribing should not error") - - amazingCandidate.Subscriber = func(context.Context, Connection, subscription.List) error { return errDastardlyReason } - assert.ErrorIs(t, multi.SubscribeToChannels(amazingConn, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber") - - err = multi.SubscribeToChannels(amazingConn, subscription.List{nil}) - assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription") -} - -// TestResubscribe tests Resubscribing to existing subscriptions -func TestResubscribe(t *testing.T) { - t.Parallel() - ws := NewWebsocket() - - wackedOutSetup := newDefaultSetup() - wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1 - err := ws.Setup(wackedOutSetup) - assert.ErrorIs(t, err, errInvalidMaxSubscriptions, "Invalid MaxWebsocketSubscriptionsPerConnection should error") - - err = ws.Setup(newDefaultSetup()) - assert.NoError(t, err, "WS Setup should not error") - - ws.Subscriber = currySimpleSub(ws) - ws.Unsubscriber = currySimpleUnsub(ws) - - channel := subscription.List{{Channel: "resubTest"}} - - assert.ErrorIs(t, ws.ResubscribeToChannel(nil, channel[0]), subscription.ErrNotFound, "Resubscribe should error when channel isn't subscribed yet") - assert.NoError(t, ws.SubscribeToChannels(nil, channel), "Subscribe should not error") - assert.NoError(t, ws.ResubscribeToChannel(nil, channel[0]), "Resubscribe should not error now the channel is subscribed") -} - -// TestSubscriptions tests adding, getting and removing subscriptions -func TestSubscriptions(t *testing.T) { - t.Parallel() - w := new(Websocket) // Do not use NewWebsocket; We want to exercise w.subs == nil - assert.ErrorIs(t, (*Websocket)(nil).AddSubscriptions(nil), common.ErrNilPointer, "Should error correctly when nil websocket") - s := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel} - require.NoError(t, w.AddSubscriptions(nil, s), "Adding first subscription should not error") - assert.Same(t, s, w.GetSubscription(42), "Get Subscription should retrieve the same subscription") - assert.ErrorIs(t, w.AddSubscriptions(nil, s), subscription.ErrDuplicate, "Adding same subscription should return error") - assert.Equal(t, subscription.SubscribingState, s.State(), "Should set state to Subscribing") - - err := w.RemoveSubscriptions(nil, s) - require.NoError(t, err, "RemoveSubscriptions must not error") - assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub") - assert.Equal(t, subscription.UnsubscribedState, s.State(), "Should set state to Unsubscribed") - - require.NoError(t, s.SetState(subscription.ResubscribingState), "SetState must not error") - require.NoError(t, w.AddSubscriptions(nil, s), "Adding first subscription should not error") - assert.Equal(t, subscription.ResubscribingState, s.State(), "Should not change resubscribing state") -} - -// TestSuccessfulSubscriptions tests adding, getting and removing subscriptions -func TestSuccessfulSubscriptions(t *testing.T) { - t.Parallel() - w := new(Websocket) // Do not use NewWebsocket; We want to exercise w.subs == nil - assert.ErrorIs(t, (*Websocket)(nil).AddSuccessfulSubscriptions(nil, nil), common.ErrNilPointer, "Should error correctly when nil websocket") - c := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel} - require.NoError(t, w.AddSuccessfulSubscriptions(nil, c), "Adding first subscription should not error") - assert.Same(t, c, w.GetSubscription(42), "Get Subscription should retrieve the same subscription") - assert.ErrorIs(t, w.AddSuccessfulSubscriptions(nil, c), subscription.ErrInStateAlready, "Adding subscription in same state should return error") - require.NoError(t, c.SetState(subscription.SubscribingState), "SetState must not error") - assert.ErrorIs(t, w.AddSuccessfulSubscriptions(nil, c), subscription.ErrDuplicate, "Adding same subscription should return error") - - err := w.RemoveSubscriptions(nil, c) - require.NoError(t, err, "RemoveSubscriptions must not error") - assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub") - assert.ErrorIs(t, w.RemoveSubscriptions(nil, c), subscription.ErrNotFound, "Should error correctly when not found") - assert.ErrorIs(t, (*Websocket)(nil).RemoveSubscriptions(nil, nil), common.ErrNilPointer, "Should error correctly when nil websocket") - w.subscriptions = nil - assert.ErrorIs(t, w.RemoveSubscriptions(nil, c), common.ErrNilPointer, "Should error correctly when nil websocket") -} - -// TestGetSubscription logic test -func TestGetSubscription(t *testing.T) { - t.Parallel() - assert.Nil(t, (*Websocket).GetSubscription(nil, "imaginary"), "GetSubscription on a nil Websocket should return nil") - assert.Nil(t, (&Websocket{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub store should return nil") - w := NewWebsocket() - assert.Nil(t, w.GetSubscription(nil), "GetSubscription with a nil key should return nil") - s := &subscription.Subscription{Key: 42, Channel: "hello3"} - require.NoError(t, w.AddSubscriptions(nil, s), "AddSubscriptions must not error") - assert.Same(t, s, w.GetSubscription(42), "GetSubscription should delegate to the store") -} - -// TestGetSubscriptions logic test -func TestGetSubscriptions(t *testing.T) { - t.Parallel() - assert.Nil(t, (*Websocket).GetSubscriptions(nil), "GetSubscription on a nil Websocket should return nil") - assert.Nil(t, (&Websocket{}).GetSubscriptions(), "GetSubscription on a Websocket with no sub store should return nil") - w := NewWebsocket() - s := subscription.List{ - {Key: 42, Channel: "hello3"}, - {Key: 45, Channel: "hello4"}, - } - err := w.AddSubscriptions(nil, s...) - require.NoError(t, err, "AddSubscriptions must not error") - assert.ElementsMatch(t, s, w.GetSubscriptions(), "GetSubscriptions should return the correct channel details") -} - // TestSetCanUseAuthenticatedEndpoints logic test func TestSetCanUseAuthenticatedEndpoints(t *testing.T) { t.Parallel() - ws := NewWebsocket() + ws := NewManager() assert.False(t, ws.CanUseAuthenticatedEndpoints(), "CanUseAuthenticatedEndpoints should return false") ws.SetCanUseAuthenticatedEndpoints(true) assert.True(t, ws.CanUseAuthenticatedEndpoints(), "CanUseAuthenticatedEndpoints should return true") @@ -657,7 +436,7 @@ func TestDial(t *testing.T) { testCases := []testStruct{ { - WC: WebsocketConnection{ + WC: connection{ ExchangeName: "test1", Verbose: true, URL: "ws" + mock.URL[len("http"):] + "/ws", @@ -667,7 +446,7 @@ func TestDial(t *testing.T) { }, { Error: errors.New(" Error: malformed ws or wss URL"), - WC: WebsocketConnection{ + WC: connection{ ExchangeName: "test2", Verbose: true, URL: "", @@ -675,7 +454,7 @@ func TestDial(t *testing.T) { }, }, { - WC: WebsocketConnection{ + WC: connection{ ExchangeName: "test3", Verbose: true, URL: "ws" + mock.URL[len("http"):] + "/ws", @@ -690,7 +469,7 @@ func TestDial(t *testing.T) { t.Log("Proxy testing not enabled, skipping") continue } - err := testCases[i].WC.Dial(&websocket.Dialer{}, http.Header{}) + err := testCases[i].WC.Dial(&gws.Dialer{}, http.Header{}) if err != nil { if testCases[i].Error != nil && strings.Contains(err.Error(), testCases[i].Error.Error()) { return @@ -709,7 +488,7 @@ func TestSendMessage(t *testing.T) { testCases := []testStruct{ { - WC: WebsocketConnection{ + WC: connection{ ExchangeName: "test1", Verbose: true, URL: "ws" + mock.URL[len("http"):] + "/ws", @@ -719,7 +498,7 @@ func TestSendMessage(t *testing.T) { }, { Error: errors.New(" Error: malformed ws or wss URL"), - WC: WebsocketConnection{ + WC: connection{ ExchangeName: "test2", Verbose: true, URL: "", @@ -727,7 +506,7 @@ func TestSendMessage(t *testing.T) { }, }, { - WC: WebsocketConnection{ + WC: connection{ ExchangeName: "test3", Verbose: true, URL: "ws" + mock.URL[len("http"):] + "/ws", @@ -742,7 +521,7 @@ func TestSendMessage(t *testing.T) { t.Log("Proxy testing not enabled, skipping") continue } - err := testCases[x].WC.Dial(&websocket.Dialer{}, http.Header{}) + err := testCases[x].WC.Dial(&gws.Dialer{}, http.Header{}) if err != nil { if testCases[x].Error != nil && strings.Contains(err.Error(), testCases[x].Error.Error()) { return @@ -751,7 +530,7 @@ func TestSendMessage(t *testing.T) { } err = testCases[x].WC.SendJSONMessage(context.Background(), request.Unset, Ping) require.NoError(t, err) - err = testCases[x].WC.SendRawMessage(context.Background(), request.Unset, websocket.TextMessage, []byte(Ping)) + err = testCases[x].WC.SendRawMessage(context.Background(), request.Unset, gws.TextMessage, []byte(Ping)) require.NoError(t, err) } } @@ -762,7 +541,7 @@ func TestSendMessageReturnResponse(t *testing.T) { mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) defer mock.Close() - wc := &WebsocketConnection{ + wc := &connection{ Verbose: true, URL: "ws" + mock.URL[len("http"):] + "/ws", ResponseMaxLimit: time.Second * 5, @@ -772,7 +551,7 @@ func TestSendMessageReturnResponse(t *testing.T) { t.Skip("Proxy testing not enabled, skipping") } - err := wc.Dial(&websocket.Dialer{}, http.Header{}) + err := wc.Dial(&gws.Dialer{}, http.Header{}) if err != nil { t.Fatal(err) } @@ -809,7 +588,7 @@ func TestSendMessageReturnResponse(t *testing.T) { func TestWaitForResponses(t *testing.T) { t.Parallel() - dummy := &WebsocketConnection{ + dummy := &connection{ ResponseMaxLimit: time.Nanosecond, Match: NewMatch(), } @@ -852,7 +631,7 @@ func (r *reporter) Latency(name string, message []byte, t time.Duration) { } // readMessages helper func -func readMessages(t *testing.T, wc *WebsocketConnection) { +func readMessages(t *testing.T, wc *connection) { t.Helper() timer := time.NewTimer(20 * time.Second) for { @@ -886,7 +665,7 @@ func TestSetupPingHandler(t *testing.T) { mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) defer mock.Close() - wc := &WebsocketConnection{ + wc := &connection{ URL: "ws" + mock.URL[len("http"):] + "/ws", ResponseMaxLimit: time.Second * 5, Match: NewMatch(), @@ -897,14 +676,14 @@ func TestSetupPingHandler(t *testing.T) { t.Skip("Proxy testing not enabled, skipping") } wc.shutdown = make(chan struct{}) - err := wc.Dial(&websocket.Dialer{}, http.Header{}) + err := wc.Dial(&gws.Dialer{}, http.Header{}) if err != nil { t.Fatal(err) } wc.SetupPingHandler(request.Unset, PingHandler{ UseGorillaHandler: true, - MessageType: websocket.PingMessage, + MessageType: gws.PingMessage, Delay: 100, }) @@ -913,12 +692,12 @@ func TestSetupPingHandler(t *testing.T) { t.Error(err) } - err = wc.Dial(&websocket.Dialer{}, http.Header{}) + err = wc.Dial(&gws.Dialer{}, http.Header{}) if err != nil { t.Fatal(err) } wc.SetupPingHandler(request.Unset, PingHandler{ - MessageType: websocket.TextMessage, + MessageType: gws.TextMessage, Message: []byte(Ping), Delay: 200, }) @@ -934,7 +713,7 @@ func TestParseBinaryResponse(t *testing.T) { mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) defer mock.Close() - wc := &WebsocketConnection{ + wc := &connection{ URL: "ws" + mock.URL[len("http"):] + "/ws", ResponseMaxLimit: time.Second * 5, Match: NewMatch(), @@ -968,7 +747,7 @@ func TestParseBinaryResponse(t *testing.T) { // TestCanUseAuthenticatedWebsocketForWrapper logic test func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) { t.Parallel() - ws := &Websocket{} + ws := &Manager{} assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") ws.setState(connectedState) @@ -981,22 +760,22 @@ func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) { func TestGenerateMessageID(t *testing.T) { t.Parallel() - wc := WebsocketConnection{} + wc := connection{} const spins = 1000 ids := make([]int64, spins) for i := range spins { id := wc.GenerateMessageID(true) - assert.NotContains(t, ids, id, "GenerateMessageID must not generate the same ID twice") + assert.NotContains(t, ids, id, "GenerateMessageID should not generate the same ID twice") ids[i] = id } wc.bespokeGenerateMessageID = func(bool) int64 { return 42 } - assert.EqualValues(t, 42, wc.GenerateMessageID(true), "GenerateMessageID must use bespokeGenerateMessageID") + assert.EqualValues(t, 42, wc.GenerateMessageID(true), "GenerateMessageID should use bespokeGenerateMessageID") } // 7002502 166.7 ns/op 48 B/op 3 allocs/op func BenchmarkGenerateMessageID_High(b *testing.B) { - wc := WebsocketConnection{} + wc := connection{} for b.Loop() { _ = wc.GenerateMessageID(true) } @@ -1004,7 +783,7 @@ func BenchmarkGenerateMessageID_High(b *testing.B) { // 6536250 186.1 ns/op 48 B/op 3 allocs/op func BenchmarkGenerateMessageID_Low(b *testing.B) { - wc := WebsocketConnection{} + wc := connection{} for b.Loop() { _ = wc.GenerateMessageID(false) } @@ -1072,7 +851,7 @@ func TestFlushChannels(t *testing.T) { t.Parallel() // Enabled pairs/setup system - dodgyWs := Websocket{} + dodgyWs := Manager{} err := dodgyWs.FlushChannels() assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "FlushChannels should error correctly") @@ -1085,7 +864,7 @@ func TestFlushChannels(t *testing.T) { currency.NewPair(currency.BTC, currency.USDT), }} - w := NewWebsocket() + w := NewManager() w.exchangeName = "test" w.connector = connect w.Subscriber = newgen.SUBME @@ -1174,7 +953,7 @@ func TestFlushChannels(t *testing.T) { amazingCandidate := &ConnectionSetup{ URL: "ws" + mock.URL[len("http"):] + "/ws", Connector: func(ctx context.Context, conn Connection) error { - return conn.DialContext(ctx, websocket.DefaultDialer, nil) + return conn.DialContext(ctx, gws.DefaultDialer, nil) }, GenerateSubscriptions: newgen.generateSubs, Subscriber: func(context.Context, Connection, subscription.List) error { return nil }, @@ -1184,7 +963,7 @@ func TestFlushChannels(t *testing.T) { require.NoError(t, w.SetupNewConnection(amazingCandidate)) require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotAdded, "Must error when no subscriptions are added to the subscription store") - w.connectionManager[0].Setup.Subscriber = func(ctx context.Context, c Connection, s subscription.List) error { + w.connectionManager[0].setup.Subscriber = func(ctx context.Context, c Connection, s subscription.List) error { return currySimpleSubConn(w)(ctx, c, s) } require.NoError(t, w.FlushChannels(), "FlushChannels must not error") @@ -1196,10 +975,10 @@ func TestFlushChannels(t *testing.T) { // Unsubscribe what's already subscribed. No subscriptions left over, which then forces the shutdown and removal // of the connection from management. w.features.Subscribe = true - w.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil } + w.connectionManager[0].setup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil } require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotRemoved, "Must error when no subscriptions are removed from subscription store") - w.connectionManager[0].Setup.Unsubscriber = func(ctx context.Context, c Connection, s subscription.List) error { + w.connectionManager[0].setup.Unsubscriber = func(ctx context.Context, c Connection, s subscription.List) error { return currySimpleUnsubConn(w)(ctx, c, s) } require.NoError(t, w.FlushChannels(), "FlushChannels must not error") @@ -1207,7 +986,7 @@ func TestFlushChannels(t *testing.T) { func TestDisable(t *testing.T) { t.Parallel() - w := NewWebsocket() + w := NewManager() w.setEnabled(true) w.setState(connectedState) require.NoError(t, w.Disable(), "Disable must not error") @@ -1216,7 +995,7 @@ func TestDisable(t *testing.T) { func TestEnable(t *testing.T) { t.Parallel() - w := NewWebsocket() + w := NewManager() w.connector = connect w.Subscriber = func(subscription.List) error { return nil } w.Unsubscriber = func(subscription.List) error { return nil } @@ -1227,15 +1006,15 @@ func TestEnable(t *testing.T) { func TestSetupNewConnection(t *testing.T) { t.Parallel() - var nonsenseWebsock *Websocket + var nonsenseWebsock *Manager err := nonsenseWebsock.SetupNewConnection(&ConnectionSetup{URL: "urlstring"}) - assert.ErrorIs(t, err, errWebsocketIsNil, "SetupNewConnection should error correctly") + assert.ErrorContains(t, err, "nil pointer: *websocket.Manager") - nonsenseWebsock = &Websocket{} + nonsenseWebsock = &Manager{} err = nonsenseWebsock.SetupNewConnection(&ConnectionSetup{URL: "urlstring"}) assert.ErrorIs(t, err, errExchangeConfigNameEmpty, "SetupNewConnection should error correctly") - nonsenseWebsock = &Websocket{exchangeName: "test"} + nonsenseWebsock = &Manager{exchangeName: "test"} err = nonsenseWebsock.SetupNewConnection(&ConnectionSetup{URL: "urlstring"}) assert.ErrorIs(t, err, errTrafficAlertNil, "SetupNewConnection should error correctly") @@ -1243,7 +1022,7 @@ func TestSetupNewConnection(t *testing.T) { err = nonsenseWebsock.SetupNewConnection(&ConnectionSetup{URL: "urlstring"}) assert.ErrorIs(t, err, errReadMessageErrorsNil, "SetupNewConnection should error correctly") - web := NewWebsocket() + web := NewManager() err = web.Setup(newDefaultSetup()) assert.NoError(t, err, "Setup should not error") @@ -1255,13 +1034,13 @@ func TestSetupNewConnection(t *testing.T) { assert.NoError(t, err, "SetupNewConnection should not error") // Test connection candidates for multi connection tracking. - multi := NewWebsocket() + multi := NewManager() set := newDefaultSetup() set.UseMultiConnectionManagement = true require.NoError(t, multi.Setup(set)) err = multi.SetupNewConnection(nil) - require.ErrorIs(t, err, errExchangeConfigEmpty) + assert.ErrorContains(t, err, "nil pointer: *websocket.ConnectionSetup") connSetup := &ConnectionSetup{ResponseCheckTimeout: time.Millisecond} err = multi.SetupNewConnection(connSetup) @@ -1302,24 +1081,24 @@ func TestSetupNewConnection(t *testing.T) { require.Nil(t, multi.Conn) err = multi.SetupNewConnection(connSetup) - require.ErrorIs(t, err, errConnectionWrapperDuplication) + require.ErrorIs(t, err, errDuplicateConnectionSetup) } -func TestWebsocketConnectionShutdown(t *testing.T) { +func TestConnectionShutdown(t *testing.T) { t.Parallel() - wc := WebsocketConnection{shutdown: make(chan struct{})} + wc := connection{shutdown: make(chan struct{})} err := wc.Shutdown() assert.NoError(t, err, "Shutdown should not error") - err = wc.Dial(&websocket.Dialer{}, nil) - assert.ErrorContains(t, err, "malformed ws or wss URL", "Dial must error correctly") + err = wc.Dial(&gws.Dialer{}, nil) + assert.ErrorContains(t, err, "malformed ws or wss URL", "Dial should error correctly") mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) defer mock.Close() wc.URL = "ws" + mock.URL[len("http"):] + "/ws" - err = wc.Dial(&websocket.Dialer{}, nil) + err = wc.Dial(&gws.Dialer{}, nil) require.NoError(t, err, "Dial must not error") err = wc.Shutdown() @@ -1335,7 +1114,7 @@ func TestLatency(t *testing.T) { r := &reporter{} exch := "Kraken" - wc := &WebsocketConnection{ + wc := &connection{ ExchangeName: exch, Verbose: true, URL: "ws" + mock.URL[len("http"):] + "/ws", @@ -1347,7 +1126,7 @@ func TestLatency(t *testing.T) { t.Skip("Proxy testing not enabled, skipping") } - err := wc.Dial(&websocket.Dialer{}, http.Header{}) + err := wc.Dial(&gws.Dialer{}, http.Header{}) require.NoError(t, err) go readMessages(t, wc) @@ -1361,40 +1140,8 @@ func TestLatency(t *testing.T) { _, err = wc.SendMessageReturnResponse(context.Background(), request.Unset, req.RequestID, req) require.NoError(t, err) - require.NotEmpty(t, r.t, "Latency should have a duration") - require.Equal(t, exch, r.name, "Latency should have the correct exchange name") -} - -func TestCheckSubscriptions(t *testing.T) { - t.Parallel() - ws := Websocket{} - err := ws.checkSubscriptions(nil, nil) - assert.ErrorIs(t, err, common.ErrNilPointer, "checkSubscriptions should error correctly on nil w.subscriptions") - assert.ErrorContains(t, err, "Websocket.subscriptions", "checkSubscriptions should error giving context correctly on nil w.subscriptions") - - ws.subscriptions = subscription.NewStore() - err = ws.checkSubscriptions(nil, nil) - assert.NoError(t, err, "checkSubscriptions should not error on a nil list") - - ws.MaxSubscriptionsPerConnection = 1 - - err = ws.checkSubscriptions(nil, subscription.List{{}}) - assert.NoError(t, err, "checkSubscriptions should not error when subscriptions is empty") - - ws.subscriptions = subscription.NewStore() - err = ws.checkSubscriptions(nil, subscription.List{{}, {}}) - assert.ErrorIs(t, err, errSubscriptionsExceedsLimit, "checkSubscriptions should error correctly") - - ws.MaxSubscriptionsPerConnection = 2 - - ws.subscriptions = subscription.NewStore() - err = ws.subscriptions.Add(&subscription.Subscription{Key: 42, Channel: "test"}) - require.NoError(t, err, "Add subscription must not error") - err = ws.checkSubscriptions(nil, subscription.List{{Key: 42, Channel: "test"}}) - assert.ErrorIs(t, err, subscription.ErrDuplicate, "checkSubscriptions should error correctly") - - err = ws.checkSubscriptions(nil, subscription.List{{}}) - assert.NoError(t, err, "checkSubscriptions should not error") + require.NotEmpty(t, r.t, "Latency must have a duration") + require.Equal(t, exch, r.name, "Latency must have the correct exchange name") } func TestRemoveURLQueryString(t *testing.T) { @@ -1406,7 +1153,7 @@ func TestRemoveURLQueryString(t *testing.T) { func TestWriteToConn(t *testing.T) { t.Parallel() - wc := WebsocketConnection{} + wc := connection{} require.ErrorIs(t, wc.writeToConn(context.Background(), request.Unset, func() error { return nil }), errWebsocketIsDisconnected) wc.setConnectedStatus(true) // No rate limits set @@ -1434,18 +1181,18 @@ func TestDrain(t *testing.T) { drain(nil) ch := make(chan error) drain(ch) - require.Empty(t, ch, "Drain should empty the channel") + require.Empty(t, ch, "Drain must empty the channel") ch = make(chan error, 10) for range 10 { ch <- errors.New("test") } drain(ch) - require.Empty(t, ch, "Drain should empty the channel") + require.Empty(t, ch, "Drain must empty the channel") } func TestMonitorFrame(t *testing.T) { t.Parallel() - ws := Websocket{} + ws := Manager{} require.Panics(t, func() { ws.monitorFrame(nil, nil) }, "monitorFrame must panic on nil frame") require.Panics(t, func() { ws.monitorFrame(nil, func() func() bool { return nil }) }, "monitorFrame must panic on nil function") ws.Wg.Add(1) @@ -1455,7 +1202,7 @@ func TestMonitorFrame(t *testing.T) { func TestMonitorData(t *testing.T) { t.Parallel() - ws := Websocket{ShutdownC: make(chan struct{}), DataHandler: make(chan any, 10)} + ws := Manager{ShutdownC: make(chan struct{}), DataHandler: make(chan any, 10)} // Handle shutdown signal close(ws.ShutdownC) require.True(t, ws.observeData(nil)) @@ -1481,7 +1228,7 @@ func TestMonitorData(t *testing.T) { func TestMonitorConnection(t *testing.T) { t.Parallel() - ws := Websocket{verbose: true, ReadMessageErrors: make(chan error, 1), ShutdownC: make(chan struct{})} + ws := Manager{verbose: true, ReadMessageErrors: make(chan error, 1), ShutdownC: make(chan struct{})} // Handle timer expired and websocket disabled, shutdown everything. timer := time.NewTimer(0) ws.setState(connectedState) @@ -1517,7 +1264,7 @@ func TestMonitorConnection(t *testing.T) { func TestMonitorTraffic(t *testing.T) { t.Parallel() - ws := Websocket{verbose: true, ShutdownC: make(chan struct{}), TrafficAlert: make(chan struct{}, 1)} + ws := Manager{verbose: true, ShutdownC: make(chan struct{}), TrafficAlert: make(chan struct{}, 1)} ws.Wg.Add(1) // Handle external shutdown signal timer := time.NewTimer(time.Second) @@ -1551,14 +1298,16 @@ func TestMonitorTraffic(t *testing.T) { func TestGetConnection(t *testing.T) { t.Parallel() - var ws *Websocket + var ws *Manager _, err := ws.GetConnection(nil) require.ErrorIs(t, err, common.ErrNilPointer) + require.ErrorContains(t, err, fmt.Sprintf("%T", ws)) - ws = &Websocket{} + ws = &Manager{} _, err = ws.GetConnection(nil) - require.ErrorIs(t, err, errMessageFilterNotSet) + require.ErrorIs(t, err, common.ErrNilPointer) + require.ErrorContains(t, err, "messageFilter") _, err = ws.GetConnection("testURL") require.ErrorIs(t, err, errCannotObtainOutboundConnection) @@ -1573,57 +1322,17 @@ func TestGetConnection(t *testing.T) { _, err = ws.GetConnection("testURL") require.ErrorIs(t, err, ErrRequestRouteNotFound) - ws.connectionManager = []*ConnectionWrapper{{ - Setup: &ConnectionSetup{MessageFilter: "testURL", URL: "testURL"}, + ws.connectionManager = []*connectionWrapper{{ + setup: &ConnectionSetup{MessageFilter: "testURL", URL: "testURL"}, }} _, err = ws.GetConnection("testURL") require.ErrorIs(t, err, ErrNotConnected) - expected := &WebsocketConnection{} - ws.connectionManager[0].Connection = expected + expected := &connection{} + ws.connectionManager[0].connection = expected conn, err := ws.GetConnection("testURL") require.NoError(t, err) assert.Same(t, expected, conn) } - -func TestUpdateChannelSubscriptions(t *testing.T) { - t.Parallel() - - ws := Websocket{} - store := subscription.NewStore() - err := ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}}) - require.ErrorIs(t, err, common.ErrNilPointer) - require.Zero(t, store.Len()) - - ws.Subscriber = func(subs subscription.List) error { - for _, sub := range subs { - if err := store.Add(sub); err != nil { - return err - } - } - return nil - } - - ws.subscriptions = store - err = ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}}) - require.NoError(t, err) - require.Equal(t, 1, store.Len()) - - err = ws.updateChannelSubscriptions(nil, store, subscription.List{}) - require.ErrorIs(t, err, common.ErrNilPointer) - - ws.Unsubscriber = func(subs subscription.List) error { - for _, sub := range subs { - if err := store.Remove(sub); err != nil { - return err - } - } - return nil - } - - err = ws.updateChannelSubscriptions(nil, store, subscription.List{}) - require.NoError(t, err) - require.Zero(t, store.Len()) -} diff --git a/exchanges/stream/stream_match.go b/internal/exchange/websocket/match.go similarity index 99% rename from exchanges/stream/stream_match.go rename to internal/exchange/websocket/match.go index 430688a8..8bed2f18 100644 --- a/exchanges/stream/stream_match.go +++ b/internal/exchange/websocket/match.go @@ -1,4 +1,4 @@ -package stream +package websocket import ( "errors" diff --git a/exchanges/stream/stream_match_test.go b/internal/exchange/websocket/match_test.go similarity index 99% rename from exchanges/stream/stream_match_test.go rename to internal/exchange/websocket/match_test.go index 8053cb37..f98cd065 100644 --- a/exchanges/stream/stream_match_test.go +++ b/internal/exchange/websocket/match_test.go @@ -1,4 +1,4 @@ -package stream +package websocket import ( "testing" diff --git a/internal/exchange/websocket/subscriptions.go b/internal/exchange/websocket/subscriptions.go new file mode 100644 index 00000000..5ff2df4e --- /dev/null +++ b/internal/exchange/websocket/subscriptions.go @@ -0,0 +1,349 @@ +package websocket + +import ( + "context" + "errors" + "fmt" + "slices" + + "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" + "github.com/thrasher-corp/gocryptotrader/log" +) + +// Public subscription errors +var ( + ErrSubscriptionFailure = errors.New("subscription failure") + ErrSubscriptionsNotAdded = errors.New("subscriptions not added") + ErrSubscriptionsNotRemoved = errors.New("subscriptions not removed") +) + +// Public subscription errors +var ( + errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit") +) + +// UnsubscribeChannels unsubscribes from a list of websocket channel +func (m *Manager) UnsubscribeChannels(conn Connection, channels subscription.List) error { + if len(channels) == 0 { + return nil // No channels to unsubscribe from is not an error + } + if wrapper, ok := m.connections[conn]; ok && conn != nil { + return m.unsubscribe(wrapper.subscriptions, channels, func(channels subscription.List) error { + return wrapper.setup.Unsubscriber(context.TODO(), conn, channels) + }) + } + + if m.Unsubscriber == nil { + return fmt.Errorf("%w: Global Unsubscriber not set", common.ErrNilPointer) + } + + return m.unsubscribe(m.subscriptions, channels, func(channels subscription.List) error { + return m.Unsubscriber(channels) + }) +} + +func (m *Manager) unsubscribe(store *subscription.Store, channels subscription.List, unsub func(channels subscription.List) error) error { + if store == nil { + return nil // No channels to unsubscribe from is not an error + } + for _, s := range channels { + if store.Get(s) == nil { + return fmt.Errorf("%w: %s", subscription.ErrNotFound, s) + } + } + return unsub(channels) +} + +// ResubscribeToChannel resubscribes to channel +// Sets state to Resubscribing, and exchanges which want to maintain a lock on it can respect this state and not RemoveSubscription +// Errors if subscription is already subscribing +func (m *Manager) ResubscribeToChannel(conn Connection, s *subscription.Subscription) error { + l := subscription.List{s} + if err := s.SetState(subscription.ResubscribingState); err != nil { + return fmt.Errorf("%w: %s", err, s) + } + if err := m.UnsubscribeChannels(conn, l); err != nil { + return err + } + return m.SubscribeToChannels(conn, l) +} + +// SubscribeToChannels subscribes to websocket channels using the exchange specific Subscriber method +// Errors are returned for duplicates or exceeding max Subscriptions +func (m *Manager) SubscribeToChannels(conn Connection, subs subscription.List) error { + if slices.Contains(subs, nil) { + return fmt.Errorf("%w: List parameter contains an nil element", common.ErrNilPointer) + } + if err := m.checkSubscriptions(conn, subs); err != nil { + return err + } + + if wrapper, ok := m.connections[conn]; ok && conn != nil { + return wrapper.setup.Subscriber(context.TODO(), conn, subs) + } + + if m.Subscriber == nil { + return fmt.Errorf("%w: Global Subscriber not set", common.ErrNilPointer) + } + + if err := m.Subscriber(subs); err != nil { + return fmt.Errorf("%w: %w", ErrSubscriptionFailure, err) + } + return nil +} + +// AddSubscriptions adds subscriptions to the subscription store +// Sets state to Subscribing unless the state is already set +func (m *Manager) AddSubscriptions(conn Connection, subs ...*subscription.Subscription) error { + if m == nil { + return fmt.Errorf("%w: AddSubscriptions called on nil Websocket", common.ErrNilPointer) + } + var subscriptionStore **subscription.Store + if wrapper, ok := m.connections[conn]; ok && conn != nil { + subscriptionStore = &wrapper.subscriptions + } else { + subscriptionStore = &m.subscriptions + } + + if *subscriptionStore == nil { + *subscriptionStore = subscription.NewStore() + } + var errs error + for _, s := range subs { + if s.State() == subscription.InactiveState { + if err := s.SetState(subscription.SubscribingState); err != nil { + errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s)) + } + } + if err := (*subscriptionStore).Add(s); err != nil { + errs = common.AppendError(errs, err) + } + } + return errs +} + +// AddSuccessfulSubscriptions marks subscriptions as subscribed and adds them to the subscription store +func (m *Manager) AddSuccessfulSubscriptions(conn Connection, subs ...*subscription.Subscription) error { + if m == nil { + return fmt.Errorf("%w: AddSuccessfulSubscriptions called on nil Websocket", common.ErrNilPointer) + } + + var subscriptionStore **subscription.Store + if wrapper, ok := m.connections[conn]; ok && conn != nil { + subscriptionStore = &wrapper.subscriptions + } else { + subscriptionStore = &m.subscriptions + } + + if *subscriptionStore == nil { + *subscriptionStore = subscription.NewStore() + } + + var errs error + for _, s := range subs { + if err := s.SetState(subscription.SubscribedState); err != nil { + errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s)) + } + if err := (*subscriptionStore).Add(s); err != nil { + errs = common.AppendError(errs, err) + } + } + return errs +} + +// RemoveSubscriptions removes subscriptions from the subscription list and sets the status to Unsubscribed +func (m *Manager) RemoveSubscriptions(conn Connection, subs ...*subscription.Subscription) error { + if m == nil { + return fmt.Errorf("%w: RemoveSubscriptions called on nil Websocket", common.ErrNilPointer) + } + + var subscriptionStore *subscription.Store + if wrapper, ok := m.connections[conn]; ok && conn != nil { + subscriptionStore = wrapper.subscriptions + } else { + subscriptionStore = m.subscriptions + } + + if subscriptionStore == nil { + return fmt.Errorf("%w: RemoveSubscriptions called on uninitialised Websocket", common.ErrNilPointer) + } + + var errs error + for _, s := range subs { + if err := s.SetState(subscription.UnsubscribedState); err != nil { + errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s)) + } + if err := subscriptionStore.Remove(s); err != nil { + errs = common.AppendError(errs, err) + } + } + return errs +} + +// GetSubscription returns a subscription at the key provided +// returns nil if no subscription is at that key or the key is nil +// Keys can implement subscription.MatchableKey in order to provide custom matching logic +func (m *Manager) GetSubscription(key any) *subscription.Subscription { + if m == nil || key == nil { + return nil + } + for _, c := range m.connectionManager { + if c.subscriptions == nil { + continue + } + sub := c.subscriptions.Get(key) + if sub != nil { + return sub + } + } + if m.subscriptions == nil { + return nil + } + return m.subscriptions.Get(key) +} + +// GetSubscriptions returns a new slice of the subscriptions +func (m *Manager) GetSubscriptions() subscription.List { + if m == nil { + return nil + } + var subs subscription.List + for _, c := range m.connectionManager { + if c.subscriptions != nil { + subs = append(subs, c.subscriptions.List()...) + } + } + if m.subscriptions != nil { + subs = append(subs, m.subscriptions.List()...) + } + return subs +} + +// checkSubscriptions checks subscriptions against the max subscription limit and if the subscription already exists +// The subscription state is not considered when counting existing subscriptions +func (m *Manager) checkSubscriptions(conn Connection, subs subscription.List) error { + var subscriptionStore *subscription.Store + if wrapper, ok := m.connections[conn]; ok && conn != nil { + subscriptionStore = wrapper.subscriptions + } else { + subscriptionStore = m.subscriptions + } + if subscriptionStore == nil { + return fmt.Errorf("%w: Websocket.subscriptions", common.ErrNilPointer) + } + + existing := subscriptionStore.Len() + if m.MaxSubscriptionsPerConnection > 0 && existing+len(subs) > m.MaxSubscriptionsPerConnection { + return fmt.Errorf("%w: current subscriptions: %v, incoming subscriptions: %v, max subscriptions per connection: %v - please reduce enabled pairs", + errSubscriptionsExceedsLimit, + existing, + len(subs), + m.MaxSubscriptionsPerConnection) + } + + for _, s := range subs { + if s.State() == subscription.ResubscribingState { + continue + } + if found := subscriptionStore.Get(s); found != nil { + return fmt.Errorf("%w: %s", subscription.ErrDuplicate, s) + } + } + + return nil +} + +// FlushChannels flushes channel subscriptions when there is a pair/asset change +func (m *Manager) FlushChannels() error { + if !m.IsEnabled() { + return fmt.Errorf("%s %w", m.exchangeName, ErrWebsocketNotEnabled) + } + + if !m.IsConnected() { + return fmt.Errorf("%s %w", m.exchangeName, ErrNotConnected) + } + + // If the exchange does not support subscribing and or unsubscribing the full connection needs to be flushed to + // maintain consistency. + if !m.features.Subscribe || !m.features.Unsubscribe { + m.m.Lock() + defer m.m.Unlock() + if err := m.shutdown(); err != nil { + return err + } + return m.connect() + } + + if !m.useMultiConnectionManagement { + newSubs, err := m.GenerateSubs() + if err != nil { + return err + } + return m.updateChannelSubscriptions(nil, m.subscriptions, newSubs) + } + + for x := range m.connectionManager { + newSubs, err := m.connectionManager[x].setup.GenerateSubscriptions() + if err != nil { + return err + } + + // Case if there is nothing to unsubscribe from and the connection is nil + if len(newSubs) == 0 && m.connectionManager[x].connection == nil { + continue + } + + // If there are subscriptions to subscribe to but no connection to subscribe to, establish a new connection. + if m.connectionManager[x].connection == nil { + conn := m.getConnectionFromSetup(m.connectionManager[x].setup) + if err := m.connectionManager[x].setup.Connector(context.TODO(), conn); err != nil { + return err + } + m.Wg.Add(1) + go m.Reader(context.TODO(), conn, m.connectionManager[x].setup.Handler) + m.connections[conn] = m.connectionManager[x] + m.connectionManager[x].connection = conn + } + + err = m.updateChannelSubscriptions(m.connectionManager[x].connection, m.connectionManager[x].subscriptions, newSubs) + if err != nil { + return err + } + + // If there are no subscriptions to subscribe to, close the connection as it is no longer needed. + if m.connectionManager[x].subscriptions.Len() == 0 { + delete(m.connections, m.connectionManager[x].connection) // Remove from lookup map + if err := m.connectionManager[x].connection.Shutdown(); err != nil { + log.Warnf(log.WebsocketMgr, "%v websocket: failed to shutdown connection: %v", m.exchangeName, err) + } + m.connectionManager[x].connection = nil + } + } + return nil +} + +// updateChannelSubscriptions subscribes or unsubscribes from channels and checks that the correct number of channels +// have been subscribed to or unsubscribed from. +func (m *Manager) updateChannelSubscriptions(c Connection, store *subscription.Store, incoming subscription.List) error { + subs, unsubs := store.Diff(incoming) + if len(unsubs) != 0 { + if err := m.UnsubscribeChannels(c, unsubs); err != nil { + return err + } + + if contained := store.Contained(unsubs); len(contained) > 0 { + return fmt.Errorf("%v %w `%s`", m.exchangeName, ErrSubscriptionsNotRemoved, contained) + } + } + if len(subs) != 0 { + if err := m.SubscribeToChannels(c, subs); err != nil { + return err + } + + if missing := store.Missing(subs); len(missing) > 0 { + return fmt.Errorf("%v %w `%s`", m.exchangeName, ErrSubscriptionsNotAdded, missing) + } + } + return nil +} diff --git a/internal/exchange/websocket/subscriptions_test.go b/internal/exchange/websocket/subscriptions_test.go new file mode 100644 index 00000000..00125573 --- /dev/null +++ b/internal/exchange/websocket/subscriptions_test.go @@ -0,0 +1,305 @@ +package websocket + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" +) + +// TestSubscribe logic test +func TestSubscribeUnsubscribe(t *testing.T) { + t.Parallel() + ws := NewManager() + assert.NoError(t, ws.Setup(newDefaultSetup()), "WS Setup should not error") + + ws.Subscriber = currySimpleSub(ws) + ws.Unsubscriber = currySimpleUnsub(ws) + + subs, err := ws.GenerateSubs() + require.NoError(t, err, "Generating test subscriptions must not error") + assert.ErrorIs(t, new(Manager).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function") + assert.NoError(t, ws.UnsubscribeChannels(nil, nil), "Unsubscribing from nil should not error") + assert.ErrorIs(t, ws.UnsubscribeChannels(nil, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed") + assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return") + assert.NoError(t, ws.SubscribeToChannels(nil, subs), "Basic Subscribing should not error") + assert.Len(t, ws.GetSubscriptions(), 4, "Should have 4 subscriptions") + bySub := ws.GetSubscription(subscription.Subscription{Channel: "TestSub"}) + if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") { + assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel") + assert.Same(t, bySub, subs[0], "GetSubscription returns the same pointer") + } + if assert.NotNil(t, ws.GetSubscription("purple"), "GetSubscription by string key should find a channel") { + assert.Equal(t, "TestSub2", ws.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel") + } + if assert.NotNil(t, ws.GetSubscription(testSubKey{"mauve"}), "GetSubscription by type key should find a channel") { + assert.Equal(t, "TestSub3", ws.GetSubscription(testSubKey{"mauve"}).Channel, "GetSubscription by type key should return a pointer a copy of the right channel") + } + if assert.NotNil(t, ws.GetSubscription(42), "GetSubscription by int key should find a channel") { + assert.Equal(t, "TestSub4", ws.GetSubscription(42).Channel, "GetSubscription by int key should return a pointer a copy of the right channel") + } + assert.Nil(t, ws.GetSubscription(nil), "GetSubscription by nil should return nil") + assert.Nil(t, ws.GetSubscription(45), "GetSubscription by invalid key should return nil") + assert.ErrorIs(t, ws.SubscribeToChannels(nil, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed") + assert.NoError(t, ws.SubscribeToChannels(nil, nil), "Subscribe to an nil List should not error") + assert.NoError(t, ws.UnsubscribeChannels(nil, subs), "Unsubscribing should not error") + + ws.Subscriber = func(subscription.List) error { return errDastardlyReason } + assert.ErrorIs(t, ws.SubscribeToChannels(nil, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber") + + err = ws.SubscribeToChannels(nil, subscription.List{nil}) + assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription") + + multi := NewManager() + set := newDefaultSetup() + set.UseMultiConnectionManagement = true + assert.NoError(t, multi.Setup(set)) + + amazingCandidate := &ConnectionSetup{ + URL: "AMAZING", + Connector: func(context.Context, Connection) error { return nil }, + GenerateSubscriptions: ws.GenerateSubs, + Subscriber: func(ctx context.Context, c Connection, s subscription.List) error { + return currySimpleSubConn(multi)(ctx, c, s) + }, + Unsubscriber: func(ctx context.Context, c Connection, s subscription.List) error { + return currySimpleUnsubConn(multi)(ctx, c, s) + }, + Handler: func(context.Context, []byte) error { return nil }, + } + require.NoError(t, multi.SetupNewConnection(amazingCandidate)) + + amazingConn := multi.getConnectionFromSetup(amazingCandidate) + multi.connections = map[Connection]*connectionWrapper{ + amazingConn: multi.connectionManager[0], + } + + subs, err = amazingCandidate.GenerateSubscriptions() + require.NoError(t, err, "Generating test subscriptions must not error") + assert.ErrorIs(t, new(Manager).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function") + assert.ErrorIs(t, new(Manager).UnsubscribeChannels(amazingConn, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function") + assert.NoError(t, multi.UnsubscribeChannels(amazingConn, nil), "Unsubscribing from nil should not error") + assert.ErrorIs(t, multi.UnsubscribeChannels(amazingConn, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed") + assert.Nil(t, multi.GetSubscription(42), "GetSubscription on empty internal map should return") + + assert.ErrorIs(t, multi.SubscribeToChannels(nil, subs), common.ErrNilPointer, "If no connection is set, Subscribe should error") + + assert.NoError(t, multi.SubscribeToChannels(amazingConn, subs), "Basic Subscribing should not error") + assert.Len(t, multi.GetSubscriptions(), 4, "Should have 4 subscriptions") + bySub = multi.GetSubscription(subscription.Subscription{Channel: "TestSub"}) + if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") { + assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel") + assert.Same(t, bySub, subs[0], "GetSubscription returns the same pointer") + } + if assert.NotNil(t, multi.GetSubscription("purple"), "GetSubscription by string key should find a channel") { + assert.Equal(t, "TestSub2", multi.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel") + } + if assert.NotNil(t, multi.GetSubscription(testSubKey{"mauve"}), "GetSubscription by type key should find a channel") { + assert.Equal(t, "TestSub3", multi.GetSubscription(testSubKey{"mauve"}).Channel, "GetSubscription by type key should return a pointer a copy of the right channel") + } + if assert.NotNil(t, multi.GetSubscription(42), "GetSubscription by int key should find a channel") { + assert.Equal(t, "TestSub4", multi.GetSubscription(42).Channel, "GetSubscription by int key should return a pointer a copy of the right channel") + } + assert.Nil(t, multi.GetSubscription(nil), "GetSubscription by nil should return nil") + assert.Nil(t, multi.GetSubscription(45), "GetSubscription by invalid key should return nil") + assert.ErrorIs(t, multi.SubscribeToChannels(amazingConn, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed") + assert.NoError(t, multi.SubscribeToChannels(amazingConn, nil), "Subscribe to an nil List should not error") + assert.NoError(t, multi.UnsubscribeChannels(amazingConn, subs), "Unsubscribing should not error") + + amazingCandidate.Subscriber = func(context.Context, Connection, subscription.List) error { return errDastardlyReason } + assert.ErrorIs(t, multi.SubscribeToChannels(amazingConn, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber") + + err = multi.SubscribeToChannels(amazingConn, subscription.List{nil}) + assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription") +} + +// TestResubscribe tests Resubscribing to existing subscriptions +func TestResubscribe(t *testing.T) { + t.Parallel() + ws := NewManager() + + wackedOutSetup := newDefaultSetup() + wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1 + err := ws.Setup(wackedOutSetup) + assert.ErrorIs(t, err, errInvalidMaxSubscriptions, "Invalid MaxWebsocketSubscriptionsPerConnection should error") + + err = ws.Setup(newDefaultSetup()) + assert.NoError(t, err, "WS Setup should not error") + + ws.Subscriber = currySimpleSub(ws) + ws.Unsubscriber = currySimpleUnsub(ws) + + channel := subscription.List{{Channel: "resubTest"}} + + assert.ErrorIs(t, ws.ResubscribeToChannel(nil, channel[0]), subscription.ErrNotFound, "Resubscribe should error when channel isn't subscribed yet") + assert.NoError(t, ws.SubscribeToChannels(nil, channel), "Subscribe should not error") + assert.NoError(t, ws.ResubscribeToChannel(nil, channel[0]), "Resubscribe should not error now the channel is subscribed") +} + +// TestSubscriptions tests adding, getting and removing subscriptions +func TestSubscriptions(t *testing.T) { + t.Parallel() + w := new(Manager) // Do not use NewManager; We want to exercise w.subs == nil + assert.ErrorIs(t, (*Manager)(nil).AddSubscriptions(nil), common.ErrNilPointer, "Should error correctly when nil websocket") + s := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel} + require.NoError(t, w.AddSubscriptions(nil, s), "Adding first subscription must not error") + assert.Same(t, s, w.GetSubscription(42), "Get Subscription should retrieve the same subscription") + assert.ErrorIs(t, w.AddSubscriptions(nil, s), subscription.ErrDuplicate, "Adding same subscription should return error") + assert.Equal(t, subscription.SubscribingState, s.State(), "Should set state to Subscribing") + + err := w.RemoveSubscriptions(nil, s) + require.NoError(t, err, "RemoveSubscriptions must not error") + assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub") + assert.Equal(t, subscription.UnsubscribedState, s.State(), "Should set state to Unsubscribed") + + require.NoError(t, s.SetState(subscription.ResubscribingState), "SetState must not error") + require.NoError(t, w.AddSubscriptions(nil, s), "Adding first subscription must not error") + assert.Equal(t, subscription.ResubscribingState, s.State(), "Should not change resubscribing state") +} + +// TestSuccessfulSubscriptions tests adding, getting and removing subscriptions +func TestSuccessfulSubscriptions(t *testing.T) { + t.Parallel() + w := new(Manager) // Do not use NewManager; We want to exercise w.subs == nil + assert.ErrorIs(t, (*Manager)(nil).AddSuccessfulSubscriptions(nil, nil), common.ErrNilPointer, "Should error correctly when nil websocket") + c := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel} + require.NoError(t, w.AddSuccessfulSubscriptions(nil, c), "Adding first subscription must not error") + assert.Same(t, c, w.GetSubscription(42), "Get Subscription should retrieve the same subscription") + assert.ErrorIs(t, w.AddSuccessfulSubscriptions(nil, c), subscription.ErrInStateAlready, "Adding subscription in same state should return error") + require.NoError(t, c.SetState(subscription.SubscribingState), "SetState must not error") + assert.ErrorIs(t, w.AddSuccessfulSubscriptions(nil, c), subscription.ErrDuplicate, "Adding same subscription should return error") + + err := w.RemoveSubscriptions(nil, c) + require.NoError(t, err, "RemoveSubscriptions must not error") + assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub") + assert.ErrorIs(t, w.RemoveSubscriptions(nil, c), subscription.ErrNotFound, "Should error correctly when not found") + assert.ErrorIs(t, (*Manager)(nil).RemoveSubscriptions(nil, nil), common.ErrNilPointer, "Should error correctly when nil websocket") + w.subscriptions = nil + assert.ErrorIs(t, w.RemoveSubscriptions(nil, c), common.ErrNilPointer, "Should error correctly when nil websocket") +} + +// TestGetSubscription logic test +func TestGetSubscription(t *testing.T) { + t.Parallel() + assert.Nil(t, (*Manager).GetSubscription(nil, "imaginary"), "GetSubscription on a nil Websocket should return nil") + assert.Nil(t, (&Manager{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub store should return nil") + w := NewManager() + assert.Nil(t, w.GetSubscription(nil), "GetSubscription with a nil key should return nil") + s := &subscription.Subscription{Key: 42, Channel: "hello3"} + require.NoError(t, w.AddSubscriptions(nil, s), "AddSubscriptions must not error") + assert.Same(t, s, w.GetSubscription(42), "GetSubscription should delegate to the store") +} + +// TestGetSubscriptions logic test +func TestGetSubscriptions(t *testing.T) { + t.Parallel() + assert.Nil(t, (*Manager).GetSubscriptions(nil), "GetSubscription on a nil Websocket should return nil") + assert.Nil(t, (&Manager{}).GetSubscriptions(), "GetSubscription on a Websocket with no sub store should return nil") + w := NewManager() + s := subscription.List{ + {Key: 42, Channel: "hello3"}, + {Key: 45, Channel: "hello4"}, + } + err := w.AddSubscriptions(nil, s...) + require.NoError(t, err, "AddSubscriptions must not error") + assert.ElementsMatch(t, s, w.GetSubscriptions(), "GetSubscriptions should return the correct channel details") +} + +func TestCheckSubscriptions(t *testing.T) { + t.Parallel() + ws := Manager{} + err := ws.checkSubscriptions(nil, nil) + assert.ErrorIs(t, err, common.ErrNilPointer, "checkSubscriptions should error correctly on nil w.subscriptions") + assert.ErrorContains(t, err, "Websocket.subscriptions", "checkSubscriptions should error giving context correctly on nil w.subscriptions") + + ws.subscriptions = subscription.NewStore() + err = ws.checkSubscriptions(nil, nil) + assert.NoError(t, err, "checkSubscriptions should not error on a nil list") + + ws.MaxSubscriptionsPerConnection = 1 + + err = ws.checkSubscriptions(nil, subscription.List{{}}) + assert.NoError(t, err, "checkSubscriptions should not error when subscriptions is empty") + + ws.subscriptions = subscription.NewStore() + err = ws.checkSubscriptions(nil, subscription.List{{}, {}}) + assert.ErrorIs(t, err, errSubscriptionsExceedsLimit, "checkSubscriptions should error correctly") + + ws.MaxSubscriptionsPerConnection = 2 + + ws.subscriptions = subscription.NewStore() + err = ws.subscriptions.Add(&subscription.Subscription{Key: 42, Channel: "test"}) + require.NoError(t, err, "Add subscription must not error") + err = ws.checkSubscriptions(nil, subscription.List{{Key: 42, Channel: "test"}}) + assert.ErrorIs(t, err, subscription.ErrDuplicate, "checkSubscriptions should error correctly") + + err = ws.checkSubscriptions(nil, subscription.List{{}}) + assert.NoError(t, err, "checkSubscriptions should not error") +} + +func TestUpdateChannelSubscriptions(t *testing.T) { + t.Parallel() + + ws := NewManager() + store := subscription.NewStore() + err := ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}}) + require.ErrorIs(t, err, common.ErrNilPointer) + require.Zero(t, store.Len()) + + ws.Subscriber = func(subs subscription.List) error { + for _, sub := range subs { + if err := store.Add(sub); err != nil { + return err + } + } + return nil + } + + ws.subscriptions = store + err = ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}}) + require.NoError(t, err) + require.Equal(t, 1, store.Len()) + + err = ws.updateChannelSubscriptions(nil, store, subscription.List{}) + require.ErrorIs(t, err, common.ErrNilPointer) + + ws.Unsubscriber = func(subs subscription.List) error { + for _, sub := range subs { + if err := store.Remove(sub); err != nil { + return err + } + } + return nil + } + + err = ws.updateChannelSubscriptions(nil, store, subscription.List{}) + require.NoError(t, err) + require.Zero(t, store.Len()) +} + +func currySimpleSub(w *Manager) func(subscription.List) error { + return func(subs subscription.List) error { + return w.AddSuccessfulSubscriptions(nil, subs...) + } +} + +func currySimpleSubConn(w *Manager) func(context.Context, Connection, subscription.List) error { + return func(_ context.Context, conn Connection, subs subscription.List) error { + return w.AddSuccessfulSubscriptions(conn, subs...) + } +} + +func currySimpleUnsub(w *Manager) func(subscription.List) error { + return func(unsubs subscription.List) error { + return w.RemoveSubscriptions(nil, unsubs...) + } +} + +func currySimpleUnsubConn(w *Manager) func(context.Context, Connection, subscription.List) error { + return func(_ context.Context, conn Connection, unsubs subscription.List) error { + return w.RemoveSubscriptions(conn, unsubs...) + } +} diff --git a/internal/exchange/websocket/types.go b/internal/exchange/websocket/types.go new file mode 100644 index 00000000..2ca3c012 --- /dev/null +++ b/internal/exchange/websocket/types.go @@ -0,0 +1,56 @@ +package websocket + +import ( + "time" + + "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/exchanges/asset" + "github.com/thrasher-corp/gocryptotrader/exchanges/order" +) + +// PingHandler container for ping handler settings +type PingHandler struct { + Websocket bool + UseGorillaHandler bool + MessageType int + Message []byte + Delay time.Duration +} + +// FundingData defines funding data +type FundingData struct { + Timestamp time.Time + CurrencyPair currency.Pair + AssetType asset.Item + Exchange string + Amount float64 + Rate float64 + Period int64 + Side order.Side +} + +// KlineData defines kline feed +type KlineData struct { + Timestamp time.Time + Pair currency.Pair + AssetType asset.Item + Exchange string + StartTime time.Time + CloseTime time.Time + Interval string + OpenPrice float64 + ClosePrice float64 + HighPrice float64 + LowPrice float64 + Volume float64 +} + +// UnhandledMessageWarning defines a container for unhandled message warnings +type UnhandledMessageWarning struct { + Message string +} + +// Reporter interface groups observability functionality over Websocket request latency. +type Reporter interface { + Latency(name string, message []byte, t time.Duration) +} diff --git a/internal/testing/exchange/exchange_test.go b/internal/testing/exchange/exchange_test.go index c77848a7..6d279082 100644 --- a/internal/testing/exchange/exchange_test.go +++ b/internal/testing/exchange/exchange_test.go @@ -3,7 +3,7 @@ package exchange import ( "testing" - "github.com/gorilla/websocket" + gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/config" @@ -31,6 +31,6 @@ func TestMockHTTPInstance(t *testing.T) { // TestMockWsInstance exercises MockWsInstance func TestMockWsInstance(t *testing.T) { - b := MockWsInstance[binance.Binance](t, mockws.CurryWsMockUpgrader(t, func(_ testing.TB, _ []byte, _ *websocket.Conn) error { return nil })) + b := MockWsInstance[binance.Binance](t, mockws.CurryWsMockUpgrader(t, func(_ testing.TB, _ []byte, _ *gws.Conn) error { return nil })) require.NotNil(t, b, "MockWsInstance must not be nil") }