engine: Add websocket data handler register function (#935)

* engine: Add websocket interceptor register function

* Update engine/engine.go

Co-authored-by: Scott <gloriousCode@users.noreply.github.com>

* Update engine/websocketroutine_manager_types.go

Co-authored-by: Scott <gloriousCode@users.noreply.github.com>

* engine/websock: switch to data handler function register and range over handlers to still include default gct handling

* engine/websocket: change name

* glorious: nits

* linter: fix

* glorious: nits

Co-authored-by: Scott <gloriousCode@users.noreply.github.com>
This commit is contained in:
Ryan O'Hara-Reid
2022-05-16 09:04:17 +10:00
committed by GitHub
parent 5cb26e7ecf
commit ccde38d25a
5 changed files with 263 additions and 47 deletions

View File

@@ -948,3 +948,23 @@ func (bot *Engine) SetupExchanges() error {
func (bot *Engine) WaitForInitialCurrencySync() error {
return bot.currencyPairSyncer.WaitForInitialSync()
}
// RegisterWebsocketDataHandler registers an externally defined data handler
// for diverting and handling websocket notifications across all enabled
// exchanges. InterceptorOnly as true will purge all other registered handlers
// (including default) bypassing all other handling.
func (bot *Engine) RegisterWebsocketDataHandler(fn WebsocketDataHandler, interceptorOnly bool) error {
if bot == nil {
return errNilBot
}
return bot.websocketRoutineManager.registerWebsocketDataHandler(fn, interceptorOnly)
}
// SetDefaultWebsocketDataHandler sets the default websocket handler and
// removing all pre-existing handlers
func (bot *Engine) SetDefaultWebsocketDataHandler() error {
if bot == nil {
return errNilBot
}
return bot.websocketRoutineManager.setWebsocketDataHandler(bot.websocketRoutineManager.websocketDataHandler)
}

View File

@@ -293,3 +293,33 @@ func TestFlagSetWith(t *testing.T) {
t.Fatalf("received: '%v' but expected: '%v'", isRunning, false)
}
}
func TestRegisterWebsocketDataHandler(t *testing.T) {
t.Parallel()
var e *Engine
err := e.RegisterWebsocketDataHandler(nil, false)
if !errors.Is(err, errNilBot) {
t.Fatalf("received: '%v' but expected: '%v'", err, errNilBot)
}
e = &Engine{websocketRoutineManager: &websocketRoutineManager{}}
err = e.RegisterWebsocketDataHandler(func(_ string, _ interface{}) error { return nil }, false)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
}
func TestSetDefaultWebsocketDataHandler(t *testing.T) {
t.Parallel()
var e *Engine
err := e.SetDefaultWebsocketDataHandler()
if !errors.Is(err, errNilBot) {
t.Fatalf("received: '%v' but expected: '%v'", err, errNilBot)
}
e = &Engine{websocketRoutineManager: &websocketRoutineManager{}}
err = e.SetDefaultWebsocketDataHandler()
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
}

View File

@@ -33,14 +33,15 @@ func setupWebsocketRoutineManager(exchangeManager iExchangeManager, orderManager
if cfg.CurrencyPairFormat == nil && verbose {
return nil, errNilCurrencyPairFormat
}
return &websocketRoutineManager{
man := &websocketRoutineManager{
verbose: verbose,
exchangeManager: exchangeManager,
orderManager: orderManager,
syncer: syncer,
currencyConfig: cfg,
shutdown: make(chan struct{}),
}, nil
}
return man, man.registerWebsocketDataHandler(man.websocketDataHandler, false)
}
// Start runs the subsystem
@@ -113,7 +114,12 @@ func (m *websocketRoutineManager) websocketRoutine() {
if err != nil {
log.Errorf(log.WebsocketMgr, "%v", err)
}
go m.WebsocketDataReceiver(ws)
err = m.websocketDataReceiver(ws)
if err != nil {
log.Errorf(log.WebsocketMgr, "%v", err)
}
err = ws.FlushChannels()
if err != nil {
log.Errorf(log.WebsocketMgr, "Failed to subscribe: %v", err)
@@ -131,34 +137,48 @@ func (m *websocketRoutineManager) websocketRoutine() {
// WebsocketDataReceiver handles websocket data coming from a websocket feed
// associated with an exchange
func (m *websocketRoutineManager) WebsocketDataReceiver(ws *stream.Websocket) {
if m == nil || atomic.LoadInt32(&m.started) == 0 {
return
func (m *websocketRoutineManager) websocketDataReceiver(ws *stream.Websocket) error {
if m == nil {
return fmt.Errorf("websocket routine manager %w", ErrNilSubsystem)
}
m.wg.Add(1)
defer m.wg.Done()
for {
select {
case <-m.shutdown:
return
case data := <-ws.ToRoutine:
err := m.WebsocketDataHandler(ws.GetName(), data)
if err != nil {
log.Error(log.WebsocketMgr, err)
if ws == nil {
return errNilWebsocket
}
if atomic.LoadInt32(&m.started) == 0 {
return errRoutineManagerNotStarted
}
m.wg.Add(1)
go func() {
defer m.wg.Done()
for {
select {
case <-m.shutdown:
return
case data := <-ws.ToRoutine:
if data == nil {
log.Errorf(log.WebsocketMgr, "exchange %s nil data sent to websocket", ws.GetName())
}
m.mu.RLock()
for x := range m.dataHandlers {
err := m.dataHandlers[x](ws.GetName(), data)
if err != nil {
log.Error(log.WebsocketMgr, err)
}
}
m.mu.RUnlock()
}
}
}
}()
return nil
}
// WebsocketDataHandler is a central point for exchange websocket implementations to send
// processed data. WebsocketDataHandler will then pass that to an appropriate handler
func (m *websocketRoutineManager) WebsocketDataHandler(exchName string, data interface{}) error {
if data == nil {
return fmt.Errorf("exchange %s nil data sent to websocket",
exchName)
}
// websocketDataHandler is the default central point for exchange websocket
// implementations to send processed data which will then pass that to an
// appropriate handler.
func (m *websocketRoutineManager) websocketDataHandler(exchName string, data interface{}) error {
switch d := data.(type) {
case string:
log.Info(log.WebsocketMgr, d)
@@ -338,3 +358,43 @@ func (m *websocketRoutineManager) printAccountHoldingsChangeSummary(o account.Ch
o.Amount,
o.Account)
}
// registerWebsocketDataHandler registers an externally (GCT Library) defined
// dedicated filter specific data types for internal & external strategy use.
// InterceptorOnly as true will purge all other registered handlers
// (including default) bypassing all other handling.
func (m *websocketRoutineManager) registerWebsocketDataHandler(fn WebsocketDataHandler, interceptorOnly bool) error {
if m == nil {
return fmt.Errorf("%T %w", m, ErrNilSubsystem)
}
if fn == nil {
return errNilWebsocketDataHandlerFunction
}
if interceptorOnly {
return m.setWebsocketDataHandler(fn)
}
m.mu.Lock()
// Push front so that any registered data handler has first preference
// over the gct default handler.
m.dataHandlers = append([]WebsocketDataHandler{fn}, m.dataHandlers...)
m.mu.Unlock()
return nil
}
// setWebsocketDataHandler sets a single websocket data handler, removing all
// pre-existing handlers.
func (m *websocketRoutineManager) setWebsocketDataHandler(fn WebsocketDataHandler) error {
if m == nil {
return fmt.Errorf("%T %w", m, ErrNilSubsystem)
}
if fn == nil {
return errNilWebsocketDataHandlerFunction
}
m.mu.Lock()
m.dataHandlers = []WebsocketDataHandler{fn}
m.mu.Unlock()
return nil
}

View File

@@ -152,19 +152,15 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) {
t.Errorf("error '%v', expected '%v'", err, nil)
}
var orderID = "1337"
err = m.WebsocketDataHandler(exchName, errors.New("error"))
err = m.websocketDataHandler(exchName, errors.New("error"))
if err == nil {
t.Error("Error not handled correctly")
}
err = m.WebsocketDataHandler(exchName, nil)
if err == nil {
t.Error("Expected nil data error")
}
err = m.WebsocketDataHandler(exchName, stream.FundingData{})
err = m.websocketDataHandler(exchName, stream.FundingData{})
if err != nil {
t.Error(err)
}
err = m.WebsocketDataHandler(exchName, &ticker.Price{
err = m.websocketDataHandler(exchName, &ticker.Price{
ExchangeName: exchName,
Pair: currency.NewPair(currency.BTC, currency.USDC),
AssetType: asset.Spot,
@@ -172,7 +168,7 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) {
if !errors.Is(err, nil) {
t.Errorf("error '%v', expected '%v'", err, nil)
}
err = m.WebsocketDataHandler(exchName, stream.KlineData{})
err = m.websocketDataHandler(exchName, stream.KlineData{})
if err != nil {
t.Error(err)
}
@@ -182,12 +178,12 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) {
Amount: 1337,
Price: 1337,
}
err = m.WebsocketDataHandler(exchName, origOrder)
err = m.websocketDataHandler(exchName, origOrder)
if err != nil {
t.Error(err)
}
// Send it again since it exists now
err = m.WebsocketDataHandler(exchName, &order.Detail{
err = m.websocketDataHandler(exchName, &order.Detail{
Exchange: exchName,
ID: orderID,
Amount: 1338,
@@ -203,7 +199,7 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) {
t.Error("Bad pipeline")
}
err = m.WebsocketDataHandler(exchName, &order.Modify{
err = m.websocketDataHandler(exchName, &order.Modify{
Exchange: "Bitstamp",
ID: orderID,
Status: order.Active,
@@ -220,12 +216,12 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) {
}
// Send some gibberish
err = m.WebsocketDataHandler(exchName, order.Stop)
err = m.websocketDataHandler(exchName, order.Stop)
if err != nil {
t.Error(err)
}
err = m.WebsocketDataHandler(exchName, stream.UnhandledMessageWarning{
err = m.websocketDataHandler(exchName, stream.UnhandledMessageWarning{
Message: "there's an issue here's a tissue"},
)
if err != nil {
@@ -237,7 +233,7 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) {
OrderID: "one",
Err: errors.New("lol"),
}
err = m.WebsocketDataHandler(exchName, classificationError)
err = m.websocketDataHandler(exchName, classificationError)
if err == nil {
t.Error("Expected error")
}
@@ -245,15 +241,116 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) {
t.Errorf("error '%v', expected '%v'", err, classificationError.Err)
}
err = m.WebsocketDataHandler(exchName, &orderbook.Base{
err = m.websocketDataHandler(exchName, &orderbook.Base{
Exchange: "Bitstamp",
Pair: currency.NewPair(currency.BTC, currency.USD),
})
if err != nil {
t.Error(err)
}
err = m.WebsocketDataHandler(exchName, "this is a test string")
err = m.websocketDataHandler(exchName, "this is a test string")
if err != nil {
t.Error(err)
}
}
func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) {
t.Parallel()
var m *websocketRoutineManager
err := m.registerWebsocketDataHandler(nil, false)
if !errors.Is(err, ErrNilSubsystem) {
t.Fatalf("received: '%v' but expected: '%v'", err, ErrNilSubsystem)
}
m = new(websocketRoutineManager)
m.shutdown = make(chan struct{})
err = m.registerWebsocketDataHandler(nil, false)
if !errors.Is(err, errNilWebsocketDataHandlerFunction) {
t.Fatalf("received: '%v' but expected: '%v'", err, errNilWebsocketDataHandlerFunction)
}
// externally defined capture device
dataChan := make(chan interface{})
fn := func(_ string, data interface{}) error {
switch data.(type) {
case string:
dataChan <- data
default:
}
return nil
}
err = m.registerWebsocketDataHandler(fn, true)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
if len(m.dataHandlers) != 1 {
t.Fatal("unexpected data handlers registered")
}
mock := stream.New()
mock.ToRoutine = make(chan interface{})
m.started = 1
err = m.websocketDataReceiver(mock)
if err != nil {
t.Fatal(err)
}
mock.ToRoutine <- nil
mock.ToRoutine <- 1336
mock.ToRoutine <- "intercepted"
if r := <-dataChan; r != "intercepted" {
t.Fatal("unexpected value received")
}
close(m.shutdown)
m.wg.Wait()
}
func TestSetWebsocketDataHandler(t *testing.T) {
t.Parallel()
var m *websocketRoutineManager
err := m.setWebsocketDataHandler(nil)
if !errors.Is(err, ErrNilSubsystem) {
t.Fatalf("received: '%v' but expected: '%v'", err, ErrNilSubsystem)
}
m = new(websocketRoutineManager)
m.shutdown = make(chan struct{})
err = m.setWebsocketDataHandler(nil)
if !errors.Is(err, errNilWebsocketDataHandlerFunction) {
t.Fatalf("received: '%v' but expected: '%v'", err, errNilWebsocketDataHandlerFunction)
}
err = m.registerWebsocketDataHandler(m.websocketDataHandler, false)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
err = m.registerWebsocketDataHandler(m.websocketDataHandler, false)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
err = m.registerWebsocketDataHandler(m.websocketDataHandler, false)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
if len(m.dataHandlers) != 3 {
t.Fatal("unexpected data handler count")
}
err = m.setWebsocketDataHandler(m.websocketDataHandler)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
if len(m.dataHandlers) != 1 {
t.Fatal("unexpected data handler count")
}
}

View File

@@ -7,6 +7,16 @@ import (
"github.com/thrasher-corp/gocryptotrader/currency"
)
var (
errNilOrderManager = errors.New("nil order manager received")
errNilCurrencyPairSyncer = errors.New("nil currency pair syncer received")
errNilCurrencyConfig = errors.New("nil currency config received")
errNilCurrencyPairFormat = errors.New("nil currency pair format received")
errNilWebsocketDataHandlerFunction = errors.New("websocket data handler function is nil")
errNilWebsocket = errors.New("websocket is nil")
errRoutineManagerNotStarted = errors.New("websocket routine manager not started")
)
// websocketRoutineManager is used to process websocket updates from a unified location
type websocketRoutineManager struct {
started int32
@@ -16,12 +26,11 @@ type websocketRoutineManager struct {
syncer iCurrencyPairSyncer
currencyConfig *currency.Config
shutdown chan struct{}
dataHandlers []WebsocketDataHandler
wg sync.WaitGroup
mu sync.RWMutex
}
var (
errNilOrderManager = errors.New("nil order manager received")
errNilCurrencyPairSyncer = errors.New("nil currency pair syncer received")
errNilCurrencyConfig = errors.New("nil currency config received")
errNilCurrencyPairFormat = errors.New("nil currency pair format received")
)
// WebsocketDataHandler defines a function signature for a function that handles
// data coming from websocket connections.
type WebsocketDataHandler func(service string, incoming interface{}) error