accounts: Move to instance methods, fix races and isolate tests (#1923)

* Bybit: Fix race in TestUpdateAccountInfo and  TestWSHandleData

* DriveBy rename TestWSHandleData
* This doesn't address running with -race=2+ due to the singleton

* Accounts: Add account.GetService()

* exchange: Assertify TestSetupDefaults

* Exchanges: Add account.Service override for testing

* Exchanges: Remove duplicate IsWebsocketEnabled test from TestSetupDefaults

* Dispatch: Replace nil checks with NilGuard

* Engine: Remove deprecated printAccountHoldingsChangeSummary

* Dispatcher: Add EnsureRunning method

* Accounts: Move singleton accounts service to exchange Accounts

* Move singleton accounts service to exchange Accounts

This maintains the concept of a global store, whilst allowing exchanges
to override it when needed, particularly for testing.

APIServer:

* Remove getAllActiveAccounts from apiserver

Deprecated apiserver only thing using this, so remove it instead of
updating it

* Update comment for UpdateAccountBalances everywhere

* Docs: Add punctuation to function comments

* Bybit: Coverage for wsProcessWalletPushData Save
This commit is contained in:
Gareth Kirwan
2025-10-28 09:52:45 +07:00
committed by GitHub
parent bda9bbec66
commit 73e200e4e7
140 changed files with 3515 additions and 4025 deletions

View File

@@ -0,0 +1,312 @@
package accounts
import (
"context"
"errors"
"fmt"
"maps"
"slices"
"sync"
"time"
"github.com/gofrs/uuid"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/key"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/dispatch"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
)
// Public errors.
var (
ErrNoBalances = errors.New("no balances found")
ErrNoSubAccounts = errors.New("no subAccounts found")
)
var (
errCredentialsEmpty = errors.New("no credentials provided")
errUpdatingBalance = errors.New("error updating balance")
errPublish = errors.New("error publishing account changes")
)
// Accounts holds a stream ID and a map to the exchange holdings.
type Accounts struct {
Exchange exchange
routingID uuid.UUID // GCT internal routing mux id
subAccounts credSubAccounts
mu sync.RWMutex
mux *dispatch.Mux
}
type (
credSubAccounts map[Credentials]subAccounts
subAccounts map[key.SubAccountAsset]currencyBalances
)
// SubAccount contains an account for an asset type and its balances.
// The SubAccount may be the main account depending on exchange structure.
type SubAccount struct {
ID string
AssetType asset.Item
Balances CurrencyBalances
}
// SubAccounts contains a list of public SubAccounts.
type SubAccounts []*SubAccount
// MustNewAccounts returns an initialised Accounts store for use in isolation from a global exchange accounts store.
// mux is set to the global dispatch.Dispatcher.
// Any errors in mux ID generation will panic, so users should balance risk vs utility accordingly depending on use-case.
func MustNewAccounts(e exchange) *Accounts {
a, err := NewAccounts(e, dispatch.GetNewMux(nil))
if err != nil {
panic(err)
}
return a
}
// NewAccounts returns an initialised Accounts store for use in isolation from a global exchange accounts store.
func NewAccounts(e exchange, mux *dispatch.Mux) (*Accounts, error) {
if err := common.NilGuard(e); err != nil {
return nil, err
}
id, err := mux.GetID()
if err != nil {
return nil, err
}
return &Accounts{
Exchange: e,
subAccounts: make(credSubAccounts),
routingID: id,
mux: mux,
}, nil
}
// NewSubAccount returns a new SubAccount.
// id may be empty.
func NewSubAccount(a asset.Item, id string) *SubAccount {
return &SubAccount{
AssetType: a,
ID: id,
Balances: CurrencyBalances{},
}
}
// Subscribe subscribes to your exchange accounts.
func (a *Accounts) Subscribe() (dispatch.Pipe, error) {
if err := common.NilGuard(a); err != nil {
return dispatch.Pipe{}, err
}
return a.mux.Subscribe(a.routingID)
}
// CurrencyBalances returns the balances for the Accounts grouped by currency.
// If creds is nil, all credential SubAccounts will be collated.
// If assetType is asset.All, all assets will be collated.
func (a *Accounts) CurrencyBalances(creds *Credentials, assetType asset.Item) (CurrencyBalances, error) {
if err := common.NilGuard(a); err != nil {
return nil, err
}
if !assetType.IsValid() && assetType != asset.All {
return nil, fmt.Errorf("%s %s %w", a.Exchange.GetName(), assetType, asset.ErrNotSupported)
}
currs := CurrencyBalances{}
a.mu.RLock()
defer a.mu.RUnlock()
for credsKey, subAccountsForCreds := range a.subAccounts {
if !creds.IsEmpty() && *creds != credsKey {
continue
}
for subAcctKey, balances := range subAccountsForCreds {
if assetType != asset.All && assetType != subAcctKey.Asset {
continue
}
for curr, bal := range balances {
if err := currs.Add(curr.Currency(), bal.Balance()); err != nil {
return nil, err // Should be impossible, so return immediately
}
}
}
}
if len(currs) == 0 {
return nil, fmt.Errorf("%w for %s credentials %s asset %s", ErrNoBalances, a.Exchange.GetName(), creds, assetType)
}
return currs, nil
}
// SubAccounts returns the public SubAccounts and their balances.
// If creds is nil, all credential SubAccounts will be returned.
// If assetType is asset.All, all assets will be returned.
func (a *Accounts) SubAccounts(creds *Credentials, assetType asset.Item) (SubAccounts, error) {
if err := common.NilGuard(a); err != nil {
return nil, err
}
if !assetType.IsValid() && assetType != asset.All {
return nil, fmt.Errorf("%s %s %w", a.Exchange.GetName(), assetType, asset.ErrNotSupported)
}
var subAccts SubAccounts
a.mu.RLock()
defer a.mu.RUnlock()
for credsKey, subAccountsForCreds := range a.subAccounts {
if !creds.IsEmpty() && *creds != credsKey {
continue
}
for subAcctKey, balances := range subAccountsForCreds {
if assetType != asset.All && assetType != subAcctKey.Asset {
continue
}
subAccts = append(subAccts, &SubAccount{
ID: subAcctKey.SubAccount,
AssetType: subAcctKey.Asset,
Balances: balances.Public(),
})
}
}
if len(subAccts) == 0 {
return nil, fmt.Errorf("%w for %s credentials %s asset %s", ErrNoSubAccounts, a.Exchange.GetName(), creds, assetType)
}
return subAccts, nil
}
// GetBalance returns a copy of the balance for that asset item.
func (a *Accounts) GetBalance(subAccount string, creds *Credentials, aType asset.Item, c currency.Code) (Balance, error) {
if err := common.NilGuard(a); err != nil {
return Balance{}, err
}
if !aType.IsValid() {
return Balance{}, fmt.Errorf("cannot get balance: %w: %q", asset.ErrNotSupported, aType)
}
if creds.IsEmpty() {
return Balance{}, fmt.Errorf("cannot get balance: %w", errCredentialsEmpty)
}
if c.IsEmpty() {
return Balance{}, fmt.Errorf("cannot get balance: %w", currency.ErrCurrencyCodeEmpty)
}
a.mu.RLock()
defer a.mu.RUnlock()
subAccts, ok := a.subAccounts[*creds]
if !ok {
return Balance{}, fmt.Errorf("%w for %s", ErrNoBalances, creds)
}
assets, ok := subAccts[key.SubAccountAsset{
SubAccount: subAccount,
Asset: aType,
}]
if !ok {
return Balance{}, fmt.Errorf("%w for %s SubAccount %q %s", ErrNoBalances, a.Exchange.GetName(), subAccount, aType)
}
b, ok := assets[c.Item]
if !ok {
return Balance{}, fmt.Errorf("%w for %s SubAccount %q %s %s", ErrNoBalances, a.Exchange.GetName(), subAccount, aType, c)
}
return b.Balance(), nil
}
// Save updates the account balances.
// If isSnapshot is true any missing currencies will be removed.
// Credentials will be retrieved from ctx, Use DeployCredentialsToContext.
// Changes to balances are published individually.
func (a *Accounts) Save(ctx context.Context, subAccts SubAccounts, isSnapshot bool) error {
if err := common.NilGuard(a); err != nil {
return fmt.Errorf("cannot save holdings: %w", err)
}
if err := common.NilGuard(a.subAccounts); err != nil {
return fmt.Errorf("cannot save holdings: %w", err)
}
creds, err := a.Exchange.GetCredentials(ctx)
if err != nil {
return err
}
if creds.IsEmpty() {
return fmt.Errorf("%w: %w", errUpdatingBalance, errCredentialsEmpty)
}
var errs error
a.mu.Lock()
defer a.mu.Unlock()
for _, s := range subAccts {
if !s.AssetType.IsValid() {
errs = common.AppendError(errs, fmt.Errorf("error loading %s[%s] SubAccount holdings: %w", s.ID, s.AssetType, asset.ErrNotSupported))
continue
}
accBalances := a.currencyBalances(creds, s.ID, s.AssetType)
updated := false
missing := maps.Clone(accBalances)
for curr, newBal := range s.Balances {
delete(missing, curr.Item)
if newBal.UpdatedAt.IsZero() {
newBal.UpdatedAt = time.Now()
}
if newBal.Currency.IsEmpty() {
newBal.Currency = curr
}
s.Balances[curr] = newBal
if u, err := accBalances.balance(curr.Item).update(newBal); err != nil {
errs = common.AppendError(errs, fmt.Errorf("%w for account ID %q [%s %s]: %w", errUpdatingBalance, s.ID, s.AssetType, curr, err))
} else if u {
updated = true
}
}
if isSnapshot {
for cur := range missing {
delete(accBalances, cur)
updated = true
}
}
if updated {
if err := a.mux.Publish(s, a.routingID); err != nil {
errs = common.AppendError(errs, fmt.Errorf("%w for %s %w", errPublish, a.Exchange, err))
}
}
}
return errs
}
// Merge adds CurrencyBalances in s to the SubAccount in l with a matching AssetType and ID.
// If no SubAccount matches, s is appended.
// Duplicate Currency Balances are added together.
func (l SubAccounts) Merge(s *SubAccount) SubAccounts {
if err := common.NilGuard(s); err != nil {
return nil
}
i := slices.IndexFunc(l, func(b *SubAccount) bool { return s.AssetType == b.AssetType && s.ID == b.ID })
if i == -1 {
return append(l, s)
}
for curr, newBal := range s.Balances {
l[i].Balances[curr] = newBal.Add(l[i].Balances[curr])
}
return l
}
// currencyBalances returns a currencyBalances entry for Credentials, SubAccount and asset.
// No nilguard protection provided, since this is a private function.
func (a *Accounts) currencyBalances(c *Credentials, subAcct string, aType asset.Item) currencyBalances {
k := key.SubAccountAsset{SubAccount: subAcct, Asset: aType}
if _, ok := a.subAccounts[*c]; !ok {
a.subAccounts[*c] = make(subAccounts)
}
if _, ok := a.subAccounts[*c][k]; !ok {
a.subAccounts[*c][k] = make(currencyBalances)
}
return a.subAccounts[*c][k]
}

View File

@@ -0,0 +1,533 @@
package accounts
import (
"context"
"fmt"
"maps"
"reflect"
"runtime"
"slices"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/key"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/dispatch"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
)
var (
creds1 = &Credentials{Key: "1"}
creds2 = &Credentials{Key: "2"}
creds3 = &Credentials{Key: "3"}
)
func TestNewAccounts(t *testing.T) {
t.Parallel()
a, err := NewAccounts(&mockEx{"mocky"}, dispatch.GetNewMux(nil))
require.NoError(t, err)
require.NotNil(t, a)
assert.Equal(t, "mocky", a.Exchange.GetName(), "Exchange name should set correctly")
assert.NotNil(t, a.subAccounts, "subAccounts should be initialised")
assert.NotEmpty(t, a.routingID, "routingID should not be empty")
assert.NotNil(t, a.mux, "mux should be set correctly")
_, err = NewAccounts(nil, dispatch.GetNewMux(nil))
assert.ErrorIs(t, err, common.ErrNilPointer)
_, err = NewAccounts(&mockEx{"mocky"}, nil)
assert.ErrorContains(t, err, "nil pointer: *dispatch.Mux")
}
func TestMustNewAccounts(t *testing.T) {
t.Parallel()
a := MustNewAccounts(&mockEx{"mocky"})
require.NotNil(t, a)
require.Panics(t, func() { _ = MustNewAccounts(nil) })
}
func TestNewSubAccount(t *testing.T) {
t.Parallel()
a := NewSubAccount(asset.Spot, "")
require.NotNil(t, a, "must not return nil with no id")
assert.Equal(t, asset.Spot, a.AssetType, "AssetType should be correct")
assert.Empty(t, a.ID, "ID should not default to anything")
a = NewSubAccount(asset.Spot, "42")
assert.Equal(t, "42", a.ID, "ID should be correct")
}
func TestSubscribe(t *testing.T) {
t.Parallel()
err := dispatch.EnsureRunning(dispatch.DefaultMaxWorkers, dispatch.DefaultJobsLimit)
require.NoError(t, err, "dispatch.EnsureRunning must not error")
p, err := MustNewAccounts(&mockEx{}).Subscribe()
require.NoError(t, err)
require.NotNil(t, p, "Subscribe must return a pipe")
require.Empty(t, p.Channel(), "Pipe must be empty before Saving anything")
}
func TestAccountsCurrencyBalances(t *testing.T) {
t.Parallel()
a := accountsFixture(t)
_, err := (*Accounts)(nil).CurrencyBalances(nil, asset.Spot)
assert.ErrorIs(t, err, common.ErrNilPointer)
_, err = a.CurrencyBalances(nil, asset.Empty)
assert.ErrorIs(t, err, asset.ErrNotSupported)
_, err = a.CurrencyBalances(creds3, asset.All)
require.ErrorIs(t, err, ErrNoBalances)
_, err = a.CurrencyBalances(creds3, asset.All)
assert.ErrorIs(t, err, ErrNoBalances)
assert.ErrorContains(t, err, "Key:[3")
// Add a balance with inconsistent currencies to cover err from currs.Add
a.subAccounts[*creds3] = map[key.SubAccountAsset]currencyBalances{
{Asset: asset.Futures}: {currency.DOGE.Item: &balance{internal: Balance{Currency: currency.ETH}}},
}
type cMap map[currency.Code]float64
for _, tc := range []struct {
c *Credentials
aT asset.Item
exp cMap
err error
}{
{nil, asset.Spot, cMap{currency.BTC: 6.0, currency.LTC: 10.0}, nil},
{creds1, asset.All, cMap{currency.BTC: 3.0, currency.LTC: 30.0}, nil},
{creds1, asset.Spot, cMap{currency.BTC: 3.0, currency.LTC: 10.0}, nil},
{creds1, asset.Futures, cMap{currency.LTC: 20.0}, nil},
{creds2, asset.Spot, cMap{currency.BTC: 3.0}, nil},
{creds3, asset.Futures, cMap{currency.DOGE: 50.0}, errBalanceCurrencyMismatch},
} {
t.Run(fmt.Sprintf("%s/%s", tc.c, tc.aT), func(t *testing.T) {
t.Parallel()
b, err := a.CurrencyBalances(tc.c, tc.aT)
if tc.err != nil {
require.ErrorIs(t, err, tc.err)
return
}
require.NoError(t, err)
require.Equal(t, len(tc.exp), len(b), "must get correct number of balances")
for c, expBal := range tc.exp {
assert.Contains(t, b, c)
assert.Equalf(t, expBal, b[c].Total, "should get correct total for %s", c)
}
})
}
}
func TestAccountsPrivateCurrencyBalances(t *testing.T) {
t.Parallel()
a := accountsFixture(t)
b := a.currencyBalances(creds3, "", asset.Spot)
r1 := a.subAccounts[*creds3]
// Using reflect since assert.Same cannot be used on maps to ensure same underlying pointer
assert.Equal(t,
reflect.ValueOf(b).UnsafePointer(),
reflect.ValueOf(r1[key.SubAccountAsset{Asset: asset.Spot}]).UnsafePointer(),
"should make and return the same map")
assert.Equal(t,
reflect.ValueOf(b).UnsafePointer(),
reflect.ValueOf(a.currencyBalances(creds3, "", asset.Spot)).UnsafePointer(),
"should return the same map on subsequent calls")
b = a.currencyBalances(creds3, "", asset.Futures)
assert.Equal(t,
reflect.ValueOf(r1).UnsafePointer(),
reflect.ValueOf(a.subAccounts[*creds3]).UnsafePointer(),
"should not make a new cred key")
assert.Equal(t,
reflect.ValueOf(b).UnsafePointer(),
reflect.ValueOf(r1[key.SubAccountAsset{Asset: asset.Futures}]).UnsafePointer(),
"should make and return the same map")
}
type tKey key.SubAccountAsset
func TestAccountsSubAccounts(t *testing.T) {
t.Parallel()
a := accountsFixture(t)
_, err := (*Accounts)(nil).SubAccounts(nil, asset.Spot)
assert.ErrorIs(t, err, common.ErrNilPointer)
_, err = a.SubAccounts(nil, asset.Empty)
assert.ErrorIs(t, err, asset.ErrNotSupported)
_, err = a.SubAccounts(creds3, asset.All)
require.ErrorIs(t, err, ErrNoSubAccounts)
require.ErrorContains(t, err, "Key:[3")
for _, tc := range []struct {
c *Credentials
aT asset.Item
exp []tKey
}{
{nil, asset.All, []tKey{{"1a", asset.Spot}, {"1b", asset.Spot}, {"1b", asset.Futures}, {"2a", asset.Spot}}},
{creds1, asset.All, []tKey{{"1a", asset.Spot}, {"1b", asset.Spot}, {"1b", asset.Futures}}},
{creds1, asset.Spot, []tKey{{"1a", asset.Spot}, {"1b", asset.Spot}}},
{creds1, asset.Futures, []tKey{{"1b", asset.Futures}}},
{creds2, asset.Spot, []tKey{{"2a", asset.Spot}}},
} {
t.Run(fmt.Sprintf("%v/%s", tc.c, tc.aT), func(t *testing.T) {
t.Parallel()
b, err := a.SubAccounts(tc.c, tc.aT)
require.NoError(t, err)
exp := subAccountsFixture(tc.exp)
require.Equal(t, len(exp), len(b), "must get correct number of subAccounts")
require.ElementsMatch(t, exp, b, "must get correct subAccounts")
})
}
}
func TestAccountsGetBalance(t *testing.T) {
t.Parallel()
a := accountsFixture(t)
_, err := (*Accounts)(nil).GetBalance("", nil, asset.Empty, currency.EMPTYCODE)
require.ErrorIs(t, err, common.ErrNilPointer)
_, err = a.GetBalance("", nil, asset.Empty, currency.EMPTYCODE)
assert.ErrorIs(t, err, asset.ErrNotSupported)
_, err = a.GetBalance("", nil, asset.Spot, currency.EMPTYCODE)
assert.ErrorIs(t, err, errCredentialsEmpty)
_, err = a.GetBalance("", creds3, asset.Spot, currency.EMPTYCODE)
assert.ErrorIs(t, err, currency.ErrCurrencyCodeEmpty)
_, err = a.GetBalance("", creds3, asset.Spot, currency.DOGE)
assert.ErrorIs(t, err, ErrNoBalances)
assert.ErrorContains(t, err, "for Key:[3")
_, err = a.GetBalance("3a", creds1, asset.Spot, currency.DOGE)
assert.ErrorIs(t, err, ErrNoBalances)
assert.ErrorContains(t, err, `for mocky SubAccount "3a" spot`)
_, err = a.GetBalance("1a", creds1, asset.Spot, currency.DOGE)
assert.ErrorIs(t, err, ErrNoBalances)
assert.ErrorContains(t, err, `for mocky SubAccount "1a" spot DOGE`)
b, err := a.GetBalance("1b", creds1, asset.Spot, currency.BTC)
require.NoError(t, err)
assert.Equal(t, 2.0, b.Total, "Total should be correct")
}
func TestAccountsSave(t *testing.T) { //nolint:tparallel // Save's internal tests are sequential
t.Parallel()
a := accountsFixture(t)
relay := subscribeFixture(t, a)
beforeNow := time.Now()
ctx := t.Context()
assert.ErrorContains(t, (*Accounts)(nil).Save(ctx, nil, false), "nil pointer: *accounts.Accounts")
assert.ErrorContains(t, new(Accounts).Save(ctx, nil, false), "nil pointer: accounts.credSubAccounts")
for _, tc := range []struct {
name string
creds *Credentials
snapshot bool
accts SubAccounts
pre func(context.Context) context.Context
post func(t *testing.T) // Any additional assertions
err error
}{
{
name: "NoCredentials",
accts: SubAccounts{},
err: errCredentialsEmpty,
},
{
name: "BadCredentials",
accts: SubAccounts{},
err: common.ErrTypeAssertFailure,
pre: func(ctx context.Context) context.Context { return context.WithValue(ctx, ContextCredentialsFlag, 42) },
},
{
name: "BadAsset",
creds: creds1,
accts: SubAccounts{{AssetType: asset.All}},
err: asset.ErrNotSupported,
},
{
name: "CurrencyMismatch",
creds: creds1,
err: errBalanceCurrencyMismatch,
accts: SubAccounts{{
AssetType: asset.Spot,
ID: "1a",
Balances: CurrencyBalances{currency.BTC: {Currency: currency.DOGE}},
}},
},
{
name: "OutOfSequence",
creds: creds1,
err: errOutOfSequence,
accts: SubAccounts{{
AssetType: asset.Spot,
ID: "1a",
Balances: CurrencyBalances{currency.BTC: {UpdatedAt: skynetDate.Add(-time.Hour)}},
}},
},
{
name: "BasicSave",
creds: creds1,
accts: SubAccounts{
{
AssetType: asset.Spot,
ID: "1a",
Balances: CurrencyBalances{currency.BTC: {Total: 4, UpdatedAt: skynetDate.Add(time.Minute)}},
},
{
AssetType: asset.Spot,
ID: "1c",
Balances: CurrencyBalances{currency.ETH: {Total: 6}},
},
},
post: func(t *testing.T) {
t.Helper()
_, err := a.GetBalance("1a", creds1, asset.Spot, currency.LTC)
require.NoError(t, err, "Other balances must not be affected")
},
},
{
name: "NewCredsSaveAndPublish",
creds: creds3,
accts: SubAccounts{
{
AssetType: asset.Futures,
ID: "3a",
Balances: CurrencyBalances{currency.DOGE: {Total: 6.2}},
},
},
post: func(t *testing.T) {
t.Helper()
require.Eventually(t, func() bool { return len(relay) > 0 }, time.Second, time.Millisecond, "Publish must eventually send to Channel")
pub := <-relay
assert.Equal(t, "3a", pub.ID, "Publish should have correct ID")
assert.Contains(t, pub.Balances, currency.DOGE, "Should get DOGE Balance")
b := pub.Balances[currency.DOGE]
assert.Equal(t, currency.DOGE, b.Currency, "Currency should default to the Balances map key")
assert.WithinRange(t, b.UpdatedAt, beforeNow, time.Now(), "UpdatedAt should default to time.Now")
assert.Equal(t, 6.2, b.Total, "Total should be correct")
},
},
{
name: "SnapshotSave",
creds: creds1,
accts: SubAccounts{
{
AssetType: asset.Spot,
ID: "1a",
Balances: CurrencyBalances{currency.LTC: {Total: 12}},
},
},
snapshot: true,
post: func(t *testing.T) {
t.Helper()
_, err := a.GetBalance("1a", creds1, asset.Spot, currency.BTC)
require.ErrorIs(t, err, ErrNoBalances, "BTC balance must be removed")
},
},
{
name: "PublishError",
creds: creds1,
accts: SubAccounts{
{
AssetType: asset.Spot,
ID: "1a",
Balances: CurrencyBalances{currency.DOGE: {Total: 7.2}},
},
},
pre: func(ctx context.Context) context.Context {
a.mux = nil
return ctx
},
err: errPublish,
},
} {
t.Run(tc.name, func(t *testing.T) {
ctx := t.Context()
if tc.creds != nil {
ctx = DeployCredentialsToContext(ctx, tc.creds)
}
expAccts := tc.accts.clone()
if tc.pre != nil {
ctx = tc.pre(ctx)
}
err := a.Save(ctx, tc.accts, tc.snapshot)
if tc.err != nil {
require.ErrorIs(t, err, tc.err)
return
}
require.NoError(t, err)
for i, acct := range tc.accts {
for curr := range acct.Balances {
t.Run(fmt.Sprintf("%s/%s/%s", acct.AssetType, acct.ID, curr), func(t *testing.T) {
exp := expAccts[i].Balances[curr]
got, err := a.GetBalance(acct.ID, tc.creds, acct.AssetType, curr)
require.NoError(t, err, "GetBalance must not error")
if !exp.UpdatedAt.IsZero() {
assert.Equal(t, exp.UpdatedAt, got.UpdatedAt, "UpdatedAt should match balance")
} else {
assert.WithinRange(t, got.UpdatedAt, beforeNow, time.Now(), "UpdatedAt should default to time.Now")
}
assert.Equal(t, exp.Total, got.Total, "Total should be correct")
if tc.post != nil {
tc.post(t)
}
})
}
}
})
}
}
var skynetDate = time.Unix(872896440, 0)
// accountsFixture returns an Accounts store with SubAccount IDs per credentials, and a subscription channel for updates
func accountsFixture(t *testing.T) *Accounts {
t.Helper()
a := MustNewAccounts(&mockEx{})
for _, f := range []struct {
c *Credentials
sA string
aT asset.Item
cC currency.Code
b float64
}{
{creds1, "1a", asset.Spot, currency.BTC, 1},
{creds1, "1a", asset.Spot, currency.LTC, 10},
{creds1, "1b", asset.Spot, currency.BTC, 2},
{creds1, "1b", asset.Futures, currency.LTC, 20},
{creds2, "2a", asset.Spot, currency.BTC, 3},
} {
// Not using t.Run because this is a helper
u, err := a.currencyBalances(f.c, f.sA, f.aT).balance(f.cC.Item).update(Balance{Total: f.b, UpdatedAt: skynetDate})
require.NoErrorf(t, err, "Deploy fixture balance must not error for %s/%s/%s/%s", f.c.Key, f.sA, f.aT, f.cC)
require.Truef(t, u, "Deploy fixture balance must apply an update for %s/%s/%s/%s", f.c.Key, f.sA, f.aT, f.cC)
}
return a
}
var subAccts = SubAccounts{
{
ID: "1a",
AssetType: asset.Spot,
Balances: CurrencyBalances{
currency.LTC: Balance{Currency: currency.LTC, Total: 10, UpdatedAt: skynetDate},
currency.BTC: Balance{Currency: currency.BTC, Total: 1, UpdatedAt: skynetDate},
},
},
{
ID: "1b",
AssetType: asset.Spot,
Balances: CurrencyBalances{currency.BTC: Balance{Currency: currency.BTC, Total: 2.0, UpdatedAt: skynetDate}},
},
{
ID: "1b",
AssetType: asset.Futures,
Balances: CurrencyBalances{currency.LTC: Balance{Currency: currency.LTC, Total: 20.0, UpdatedAt: skynetDate}},
},
{
ID: "2a",
AssetType: asset.Spot,
Balances: CurrencyBalances{currency.BTC: Balance{Currency: currency.BTC, Total: 3.0, UpdatedAt: skynetDate}},
},
}
func subAccountsFixture(keys []tKey) (a SubAccounts) {
if keys == nil {
return subAccts.clone()
}
for _, k := range keys {
i := slices.IndexFunc(subAccts, func(s *SubAccount) bool {
return k.SubAccount == s.ID && k.Asset == s.AssetType
})
if i == -1 {
panic(fmt.Sprintf("subAccountsFixture called with unknown subAccount key: %v", k))
}
a = append(a, subAccts[i])
}
return a
}
func subscribeFixture(t *testing.T, a *Accounts) chan *SubAccount {
t.Helper()
err := dispatch.EnsureRunning(dispatch.DefaultMaxWorkers, dispatch.DefaultJobsLimit)
require.NoError(t, err, "dispatch.EnsureRunning must not error")
p, err := a.Subscribe()
require.NoError(t, err, "Subscribe must not error")
require.NotNil(t, p, "Subscribe must return a pipe")
relay := make(chan *SubAccount, 64)
go func() {
for v := range p.Channel() {
if s, ok := v.(*SubAccount); ok && s.ID == "3a" { // Only interested in relaying events for a single test account
relay <- s
}
}
}()
runtime.Gosched()
return relay
}
func TestMerge(t *testing.T) {
t.Parallel()
s := subAccountsFixture(nil)
assert.Nil(t, s.Merge(nil), "Should return nil for a merge of nil SubAccounts")
exp := len(s)
a := &SubAccount{
ID: "1a",
AssetType: asset.Spot,
Balances: CurrencyBalances{currency.BTC: Balance{Total: 1}},
}
s = s.Merge(a)
require.Equal(t, exp, len(s), "Must contain correct number of accounts after merging")
for _, acct := range s {
if acct.ID == "1a" && acct.AssetType == asset.Spot {
assert.Equal(t, 2.0, acct.Balances[currency.BTC].Total)
}
}
a = &SubAccount{
ID: "new",
AssetType: asset.Spot,
Balances: CurrencyBalances{currency.BTC: Balance{Total: 1}},
}
s = s.Merge(a)
assert.Contains(t, s, a, "Should contain the new subaccount")
}
func TestSubAccountsClone(t *testing.T) {
t.Parallel()
s := SubAccounts{
{ID: "1", AssetType: asset.Spot, Balances: CurrencyBalances{currency.BTC: {Total: 1}}},
{ID: "2", AssetType: asset.Futures, Balances: CurrencyBalances{currency.LTC: {Total: 2}}},
}
c := s.clone()
require.Equal(t, s, c, "Clone must match original")
c[0].ID = "3"
assert.NotEqual(t, s, c, "Should not be equal after modification")
}
func (l SubAccounts) clone() (c SubAccounts) {
for _, s := range l {
bals := make(CurrencyBalances, len(s.Balances))
maps.Copy(bals, s.Balances)
c = append(c, &SubAccount{
ID: s.ID,
AssetType: s.AssetType,
Balances: bals,
})
}
return c
}

View File

@@ -0,0 +1,148 @@
package accounts
import (
"errors"
"fmt"
"sync"
"time"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
)
var (
errBalanceCurrencyMismatch = errors.New("balance currency does not match update currency")
errOutOfSequence = errors.New("out of sequence")
errUpdatedAtIsZero = errors.New("updatedAt may not be zero")
)
// Balance contains an exchange currency balance.
type Balance struct {
Currency currency.Code
Total float64
Hold float64
Free float64
AvailableWithoutBorrow float64
Borrowed float64
UpdatedAt time.Time
}
// Change defines incoming balance change on currency holdings.
type Change struct {
Account string
AssetType asset.Item
Balance Balance
}
// balance contains a balance with live updates.
type balance struct {
internal Balance
m sync.RWMutex
}
// CurrencyBalances provides a map of currencies to balances.
type CurrencyBalances map[currency.Code]Balance
// currencyBalances provides a map of currencies to balances.
type currencyBalances map[*currency.Item]*balance
// Set will set a currency balance, overwriting any previous Balance.
//
//nolint:gocritic // Ignoring hugeparam because we want the convenience of all callers passing by value
//nolint:gocritic // and we want to store a copy anyway so the hugeparam warning that this copies a value is not relevant
func (c *CurrencyBalances) Set(curr currency.Code, b Balance) {
b.Currency = curr
(*c)[curr] = b
}
// Add will add to a currency balance.
func (c *CurrencyBalances) Add(curr currency.Code, b Balance) error { //nolint:gocritic // hugeparam not relevant; we want to store a value so we'd deref anyway
if curr == currency.EMPTYCODE {
return currency.ErrCurrencyCodeEmpty
}
if b.Currency != currency.EMPTYCODE && !b.Currency.Equal(curr) {
return fmt.Errorf("%w: %q != %q", errBalanceCurrencyMismatch, b.Currency, curr)
}
if e, ok := (*c)[curr]; !ok {
b.Currency = curr
(*c)[curr] = b
} else {
(*c)[curr] = e.Add(b)
}
return nil
}
// Balance returns a snapshot copy of the Balance.
func (b *balance) Balance() Balance {
b.m.RLock()
defer b.m.RUnlock()
return b.internal
}
// Add returns a new Balance adding together a and b.
// UpdatedAt is the later of the two Balances.
func (b *Balance) Add(a Balance) Balance { //nolint:gocritic // hugeparam not relevant; We'd need to copy it in map iterations anyway
var u time.Time
if a.UpdatedAt.After(b.UpdatedAt) {
u = a.UpdatedAt
} else {
u = b.UpdatedAt
}
return Balance{
Total: b.Total + a.Total,
Hold: b.Hold + a.Hold,
Free: b.Free + a.Free,
AvailableWithoutBorrow: b.AvailableWithoutBorrow + a.AvailableWithoutBorrow,
Borrowed: b.Borrowed + a.Borrowed,
UpdatedAt: u,
}
}
// Public returns a copy of the currencyBalances converted to CurrencyBalances for use outside this package.
func (c currencyBalances) Public() CurrencyBalances {
n := make(CurrencyBalances, len(c))
for curr, bal := range c {
n[curr.Currency()] = bal.Balance()
}
return n
}
// update checks that an incoming change has a valid change, and returns if the balances were changed.
// If change does not have a Currency set, the existing Currency is preserved.
func (b *balance) update(change Balance) (bool, error) { //nolint:gocritic // hugeparam not relevant; We'd need to copy it later anyway
if err := common.NilGuard(b); err != nil {
return false, err
}
if change.UpdatedAt.IsZero() {
return false, errUpdatedAtIsZero
}
b.m.Lock()
defer b.m.Unlock()
if b.internal.Currency != currency.EMPTYCODE {
if change.Currency == currency.EMPTYCODE {
change.Currency = b.internal.Currency
} else if !change.Currency.Equal(b.internal.Currency) {
return false, fmt.Errorf("%w %q != %q", errBalanceCurrencyMismatch, b.internal.Currency, change.Currency)
}
}
if b.internal.UpdatedAt.After(change.UpdatedAt) {
return false, errOutOfSequence
}
b.internal.UpdatedAt = change.UpdatedAt // Set just the time, and then can compare easily
if b.internal == change {
return false, nil
}
b.internal = change
return true, nil
}
// balance returns a balance for a currency.
func (c currencyBalances) balance(curr *currency.Item) *balance {
b, ok := c[curr]
if !ok {
b = &balance{internal: Balance{Currency: curr.Currency()}}
c[curr] = b
}
return b
}

View File

@@ -0,0 +1,152 @@
package accounts
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/currency"
)
// TestCurrencyBalancesSet exercises CurrencyBalances.Set
func TestCurrencyBalancesSet(t *testing.T) {
t.Parallel()
c := CurrencyBalances{}
c.Set(currency.BTC, Balance{Total: 4.2})
require.Contains(t, c, currency.BTC, "must add an entry to an uninitialised CurrencyBalances")
assert.Equal(t, currency.BTC, c[currency.BTC].Currency, "should set the Currency")
c.Set(currency.LTC, Balance{Currency: currency.ETH, Total: 52.4})
require.Contains(t, c, currency.LTC, "must add an entry to an existing CurrencyBalances")
assert.Equal(t, currency.LTC, c[currency.LTC].Currency, "should overwrite Currency")
}
// TestCurrencyBalancesAdd exercises CurrencyBalances.Add
func TestCurrencyBalancesAdd(t *testing.T) {
t.Parallel()
c := CurrencyBalances{}
assert.ErrorIs(t, c.Add(currency.EMPTYCODE, Balance{}), currency.ErrCurrencyCodeEmpty)
err := c.Add(currency.BTC, Balance{Total: 4.2})
require.NoError(t, err)
require.Contains(t, c, currency.BTC, "must add an entry to an uninitialised CurrencyBalances")
assert.Equal(t, currency.BTC, c[currency.BTC].Currency, "should set the Currency")
assert.Equal(t, 4.2, c[currency.BTC].Total, "should initialise the Total")
err = c.Add(currency.BTC, Balance{Total: 1.3, Hold: 2.4})
require.NoError(t, err)
assert.Equal(t, 5.5, c[currency.BTC].Total, "should add to existing Total")
assert.Equal(t, 2.4, c[currency.BTC].Hold, "should initialise Hold")
err = c.Add(currency.LTC, Balance{Currency: currency.LTC, Total: 14.3})
require.NoError(t, err)
require.Contains(t, c, currency.LTC, "must add an entry to an existing CurrencyBalances")
assert.Equal(t, 14.3, c[currency.LTC].Total, "should add when Balance.Currency is equal")
err = c.Add(currency.ETH, Balance{Currency: currency.LTC, Total: 14.2})
assert.ErrorIs(t, err, errBalanceCurrencyMismatch)
}
// TestCurrencyBalancesPublic exercises currencyBalances.Public
func TestCurrencyBalancesPublic(t *testing.T) {
t.Parallel()
b := (&currencyBalances{
currency.BTC.Item: &balance{internal: Balance{Total: 4.2}},
currency.LTC.Item: &balance{internal: Balance{Total: 1.7}},
}).Public()
require.Equal(t, 2, len(b), "Pulbic must return the correct number of Balances")
require.Contains(t, b, currency.BTC)
require.Contains(t, b, currency.LTC)
assert.Equal(t, 4.2, b[currency.BTC].Total)
assert.Equal(t, 1.7, b[currency.LTC].Total)
}
// TestCurrencyBalancesBalance exercises currencyBalances.balance
func TestCurrencyBalancesBalance(t *testing.T) {
t.Parallel()
c := currencyBalances{}
b := c.balance(currency.BTC.Item)
require.NotNil(t, b)
assert.Same(t, c[currency.BTC.Item], b, "should make and return the same entry")
assert.Same(t, b, c.balance(currency.BTC.Item), "should make and return the same entry")
}
// TestBalanceBalance exercises balance.Balance
func TestBalanceBalance(t *testing.T) {
t.Parallel()
b := &balance{internal: Balance{Currency: currency.BTC}}
i := b.Balance()
assert.Equal(t, b.internal, i)
}
// TestBalanceAdd exercises Balance.Add
func TestBalanceAdd(t *testing.T) {
t.Parallel()
n1 := time.Now()
n2 := n1.Add(-2 * time.Minute)
b := new(Balance).Add(Balance{Total: 4.2, UpdatedAt: n2})
assert.Equal(t, 4.2, b.Total, "should initialise Total")
assert.Equal(t, n2, b.UpdatedAt, "should set UpdatedAt")
b = b.Add(Balance{Total: 1.3, Hold: 3.0, UpdatedAt: n1})
assert.Equal(t, 5.5, b.Total, "should add to Total")
assert.Equal(t, 3.0, b.Hold, "should initialise Hold")
assert.Equal(t, n1, b.UpdatedAt, "should set UpdatedAt")
b = b.Add(Balance{Total: 2.2, Hold: 4.0, UpdatedAt: n1.Add(-time.Minute)})
assert.Equal(t, 7.7, b.Total, "should add to Total")
assert.Equal(t, 7.0, b.Hold, "should add to Hold")
assert.Equal(t, n1, b.UpdatedAt, "should keep newer UpdatedAt")
}
// TestBalanceUpdate exercises balance.update
func TestBalanceUpdate(t *testing.T) {
t.Parallel()
_, err := (*balance)(nil).update(Balance{})
require.ErrorIs(t, err, common.ErrNilPointer)
n := time.Now()
b := &balance{internal: Balance{
Currency: currency.LTC,
Total: 4.2,
UpdatedAt: n,
}}
_, err = b.update(Balance{})
require.ErrorIs(t, err, errUpdatedAtIsZero)
_, err = b.update(Balance{UpdatedAt: n, Currency: currency.ETH})
assert.ErrorIs(t, err, errBalanceCurrencyMismatch)
_, err = b.update(Balance{UpdatedAt: n.Add(-time.Millisecond)})
assert.ErrorIs(t, err, errOutOfSequence, "should error when time out of sequence")
u, err := b.update(Balance{UpdatedAt: n, Total: 5.1})
require.NoError(t, err, "must not error when time is the same instant and currency is empty")
assert.Equal(t, 5.1, b.internal.Total, "Total should be correct")
assert.True(t, u, "should return updated")
n = time.Now()
u, err = b.update(Balance{UpdatedAt: n, Total: 5.1})
require.NoError(t, err)
assert.Equal(t, n, b.internal.UpdatedAt, "should update UpdatedAt")
assert.False(t, u, "should not return updated when nothing really changed")
n = time.Now()
u, err = b.update(Balance{UpdatedAt: n, Currency: currency.LTC, Total: 5.1})
require.NoError(t, err, "must not error when Currency matches")
assert.Equal(t, n, b.internal.UpdatedAt, "should update UpdatedAt")
assert.False(t, u, "should return not updated when only time changed")
n = time.Now()
u, err = b.update(Balance{UpdatedAt: n, Currency: currency.LTC, Total: 4.4})
require.NoError(t, err, "must not error when Currency matches")
assert.Equal(t, n, b.internal.UpdatedAt, "should update UpdatedAt")
assert.Equal(t, 4.4, b.internal.Total, "should update Total")
assert.True(t, u, "should return updated")
}

