diff --git a/exchanges/okx/okx_types.go b/exchanges/okx/okx_types.go index cd99e930..2ae81cce 100644 --- a/exchanges/okx/okx_types.go +++ b/exchanges/okx/okx_types.go @@ -3171,6 +3171,7 @@ type wsRequestDataChannelsMultiplexer struct { Register chan *wsRequestInfo Unregister chan string Message chan *wsIncomingData + shutdown chan bool } // wsSubscriptionParameters represents toggling boolean values for subscription parameters. diff --git a/exchanges/okx/okx_websocket.go b/exchanges/okx/okx_websocket.go index 5f5fece5..8ba715bb 100644 --- a/exchanges/okx/okx_websocket.go +++ b/exchanges/okx/okx_websocket.go @@ -27,8 +27,6 @@ import ( var ( errInvalidChecksum = errors.New("invalid checksum") - // responseStream a channel thought which the data coming from the two websocket connection will go through. - responseStream = make(chan stream.Response) ) var ( @@ -226,10 +224,8 @@ func (ok *Okx) WsConnect() error { if err != nil { return err } - ok.Websocket.Wg.Add(2) - go ok.wsFunnelConnectionData(ok.Websocket.Conn) - go ok.WsReadData() - go ok.WsResponseMultiplexer.Run() + ok.Websocket.Wg.Add(1) + go ok.wsReadData(ok.Websocket.Conn) if ok.Verbose { log.Debugf(log.ExchangeSys, "Successful connection to %v\n", ok.Websocket.GetWebsocketURL()) @@ -262,7 +258,7 @@ func (ok *Okx) WsAuth(ctx context.Context, dialer *websocket.Dialer) error { return fmt.Errorf("%v Websocket connection %v error. Error %v", ok.Name, okxAPIWebsocketPrivateURL, err) } ok.Websocket.Wg.Add(1) - go ok.wsFunnelConnectionData(ok.Websocket.AuthConn) + go ok.wsReadData(ok.Websocket.AuthConn) ok.Websocket.AuthConn.SetupPingHandler(stream.PingHandler{ MessageType: websocket.TextMessage, Message: pingMsg, @@ -335,16 +331,17 @@ func (ok *Okx) WsAuth(ctx context.Context, dialer *websocket.Dialer) error { } } -// wsFunnelConnectionData receives data from multiple connection and pass the data -// to wsRead through a channel responseStream -func (ok *Okx) wsFunnelConnectionData(ws stream.Connection) { +// wsReadData sends msgs from public and auth websockets to data handler +func (ok *Okx) wsReadData(ws stream.Connection) { defer ok.Websocket.Wg.Done() for { resp := ws.ReadMessage() if resp.Raw == nil { return } - responseStream <- stream.Response{Raw: resp.Raw} + if err := ok.WsHandleData(resp.Raw); err != nil { + ok.Websocket.DataHandler <- err + } } } @@ -531,34 +528,6 @@ func (ok *Okx) handleSubscription(operation string, subscriptions []stream.Chann return nil } -// WsReadData read coming messages thought the websocket connection and process the data. -func (ok *Okx) WsReadData() { - defer ok.Websocket.Wg.Done() - for { - select { - case <-ok.Websocket.ShutdownC: - select { - case resp := <-responseStream: - err := ok.WsHandleData(resp.Raw) - if err != nil { - select { - case ok.Websocket.DataHandler <- err: - default: - log.Errorf(log.WebsocketMgr, "%s websocket handle data error: %v", ok.Name, err) - } - } - default: - } - return - case resp := <-responseStream: - err := ok.WsHandleData(resp.Raw) - if err != nil { - ok.Websocket.DataHandler <- err - } - } - } -} - // WsHandleData will read websocket raw data and pass to appropriate handler func (ok *Okx) WsHandleData(respRaw []byte) error { var resp wsIncomingData @@ -1675,6 +1644,10 @@ func (m *wsRequestDataChannelsMultiplexer) Run() { tickerData := time.NewTicker(time.Second) for { select { + case <-m.shutdown: + // We've consumed the shutdown, so create a new chan for subsequent runs + m.shutdown = make(chan bool) + return case <-tickerData.C: for x, myChan := range m.WsResponseChannelsMap { if myChan == nil { @@ -1709,6 +1682,12 @@ func (m *wsRequestDataChannelsMultiplexer) Run() { } } +// Shutdown causes the multiplexer to exit its Run loop +// All channels are left open, but websocket shutdown first will ensure no more messages block on multiplexer reading +func (m *wsRequestDataChannelsMultiplexer) Shutdown() { + close(m.shutdown) +} + // wsChannelSubscription sends a subscription or unsubscription request for different channels through the websocket stream. func (ok *Okx) wsChannelSubscription(operation, channel string, assetType asset.Item, pair currency.Pair, tInstrumentType, tInstrumentID, tUnderlying bool) error { if operation != operationSubscribe && operation != operationUnsubscribe { diff --git a/exchanges/okx/okx_wrapper.go b/exchanges/okx/okx_wrapper.go index 7ba52901..2cc5b106 100644 --- a/exchanges/okx/okx_wrapper.go +++ b/exchanges/okx/okx_wrapper.go @@ -174,16 +174,14 @@ func (ok *Okx) SetDefaults() { // Setup takes in the supplied exchange configuration details and sets params func (ok *Okx) Setup(exch *config.Exchange) error { - err := exch.Validate() - if err != nil { + if err := exch.Validate(); err != nil { return err } if !exch.Enabled { ok.SetEnabled(false) return nil } - err = ok.SetupDefaults(exch) - if err != nil { + if err := ok.SetupDefaults(exch); err != nil { return err } @@ -192,13 +190,14 @@ func (ok *Okx) Setup(exch *config.Exchange) error { Register: make(chan *wsRequestInfo), Unregister: make(chan string), Message: make(chan *wsIncomingData), + shutdown: make(chan bool), } wsRunningEndpoint, err := ok.API.Endpoints.GetURL(exchange.WebsocketSpot) if err != nil { return err } - err = ok.Websocket.Setup(&stream.WebsocketSetup{ + if err := ok.Websocket.Setup(&stream.WebsocketSetup{ ExchangeConfig: exch, DefaultURL: okxAPIWebsocketPublicURL, RunningURL: wsRunningEndpoint, @@ -211,20 +210,21 @@ func (ok *Okx) Setup(exch *config.Exchange) error { OrderbookBufferConfig: buffer.Config{ Checksum: ok.CalculateUpdateOrderbookChecksum, }, - }) - - if err != nil { + }); err != nil { return err } - err = ok.Websocket.SetupNewConnection(stream.ConnectionSetup{ + + go ok.WsResponseMultiplexer.Run() + + if err := ok.Websocket.SetupNewConnection(stream.ConnectionSetup{ URL: okxAPIWebsocketPublicURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: okxWebsocketResponseMaxLimit, RateLimit: 500, - }) - if err != nil { + }); err != nil { return err } + return ok.Websocket.SetupNewConnection(stream.ConnectionSetup{ URL: okxAPIWebsocketPrivateURL, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, @@ -277,6 +277,18 @@ func (ok *Okx) Run(ctx context.Context) { } } +// Shutdown calls Base.Shutdown and then shuts down the response multiplexer +func (ok *Okx) Shutdown() error { + if err := ok.Base.Shutdown(); err != nil { + return err + } + + // Must happen after the Websocket shutdown in Base.Shutdown, so there are no new blocking writes to the multiplexer + ok.WsResponseMultiplexer.Shutdown() + + return nil +} + // GetServerTime returns the current exchange server time. func (ok *Okx) GetServerTime(ctx context.Context, _ asset.Item) (time.Time, error) { return ok.GetSystemTime(ctx)