diff --git a/engine/engine.go b/engine/engine.go index 8daddc11..eb2b18ce 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -44,7 +44,7 @@ type Engine struct { OrderManager *OrderManager portfolioManager *portfolioManager gctScriptManager *gctscript.GctScriptManager - websocketRoutineManager *websocketRoutineManager + WebsocketRoutineManager *WebsocketRoutineManager WithdrawManager *WithdrawManager dataHistoryManager *DataHistoryManager currencyStateManager *CurrencyStateManager @@ -542,11 +542,11 @@ func (bot *Engine) Start() error { } if bot.Settings.EnableWebsocketRoutine { - bot.websocketRoutineManager, err = setupWebsocketRoutineManager(bot.ExchangeManager, bot.OrderManager, bot.currencyPairSyncer, &bot.Config.Currency, bot.Settings.Verbose) + bot.WebsocketRoutineManager, err = setupWebsocketRoutineManager(bot.ExchangeManager, bot.OrderManager, bot.currencyPairSyncer, &bot.Config.Currency, bot.Settings.Verbose) if err != nil { gctlog.Errorf(gctlog.Global, "Unable to initialise websocket routine manager. Err: %s", err) } else { - err = bot.websocketRoutineManager.Start() + err = bot.WebsocketRoutineManager.Start() if err != nil { gctlog.Errorf(gctlog.Global, "failed to start websocket routine manager. Err: %s", err) } @@ -656,8 +656,8 @@ func (bot *Engine) Stop() { gctlog.Errorf(gctlog.DispatchMgr, "Dispatch system unable to stop. Error: %v", err) } } - if bot.websocketRoutineManager.IsRunning() { - if err := bot.websocketRoutineManager.Stop(); err != nil { + if bot.WebsocketRoutineManager.IsRunning() { + if err := bot.WebsocketRoutineManager.Stop(); err != nil { gctlog.Errorf(gctlog.Global, "websocket routine manager unable to stop. Error: %v", err) } } @@ -957,7 +957,7 @@ func (bot *Engine) RegisterWebsocketDataHandler(fn WebsocketDataHandler, interce if bot == nil { return errNilBot } - return bot.websocketRoutineManager.registerWebsocketDataHandler(fn, interceptorOnly) + return bot.WebsocketRoutineManager.registerWebsocketDataHandler(fn, interceptorOnly) } // SetDefaultWebsocketDataHandler sets the default websocket handler and @@ -966,7 +966,7 @@ func (bot *Engine) SetDefaultWebsocketDataHandler() error { if bot == nil { return errNilBot } - return bot.websocketRoutineManager.setWebsocketDataHandler(bot.websocketRoutineManager.websocketDataHandler) + return bot.WebsocketRoutineManager.setWebsocketDataHandler(bot.WebsocketRoutineManager.websocketDataHandler) } // waitForGPRCShutdown routines waits for a signal from the grpc server to diff --git a/engine/engine_test.go b/engine/engine_test.go index 92d59c75..f408a8d0 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -308,7 +308,7 @@ func TestRegisterWebsocketDataHandler(t *testing.T) { t.Fatalf("received: '%v' but expected: '%v'", err, errNilBot) } - e = &Engine{websocketRoutineManager: &websocketRoutineManager{}} + e = &Engine{WebsocketRoutineManager: &WebsocketRoutineManager{}} err = e.RegisterWebsocketDataHandler(func(_ string, _ interface{}) error { return nil }, false) if !errors.Is(err, nil) { t.Fatalf("received: '%v' but expected: '%v'", err, nil) @@ -323,7 +323,7 @@ func TestSetDefaultWebsocketDataHandler(t *testing.T) { t.Fatalf("received: '%v' but expected: '%v'", err, errNilBot) } - e = &Engine{websocketRoutineManager: &websocketRoutineManager{}} + e = &Engine{WebsocketRoutineManager: &WebsocketRoutineManager{}} err = e.SetDefaultWebsocketDataHandler() if !errors.Is(err, nil) { t.Fatalf("received: '%v' but expected: '%v'", err, nil) diff --git a/engine/websocketroutine_manager.go b/engine/websocketroutine_manager.go index 1f60c0ed..a27d61f9 100644 --- a/engine/websocketroutine_manager.go +++ b/engine/websocketroutine_manager.go @@ -2,6 +2,7 @@ package engine import ( "fmt" + "sync" "sync/atomic" "github.com/thrasher-corp/gocryptotrader/common" @@ -17,7 +18,7 @@ import ( ) // setupWebsocketRoutineManager creates a new websocket routine manager -func setupWebsocketRoutineManager(exchangeManager iExchangeManager, orderManager iOrderManager, syncer iCurrencyPairSyncer, cfg *currency.Config, verbose bool) (*websocketRoutineManager, error) { +func setupWebsocketRoutineManager(exchangeManager iExchangeManager, orderManager iOrderManager, syncer iCurrencyPairSyncer, cfg *currency.Config, verbose bool) (*WebsocketRoutineManager, error) { if exchangeManager == nil { return nil, errNilExchangeManager } @@ -33,19 +34,18 @@ func setupWebsocketRoutineManager(exchangeManager iExchangeManager, orderManager if cfg.CurrencyPairFormat == nil { return nil, errNilCurrencyPairFormat } - man := &websocketRoutineManager{ + man := &WebsocketRoutineManager{ verbose: verbose, exchangeManager: exchangeManager, orderManager: orderManager, syncer: syncer, currencyConfig: cfg, - shutdown: make(chan struct{}), } return man, man.registerWebsocketDataHandler(man.websocketDataHandler, false) } // Start runs the subsystem -func (m *websocketRoutineManager) Start() error { +func (m *WebsocketRoutineManager) Start() error { if m == nil { return fmt.Errorf("websocket routine manager %w", ErrNilSubsystem) } @@ -58,37 +58,50 @@ func (m *websocketRoutineManager) Start() error { return errNilCurrencyPairFormat } - if !atomic.CompareAndSwapInt32(&m.started, 0, 1) { + if !atomic.CompareAndSwapInt32(&m.state, stoppedState, startingState) { return ErrSubSystemAlreadyStarted } + m.shutdown = make(chan struct{}) - m.websocketRoutine() + + go func() { + m.websocketRoutine() + // It's okay for this to fail, just means shutdown has started + atomic.CompareAndSwapInt32(&m.state, startingState, readyState) + }() return nil } // IsRunning safely checks whether the subsystem is running -func (m *websocketRoutineManager) IsRunning() bool { +func (m *WebsocketRoutineManager) IsRunning() bool { if m == nil { return false } - return atomic.LoadInt32(&m.started) == 1 + return atomic.LoadInt32(&m.state) == readyState } // Stop attempts to shutdown the subsystem -func (m *websocketRoutineManager) Stop() error { +func (m *WebsocketRoutineManager) Stop() error { if m == nil { return fmt.Errorf("websocket routine manager %w", ErrNilSubsystem) } - if !atomic.CompareAndSwapInt32(&m.started, 1, 0) { + + m.mu.Lock() + if atomic.LoadInt32(&m.state) == stoppedState { + m.mu.Unlock() return fmt.Errorf("websocket routine manager %w", ErrSubSystemNotStarted) } + atomic.StoreInt32(&m.state, stoppedState) + m.mu.Unlock() + close(m.shutdown) m.wg.Wait() + return nil } // websocketRoutine Initial routine management system for websocket -func (m *websocketRoutineManager) websocketRoutine() { +func (m *WebsocketRoutineManager) websocketRoutine() { if m.verbose { log.Debugln(log.WebsocketMgr, "Connecting exchange websocket services...") } @@ -96,8 +109,11 @@ func (m *websocketRoutineManager) websocketRoutine() { if err != nil { log.Errorf(log.WebsocketMgr, "websocket routine manager cannot get exchanges: %v", err) } + wg := sync.WaitGroup{} + wg.Add(len(exchanges)) for i := range exchanges { go func(i int) { + defer wg.Done() if exchanges[i].SupportsWebsocket() { if m.verbose { log.Debugf(log.WebsocketMgr, @@ -142,11 +158,12 @@ func (m *websocketRoutineManager) websocketRoutine() { } }(i) } + wg.Wait() } // WebsocketDataReceiver handles websocket data coming from a websocket feed // associated with an exchange -func (m *websocketRoutineManager) websocketDataReceiver(ws *stream.Websocket) error { +func (m *WebsocketRoutineManager) websocketDataReceiver(ws *stream.Websocket) error { if m == nil { return fmt.Errorf("websocket routine manager %w", ErrNilSubsystem) } @@ -155,7 +172,7 @@ func (m *websocketRoutineManager) websocketDataReceiver(ws *stream.Websocket) er return errNilWebsocket } - if atomic.LoadInt32(&m.started) == 0 { + if atomic.LoadInt32(&m.state) == stoppedState { return errRoutineManagerNotStarted } @@ -187,7 +204,7 @@ func (m *websocketRoutineManager) websocketDataReceiver(ws *stream.Websocket) er // websocketDataHandler is the default central point for exchange websocket // implementations to send processed data which will then pass that to an // appropriate handler. -func (m *websocketRoutineManager) websocketDataHandler(exchName string, data interface{}) error { +func (m *WebsocketRoutineManager) websocketDataHandler(exchName string, data interface{}) error { switch d := data.(type) { case string: log.Infoln(log.WebsocketMgr, d) @@ -293,8 +310,8 @@ func (m *websocketRoutineManager) websocketDataHandler(exchName string, data int // FormatCurrency is a method that formats and returns a currency pair // based on the user currency display preferences -func (m *websocketRoutineManager) FormatCurrency(p currency.Pair) currency.Pair { - if m == nil || atomic.LoadInt32(&m.started) == 0 { +func (m *WebsocketRoutineManager) FormatCurrency(p currency.Pair) currency.Pair { + if m == nil || atomic.LoadInt32(&m.state) == stoppedState { return p } return p.Format(*m.currencyConfig.CurrencyPairFormat) @@ -302,8 +319,8 @@ func (m *websocketRoutineManager) FormatCurrency(p currency.Pair) currency.Pair // printOrderSummary this function will be deprecated when a order manager // update is done. -func (m *websocketRoutineManager) printOrderSummary(o *order.Detail, isUpdate bool) { - if m == nil || atomic.LoadInt32(&m.started) == 0 || o == nil { +func (m *WebsocketRoutineManager) printOrderSummary(o *order.Detail, isUpdate bool) { + if m == nil || atomic.LoadInt32(&m.state) == stoppedState || o == nil { return } @@ -331,8 +348,8 @@ func (m *websocketRoutineManager) printOrderSummary(o *order.Detail, isUpdate bo // printAccountHoldingsChangeSummary this function will be deprecated when a // account holdings update is done. -func (m *websocketRoutineManager) printAccountHoldingsChangeSummary(o account.Change) { - if m == nil || atomic.LoadInt32(&m.started) == 0 { +func (m *WebsocketRoutineManager) printAccountHoldingsChangeSummary(o account.Change) { + if m == nil || atomic.LoadInt32(&m.state) == stoppedState { return } log.Debugf(log.WebsocketMgr, @@ -348,7 +365,7 @@ func (m *websocketRoutineManager) printAccountHoldingsChangeSummary(o account.Ch // dedicated filter specific data types for internal & external strategy use. // InterceptorOnly as true will purge all other registered handlers // (including default) bypassing all other handling. -func (m *websocketRoutineManager) registerWebsocketDataHandler(fn WebsocketDataHandler, interceptorOnly bool) error { +func (m *WebsocketRoutineManager) registerWebsocketDataHandler(fn WebsocketDataHandler, interceptorOnly bool) error { if m == nil { return fmt.Errorf("%T %w", m, ErrNilSubsystem) } @@ -371,7 +388,7 @@ func (m *websocketRoutineManager) registerWebsocketDataHandler(fn WebsocketDataH // setWebsocketDataHandler sets a single websocket data handler, removing all // pre-existing handlers. -func (m *websocketRoutineManager) setWebsocketDataHandler(fn WebsocketDataHandler) error { +func (m *WebsocketRoutineManager) setWebsocketDataHandler(fn WebsocketDataHandler) error { if m == nil { return fmt.Errorf("%T %w", m, ErrNilSubsystem) } diff --git a/engine/websocketroutine_manager_test.go b/engine/websocketroutine_manager_test.go index 45eca731..3d175334 100644 --- a/engine/websocketroutine_manager_test.go +++ b/engine/websocketroutine_manager_test.go @@ -3,7 +3,9 @@ package engine import ( "errors" "sync" + "sync/atomic" "testing" + "time" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" @@ -48,7 +50,7 @@ func TestWebsocketRoutineManagerSetup(t *testing.T) { } func TestWebsocketRoutineManagerStart(t *testing.T) { - var m *websocketRoutineManager + var m *WebsocketRoutineManager err := m.Start() if !errors.Is(err, ErrNilSubsystem) { t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) @@ -72,7 +74,7 @@ func TestWebsocketRoutineManagerStart(t *testing.T) { } func TestWebsocketRoutineManagerIsRunning(t *testing.T) { - var m *websocketRoutineManager + var m *WebsocketRoutineManager if m.IsRunning() { t.Error("expected false") } @@ -89,13 +91,16 @@ func TestWebsocketRoutineManagerIsRunning(t *testing.T) { if !errors.Is(err, nil) { t.Errorf("error '%v', expected '%v'", err, nil) } + for atomic.LoadInt32(&m.state) == startingState { + <-time.After(time.Second / 100) + } if !m.IsRunning() { t.Error("expected true") } } func TestWebsocketRoutineManagerStop(t *testing.T) { - var m *websocketRoutineManager + var m *WebsocketRoutineManager err := m.Stop() if !errors.Is(err, ErrNilSubsystem) { t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) @@ -258,13 +263,13 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) { func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) { t.Parallel() - var m *websocketRoutineManager + var m *WebsocketRoutineManager err := m.registerWebsocketDataHandler(nil, false) if !errors.Is(err, ErrNilSubsystem) { t.Fatalf("received: '%v' but expected: '%v'", err, ErrNilSubsystem) } - m = new(websocketRoutineManager) + m = new(WebsocketRoutineManager) m.shutdown = make(chan struct{}) err = m.registerWebsocketDataHandler(nil, false) @@ -294,7 +299,7 @@ func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) { mock := stream.New() mock.ToRoutine = make(chan interface{}) - m.started = 1 + m.state = readyState err = m.websocketDataReceiver(mock) if err != nil { t.Fatal(err) @@ -314,13 +319,13 @@ func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) { func TestSetWebsocketDataHandler(t *testing.T) { t.Parallel() - var m *websocketRoutineManager + var m *WebsocketRoutineManager err := m.setWebsocketDataHandler(nil) if !errors.Is(err, ErrNilSubsystem) { t.Fatalf("received: '%v' but expected: '%v'", err, ErrNilSubsystem) } - m = new(websocketRoutineManager) + m = new(WebsocketRoutineManager) m.shutdown = make(chan struct{}) err = m.setWebsocketDataHandler(nil) diff --git a/engine/websocketroutine_manager_types.go b/engine/websocketroutine_manager_types.go index a00f710e..c6d2bccf 100644 --- a/engine/websocketroutine_manager_types.go +++ b/engine/websocketroutine_manager_types.go @@ -17,9 +17,15 @@ var ( errRoutineManagerNotStarted = errors.New("websocket routine manager not started") ) -// websocketRoutineManager is used to process websocket updates from a unified location -type websocketRoutineManager struct { - started int32 +const ( + stoppedState int32 = iota + startingState + readyState +) + +// WebsocketRoutineManager is used to process websocket updates from a unified location +type WebsocketRoutineManager struct { + state int32 verbose bool exchangeManager iExchangeManager orderManager iOrderManager