mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-21 23:16:49 +00:00
engine: Add websocket data handler register function (#935)
* engine: Add websocket interceptor register function * Update engine/engine.go Co-authored-by: Scott <gloriousCode@users.noreply.github.com> * Update engine/websocketroutine_manager_types.go Co-authored-by: Scott <gloriousCode@users.noreply.github.com> * engine/websock: switch to data handler function register and range over handlers to still include default gct handling * engine/websocket: change name * glorious: nits * linter: fix * glorious: nits Co-authored-by: Scott <gloriousCode@users.noreply.github.com>
This commit is contained in:
@@ -948,3 +948,23 @@ func (bot *Engine) SetupExchanges() error {
|
||||
func (bot *Engine) WaitForInitialCurrencySync() error {
|
||||
return bot.currencyPairSyncer.WaitForInitialSync()
|
||||
}
|
||||
|
||||
// RegisterWebsocketDataHandler registers an externally defined data handler
|
||||
// for diverting and handling websocket notifications across all enabled
|
||||
// exchanges. InterceptorOnly as true will purge all other registered handlers
|
||||
// (including default) bypassing all other handling.
|
||||
func (bot *Engine) RegisterWebsocketDataHandler(fn WebsocketDataHandler, interceptorOnly bool) error {
|
||||
if bot == nil {
|
||||
return errNilBot
|
||||
}
|
||||
return bot.websocketRoutineManager.registerWebsocketDataHandler(fn, interceptorOnly)
|
||||
}
|
||||
|
||||
// SetDefaultWebsocketDataHandler sets the default websocket handler and
|
||||
// removing all pre-existing handlers
|
||||
func (bot *Engine) SetDefaultWebsocketDataHandler() error {
|
||||
if bot == nil {
|
||||
return errNilBot
|
||||
}
|
||||
return bot.websocketRoutineManager.setWebsocketDataHandler(bot.websocketRoutineManager.websocketDataHandler)
|
||||
}
|
||||
|
||||
@@ -293,3 +293,33 @@ func TestFlagSetWith(t *testing.T) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", isRunning, false)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterWebsocketDataHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
var e *Engine
|
||||
err := e.RegisterWebsocketDataHandler(nil, false)
|
||||
if !errors.Is(err, errNilBot) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errNilBot)
|
||||
}
|
||||
|
||||
e = &Engine{websocketRoutineManager: &websocketRoutineManager{}}
|
||||
err = e.RegisterWebsocketDataHandler(func(_ string, _ interface{}) error { return nil }, false)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetDefaultWebsocketDataHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
var e *Engine
|
||||
err := e.SetDefaultWebsocketDataHandler()
|
||||
if !errors.Is(err, errNilBot) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errNilBot)
|
||||
}
|
||||
|
||||
e = &Engine{websocketRoutineManager: &websocketRoutineManager{}}
|
||||
err = e.SetDefaultWebsocketDataHandler()
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,14 +33,15 @@ func setupWebsocketRoutineManager(exchangeManager iExchangeManager, orderManager
|
||||
if cfg.CurrencyPairFormat == nil && verbose {
|
||||
return nil, errNilCurrencyPairFormat
|
||||
}
|
||||
return &websocketRoutineManager{
|
||||
man := &websocketRoutineManager{
|
||||
verbose: verbose,
|
||||
exchangeManager: exchangeManager,
|
||||
orderManager: orderManager,
|
||||
syncer: syncer,
|
||||
currencyConfig: cfg,
|
||||
shutdown: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
return man, man.registerWebsocketDataHandler(man.websocketDataHandler, false)
|
||||
}
|
||||
|
||||
// Start runs the subsystem
|
||||
@@ -113,7 +114,12 @@ func (m *websocketRoutineManager) websocketRoutine() {
|
||||
if err != nil {
|
||||
log.Errorf(log.WebsocketMgr, "%v", err)
|
||||
}
|
||||
go m.WebsocketDataReceiver(ws)
|
||||
|
||||
err = m.websocketDataReceiver(ws)
|
||||
if err != nil {
|
||||
log.Errorf(log.WebsocketMgr, "%v", err)
|
||||
}
|
||||
|
||||
err = ws.FlushChannels()
|
||||
if err != nil {
|
||||
log.Errorf(log.WebsocketMgr, "Failed to subscribe: %v", err)
|
||||
@@ -131,34 +137,48 @@ func (m *websocketRoutineManager) websocketRoutine() {
|
||||
|
||||
// WebsocketDataReceiver handles websocket data coming from a websocket feed
|
||||
// associated with an exchange
|
||||
func (m *websocketRoutineManager) WebsocketDataReceiver(ws *stream.Websocket) {
|
||||
if m == nil || atomic.LoadInt32(&m.started) == 0 {
|
||||
return
|
||||
func (m *websocketRoutineManager) websocketDataReceiver(ws *stream.Websocket) error {
|
||||
if m == nil {
|
||||
return fmt.Errorf("websocket routine manager %w", ErrNilSubsystem)
|
||||
}
|
||||
m.wg.Add(1)
|
||||
defer m.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.shutdown:
|
||||
return
|
||||
case data := <-ws.ToRoutine:
|
||||
err := m.WebsocketDataHandler(ws.GetName(), data)
|
||||
if err != nil {
|
||||
log.Error(log.WebsocketMgr, err)
|
||||
if ws == nil {
|
||||
return errNilWebsocket
|
||||
}
|
||||
|
||||
if atomic.LoadInt32(&m.started) == 0 {
|
||||
return errRoutineManagerNotStarted
|
||||
}
|
||||
|
||||
m.wg.Add(1)
|
||||
go func() {
|
||||
defer m.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-m.shutdown:
|
||||
return
|
||||
case data := <-ws.ToRoutine:
|
||||
if data == nil {
|
||||
log.Errorf(log.WebsocketMgr, "exchange %s nil data sent to websocket", ws.GetName())
|
||||
}
|
||||
m.mu.RLock()
|
||||
for x := range m.dataHandlers {
|
||||
err := m.dataHandlers[x](ws.GetName(), data)
|
||||
if err != nil {
|
||||
log.Error(log.WebsocketMgr, err)
|
||||
}
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// WebsocketDataHandler is a central point for exchange websocket implementations to send
|
||||
// processed data. WebsocketDataHandler will then pass that to an appropriate handler
|
||||
func (m *websocketRoutineManager) WebsocketDataHandler(exchName string, data interface{}) error {
|
||||
if data == nil {
|
||||
return fmt.Errorf("exchange %s nil data sent to websocket",
|
||||
exchName)
|
||||
}
|
||||
|
||||
// websocketDataHandler is the default central point for exchange websocket
|
||||
// implementations to send processed data which will then pass that to an
|
||||
// appropriate handler.
|
||||
func (m *websocketRoutineManager) websocketDataHandler(exchName string, data interface{}) error {
|
||||
switch d := data.(type) {
|
||||
case string:
|
||||
log.Info(log.WebsocketMgr, d)
|
||||
@@ -338,3 +358,43 @@ func (m *websocketRoutineManager) printAccountHoldingsChangeSummary(o account.Ch
|
||||
o.Amount,
|
||||
o.Account)
|
||||
}
|
||||
|
||||
// registerWebsocketDataHandler registers an externally (GCT Library) defined
|
||||
// dedicated filter specific data types for internal & external strategy use.
|
||||
// InterceptorOnly as true will purge all other registered handlers
|
||||
// (including default) bypassing all other handling.
|
||||
func (m *websocketRoutineManager) registerWebsocketDataHandler(fn WebsocketDataHandler, interceptorOnly bool) error {
|
||||
if m == nil {
|
||||
return fmt.Errorf("%T %w", m, ErrNilSubsystem)
|
||||
}
|
||||
|
||||
if fn == nil {
|
||||
return errNilWebsocketDataHandlerFunction
|
||||
}
|
||||
|
||||
if interceptorOnly {
|
||||
return m.setWebsocketDataHandler(fn)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
// Push front so that any registered data handler has first preference
|
||||
// over the gct default handler.
|
||||
m.dataHandlers = append([]WebsocketDataHandler{fn}, m.dataHandlers...)
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// setWebsocketDataHandler sets a single websocket data handler, removing all
|
||||
// pre-existing handlers.
|
||||
func (m *websocketRoutineManager) setWebsocketDataHandler(fn WebsocketDataHandler) error {
|
||||
if m == nil {
|
||||
return fmt.Errorf("%T %w", m, ErrNilSubsystem)
|
||||
}
|
||||
if fn == nil {
|
||||
return errNilWebsocketDataHandlerFunction
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.dataHandlers = []WebsocketDataHandler{fn}
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -152,19 +152,15 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) {
|
||||
t.Errorf("error '%v', expected '%v'", err, nil)
|
||||
}
|
||||
var orderID = "1337"
|
||||
err = m.WebsocketDataHandler(exchName, errors.New("error"))
|
||||
err = m.websocketDataHandler(exchName, errors.New("error"))
|
||||
if err == nil {
|
||||
t.Error("Error not handled correctly")
|
||||
}
|
||||
err = m.WebsocketDataHandler(exchName, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected nil data error")
|
||||
}
|
||||
err = m.WebsocketDataHandler(exchName, stream.FundingData{})
|
||||
err = m.websocketDataHandler(exchName, stream.FundingData{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
err = m.WebsocketDataHandler(exchName, &ticker.Price{
|
||||
err = m.websocketDataHandler(exchName, &ticker.Price{
|
||||
ExchangeName: exchName,
|
||||
Pair: currency.NewPair(currency.BTC, currency.USDC),
|
||||
AssetType: asset.Spot,
|
||||
@@ -172,7 +168,7 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) {
|
||||
if !errors.Is(err, nil) {
|
||||
t.Errorf("error '%v', expected '%v'", err, nil)
|
||||
}
|
||||
err = m.WebsocketDataHandler(exchName, stream.KlineData{})
|
||||
err = m.websocketDataHandler(exchName, stream.KlineData{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -182,12 +178,12 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) {
|
||||
Amount: 1337,
|
||||
Price: 1337,
|
||||
}
|
||||
err = m.WebsocketDataHandler(exchName, origOrder)
|
||||
err = m.websocketDataHandler(exchName, origOrder)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
// Send it again since it exists now
|
||||
err = m.WebsocketDataHandler(exchName, &order.Detail{
|
||||
err = m.websocketDataHandler(exchName, &order.Detail{
|
||||
Exchange: exchName,
|
||||
ID: orderID,
|
||||
Amount: 1338,
|
||||
@@ -203,7 +199,7 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) {
|
||||
t.Error("Bad pipeline")
|
||||
}
|
||||
|
||||
err = m.WebsocketDataHandler(exchName, &order.Modify{
|
||||
err = m.websocketDataHandler(exchName, &order.Modify{
|
||||
Exchange: "Bitstamp",
|
||||
ID: orderID,
|
||||
Status: order.Active,
|
||||
@@ -220,12 +216,12 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) {
|
||||
}
|
||||
|
||||
// Send some gibberish
|
||||
err = m.WebsocketDataHandler(exchName, order.Stop)
|
||||
err = m.websocketDataHandler(exchName, order.Stop)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
err = m.WebsocketDataHandler(exchName, stream.UnhandledMessageWarning{
|
||||
err = m.websocketDataHandler(exchName, stream.UnhandledMessageWarning{
|
||||
Message: "there's an issue here's a tissue"},
|
||||
)
|
||||
if err != nil {
|
||||
@@ -237,7 +233,7 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) {
|
||||
OrderID: "one",
|
||||
Err: errors.New("lol"),
|
||||
}
|
||||
err = m.WebsocketDataHandler(exchName, classificationError)
|
||||
err = m.websocketDataHandler(exchName, classificationError)
|
||||
if err == nil {
|
||||
t.Error("Expected error")
|
||||
}
|
||||
@@ -245,15 +241,116 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) {
|
||||
t.Errorf("error '%v', expected '%v'", err, classificationError.Err)
|
||||
}
|
||||
|
||||
err = m.WebsocketDataHandler(exchName, &orderbook.Base{
|
||||
err = m.websocketDataHandler(exchName, &orderbook.Base{
|
||||
Exchange: "Bitstamp",
|
||||
Pair: currency.NewPair(currency.BTC, currency.USD),
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
err = m.WebsocketDataHandler(exchName, "this is a test string")
|
||||
err = m.websocketDataHandler(exchName, "this is a test string")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) {
|
||||
t.Parallel()
|
||||
var m *websocketRoutineManager
|
||||
err := m.registerWebsocketDataHandler(nil, false)
|
||||
if !errors.Is(err, ErrNilSubsystem) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, ErrNilSubsystem)
|
||||
}
|
||||
|
||||
m = new(websocketRoutineManager)
|
||||
m.shutdown = make(chan struct{})
|
||||
|
||||
err = m.registerWebsocketDataHandler(nil, false)
|
||||
if !errors.Is(err, errNilWebsocketDataHandlerFunction) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errNilWebsocketDataHandlerFunction)
|
||||
}
|
||||
|
||||
// externally defined capture device
|
||||
dataChan := make(chan interface{})
|
||||
fn := func(_ string, data interface{}) error {
|
||||
switch data.(type) {
|
||||
case string:
|
||||
dataChan <- data
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err = m.registerWebsocketDataHandler(fn, true)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
if len(m.dataHandlers) != 1 {
|
||||
t.Fatal("unexpected data handlers registered")
|
||||
}
|
||||
|
||||
mock := stream.New()
|
||||
mock.ToRoutine = make(chan interface{})
|
||||
m.started = 1
|
||||
err = m.websocketDataReceiver(mock)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mock.ToRoutine <- nil
|
||||
mock.ToRoutine <- 1336
|
||||
mock.ToRoutine <- "intercepted"
|
||||
|
||||
if r := <-dataChan; r != "intercepted" {
|
||||
t.Fatal("unexpected value received")
|
||||
}
|
||||
|
||||
close(m.shutdown)
|
||||
m.wg.Wait()
|
||||
}
|
||||
|
||||
func TestSetWebsocketDataHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
var m *websocketRoutineManager
|
||||
err := m.setWebsocketDataHandler(nil)
|
||||
if !errors.Is(err, ErrNilSubsystem) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, ErrNilSubsystem)
|
||||
}
|
||||
|
||||
m = new(websocketRoutineManager)
|
||||
m.shutdown = make(chan struct{})
|
||||
|
||||
err = m.setWebsocketDataHandler(nil)
|
||||
if !errors.Is(err, errNilWebsocketDataHandlerFunction) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errNilWebsocketDataHandlerFunction)
|
||||
}
|
||||
|
||||
err = m.registerWebsocketDataHandler(m.websocketDataHandler, false)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
err = m.registerWebsocketDataHandler(m.websocketDataHandler, false)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
err = m.registerWebsocketDataHandler(m.websocketDataHandler, false)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
if len(m.dataHandlers) != 3 {
|
||||
t.Fatal("unexpected data handler count")
|
||||
}
|
||||
|
||||
err = m.setWebsocketDataHandler(m.websocketDataHandler)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
if len(m.dataHandlers) != 1 {
|
||||
t.Fatal("unexpected data handler count")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,16 @@ import (
|
||||
"github.com/thrasher-corp/gocryptotrader/currency"
|
||||
)
|
||||
|
||||
var (
|
||||
errNilOrderManager = errors.New("nil order manager received")
|
||||
errNilCurrencyPairSyncer = errors.New("nil currency pair syncer received")
|
||||
errNilCurrencyConfig = errors.New("nil currency config received")
|
||||
errNilCurrencyPairFormat = errors.New("nil currency pair format received")
|
||||
errNilWebsocketDataHandlerFunction = errors.New("websocket data handler function is nil")
|
||||
errNilWebsocket = errors.New("websocket is nil")
|
||||
errRoutineManagerNotStarted = errors.New("websocket routine manager not started")
|
||||
)
|
||||
|
||||
// websocketRoutineManager is used to process websocket updates from a unified location
|
||||
type websocketRoutineManager struct {
|
||||
started int32
|
||||
@@ -16,12 +26,11 @@ type websocketRoutineManager struct {
|
||||
syncer iCurrencyPairSyncer
|
||||
currencyConfig *currency.Config
|
||||
shutdown chan struct{}
|
||||
dataHandlers []WebsocketDataHandler
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var (
|
||||
errNilOrderManager = errors.New("nil order manager received")
|
||||
errNilCurrencyPairSyncer = errors.New("nil currency pair syncer received")
|
||||
errNilCurrencyConfig = errors.New("nil currency config received")
|
||||
errNilCurrencyPairFormat = errors.New("nil currency pair format received")
|
||||
)
|
||||
// WebsocketDataHandler defines a function signature for a function that handles
|
||||
// data coming from websocket connections.
|
||||
type WebsocketDataHandler func(service string, incoming interface{}) error
|
||||
|
||||
Reference in New Issue
Block a user