View File

@@ -0,0 +1,212 @@
package accounts
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"google.golang.org/grpc/metadata"
)
// contextCredential is a string flag for use with context values when setting
// credentials internally or via gRPC.
type contextCredential string
const (
// ContextCredentialsFlag used for retrieving api credentials from context
ContextCredentialsFlag contextCredential = "apicredentials"
// ContextSubAccountFlag used for retrieving just the sub account from
// context, when the default config credentials sub account needs to be
// changed while the same keys can be used.
ContextSubAccountFlag contextCredential = "subaccountoverride"
apiKeyDisplaySize = 16
)
// Default credential values
const (
Key = "key"
Secret = "secret"
SubAccountSTR = "subaccount"
ClientID = "clientid"
OneTimePassword = "otp"
PEMKey = "pemkey"
)
var (
errMetaDataIsNil = errors.New("meta data is nil")
errInvalidCredentialMetaDataLength = errors.New("invalid meta data to process credentials")
errMissingInfo = errors.New("cannot parse meta data missing information in key value pair")
)
// Credentials define parameters that allow for an authenticated request.
type Credentials struct {
Key string
Secret string
ClientID string // TODO: Implement with exchange orders functionality
PEMKey string
SubAccount string
OneTimePassword string
SecretBase64Decoded bool
// TODO: Add AccessControl uint8 for READ/WRITE/Withdraw capabilities.
}
// GetMetaData returns the credentials for metadata context deployment
func (c *Credentials) GetMetaData() (flag, values string) {
vals := make([]string, 0, 6)
if c.Key != "" {
vals = append(vals, Key+":"+c.Key)
}
if c.Secret != "" {
vals = append(vals, Secret+":"+c.Secret)
}
if c.SubAccount != "" {
vals = append(vals, SubAccountSTR+":"+c.SubAccount)
}
if c.ClientID != "" {
vals = append(vals, ClientID+":"+c.ClientID)
}
if c.PEMKey != "" {
vals = append(vals, PEMKey+":"+c.PEMKey)
}
if c.OneTimePassword != "" {
vals = append(vals, OneTimePassword+":"+c.OneTimePassword)
}
return string(ContextCredentialsFlag), strings.Join(vals, ",")
}
// String prints out basic credential info (obfuscated) to track key instances
// associated with exchanges.
func (c *Credentials) String() string {
obfuscated := c.Key
if len(obfuscated) > apiKeyDisplaySize {
obfuscated = obfuscated[:apiKeyDisplaySize]
}
return fmt.Sprintf("Key:[%s...] SubAccount:[%s] ClientID:[%s]",
obfuscated,
c.SubAccount,
c.ClientID)
}
// getInternal returns the values for assignment to an internal context
func (c *Credentials) getInternal() (contextCredential, *ContextCredentialsStore) {
if c.IsEmpty() {
return "", nil
}
store := &ContextCredentialsStore{}
store.Load(c)
return ContextCredentialsFlag, store
}
// IsEmpty return true if the underlying credentials type has not been filled
// with at least one item.
func (c *Credentials) IsEmpty() bool {
return c == nil || c.ClientID == "" &&
c.Key == "" &&
c.OneTimePassword == "" &&
c.PEMKey == "" &&
c.Secret == "" &&
c.SubAccount == ""
}
// Equal determines if the keys are the same.
// OTP omitted because it's generated per request.
// PEMKey and Secret omitted because of direct correlation with api key.
func (c *Credentials) Equal(other *Credentials) bool {
return c != nil &&
other != nil &&
c.Key == other.Key &&
c.ClientID == other.ClientID &&
(c.SubAccount == other.SubAccount || c.SubAccount == "" && other.SubAccount == "main" || c.SubAccount == "main" && other.SubAccount == "")
}
// ContextCredentialsStore protects the stored credentials for use in a context
type ContextCredentialsStore struct {
creds *Credentials
mu sync.RWMutex
}
// Load stores provided credentials
func (c *ContextCredentialsStore) Load(creds *Credentials) {
// Segregate from external call
cpy := *creds
c.mu.Lock()
c.creds = &cpy
c.mu.Unlock()
}
// Get returns the full credentials from the store
func (c *ContextCredentialsStore) Get() *Credentials {
c.mu.RLock()
creds := *c.creds
c.mu.RUnlock()
return &creds
}
// ParseCredentialsMetadata intercepts and converts credentials metadata to a
// static type for authentication processing and protection.
func ParseCredentialsMetadata(ctx context.Context, md metadata.MD) (context.Context, error) {
if md == nil {
return ctx, errMetaDataIsNil
}
credMD, ok := md[string(ContextCredentialsFlag)]
if !ok || len(credMD) == 0 {
return ctx, nil
}
if len(credMD) != 1 {
return ctx, errInvalidCredentialMetaDataLength
}
segregatedCreds := strings.Split(credMD[0], ",")
var ctxCreds Credentials
var subAccountHere string
for x := range segregatedCreds {
keyVals := strings.Split(segregatedCreds[x], ":")
if len(keyVals) != 2 {
return ctx, fmt.Errorf("%w received %v fields, expected 2 contains: %s",
errMissingInfo,
len(keyVals),
keyVals)
}
switch keyVals[0] {
case Key:
ctxCreds.Key = keyVals[1]
case Secret:
ctxCreds.Secret = keyVals[1]
case SubAccountSTR:
// Capture sub account as this can override if other values are
// not included in metadata.
subAccountHere = keyVals[1]
case ClientID:
ctxCreds.ClientID = keyVals[1]
case PEMKey:
ctxCreds.PEMKey = keyVals[1]
case OneTimePassword:
ctxCreds.OneTimePassword = keyVals[1]
}
}
if ctxCreds.IsEmpty() && subAccountHere != "" {
// This will override default sub account details if needed.
return DeploySubAccountOverrideToContext(ctx, subAccountHere), nil
}
// merge sub account to main context credentials
ctxCreds.SubAccount = subAccountHere
return DeployCredentialsToContext(ctx, &ctxCreds), nil
}
// DeployCredentialsToContext sets credentials for internal use to context which
// can override default credential values.
func DeployCredentialsToContext(ctx context.Context, creds *Credentials) context.Context {
flag, store := creds.getInternal()
return context.WithValue(ctx, flag, store)
}
// DeploySubAccountOverrideToContext sets subaccount as override to credentials
// as a separate flag.
func DeploySubAccountOverrideToContext(ctx context.Context, subAccount string) context.Context {
return context.WithValue(ctx, ContextSubAccountFlag, subAccount)
}

