diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 404d76d8..e7585e94 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -348,18 +348,10 @@ func (w *Websocket) dataMonitor() { go func() { defer func() { - for { - // Bleeds data from the websocket connection if needed - select { - case <-w.DataHandler: - default: - w.setDataMonitorRunning(false) - w.Wg.Done() - return - } - } + w.setDataMonitorRunning(false) + w.Wg.Done() }() - + dropped := 0 for { select { case <-w.ShutdownC: @@ -367,15 +359,16 @@ func (w *Websocket) dataMonitor() { case d := <-w.DataHandler: select { case w.ToRoutine <- d: - case <-w.ShutdownC: - return - default: - log.Warnf(log.WebsocketMgr, "%s exchange backlog in websocket processing detected", w.exchangeName) - select { - case w.ToRoutine <- d: - case <-w.ShutdownC: - return + if dropped != 0 { + log.Infof(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer recovered; %d messages were dropped", w.exchangeName, dropped) + dropped = 0 } + default: + if dropped == 0 { + // If this becomes prone to flapping we could drain the buffer, but that's extreme and we'd like to avoid it if possible + log.Warnf(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer full; dropping messages", w.exchangeName) + } + dropped++ } } } @@ -413,12 +406,15 @@ func (w *Websocket) connectionMonitor() error { } select { case err := <-w.ReadMessageErrors: + w.DataHandler <- err if IsDisconnectionError(err) { log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err) - w.setState(disconnected) + if w.IsConnected() { + if shutdownErr := w.Shutdown(); shutdownErr != nil { + log.Errorf(log.WebsocketMgr, "%v websocket: connectionMonitor shutdown err: %s", w.exchangeName, shutdownErr) + } + } } - - w.DataHandler <- err case <-timer.C: if !w.IsConnecting() && !w.IsConnected() { err := w.Connect() diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 0d4e9c02..5e4891f8 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -326,26 +326,29 @@ func TestConnectionMessageErrors(t *testing.T) { err = ws.Connect() require.NoError(t, err, "Connect must not error") - ws.TrafficAlert <- struct{}{} - c := func(tb *assert.CollectT) { select { - case v := <-ws.ToRoutine: + case v, ok := <-ws.ToRoutine: + require.True(tb, ok, "ToRoutine should not be closed on us") switch err := v.(type) { case *websocket.CloseError: assert.Equal(tb, "SpecialText", err.Text, "Should get correct Close Error") case error: assert.ErrorIs(tb, err, errDastardlyReason, "Should get the correct error") + default: + assert.Failf(tb, "Wrong data type sent to ToRoutine", "Got type: %T", err) } default: + assert.Fail(tb, "Nothing available on ToRoutine") } } + ws.TrafficAlert <- struct{}{} ws.ReadMessageErrors <- errDastardlyReason - assert.EventuallyWithT(t, c, 900*time.Millisecond, 10*time.Millisecond, "Should get an error down the routine") + assert.EventuallyWithT(t, c, 2*time.Second, 10*time.Millisecond, "Should get an error down the routine") ws.ReadMessageErrors <- &websocket.CloseError{Code: 1006, Text: "SpecialText"} - assert.EventuallyWithT(t, c, 900*time.Millisecond, 10*time.Millisecond, "Should get an error down the routine") + assert.EventuallyWithT(t, c, 2*time.Second, 10*time.Millisecond, "Should get an error down the routine") } func TestWebsocket(t *testing.T) {