From ccde38d25abde72b8b06344706a8961c4098280b Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Mon, 16 May 2022 09:04:17 +1000 Subject: [PATCH] engine: Add websocket data handler register function (#935) * engine: Add websocket interceptor register function * Update engine/engine.go Co-authored-by: Scott * Update engine/websocketroutine_manager_types.go Co-authored-by: Scott * engine/websock: switch to data handler function register and range over handlers to still include default gct handling * engine/websocket: change name * glorious: nits * linter: fix * glorious: nits Co-authored-by: Scott --- engine/engine.go | 20 ++++ engine/engine_test.go | 30 ++++++ engine/websocketroutine_manager.go | 110 ++++++++++++++----- engine/websocketroutine_manager_test.go | 129 ++++++++++++++++++++--- engine/websocketroutine_manager_types.go | 21 ++-- 5 files changed, 263 insertions(+), 47 deletions(-) diff --git a/engine/engine.go b/engine/engine.go index f857c361..f70a3ccd 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -948,3 +948,23 @@ func (bot *Engine) SetupExchanges() error { func (bot *Engine) WaitForInitialCurrencySync() error { return bot.currencyPairSyncer.WaitForInitialSync() } + +// RegisterWebsocketDataHandler registers an externally defined data handler +// for diverting and handling websocket notifications across all enabled +// exchanges. InterceptorOnly as true will purge all other registered handlers +// (including default) bypassing all other handling. +func (bot *Engine) RegisterWebsocketDataHandler(fn WebsocketDataHandler, interceptorOnly bool) error { + if bot == nil { + return errNilBot + } + return bot.websocketRoutineManager.registerWebsocketDataHandler(fn, interceptorOnly) +} + +// SetDefaultWebsocketDataHandler sets the default websocket handler and +// removing all pre-existing handlers +func (bot *Engine) SetDefaultWebsocketDataHandler() error { + if bot == nil { + return errNilBot + } + return bot.websocketRoutineManager.setWebsocketDataHandler(bot.websocketRoutineManager.websocketDataHandler) +} diff --git a/engine/engine_test.go b/engine/engine_test.go index 8aa707b4..2cad97c1 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -293,3 +293,33 @@ func TestFlagSetWith(t *testing.T) { t.Fatalf("received: '%v' but expected: '%v'", isRunning, false) } } + +func TestRegisterWebsocketDataHandler(t *testing.T) { + t.Parallel() + var e *Engine + err := e.RegisterWebsocketDataHandler(nil, false) + if !errors.Is(err, errNilBot) { + t.Fatalf("received: '%v' but expected: '%v'", err, errNilBot) + } + + e = &Engine{websocketRoutineManager: &websocketRoutineManager{}} + err = e.RegisterWebsocketDataHandler(func(_ string, _ interface{}) error { return nil }, false) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } +} + +func TestSetDefaultWebsocketDataHandler(t *testing.T) { + t.Parallel() + var e *Engine + err := e.SetDefaultWebsocketDataHandler() + if !errors.Is(err, errNilBot) { + t.Fatalf("received: '%v' but expected: '%v'", err, errNilBot) + } + + e = &Engine{websocketRoutineManager: &websocketRoutineManager{}} + err = e.SetDefaultWebsocketDataHandler() + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } +} diff --git a/engine/websocketroutine_manager.go b/engine/websocketroutine_manager.go index 0bb97a43..f361c135 100644 --- a/engine/websocketroutine_manager.go +++ b/engine/websocketroutine_manager.go @@ -33,14 +33,15 @@ func setupWebsocketRoutineManager(exchangeManager iExchangeManager, orderManager if cfg.CurrencyPairFormat == nil && verbose { return nil, errNilCurrencyPairFormat } - return &websocketRoutineManager{ + man := &websocketRoutineManager{ verbose: verbose, exchangeManager: exchangeManager, orderManager: orderManager, syncer: syncer, currencyConfig: cfg, shutdown: make(chan struct{}), - }, nil + } + return man, man.registerWebsocketDataHandler(man.websocketDataHandler, false) } // Start runs the subsystem @@ -113,7 +114,12 @@ func (m *websocketRoutineManager) websocketRoutine() { if err != nil { log.Errorf(log.WebsocketMgr, "%v", err) } - go m.WebsocketDataReceiver(ws) + + err = m.websocketDataReceiver(ws) + if err != nil { + log.Errorf(log.WebsocketMgr, "%v", err) + } + err = ws.FlushChannels() if err != nil { log.Errorf(log.WebsocketMgr, "Failed to subscribe: %v", err) @@ -131,34 +137,48 @@ func (m *websocketRoutineManager) websocketRoutine() { // WebsocketDataReceiver handles websocket data coming from a websocket feed // associated with an exchange -func (m *websocketRoutineManager) WebsocketDataReceiver(ws *stream.Websocket) { - if m == nil || atomic.LoadInt32(&m.started) == 0 { - return +func (m *websocketRoutineManager) websocketDataReceiver(ws *stream.Websocket) error { + if m == nil { + return fmt.Errorf("websocket routine manager %w", ErrNilSubsystem) } - m.wg.Add(1) - defer m.wg.Done() - for { - select { - case <-m.shutdown: - return - case data := <-ws.ToRoutine: - err := m.WebsocketDataHandler(ws.GetName(), data) - if err != nil { - log.Error(log.WebsocketMgr, err) + if ws == nil { + return errNilWebsocket + } + + if atomic.LoadInt32(&m.started) == 0 { + return errRoutineManagerNotStarted + } + + m.wg.Add(1) + go func() { + defer m.wg.Done() + for { + select { + case <-m.shutdown: + return + case data := <-ws.ToRoutine: + if data == nil { + log.Errorf(log.WebsocketMgr, "exchange %s nil data sent to websocket", ws.GetName()) + } + m.mu.RLock() + for x := range m.dataHandlers { + err := m.dataHandlers[x](ws.GetName(), data) + if err != nil { + log.Error(log.WebsocketMgr, err) + } + } + m.mu.RUnlock() } } - } + }() + return nil } -// WebsocketDataHandler is a central point for exchange websocket implementations to send -// processed data. WebsocketDataHandler will then pass that to an appropriate handler -func (m *websocketRoutineManager) WebsocketDataHandler(exchName string, data interface{}) error { - if data == nil { - return fmt.Errorf("exchange %s nil data sent to websocket", - exchName) - } - +// websocketDataHandler is the default central point for exchange websocket +// implementations to send processed data which will then pass that to an +// appropriate handler. +func (m *websocketRoutineManager) websocketDataHandler(exchName string, data interface{}) error { switch d := data.(type) { case string: log.Info(log.WebsocketMgr, d) @@ -338,3 +358,43 @@ func (m *websocketRoutineManager) printAccountHoldingsChangeSummary(o account.Ch o.Amount, o.Account) } + +// registerWebsocketDataHandler registers an externally (GCT Library) defined +// dedicated filter specific data types for internal & external strategy use. +// InterceptorOnly as true will purge all other registered handlers +// (including default) bypassing all other handling. +func (m *websocketRoutineManager) registerWebsocketDataHandler(fn WebsocketDataHandler, interceptorOnly bool) error { + if m == nil { + return fmt.Errorf("%T %w", m, ErrNilSubsystem) + } + + if fn == nil { + return errNilWebsocketDataHandlerFunction + } + + if interceptorOnly { + return m.setWebsocketDataHandler(fn) + } + + m.mu.Lock() + // Push front so that any registered data handler has first preference + // over the gct default handler. + m.dataHandlers = append([]WebsocketDataHandler{fn}, m.dataHandlers...) + m.mu.Unlock() + return nil +} + +// setWebsocketDataHandler sets a single websocket data handler, removing all +// pre-existing handlers. +func (m *websocketRoutineManager) setWebsocketDataHandler(fn WebsocketDataHandler) error { + if m == nil { + return fmt.Errorf("%T %w", m, ErrNilSubsystem) + } + if fn == nil { + return errNilWebsocketDataHandlerFunction + } + m.mu.Lock() + m.dataHandlers = []WebsocketDataHandler{fn} + m.mu.Unlock() + return nil +} diff --git a/engine/websocketroutine_manager_test.go b/engine/websocketroutine_manager_test.go index 34cf56a9..1938aa7a 100644 --- a/engine/websocketroutine_manager_test.go +++ b/engine/websocketroutine_manager_test.go @@ -152,19 +152,15 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) { t.Errorf("error '%v', expected '%v'", err, nil) } var orderID = "1337" - err = m.WebsocketDataHandler(exchName, errors.New("error")) + err = m.websocketDataHandler(exchName, errors.New("error")) if err == nil { t.Error("Error not handled correctly") } - err = m.WebsocketDataHandler(exchName, nil) - if err == nil { - t.Error("Expected nil data error") - } - err = m.WebsocketDataHandler(exchName, stream.FundingData{}) + err = m.websocketDataHandler(exchName, stream.FundingData{}) if err != nil { t.Error(err) } - err = m.WebsocketDataHandler(exchName, &ticker.Price{ + err = m.websocketDataHandler(exchName, &ticker.Price{ ExchangeName: exchName, Pair: currency.NewPair(currency.BTC, currency.USDC), AssetType: asset.Spot, @@ -172,7 +168,7 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) { if !errors.Is(err, nil) { t.Errorf("error '%v', expected '%v'", err, nil) } - err = m.WebsocketDataHandler(exchName, stream.KlineData{}) + err = m.websocketDataHandler(exchName, stream.KlineData{}) if err != nil { t.Error(err) } @@ -182,12 +178,12 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) { Amount: 1337, Price: 1337, } - err = m.WebsocketDataHandler(exchName, origOrder) + err = m.websocketDataHandler(exchName, origOrder) if err != nil { t.Error(err) } // Send it again since it exists now - err = m.WebsocketDataHandler(exchName, &order.Detail{ + err = m.websocketDataHandler(exchName, &order.Detail{ Exchange: exchName, ID: orderID, Amount: 1338, @@ -203,7 +199,7 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) { t.Error("Bad pipeline") } - err = m.WebsocketDataHandler(exchName, &order.Modify{ + err = m.websocketDataHandler(exchName, &order.Modify{ Exchange: "Bitstamp", ID: orderID, Status: order.Active, @@ -220,12 +216,12 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) { } // Send some gibberish - err = m.WebsocketDataHandler(exchName, order.Stop) + err = m.websocketDataHandler(exchName, order.Stop) if err != nil { t.Error(err) } - err = m.WebsocketDataHandler(exchName, stream.UnhandledMessageWarning{ + err = m.websocketDataHandler(exchName, stream.UnhandledMessageWarning{ Message: "there's an issue here's a tissue"}, ) if err != nil { @@ -237,7 +233,7 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) { OrderID: "one", Err: errors.New("lol"), } - err = m.WebsocketDataHandler(exchName, classificationError) + err = m.websocketDataHandler(exchName, classificationError) if err == nil { t.Error("Expected error") } @@ -245,15 +241,116 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) { t.Errorf("error '%v', expected '%v'", err, classificationError.Err) } - err = m.WebsocketDataHandler(exchName, &orderbook.Base{ + err = m.websocketDataHandler(exchName, &orderbook.Base{ Exchange: "Bitstamp", Pair: currency.NewPair(currency.BTC, currency.USD), }) if err != nil { t.Error(err) } - err = m.WebsocketDataHandler(exchName, "this is a test string") + err = m.websocketDataHandler(exchName, "this is a test string") if err != nil { t.Error(err) } } + +func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) { + t.Parallel() + var m *websocketRoutineManager + err := m.registerWebsocketDataHandler(nil, false) + if !errors.Is(err, ErrNilSubsystem) { + t.Fatalf("received: '%v' but expected: '%v'", err, ErrNilSubsystem) + } + + m = new(websocketRoutineManager) + m.shutdown = make(chan struct{}) + + err = m.registerWebsocketDataHandler(nil, false) + if !errors.Is(err, errNilWebsocketDataHandlerFunction) { + t.Fatalf("received: '%v' but expected: '%v'", err, errNilWebsocketDataHandlerFunction) + } + + // externally defined capture device + dataChan := make(chan interface{}) + fn := func(_ string, data interface{}) error { + switch data.(type) { + case string: + dataChan <- data + default: + } + return nil + } + + err = m.registerWebsocketDataHandler(fn, true) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + if len(m.dataHandlers) != 1 { + t.Fatal("unexpected data handlers registered") + } + + mock := stream.New() + mock.ToRoutine = make(chan interface{}) + m.started = 1 + err = m.websocketDataReceiver(mock) + if err != nil { + t.Fatal(err) + } + + mock.ToRoutine <- nil + mock.ToRoutine <- 1336 + mock.ToRoutine <- "intercepted" + + if r := <-dataChan; r != "intercepted" { + t.Fatal("unexpected value received") + } + + close(m.shutdown) + m.wg.Wait() +} + +func TestSetWebsocketDataHandler(t *testing.T) { + t.Parallel() + var m *websocketRoutineManager + err := m.setWebsocketDataHandler(nil) + if !errors.Is(err, ErrNilSubsystem) { + t.Fatalf("received: '%v' but expected: '%v'", err, ErrNilSubsystem) + } + + m = new(websocketRoutineManager) + m.shutdown = make(chan struct{}) + + err = m.setWebsocketDataHandler(nil) + if !errors.Is(err, errNilWebsocketDataHandlerFunction) { + t.Fatalf("received: '%v' but expected: '%v'", err, errNilWebsocketDataHandlerFunction) + } + + err = m.registerWebsocketDataHandler(m.websocketDataHandler, false) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + err = m.registerWebsocketDataHandler(m.websocketDataHandler, false) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + err = m.registerWebsocketDataHandler(m.websocketDataHandler, false) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + if len(m.dataHandlers) != 3 { + t.Fatal("unexpected data handler count") + } + + err = m.setWebsocketDataHandler(m.websocketDataHandler) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + if len(m.dataHandlers) != 1 { + t.Fatal("unexpected data handler count") + } +} diff --git a/engine/websocketroutine_manager_types.go b/engine/websocketroutine_manager_types.go index c1fc7b58..a00f710e 100644 --- a/engine/websocketroutine_manager_types.go +++ b/engine/websocketroutine_manager_types.go @@ -7,6 +7,16 @@ import ( "github.com/thrasher-corp/gocryptotrader/currency" ) +var ( + errNilOrderManager = errors.New("nil order manager received") + errNilCurrencyPairSyncer = errors.New("nil currency pair syncer received") + errNilCurrencyConfig = errors.New("nil currency config received") + errNilCurrencyPairFormat = errors.New("nil currency pair format received") + errNilWebsocketDataHandlerFunction = errors.New("websocket data handler function is nil") + errNilWebsocket = errors.New("websocket is nil") + errRoutineManagerNotStarted = errors.New("websocket routine manager not started") +) + // websocketRoutineManager is used to process websocket updates from a unified location type websocketRoutineManager struct { started int32 @@ -16,12 +26,11 @@ type websocketRoutineManager struct { syncer iCurrencyPairSyncer currencyConfig *currency.Config shutdown chan struct{} + dataHandlers []WebsocketDataHandler wg sync.WaitGroup + mu sync.RWMutex } -var ( - errNilOrderManager = errors.New("nil order manager received") - errNilCurrencyPairSyncer = errors.New("nil currency pair syncer received") - errNilCurrencyConfig = errors.New("nil currency config received") - errNilCurrencyPairFormat = errors.New("nil currency pair format received") -) +// WebsocketDataHandler defines a function signature for a function that handles +// data coming from websocket connections. +type WebsocketDataHandler func(service string, incoming interface{}) error