View File

@@ -0,0 +1,174 @@
package accounts
import (
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/metadata"
)
func TestIsEmpty(t *testing.T) {
t.Parallel()
var c *Credentials
if !c.IsEmpty() {
t.Fatalf("expected: %v but received: %v", true, c.IsEmpty())
}
c = new(Credentials)
if !c.IsEmpty() {
t.Fatalf("expected: %v but received: %v", true, c.IsEmpty())
}
c.SubAccount = "woow"
if c.IsEmpty() {
t.Fatalf("expected: %v but received: %v", false, c.IsEmpty())
}
}
func TestParseCredentialsMetadata(t *testing.T) {
t.Parallel()
_, err := ParseCredentialsMetadata(t.Context(), nil)
require.ErrorIs(t, err, errMetaDataIsNil)
_, err = ParseCredentialsMetadata(t.Context(), metadata.MD{})
require.NoError(t, err)
ctx := metadata.AppendToOutgoingContext(t.Context(),
string(ContextCredentialsFlag), "wow", string(ContextCredentialsFlag), "wow2")
nortyMD, _ := metadata.FromOutgoingContext(ctx)
_, err = ParseCredentialsMetadata(t.Context(), nortyMD)
require.ErrorIs(t, err, errInvalidCredentialMetaDataLength)
ctx = metadata.AppendToOutgoingContext(t.Context(),
string(ContextCredentialsFlag), "brokenstring")
nortyMD, _ = metadata.FromOutgoingContext(ctx)
_, err = ParseCredentialsMetadata(t.Context(), nortyMD)
require.ErrorIs(t, err, errMissingInfo)
beforeCreds := Credentials{
Key: "superkey",
Secret: "supersecret",
SubAccount: "supersub",
ClientID: "superclient",
PEMKey: "superpem",
OneTimePassword: "superOneTimePasssssss",
}
flag, outGoing := beforeCreds.GetMetaData()
ctx = metadata.AppendToOutgoingContext(t.Context(), flag, outGoing)
lovelyMD, _ := metadata.FromOutgoingContext(ctx)
ctx, err = ParseCredentialsMetadata(t.Context(), lovelyMD)
require.NoError(t, err)
store, ok := ctx.Value(ContextCredentialsFlag).(*ContextCredentialsStore)
if !ok {
t.Fatal("should have processed")
}
afterCreds := store.Get()
if afterCreds.Key != "superkey" &&
afterCreds.Secret != "supersecret" &&
afterCreds.SubAccount != "supersub" &&
afterCreds.ClientID != "superclient" &&
afterCreds.PEMKey != "superpem" &&
afterCreds.OneTimePassword != "superOneTimePasssssss" {
t.Fatal("unexpected values")
}
// subaccount override
subaccount := Credentials{
SubAccount: "supersub",
}
flag, outGoing = subaccount.GetMetaData()
ctx = metadata.AppendToOutgoingContext(t.Context(), flag, outGoing)
lovelyMD, _ = metadata.FromOutgoingContext(ctx)
ctx, err = ParseCredentialsMetadata(t.Context(), lovelyMD)
require.NoError(t, err)
sa, ok := ctx.Value(ContextSubAccountFlag).(string)
if !ok {
t.Fatal("should have processed")
}
if sa != "supersub" {
t.Fatal("unexpected value")
}
}
func TestGetInternal(t *testing.T) {
t.Parallel()
flag, store := (&Credentials{}).getInternal()
if flag != "" {
t.Fatal("unexpected value")
}
if store != nil {
t.Fatal("unexpected value")
}
flag, store = (&Credentials{Key: "wow"}).getInternal()
if flag != ContextCredentialsFlag {
t.Fatal("unexpected value")
}
if store == nil {
t.Fatal("unexpected value")
}
if store.Get().Key != "wow" {
t.Fatal("unexpected value")
}
}
func TestString(t *testing.T) {
t.Parallel()
creds := Credentials{}
if s := creds.String(); s != "Key:[...] SubAccount:[] ClientID:[]" {
t.Fatal("unexpected value")
}
creds.Key = "12345678910111234"
creds.SubAccount = "sub"
creds.ClientID = "client"
if s := creds.String(); s != "Key:[1234567891011123...] SubAccount:[sub] ClientID:[client]" {
t.Fatal("unexpected value")
}
}
func TestCredentialsEqual(t *testing.T) {
t.Parallel()
var this, that *Credentials
if this.Equal(that) {
t.Fatal("unexpected value")
}
this = &Credentials{}
if this.Equal(that) {
t.Fatal("unexpected value")
}
that = &Credentials{Key: "1337"}
if this.Equal(that) {
t.Fatal("unexpected value")
}
this.Key = "1337"
if !this.Equal(that) {
t.Fatal("unexpected value")
}
this.ClientID = "1337"
if this.Equal(that) {
t.Fatal("unexpected value")
}
that.ClientID = "1337"
if !this.Equal(that) {
t.Fatal("unexpected value")
}
this.SubAccount = "someSub"
if this.Equal(that) {
t.Fatal("unexpected value")
}
that.SubAccount = "someSub"
if !this.Equal(that) {
t.Fatal("unexpected value")
}
}

