mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-18 23:16:49 +00:00
accounts: Move to instance methods, fix races and isolate tests (#1923)
* Bybit: Fix race in TestUpdateAccountInfo and TestWSHandleData * DriveBy rename TestWSHandleData * This doesn't address running with -race=2+ due to the singleton * Accounts: Add account.GetService() * exchange: Assertify TestSetupDefaults * Exchanges: Add account.Service override for testing * Exchanges: Remove duplicate IsWebsocketEnabled test from TestSetupDefaults * Dispatch: Replace nil checks with NilGuard * Engine: Remove deprecated printAccountHoldingsChangeSummary * Dispatcher: Add EnsureRunning method * Accounts: Move singleton accounts service to exchange Accounts * Move singleton accounts service to exchange Accounts This maintains the concept of a global store, whilst allowing exchanges to override it when needed, particularly for testing. APIServer: * Remove getAllActiveAccounts from apiserver Deprecated apiserver only thing using this, so remove it instead of updating it * Update comment for UpdateAccountBalances everywhere * Docs: Add punctuation to function comments * Bybit: Coverage for wsProcessWalletPushData Save
This commit is contained in:
@@ -8,15 +8,17 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
"github.com/thrasher-corp/gocryptotrader/log"
|
||||
)
|
||||
|
||||
// Public errors.
|
||||
var (
|
||||
// ErrNotRunning defines an error when the dispatcher is not running
|
||||
ErrNotRunning = errors.New("dispatcher not running")
|
||||
ErrNotRunning = errors.New("dispatcher not running")
|
||||
ErrDispatcherAlreadyRunning = errors.New("dispatcher already running")
|
||||
)
|
||||
|
||||
errDispatcherNotInitialized = errors.New("dispatcher not initialised")
|
||||
errDispatcherAlreadyRunning = errors.New("dispatcher already running")
|
||||
var (
|
||||
errDispatchShutdown = errors.New("dispatcher did not shutdown properly, routines failed to close")
|
||||
errDispatcherUUIDNotFoundInRouteList = errors.New("dispatcher uuid not found in route list")
|
||||
errTypeAssertionFailure = errors.New("type assertion failure")
|
||||
@@ -29,7 +31,7 @@ var (
|
||||
limitMessage = "%w [%d] current worker count [%d]. Spawn more workers via --dispatchworkers=x, or increase the jobs limit via --dispatchjobslimit=x"
|
||||
)
|
||||
|
||||
// Name is an exported subsystem name
|
||||
// Name is an exported subsystem name.
|
||||
const Name = "dispatch"
|
||||
|
||||
func init() {
|
||||
@@ -41,59 +43,58 @@ func NewDispatcher() *Dispatcher {
|
||||
return &Dispatcher{
|
||||
routes: make(map[uuid.UUID][]chan any),
|
||||
outbound: sync.Pool{
|
||||
New: getChan,
|
||||
New: func() any { return make(chan any) },
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getChan() any {
|
||||
// Create unbuffered channel for data pass
|
||||
return make(chan any)
|
||||
}
|
||||
|
||||
// Start starts the dispatch system by spawning workers and allocating memory
|
||||
// Start starts the dispatch system and spawns workers.
|
||||
func Start(workers, jobsLimit int) error {
|
||||
dispatcher.m.Lock()
|
||||
defer dispatcher.m.Unlock()
|
||||
return dispatcher.start(workers, jobsLimit)
|
||||
}
|
||||
|
||||
// Stop attempts to stop the dispatch service, this will close all pipe channels
|
||||
// flush job list and drop all workers
|
||||
// EnsureRunning starts the global dispatcher if it's not already running.
|
||||
func EnsureRunning(workers, jobsLimit int) error {
|
||||
dispatcher.m.Lock()
|
||||
defer dispatcher.m.Unlock()
|
||||
if dispatcher.running {
|
||||
return nil
|
||||
}
|
||||
return dispatcher.start(workers, jobsLimit)
|
||||
}
|
||||
|
||||
// Stop will halt the dispatch service.
|
||||
func Stop() error {
|
||||
log.Debugln(log.DispatchMgr, "Dispatch manager shutting down...")
|
||||
return dispatcher.stop()
|
||||
}
|
||||
|
||||
// IsRunning checks to see if the dispatch service is running
|
||||
// IsRunning checks to see if the dispatch service is running.
|
||||
func IsRunning() bool {
|
||||
return dispatcher.isRunning()
|
||||
}
|
||||
|
||||
// start compares atomic running value, sets defaults, overrides with
|
||||
// configuration, then spawns workers
|
||||
// start sets defaults and config and spawns workers.
|
||||
// Does not provide locking protection.
|
||||
func (d *Dispatcher) start(workers, channelCapacity int) error {
|
||||
if d == nil {
|
||||
return errDispatcherNotInitialized
|
||||
if err := common.NilGuard(d); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.m.Lock()
|
||||
defer d.m.Unlock()
|
||||
|
||||
if d.running {
|
||||
return errDispatcherAlreadyRunning
|
||||
return ErrDispatcherAlreadyRunning
|
||||
}
|
||||
|
||||
d.running = true
|
||||
|
||||
if workers < 1 {
|
||||
log.Warnf(log.DispatchMgr,
|
||||
"workers cannot be zero, using default value %d\n",
|
||||
DefaultMaxWorkers)
|
||||
log.Warnf(log.DispatchMgr, "Dispatcher workers cannot be zero, using default value %d\n", DefaultMaxWorkers)
|
||||
workers = DefaultMaxWorkers
|
||||
}
|
||||
if channelCapacity < 1 {
|
||||
log.Warnf(log.DispatchMgr,
|
||||
"jobs limit cannot be zero, using default values %d\n",
|
||||
DefaultJobsLimit)
|
||||
log.Warnf(log.DispatchMgr, "Dispatcher jobs limit cannot be zero, using default values %d\n", DefaultJobsLimit)
|
||||
channelCapacity = DefaultJobsLimit
|
||||
}
|
||||
d.jobs = make(chan job, channelCapacity)
|
||||
@@ -107,10 +108,10 @@ func (d *Dispatcher) start(workers, channelCapacity int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// stop stops the service and shuts down all worker routines
|
||||
// stop stops the service and shuts down all worker routines.
|
||||
func (d *Dispatcher) stop() error {
|
||||
if d == nil {
|
||||
return errDispatcherNotInitialized
|
||||
if err := common.NilGuard(d); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.m.Lock()
|
||||
@@ -155,7 +156,7 @@ func (d *Dispatcher) stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// isRunning returns if the dispatch system is running
|
||||
// isRunning returns if the dispatch system is running.
|
||||
func (d *Dispatcher) isRunning() bool {
|
||||
if d == nil {
|
||||
return false
|
||||
@@ -166,7 +167,7 @@ func (d *Dispatcher) isRunning() bool {
|
||||
return d.running
|
||||
}
|
||||
|
||||
// relayer routine relays communications across the defined routes
|
||||
// relayer routine relays communications across the defined routes.
|
||||
func (d *Dispatcher) relayer() {
|
||||
for {
|
||||
select {
|
||||
@@ -201,20 +202,16 @@ func (d *Dispatcher) relayer() {
|
||||
}
|
||||
}
|
||||
|
||||
// publish relays data to the subscribed subsystems
|
||||
// publish relays data to the subscribed subsystems.
|
||||
func (d *Dispatcher) publish(id uuid.UUID, data any) error {
|
||||
if d == nil {
|
||||
return errDispatcherNotInitialized
|
||||
if err := common.NilGuard(d, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if id.IsNil() {
|
||||
return errIDNotSet
|
||||
}
|
||||
|
||||
if data == nil {
|
||||
return errNoData
|
||||
}
|
||||
|
||||
d.m.RLock()
|
||||
defer d.m.RUnlock()
|
||||
|
||||
@@ -226,18 +223,14 @@ func (d *Dispatcher) publish(id uuid.UUID, data any) error {
|
||||
case d.jobs <- job{data, id}: // Push job into job channel.
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf(limitMessage,
|
||||
errDispatcherJobsAtLimit,
|
||||
len(d.jobs),
|
||||
d.maxWorkers)
|
||||
return fmt.Errorf(limitMessage, errDispatcherJobsAtLimit, len(d.jobs), d.maxWorkers)
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe subscribes a system and returns a communication chan, this does not
|
||||
// ensure initial push.
|
||||
// Subscribe subscribes a system and returns a communication chan, this does not ensure initial push.
|
||||
func (d *Dispatcher) subscribe(id uuid.UUID) (chan any, error) {
|
||||
if d == nil {
|
||||
return nil, errDispatcherNotInitialized
|
||||
if err := common.NilGuard(d); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if id.IsNil() {
|
||||
@@ -268,10 +261,10 @@ func (d *Dispatcher) subscribe(id uuid.UUID) (chan any, error) {
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// Unsubscribe unsubs a routine from the dispatcher
|
||||
// Unsubscribe unsubs a routine from the dispatcher.
|
||||
func (d *Dispatcher) unsubscribe(id uuid.UUID, usedChan chan any) error {
|
||||
if d == nil {
|
||||
return errDispatcherNotInitialized
|
||||
if err := common.NilGuard(d); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if id.IsNil() {
|
||||
@@ -321,10 +314,10 @@ func (d *Dispatcher) unsubscribe(id uuid.UUID, usedChan chan any) error {
|
||||
return errChannelNotFoundInUUIDRef
|
||||
}
|
||||
|
||||
// GetNewID returns a new ID
|
||||
// GetNewID returns a new ID.
|
||||
func (d *Dispatcher) getNewID(genFn func() (uuid.UUID, error)) (uuid.UUID, error) {
|
||||
if d == nil {
|
||||
return uuid.Nil, errDispatcherNotInitialized
|
||||
if err := common.NilGuard(d); err != nil {
|
||||
return uuid.Nil, err
|
||||
}
|
||||
|
||||
if genFn == nil {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -22,8 +23,17 @@ func TestGlobalDispatcher(t *testing.T) {
|
||||
assert.True(t, IsRunning(), "IsRunning should return true")
|
||||
|
||||
err = Stop()
|
||||
assert.NoError(t, err, "Stop should not error")
|
||||
require.NoError(t, err, "Stop must not error")
|
||||
assert.False(t, IsRunning(), "IsRunning should return false")
|
||||
|
||||
err = EnsureRunning(0, 0)
|
||||
require.NoError(t, err, "EnsureRunning must not error when starting")
|
||||
assert.True(t, IsRunning(), "IsRunning should return true after EnsureRunning")
|
||||
|
||||
err = EnsureRunning(0, 0)
|
||||
require.NoError(t, err, "EnsureRunning must not error when called twice")
|
||||
|
||||
assert.NoError(t, Stop(), "Stop should not error")
|
||||
}
|
||||
|
||||
func TestStartStop(t *testing.T) {
|
||||
@@ -33,10 +43,10 @@ func TestStartStop(t *testing.T) {
|
||||
assert.False(t, d.isRunning(), "IsRunning should return false")
|
||||
|
||||
err := d.stop()
|
||||
assert.ErrorIs(t, err, errDispatcherNotInitialized, "stop should error correctly")
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "stop should error correctly")
|
||||
|
||||
err = d.start(10, 0)
|
||||
assert.ErrorIs(t, err, errDispatcherNotInitialized, "start should error correctly")
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "start should error correctly")
|
||||
|
||||
d = NewDispatcher()
|
||||
|
||||
@@ -49,7 +59,7 @@ func TestStartStop(t *testing.T) {
|
||||
assert.True(t, d.isRunning(), "IsRunning should return true")
|
||||
|
||||
err = d.start(0, 0)
|
||||
assert.ErrorIs(t, err, errDispatcherAlreadyRunning, "start should error correctly")
|
||||
assert.ErrorIs(t, err, ErrDispatcherAlreadyRunning, "start should error correctly")
|
||||
|
||||
// Add route option
|
||||
id, err := d.getNewID(uuid.NewV4)
|
||||
@@ -74,7 +84,7 @@ func TestSubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
var d *Dispatcher
|
||||
_, err := d.subscribe(uuid.Nil)
|
||||
assert.ErrorIs(t, err, errDispatcherNotInitialized, "subscribe should error correctly")
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "subscribe should error correctly")
|
||||
|
||||
d = NewDispatcher()
|
||||
|
||||
@@ -97,7 +107,7 @@ func TestSubscribe(t *testing.T) {
|
||||
_, err = d.subscribe(id)
|
||||
assert.ErrorIs(t, err, errTypeAssertionFailure, "subscribe should error correctly")
|
||||
|
||||
d.outbound.New = getChan
|
||||
d.outbound.New = func() any { return make(chan any) }
|
||||
ch, err := d.subscribe(id)
|
||||
assert.NoError(t, err, "subscribe should not error")
|
||||
assert.NotNil(t, ch, "Channel should not be nil")
|
||||
@@ -108,7 +118,7 @@ func TestUnsubscribe(t *testing.T) {
|
||||
var d *Dispatcher
|
||||
|
||||
err := d.unsubscribe(uuid.Nil, nil)
|
||||
assert.ErrorIs(t, err, errDispatcherNotInitialized, "unsubscribe should error correctly")
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "unsubscribe should error correctly")
|
||||
|
||||
d = NewDispatcher()
|
||||
|
||||
@@ -152,7 +162,7 @@ func TestPublish(t *testing.T) {
|
||||
var d *Dispatcher
|
||||
|
||||
err := d.publish(uuid.Nil, nil)
|
||||
assert.ErrorIs(t, err, errDispatcherNotInitialized, "publish should error correctly")
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "publish should error correctly")
|
||||
|
||||
d = NewDispatcher()
|
||||
|
||||
@@ -162,11 +172,11 @@ func TestPublish(t *testing.T) {
|
||||
err = d.start(2, 10)
|
||||
require.NoError(t, err, "start must not error")
|
||||
|
||||
err = d.publish(uuid.Nil, nil)
|
||||
err = d.publish(uuid.Nil, "test")
|
||||
assert.ErrorIs(t, err, errIDNotSet, "publish should error correctly")
|
||||
|
||||
err = d.publish(nonEmptyUUID, nil)
|
||||
assert.ErrorIs(t, err, errNoData, "publish should error correctly")
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "publish should error correctly")
|
||||
|
||||
// demonstrate job limit error
|
||||
d.routes[nonEmptyUUID] = []chan any{
|
||||
@@ -209,7 +219,7 @@ func TestGetNewID(t *testing.T) {
|
||||
var d *Dispatcher
|
||||
|
||||
_, err := d.getNewID(uuid.NewV4)
|
||||
assert.ErrorIs(t, err, errDispatcherNotInitialized, "getNewID should error correctly")
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "getNewID should error correctly")
|
||||
|
||||
d = NewDispatcher()
|
||||
|
||||
@@ -233,16 +243,16 @@ func TestMux(t *testing.T) {
|
||||
t.Parallel()
|
||||
var mux *Mux
|
||||
_, err := mux.Subscribe(uuid.Nil)
|
||||
assert.ErrorIs(t, err, errMuxIsNil, "Subscribe should error correctly")
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "Subscribe should error correctly")
|
||||
|
||||
err = mux.Unsubscribe(uuid.Nil, nil)
|
||||
assert.ErrorIs(t, err, errMuxIsNil, "Unsubscribe should error correctly")
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "Unsubscribe should error correctly")
|
||||
|
||||
err = mux.Publish(nil)
|
||||
assert.ErrorIs(t, err, errMuxIsNil, "Publish should error correctly")
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "Publish should error correctly")
|
||||
|
||||
_, err = mux.GetID()
|
||||
assert.ErrorIs(t, err, errMuxIsNil, "GetID should error correctly")
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "GetID should error correctly")
|
||||
|
||||
d := NewDispatcher()
|
||||
err = d.start(0, 0)
|
||||
@@ -251,7 +261,7 @@ func TestMux(t *testing.T) {
|
||||
mux = GetNewMux(d)
|
||||
|
||||
err = mux.Publish(nil)
|
||||
assert.ErrorIs(t, err, errNoData, "Publish should error correctly")
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "Publish should error correctly")
|
||||
|
||||
err = mux.Publish("lol")
|
||||
assert.ErrorIs(t, err, errNoIDs, "Publish should error correctly")
|
||||
|
||||
@@ -5,12 +5,11 @@ import (
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
)
|
||||
|
||||
var (
|
||||
errMuxIsNil = errors.New("mux is nil")
|
||||
errIDNotSet = errors.New("id not set")
|
||||
errNoData = errors.New("data payload is nil")
|
||||
errNoIDs = errors.New("no IDs to publish data to")
|
||||
)
|
||||
|
||||
@@ -26,8 +25,8 @@ func GetNewMux(d *Dispatcher) *Mux {
|
||||
// Subscribe takes in a package defined signature element pointing to an ID set
|
||||
// and returns the associated pipe
|
||||
func (m *Mux) Subscribe(id uuid.UUID) (Pipe, error) {
|
||||
if m == nil {
|
||||
return Pipe{}, errMuxIsNil
|
||||
if err := common.NilGuard(m); err != nil {
|
||||
return Pipe{}, err
|
||||
}
|
||||
|
||||
if id.IsNil() {
|
||||
@@ -44,8 +43,8 @@ func (m *Mux) Subscribe(id uuid.UUID) (Pipe, error) {
|
||||
|
||||
// Unsubscribe returns channel to the pool for the full signature set
|
||||
func (m *Mux) Unsubscribe(id uuid.UUID, ch chan any) error {
|
||||
if m == nil {
|
||||
return errMuxIsNil
|
||||
if err := common.NilGuard(m); err != nil {
|
||||
return err
|
||||
}
|
||||
return m.d.unsubscribe(id, ch)
|
||||
}
|
||||
@@ -53,12 +52,8 @@ func (m *Mux) Unsubscribe(id uuid.UUID, ch chan any) error {
|
||||
// Publish takes in a persistent memory address and dispatches changes to
|
||||
// required pipes.
|
||||
func (m *Mux) Publish(data any, ids ...uuid.UUID) error {
|
||||
if m == nil {
|
||||
return errMuxIsNil
|
||||
}
|
||||
|
||||
if data == nil {
|
||||
return errNoData
|
||||
if err := common.NilGuard(m, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(ids) == 0 {
|
||||
@@ -69,8 +64,7 @@ func (m *Mux) Publish(data any, ids ...uuid.UUID) error {
|
||||
}
|
||||
|
||||
for i := range ids {
|
||||
err := m.d.publish(ids[i], data)
|
||||
if err != nil {
|
||||
if err := m.d.publish(ids[i], data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -79,8 +73,8 @@ func (m *Mux) Publish(data any, ids ...uuid.UUID) error {
|
||||
|
||||
// GetID a new unique ID to track routing information in the dispatch system
|
||||
func (m *Mux) GetID() (uuid.UUID, error) {
|
||||
if m == nil {
|
||||
return uuid.UUID{}, errMuxIsNil
|
||||
if err := common.NilGuard(m); err != nil {
|
||||
return uuid.UUID{}, err
|
||||
}
|
||||
return m.d.getNewID(uuid.NewV4)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user