View File

@@ -0,0 +1,65 @@
package accounts
import (
"context"
"sync"
"sync/atomic"
"github.com/thrasher-corp/gocryptotrader/dispatch"
)
// Store contains accounts for exchanges.
type Store struct {
exchangeAccounts exchangeMap
mu sync.Mutex
mux *dispatch.Mux
}
type exchangeMap map[exchange]*Accounts
type exchange interface {
GetName() string
GetCredentials(context.Context) (*Credentials, error)
}
type exchangeWrapper interface {
GetBase() exchange
}
var global atomic.Pointer[Store]
// NewStore returns a new store with the default global dispatcher mux.
func NewStore() *Store {
return &Store{
exchangeAccounts: make(exchangeMap),
mux: dispatch.GetNewMux(nil),
}
}
// GetStore returns the singleton accounts store for global use; Initialising if necessary.
func GetStore() *Store {
if s := global.Load(); s != nil {
return s
}
_ = global.CompareAndSwap(nil, NewStore())
return global.Load()
}
// GetExchangeAccounts returns accounts for a specific exchange.
func (s *Store) GetExchangeAccounts(e exchange) (a *Accounts, err error) {
s.mu.Lock()
defer s.mu.Unlock()
if w, ok := e.(exchangeWrapper); ok {
// Because SetupDefaults is called on Base, it's easiest to just use the Base pointer as the key
e = w.GetBase()
}
a, ok := s.exchangeAccounts[e]
if !ok {
a, err = NewAccounts(e, s.mux)
if err != nil {
return nil, err
}
s.exchangeAccounts[e] = a
}
return a, nil
}

View File

@@ -0,0 +1,86 @@
package accounts
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/common"
)
func TestNewStore(t *testing.T) {
t.Parallel()
s := NewStore()
require.NotNil(t, s, "NewStore must return a store")
require.NotNil(t, s.mux, "NewStore must set mux")
require.NotNil(t, s.exchangeAccounts, "NewStore must set exchangeAccounts")
}
func TestGetStore(t *testing.T) {
t.Parallel()
// Initialise global in case of -count=N+; No other tests should be relying on it
global.Store(nil)
s := GetStore()
require.NotNil(t, s, "GetStore must return a Store")
require.Same(t, global.Load(), s, "GetStore must set the global store")
require.Same(t, s, GetStore(), "GetStore must return the global store on second call")
}
func TestGetExchangeAccounts(t *testing.T) {
t.Parallel()
s := NewStore()
m := &mockEx{"mocky"}
a := &Accounts{}
s.exchangeAccounts[m] = a
got, err := s.GetExchangeAccounts(m)
require.NoError(t, err)
assert.Same(t, a, got, "Should retrieve same existing Accounts")
m = &mockEx{"new"}
got, err = s.GetExchangeAccounts(m)
require.NoError(t, err)
assert.Same(t, s.exchangeAccounts[m], got, "Should retrieve the new exchange")
w := &mockExBase{m}
got, err = s.GetExchangeAccounts(w)
require.NoError(t, err)
assert.NotNil(t, got)
_, err = s.GetExchangeAccounts(nil)
assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly on nil exchange")
}
type mockEx struct {
name string
}
func (m *mockEx) GetName() string {
return "mocky"
}
func (m *mockEx) GetCredentials(ctx context.Context) (*Credentials, error) {
if value := ctx.Value(ContextCredentialsFlag); value != nil {
if s, ok := value.(*ContextCredentialsStore); ok {
return s.Get(), nil
}
return nil, common.GetTypeAssertError("*accounts.ContextCredentialsStore", value)
}
return nil, nil
}
type mockExBase struct {
base exchange
}
func (m *mockExBase) GetBase() exchange {
return m.base
}
func (m *mockExBase) GetCredentials(ctx context.Context) (*Credentials, error) {
return m.base.GetCredentials(ctx)
}
func (m *mockExBase) GetName() string {
return m.base.GetName()
}