mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-06-01 15:10:44 +00:00
accounts: Move to instance methods, fix races and isolate tests (#1923)
* Bybit: Fix race in TestUpdateAccountInfo and TestWSHandleData * DriveBy rename TestWSHandleData * This doesn't address running with -race=2+ due to the singleton * Accounts: Add account.GetService() * exchange: Assertify TestSetupDefaults * Exchanges: Add account.Service override for testing * Exchanges: Remove duplicate IsWebsocketEnabled test from TestSetupDefaults * Dispatch: Replace nil checks with NilGuard * Engine: Remove deprecated printAccountHoldingsChangeSummary * Dispatcher: Add EnsureRunning method * Accounts: Move singleton accounts service to exchange Accounts * Move singleton accounts service to exchange Accounts This maintains the concept of a global store, whilst allowing exchanges to override it when needed, particularly for testing. APIServer: * Remove getAllActiveAccounts from apiserver Deprecated apiserver only thing using this, so remove it instead of updating it * Update comment for UpdateAccountBalances everywhere * Docs: Add punctuation to function comments * Bybit: Coverage for wsProcessWalletPushData Save
This commit is contained in:
312
exchange/accounts/accounts.go
Normal file
312
exchange/accounts/accounts.go
Normal file
@@ -0,0 +1,312 @@
|
||||
package accounts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
"github.com/thrasher-corp/gocryptotrader/common/key"
|
||||
"github.com/thrasher-corp/gocryptotrader/currency"
|
||||
"github.com/thrasher-corp/gocryptotrader/dispatch"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
|
||||
)
|
||||
|
||||
// Public errors.
|
||||
var (
|
||||
ErrNoBalances = errors.New("no balances found")
|
||||
ErrNoSubAccounts = errors.New("no subAccounts found")
|
||||
)
|
||||
|
||||
var (
|
||||
errCredentialsEmpty = errors.New("no credentials provided")
|
||||
errUpdatingBalance = errors.New("error updating balance")
|
||||
errPublish = errors.New("error publishing account changes")
|
||||
)
|
||||
|
||||
// Accounts holds a stream ID and a map to the exchange holdings.
|
||||
type Accounts struct {
|
||||
Exchange exchange
|
||||
routingID uuid.UUID // GCT internal routing mux id
|
||||
subAccounts credSubAccounts
|
||||
mu sync.RWMutex
|
||||
mux *dispatch.Mux
|
||||
}
|
||||
|
||||
type (
|
||||
credSubAccounts map[Credentials]subAccounts
|
||||
subAccounts map[key.SubAccountAsset]currencyBalances
|
||||
)
|
||||
|
||||
// SubAccount contains an account for an asset type and its balances.
|
||||
// The SubAccount may be the main account depending on exchange structure.
|
||||
type SubAccount struct {
|
||||
ID string
|
||||
AssetType asset.Item
|
||||
Balances CurrencyBalances
|
||||
}
|
||||
|
||||
// SubAccounts contains a list of public SubAccounts.
|
||||
type SubAccounts []*SubAccount
|
||||
|
||||
// MustNewAccounts returns an initialised Accounts store for use in isolation from a global exchange accounts store.
|
||||
// mux is set to the global dispatch.Dispatcher.
|
||||
// Any errors in mux ID generation will panic, so users should balance risk vs utility accordingly depending on use-case.
|
||||
func MustNewAccounts(e exchange) *Accounts {
|
||||
a, err := NewAccounts(e, dispatch.GetNewMux(nil))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
// NewAccounts returns an initialised Accounts store for use in isolation from a global exchange accounts store.
|
||||
func NewAccounts(e exchange, mux *dispatch.Mux) (*Accounts, error) {
|
||||
if err := common.NilGuard(e); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id, err := mux.GetID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Accounts{
|
||||
Exchange: e,
|
||||
subAccounts: make(credSubAccounts),
|
||||
routingID: id,
|
||||
mux: mux,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewSubAccount returns a new SubAccount.
|
||||
// id may be empty.
|
||||
func NewSubAccount(a asset.Item, id string) *SubAccount {
|
||||
return &SubAccount{
|
||||
AssetType: a,
|
||||
ID: id,
|
||||
Balances: CurrencyBalances{},
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe subscribes to your exchange accounts.
|
||||
func (a *Accounts) Subscribe() (dispatch.Pipe, error) {
|
||||
if err := common.NilGuard(a); err != nil {
|
||||
return dispatch.Pipe{}, err
|
||||
}
|
||||
return a.mux.Subscribe(a.routingID)
|
||||
}
|
||||
|
||||
// CurrencyBalances returns the balances for the Accounts grouped by currency.
|
||||
// If creds is nil, all credential SubAccounts will be collated.
|
||||
// If assetType is asset.All, all assets will be collated.
|
||||
func (a *Accounts) CurrencyBalances(creds *Credentials, assetType asset.Item) (CurrencyBalances, error) {
|
||||
if err := common.NilGuard(a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !assetType.IsValid() && assetType != asset.All {
|
||||
return nil, fmt.Errorf("%s %s %w", a.Exchange.GetName(), assetType, asset.ErrNotSupported)
|
||||
}
|
||||
|
||||
currs := CurrencyBalances{}
|
||||
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
|
||||
for credsKey, subAccountsForCreds := range a.subAccounts {
|
||||
if !creds.IsEmpty() && *creds != credsKey {
|
||||
continue
|
||||
}
|
||||
for subAcctKey, balances := range subAccountsForCreds {
|
||||
if assetType != asset.All && assetType != subAcctKey.Asset {
|
||||
continue
|
||||
}
|
||||
for curr, bal := range balances {
|
||||
if err := currs.Add(curr.Currency(), bal.Balance()); err != nil {
|
||||
return nil, err // Should be impossible, so return immediately
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(currs) == 0 {
|
||||
return nil, fmt.Errorf("%w for %s credentials %s asset %s", ErrNoBalances, a.Exchange.GetName(), creds, assetType)
|
||||
}
|
||||
return currs, nil
|
||||
}
|
||||
|
||||
// SubAccounts returns the public SubAccounts and their balances.
|
||||
// If creds is nil, all credential SubAccounts will be returned.
|
||||
// If assetType is asset.All, all assets will be returned.
|
||||
func (a *Accounts) SubAccounts(creds *Credentials, assetType asset.Item) (SubAccounts, error) {
|
||||
if err := common.NilGuard(a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !assetType.IsValid() && assetType != asset.All {
|
||||
return nil, fmt.Errorf("%s %s %w", a.Exchange.GetName(), assetType, asset.ErrNotSupported)
|
||||
}
|
||||
|
||||
var subAccts SubAccounts
|
||||
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
|
||||
for credsKey, subAccountsForCreds := range a.subAccounts {
|
||||
if !creds.IsEmpty() && *creds != credsKey {
|
||||
continue
|
||||
}
|
||||
for subAcctKey, balances := range subAccountsForCreds {
|
||||
if assetType != asset.All && assetType != subAcctKey.Asset {
|
||||
continue
|
||||
}
|
||||
subAccts = append(subAccts, &SubAccount{
|
||||
ID: subAcctKey.SubAccount,
|
||||
AssetType: subAcctKey.Asset,
|
||||
Balances: balances.Public(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(subAccts) == 0 {
|
||||
return nil, fmt.Errorf("%w for %s credentials %s asset %s", ErrNoSubAccounts, a.Exchange.GetName(), creds, assetType)
|
||||
}
|
||||
return subAccts, nil
|
||||
}
|
||||
|
||||
// GetBalance returns a copy of the balance for that asset item.
|
||||
func (a *Accounts) GetBalance(subAccount string, creds *Credentials, aType asset.Item, c currency.Code) (Balance, error) {
|
||||
if err := common.NilGuard(a); err != nil {
|
||||
return Balance{}, err
|
||||
}
|
||||
if !aType.IsValid() {
|
||||
return Balance{}, fmt.Errorf("cannot get balance: %w: %q", asset.ErrNotSupported, aType)
|
||||
}
|
||||
|
||||
if creds.IsEmpty() {
|
||||
return Balance{}, fmt.Errorf("cannot get balance: %w", errCredentialsEmpty)
|
||||
}
|
||||
|
||||
if c.IsEmpty() {
|
||||
return Balance{}, fmt.Errorf("cannot get balance: %w", currency.ErrCurrencyCodeEmpty)
|
||||
}
|
||||
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
|
||||
subAccts, ok := a.subAccounts[*creds]
|
||||
if !ok {
|
||||
return Balance{}, fmt.Errorf("%w for %s", ErrNoBalances, creds)
|
||||
}
|
||||
|
||||
assets, ok := subAccts[key.SubAccountAsset{
|
||||
SubAccount: subAccount,
|
||||
Asset: aType,
|
||||
}]
|
||||
if !ok {
|
||||
return Balance{}, fmt.Errorf("%w for %s SubAccount %q %s", ErrNoBalances, a.Exchange.GetName(), subAccount, aType)
|
||||
}
|
||||
b, ok := assets[c.Item]
|
||||
if !ok {
|
||||
return Balance{}, fmt.Errorf("%w for %s SubAccount %q %s %s", ErrNoBalances, a.Exchange.GetName(), subAccount, aType, c)
|
||||
}
|
||||
return b.Balance(), nil
|
||||
}
|
||||
|
||||
// Save updates the account balances.
|
||||
// If isSnapshot is true any missing currencies will be removed.
|
||||
// Credentials will be retrieved from ctx, Use DeployCredentialsToContext.
|
||||
// Changes to balances are published individually.
|
||||
func (a *Accounts) Save(ctx context.Context, subAccts SubAccounts, isSnapshot bool) error {
|
||||
if err := common.NilGuard(a); err != nil {
|
||||
return fmt.Errorf("cannot save holdings: %w", err)
|
||||
}
|
||||
if err := common.NilGuard(a.subAccounts); err != nil {
|
||||
return fmt.Errorf("cannot save holdings: %w", err)
|
||||
}
|
||||
|
||||
creds, err := a.Exchange.GetCredentials(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if creds.IsEmpty() {
|
||||
return fmt.Errorf("%w: %w", errUpdatingBalance, errCredentialsEmpty)
|
||||
}
|
||||
|
||||
var errs error
|
||||
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
for _, s := range subAccts {
|
||||
if !s.AssetType.IsValid() {
|
||||
errs = common.AppendError(errs, fmt.Errorf("error loading %s[%s] SubAccount holdings: %w", s.ID, s.AssetType, asset.ErrNotSupported))
|
||||
continue
|
||||
}
|
||||
|
||||
accBalances := a.currencyBalances(creds, s.ID, s.AssetType)
|
||||
|
||||
updated := false
|
||||
missing := maps.Clone(accBalances)
|
||||
for curr, newBal := range s.Balances {
|
||||
delete(missing, curr.Item)
|
||||
if newBal.UpdatedAt.IsZero() {
|
||||
newBal.UpdatedAt = time.Now()
|
||||
}
|
||||
if newBal.Currency.IsEmpty() {
|
||||
newBal.Currency = curr
|
||||
}
|
||||
s.Balances[curr] = newBal
|
||||
if u, err := accBalances.balance(curr.Item).update(newBal); err != nil {
|
||||
errs = common.AppendError(errs, fmt.Errorf("%w for account ID %q [%s %s]: %w", errUpdatingBalance, s.ID, s.AssetType, curr, err))
|
||||
} else if u {
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
if isSnapshot {
|
||||
for cur := range missing {
|
||||
delete(accBalances, cur)
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
if updated {
|
||||
if err := a.mux.Publish(s, a.routingID); err != nil {
|
||||
errs = common.AppendError(errs, fmt.Errorf("%w for %s %w", errPublish, a.Exchange, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
// Merge adds CurrencyBalances in s to the SubAccount in l with a matching AssetType and ID.
|
||||
// If no SubAccount matches, s is appended.
|
||||
// Duplicate Currency Balances are added together.
|
||||
func (l SubAccounts) Merge(s *SubAccount) SubAccounts {
|
||||
if err := common.NilGuard(s); err != nil {
|
||||
return nil
|
||||
}
|
||||
i := slices.IndexFunc(l, func(b *SubAccount) bool { return s.AssetType == b.AssetType && s.ID == b.ID })
|
||||
if i == -1 {
|
||||
return append(l, s)
|
||||
}
|
||||
for curr, newBal := range s.Balances {
|
||||
l[i].Balances[curr] = newBal.Add(l[i].Balances[curr])
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// currencyBalances returns a currencyBalances entry for Credentials, SubAccount and asset.
|
||||
// No nilguard protection provided, since this is a private function.
|
||||
func (a *Accounts) currencyBalances(c *Credentials, subAcct string, aType asset.Item) currencyBalances {
|
||||
k := key.SubAccountAsset{SubAccount: subAcct, Asset: aType}
|
||||
if _, ok := a.subAccounts[*c]; !ok {
|
||||
a.subAccounts[*c] = make(subAccounts)
|
||||
}
|
||||
if _, ok := a.subAccounts[*c][k]; !ok {
|
||||
a.subAccounts[*c][k] = make(currencyBalances)
|
||||
}
|
||||
return a.subAccounts[*c][k]
|
||||
}
|
||||
533
exchange/accounts/accounts_test.go
Normal file
533
exchange/accounts/accounts_test.go
Normal file
@@ -0,0 +1,533 @@
|
||||
package accounts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"maps"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
"github.com/thrasher-corp/gocryptotrader/common/key"
|
||||
"github.com/thrasher-corp/gocryptotrader/currency"
|
||||
"github.com/thrasher-corp/gocryptotrader/dispatch"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
|
||||
)
|
||||
|
||||
var (
|
||||
creds1 = &Credentials{Key: "1"}
|
||||
creds2 = &Credentials{Key: "2"}
|
||||
creds3 = &Credentials{Key: "3"}
|
||||
)
|
||||
|
||||
func TestNewAccounts(t *testing.T) {
|
||||
t.Parallel()
|
||||
a, err := NewAccounts(&mockEx{"mocky"}, dispatch.GetNewMux(nil))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, a)
|
||||
assert.Equal(t, "mocky", a.Exchange.GetName(), "Exchange name should set correctly")
|
||||
assert.NotNil(t, a.subAccounts, "subAccounts should be initialised")
|
||||
assert.NotEmpty(t, a.routingID, "routingID should not be empty")
|
||||
assert.NotNil(t, a.mux, "mux should be set correctly")
|
||||
_, err = NewAccounts(nil, dispatch.GetNewMux(nil))
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer)
|
||||
_, err = NewAccounts(&mockEx{"mocky"}, nil)
|
||||
assert.ErrorContains(t, err, "nil pointer: *dispatch.Mux")
|
||||
}
|
||||
|
||||
func TestMustNewAccounts(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := MustNewAccounts(&mockEx{"mocky"})
|
||||
require.NotNil(t, a)
|
||||
require.Panics(t, func() { _ = MustNewAccounts(nil) })
|
||||
}
|
||||
|
||||
func TestNewSubAccount(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := NewSubAccount(asset.Spot, "")
|
||||
require.NotNil(t, a, "must not return nil with no id")
|
||||
assert.Equal(t, asset.Spot, a.AssetType, "AssetType should be correct")
|
||||
assert.Empty(t, a.ID, "ID should not default to anything")
|
||||
a = NewSubAccount(asset.Spot, "42")
|
||||
assert.Equal(t, "42", a.ID, "ID should be correct")
|
||||
}
|
||||
|
||||
func TestSubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := dispatch.EnsureRunning(dispatch.DefaultMaxWorkers, dispatch.DefaultJobsLimit)
|
||||
require.NoError(t, err, "dispatch.EnsureRunning must not error")
|
||||
p, err := MustNewAccounts(&mockEx{}).Subscribe()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, p, "Subscribe must return a pipe")
|
||||
require.Empty(t, p.Channel(), "Pipe must be empty before Saving anything")
|
||||
}
|
||||
|
||||
func TestAccountsCurrencyBalances(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a := accountsFixture(t)
|
||||
|
||||
_, err := (*Accounts)(nil).CurrencyBalances(nil, asset.Spot)
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer)
|
||||
|
||||
_, err = a.CurrencyBalances(nil, asset.Empty)
|
||||
assert.ErrorIs(t, err, asset.ErrNotSupported)
|
||||
|
||||
_, err = a.CurrencyBalances(creds3, asset.All)
|
||||
require.ErrorIs(t, err, ErrNoBalances)
|
||||
|
||||
_, err = a.CurrencyBalances(creds3, asset.All)
|
||||
assert.ErrorIs(t, err, ErrNoBalances)
|
||||
assert.ErrorContains(t, err, "Key:[3")
|
||||
|
||||
// Add a balance with inconsistent currencies to cover err from currs.Add
|
||||
a.subAccounts[*creds3] = map[key.SubAccountAsset]currencyBalances{
|
||||
{Asset: asset.Futures}: {currency.DOGE.Item: &balance{internal: Balance{Currency: currency.ETH}}},
|
||||
}
|
||||
|
||||
type cMap map[currency.Code]float64
|
||||
for _, tc := range []struct {
|
||||
c *Credentials
|
||||
aT asset.Item
|
||||
exp cMap
|
||||
err error
|
||||
}{
|
||||
{nil, asset.Spot, cMap{currency.BTC: 6.0, currency.LTC: 10.0}, nil},
|
||||
{creds1, asset.All, cMap{currency.BTC: 3.0, currency.LTC: 30.0}, nil},
|
||||
{creds1, asset.Spot, cMap{currency.BTC: 3.0, currency.LTC: 10.0}, nil},
|
||||
{creds1, asset.Futures, cMap{currency.LTC: 20.0}, nil},
|
||||
{creds2, asset.Spot, cMap{currency.BTC: 3.0}, nil},
|
||||
{creds3, asset.Futures, cMap{currency.DOGE: 50.0}, errBalanceCurrencyMismatch},
|
||||
} {
|
||||
t.Run(fmt.Sprintf("%s/%s", tc.c, tc.aT), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
b, err := a.CurrencyBalances(tc.c, tc.aT)
|
||||
if tc.err != nil {
|
||||
require.ErrorIs(t, err, tc.err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, len(tc.exp), len(b), "must get correct number of balances")
|
||||
for c, expBal := range tc.exp {
|
||||
assert.Contains(t, b, c)
|
||||
assert.Equalf(t, expBal, b[c].Total, "should get correct total for %s", c)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountsPrivateCurrencyBalances(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a := accountsFixture(t)
|
||||
b := a.currencyBalances(creds3, "", asset.Spot)
|
||||
r1 := a.subAccounts[*creds3]
|
||||
// Using reflect since assert.Same cannot be used on maps to ensure same underlying pointer
|
||||
assert.Equal(t,
|
||||
reflect.ValueOf(b).UnsafePointer(),
|
||||
reflect.ValueOf(r1[key.SubAccountAsset{Asset: asset.Spot}]).UnsafePointer(),
|
||||
"should make and return the same map")
|
||||
assert.Equal(t,
|
||||
reflect.ValueOf(b).UnsafePointer(),
|
||||
reflect.ValueOf(a.currencyBalances(creds3, "", asset.Spot)).UnsafePointer(),
|
||||
"should return the same map on subsequent calls")
|
||||
b = a.currencyBalances(creds3, "", asset.Futures)
|
||||
assert.Equal(t,
|
||||
reflect.ValueOf(r1).UnsafePointer(),
|
||||
reflect.ValueOf(a.subAccounts[*creds3]).UnsafePointer(),
|
||||
"should not make a new cred key")
|
||||
assert.Equal(t,
|
||||
reflect.ValueOf(b).UnsafePointer(),
|
||||
reflect.ValueOf(r1[key.SubAccountAsset{Asset: asset.Futures}]).UnsafePointer(),
|
||||
"should make and return the same map")
|
||||
}
|
||||
|
||||
type tKey key.SubAccountAsset
|
||||
|
||||
func TestAccountsSubAccounts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a := accountsFixture(t)
|
||||
|
||||
_, err := (*Accounts)(nil).SubAccounts(nil, asset.Spot)
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer)
|
||||
|
||||
_, err = a.SubAccounts(nil, asset.Empty)
|
||||
assert.ErrorIs(t, err, asset.ErrNotSupported)
|
||||
|
||||
_, err = a.SubAccounts(creds3, asset.All)
|
||||
require.ErrorIs(t, err, ErrNoSubAccounts)
|
||||
require.ErrorContains(t, err, "Key:[3")
|
||||
|
||||
for _, tc := range []struct {
|
||||
c *Credentials
|
||||
aT asset.Item
|
||||
exp []tKey
|
||||
}{
|
||||
{nil, asset.All, []tKey{{"1a", asset.Spot}, {"1b", asset.Spot}, {"1b", asset.Futures}, {"2a", asset.Spot}}},
|
||||
{creds1, asset.All, []tKey{{"1a", asset.Spot}, {"1b", asset.Spot}, {"1b", asset.Futures}}},
|
||||
{creds1, asset.Spot, []tKey{{"1a", asset.Spot}, {"1b", asset.Spot}}},
|
||||
{creds1, asset.Futures, []tKey{{"1b", asset.Futures}}},
|
||||
{creds2, asset.Spot, []tKey{{"2a", asset.Spot}}},
|
||||
} {
|
||||
t.Run(fmt.Sprintf("%v/%s", tc.c, tc.aT), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
b, err := a.SubAccounts(tc.c, tc.aT)
|
||||
require.NoError(t, err)
|
||||
exp := subAccountsFixture(tc.exp)
|
||||
require.Equal(t, len(exp), len(b), "must get correct number of subAccounts")
|
||||
require.ElementsMatch(t, exp, b, "must get correct subAccounts")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountsGetBalance(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a := accountsFixture(t)
|
||||
|
||||
_, err := (*Accounts)(nil).GetBalance("", nil, asset.Empty, currency.EMPTYCODE)
|
||||
require.ErrorIs(t, err, common.ErrNilPointer)
|
||||
|
||||
_, err = a.GetBalance("", nil, asset.Empty, currency.EMPTYCODE)
|
||||
assert.ErrorIs(t, err, asset.ErrNotSupported)
|
||||
|
||||
_, err = a.GetBalance("", nil, asset.Spot, currency.EMPTYCODE)
|
||||
assert.ErrorIs(t, err, errCredentialsEmpty)
|
||||
|
||||
_, err = a.GetBalance("", creds3, asset.Spot, currency.EMPTYCODE)
|
||||
assert.ErrorIs(t, err, currency.ErrCurrencyCodeEmpty)
|
||||
|
||||
_, err = a.GetBalance("", creds3, asset.Spot, currency.DOGE)
|
||||
assert.ErrorIs(t, err, ErrNoBalances)
|
||||
assert.ErrorContains(t, err, "for Key:[3")
|
||||
|
||||
_, err = a.GetBalance("3a", creds1, asset.Spot, currency.DOGE)
|
||||
assert.ErrorIs(t, err, ErrNoBalances)
|
||||
assert.ErrorContains(t, err, `for mocky SubAccount "3a" spot`)
|
||||
|
||||
_, err = a.GetBalance("1a", creds1, asset.Spot, currency.DOGE)
|
||||
assert.ErrorIs(t, err, ErrNoBalances)
|
||||
assert.ErrorContains(t, err, `for mocky SubAccount "1a" spot DOGE`)
|
||||
|
||||
b, err := a.GetBalance("1b", creds1, asset.Spot, currency.BTC)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2.0, b.Total, "Total should be correct")
|
||||
}
|
||||
|
||||
func TestAccountsSave(t *testing.T) { //nolint:tparallel // Save's internal tests are sequential
|
||||
t.Parallel()
|
||||
|
||||
a := accountsFixture(t)
|
||||
relay := subscribeFixture(t, a)
|
||||
beforeNow := time.Now()
|
||||
|
||||
ctx := t.Context()
|
||||
assert.ErrorContains(t, (*Accounts)(nil).Save(ctx, nil, false), "nil pointer: *accounts.Accounts")
|
||||
assert.ErrorContains(t, new(Accounts).Save(ctx, nil, false), "nil pointer: accounts.credSubAccounts")
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
creds *Credentials
|
||||
snapshot bool
|
||||
accts SubAccounts
|
||||
pre func(context.Context) context.Context
|
||||
post func(t *testing.T) // Any additional assertions
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "NoCredentials",
|
||||
accts: SubAccounts{},
|
||||
err: errCredentialsEmpty,
|
||||
},
|
||||
{
|
||||
name: "BadCredentials",
|
||||
accts: SubAccounts{},
|
||||
err: common.ErrTypeAssertFailure,
|
||||
pre: func(ctx context.Context) context.Context { return context.WithValue(ctx, ContextCredentialsFlag, 42) },
|
||||
},
|
||||
{
|
||||
name: "BadAsset",
|
||||
creds: creds1,
|
||||
accts: SubAccounts{{AssetType: asset.All}},
|
||||
err: asset.ErrNotSupported,
|
||||
},
|
||||
{
|
||||
name: "CurrencyMismatch",
|
||||
creds: creds1,
|
||||
err: errBalanceCurrencyMismatch,
|
||||
accts: SubAccounts{{
|
||||
AssetType: asset.Spot,
|
||||
ID: "1a",
|
||||
Balances: CurrencyBalances{currency.BTC: {Currency: currency.DOGE}},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "OutOfSequence",
|
||||
creds: creds1,
|
||||
err: errOutOfSequence,
|
||||
accts: SubAccounts{{
|
||||
AssetType: asset.Spot,
|
||||
ID: "1a",
|
||||
Balances: CurrencyBalances{currency.BTC: {UpdatedAt: skynetDate.Add(-time.Hour)}},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "BasicSave",
|
||||
creds: creds1,
|
||||
accts: SubAccounts{
|
||||
{
|
||||
AssetType: asset.Spot,
|
||||
ID: "1a",
|
||||
Balances: CurrencyBalances{currency.BTC: {Total: 4, UpdatedAt: skynetDate.Add(time.Minute)}},
|
||||
},
|
||||
{
|
||||
AssetType: asset.Spot,
|
||||
ID: "1c",
|
||||
Balances: CurrencyBalances{currency.ETH: {Total: 6}},
|
||||
},
|
||||
},
|
||||
post: func(t *testing.T) {
|
||||
t.Helper()
|
||||
_, err := a.GetBalance("1a", creds1, asset.Spot, currency.LTC)
|
||||
require.NoError(t, err, "Other balances must not be affected")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewCredsSaveAndPublish",
|
||||
creds: creds3,
|
||||
accts: SubAccounts{
|
||||
{
|
||||
AssetType: asset.Futures,
|
||||
ID: "3a",
|
||||
Balances: CurrencyBalances{currency.DOGE: {Total: 6.2}},
|
||||
},
|
||||
},
|
||||
post: func(t *testing.T) {
|
||||
t.Helper()
|
||||
require.Eventually(t, func() bool { return len(relay) > 0 }, time.Second, time.Millisecond, "Publish must eventually send to Channel")
|
||||
pub := <-relay
|
||||
assert.Equal(t, "3a", pub.ID, "Publish should have correct ID")
|
||||
assert.Contains(t, pub.Balances, currency.DOGE, "Should get DOGE Balance")
|
||||
b := pub.Balances[currency.DOGE]
|
||||
assert.Equal(t, currency.DOGE, b.Currency, "Currency should default to the Balances map key")
|
||||
assert.WithinRange(t, b.UpdatedAt, beforeNow, time.Now(), "UpdatedAt should default to time.Now")
|
||||
assert.Equal(t, 6.2, b.Total, "Total should be correct")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SnapshotSave",
|
||||
creds: creds1,
|
||||
accts: SubAccounts{
|
||||
{
|
||||
AssetType: asset.Spot,
|
||||
ID: "1a",
|
||||
Balances: CurrencyBalances{currency.LTC: {Total: 12}},
|
||||
},
|
||||
},
|
||||
snapshot: true,
|
||||
post: func(t *testing.T) {
|
||||
t.Helper()
|
||||
_, err := a.GetBalance("1a", creds1, asset.Spot, currency.BTC)
|
||||
require.ErrorIs(t, err, ErrNoBalances, "BTC balance must be removed")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "PublishError",
|
||||
creds: creds1,
|
||||
accts: SubAccounts{
|
||||
{
|
||||
AssetType: asset.Spot,
|
||||
ID: "1a",
|
||||
Balances: CurrencyBalances{currency.DOGE: {Total: 7.2}},
|
||||
},
|
||||
},
|
||||
pre: func(ctx context.Context) context.Context {
|
||||
a.mux = nil
|
||||
return ctx
|
||||
},
|
||||
err: errPublish,
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
if tc.creds != nil {
|
||||
ctx = DeployCredentialsToContext(ctx, tc.creds)
|
||||
}
|
||||
expAccts := tc.accts.clone()
|
||||
if tc.pre != nil {
|
||||
ctx = tc.pre(ctx)
|
||||
}
|
||||
err := a.Save(ctx, tc.accts, tc.snapshot)
|
||||
if tc.err != nil {
|
||||
require.ErrorIs(t, err, tc.err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
for i, acct := range tc.accts {
|
||||
for curr := range acct.Balances {
|
||||
t.Run(fmt.Sprintf("%s/%s/%s", acct.AssetType, acct.ID, curr), func(t *testing.T) {
|
||||
exp := expAccts[i].Balances[curr]
|
||||
got, err := a.GetBalance(acct.ID, tc.creds, acct.AssetType, curr)
|
||||
require.NoError(t, err, "GetBalance must not error")
|
||||
if !exp.UpdatedAt.IsZero() {
|
||||
assert.Equal(t, exp.UpdatedAt, got.UpdatedAt, "UpdatedAt should match balance")
|
||||
} else {
|
||||
assert.WithinRange(t, got.UpdatedAt, beforeNow, time.Now(), "UpdatedAt should default to time.Now")
|
||||
}
|
||||
assert.Equal(t, exp.Total, got.Total, "Total should be correct")
|
||||
if tc.post != nil {
|
||||
tc.post(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var skynetDate = time.Unix(872896440, 0)
|
||||
|
||||
// accountsFixture returns an Accounts store with SubAccount IDs per credentials, and a subscription channel for updates
|
||||
func accountsFixture(t *testing.T) *Accounts {
|
||||
t.Helper()
|
||||
a := MustNewAccounts(&mockEx{})
|
||||
for _, f := range []struct {
|
||||
c *Credentials
|
||||
sA string
|
||||
aT asset.Item
|
||||
cC currency.Code
|
||||
b float64
|
||||
}{
|
||||
{creds1, "1a", asset.Spot, currency.BTC, 1},
|
||||
{creds1, "1a", asset.Spot, currency.LTC, 10},
|
||||
{creds1, "1b", asset.Spot, currency.BTC, 2},
|
||||
{creds1, "1b", asset.Futures, currency.LTC, 20},
|
||||
{creds2, "2a", asset.Spot, currency.BTC, 3},
|
||||
} {
|
||||
// Not using t.Run because this is a helper
|
||||
u, err := a.currencyBalances(f.c, f.sA, f.aT).balance(f.cC.Item).update(Balance{Total: f.b, UpdatedAt: skynetDate})
|
||||
require.NoErrorf(t, err, "Deploy fixture balance must not error for %s/%s/%s/%s", f.c.Key, f.sA, f.aT, f.cC)
|
||||
require.Truef(t, u, "Deploy fixture balance must apply an update for %s/%s/%s/%s", f.c.Key, f.sA, f.aT, f.cC)
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
var subAccts = SubAccounts{
|
||||
{
|
||||
ID: "1a",
|
||||
AssetType: asset.Spot,
|
||||
Balances: CurrencyBalances{
|
||||
currency.LTC: Balance{Currency: currency.LTC, Total: 10, UpdatedAt: skynetDate},
|
||||
currency.BTC: Balance{Currency: currency.BTC, Total: 1, UpdatedAt: skynetDate},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "1b",
|
||||
AssetType: asset.Spot,
|
||||
Balances: CurrencyBalances{currency.BTC: Balance{Currency: currency.BTC, Total: 2.0, UpdatedAt: skynetDate}},
|
||||
},
|
||||
{
|
||||
ID: "1b",
|
||||
AssetType: asset.Futures,
|
||||
Balances: CurrencyBalances{currency.LTC: Balance{Currency: currency.LTC, Total: 20.0, UpdatedAt: skynetDate}},
|
||||
},
|
||||
{
|
||||
ID: "2a",
|
||||
AssetType: asset.Spot,
|
||||
Balances: CurrencyBalances{currency.BTC: Balance{Currency: currency.BTC, Total: 3.0, UpdatedAt: skynetDate}},
|
||||
},
|
||||
}
|
||||
|
||||
func subAccountsFixture(keys []tKey) (a SubAccounts) {
|
||||
if keys == nil {
|
||||
return subAccts.clone()
|
||||
}
|
||||
for _, k := range keys {
|
||||
i := slices.IndexFunc(subAccts, func(s *SubAccount) bool {
|
||||
return k.SubAccount == s.ID && k.Asset == s.AssetType
|
||||
})
|
||||
if i == -1 {
|
||||
panic(fmt.Sprintf("subAccountsFixture called with unknown subAccount key: %v", k))
|
||||
}
|
||||
a = append(a, subAccts[i])
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func subscribeFixture(t *testing.T, a *Accounts) chan *SubAccount {
|
||||
t.Helper()
|
||||
err := dispatch.EnsureRunning(dispatch.DefaultMaxWorkers, dispatch.DefaultJobsLimit)
|
||||
require.NoError(t, err, "dispatch.EnsureRunning must not error")
|
||||
p, err := a.Subscribe()
|
||||
require.NoError(t, err, "Subscribe must not error")
|
||||
require.NotNil(t, p, "Subscribe must return a pipe")
|
||||
relay := make(chan *SubAccount, 64)
|
||||
go func() {
|
||||
for v := range p.Channel() {
|
||||
if s, ok := v.(*SubAccount); ok && s.ID == "3a" { // Only interested in relaying events for a single test account
|
||||
relay <- s
|
||||
}
|
||||
}
|
||||
}()
|
||||
runtime.Gosched()
|
||||
return relay
|
||||
}
|
||||
|
||||
func TestMerge(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := subAccountsFixture(nil)
|
||||
assert.Nil(t, s.Merge(nil), "Should return nil for a merge of nil SubAccounts")
|
||||
exp := len(s)
|
||||
a := &SubAccount{
|
||||
ID: "1a",
|
||||
AssetType: asset.Spot,
|
||||
Balances: CurrencyBalances{currency.BTC: Balance{Total: 1}},
|
||||
}
|
||||
s = s.Merge(a)
|
||||
require.Equal(t, exp, len(s), "Must contain correct number of accounts after merging")
|
||||
|
||||
for _, acct := range s {
|
||||
if acct.ID == "1a" && acct.AssetType == asset.Spot {
|
||||
assert.Equal(t, 2.0, acct.Balances[currency.BTC].Total)
|
||||
}
|
||||
}
|
||||
|
||||
a = &SubAccount{
|
||||
ID: "new",
|
||||
AssetType: asset.Spot,
|
||||
Balances: CurrencyBalances{currency.BTC: Balance{Total: 1}},
|
||||
}
|
||||
s = s.Merge(a)
|
||||
assert.Contains(t, s, a, "Should contain the new subaccount")
|
||||
}
|
||||
|
||||
func TestSubAccountsClone(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := SubAccounts{
|
||||
{ID: "1", AssetType: asset.Spot, Balances: CurrencyBalances{currency.BTC: {Total: 1}}},
|
||||
{ID: "2", AssetType: asset.Futures, Balances: CurrencyBalances{currency.LTC: {Total: 2}}},
|
||||
}
|
||||
c := s.clone()
|
||||
require.Equal(t, s, c, "Clone must match original")
|
||||
c[0].ID = "3"
|
||||
assert.NotEqual(t, s, c, "Should not be equal after modification")
|
||||
}
|
||||
|
||||
func (l SubAccounts) clone() (c SubAccounts) {
|
||||
for _, s := range l {
|
||||
bals := make(CurrencyBalances, len(s.Balances))
|
||||
maps.Copy(bals, s.Balances)
|
||||
c = append(c, &SubAccount{
|
||||
ID: s.ID,
|
||||
AssetType: s.AssetType,
|
||||
Balances: bals,
|
||||
})
|
||||
}
|
||||
return c
|
||||
}
|
||||
148
exchange/accounts/balance.go
Normal file
148
exchange/accounts/balance.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package accounts
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
"github.com/thrasher-corp/gocryptotrader/currency"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
|
||||
)
|
||||
|
||||
var (
|
||||
errBalanceCurrencyMismatch = errors.New("balance currency does not match update currency")
|
||||
errOutOfSequence = errors.New("out of sequence")
|
||||
errUpdatedAtIsZero = errors.New("updatedAt may not be zero")
|
||||
)
|
||||
|
||||
// Balance contains an exchange currency balance.
|
||||
type Balance struct {
|
||||
Currency currency.Code
|
||||
Total float64
|
||||
Hold float64
|
||||
Free float64
|
||||
AvailableWithoutBorrow float64
|
||||
Borrowed float64
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// Change defines incoming balance change on currency holdings.
|
||||
type Change struct {
|
||||
Account string
|
||||
AssetType asset.Item
|
||||
Balance Balance
|
||||
}
|
||||
|
||||
// balance contains a balance with live updates.
|
||||
type balance struct {
|
||||
internal Balance
|
||||
m sync.RWMutex
|
||||
}
|
||||
|
||||
// CurrencyBalances provides a map of currencies to balances.
|
||||
type CurrencyBalances map[currency.Code]Balance
|
||||
|
||||
// currencyBalances provides a map of currencies to balances.
|
||||
type currencyBalances map[*currency.Item]*balance
|
||||
|
||||
// Set will set a currency balance, overwriting any previous Balance.
|
||||
//
|
||||
//nolint:gocritic // Ignoring hugeparam because we want the convenience of all callers passing by value
|
||||
//nolint:gocritic // and we want to store a copy anyway so the hugeparam warning that this copies a value is not relevant
|
||||
func (c *CurrencyBalances) Set(curr currency.Code, b Balance) {
|
||||
b.Currency = curr
|
||||
(*c)[curr] = b
|
||||
}
|
||||
|
||||
// Add will add to a currency balance.
|
||||
func (c *CurrencyBalances) Add(curr currency.Code, b Balance) error { //nolint:gocritic // hugeparam not relevant; we want to store a value so we'd deref anyway
|
||||
if curr == currency.EMPTYCODE {
|
||||
return currency.ErrCurrencyCodeEmpty
|
||||
}
|
||||
if b.Currency != currency.EMPTYCODE && !b.Currency.Equal(curr) {
|
||||
return fmt.Errorf("%w: %q != %q", errBalanceCurrencyMismatch, b.Currency, curr)
|
||||
}
|
||||
if e, ok := (*c)[curr]; !ok {
|
||||
b.Currency = curr
|
||||
(*c)[curr] = b
|
||||
} else {
|
||||
(*c)[curr] = e.Add(b)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Balance returns a snapshot copy of the Balance.
|
||||
func (b *balance) Balance() Balance {
|
||||
b.m.RLock()
|
||||
defer b.m.RUnlock()
|
||||
return b.internal
|
||||
}
|
||||
|
||||
// Add returns a new Balance adding together a and b.
|
||||
// UpdatedAt is the later of the two Balances.
|
||||
func (b *Balance) Add(a Balance) Balance { //nolint:gocritic // hugeparam not relevant; We'd need to copy it in map iterations anyway
|
||||
var u time.Time
|
||||
if a.UpdatedAt.After(b.UpdatedAt) {
|
||||
u = a.UpdatedAt
|
||||
} else {
|
||||
u = b.UpdatedAt
|
||||
}
|
||||
return Balance{
|
||||
Total: b.Total + a.Total,
|
||||
Hold: b.Hold + a.Hold,
|
||||
Free: b.Free + a.Free,
|
||||
AvailableWithoutBorrow: b.AvailableWithoutBorrow + a.AvailableWithoutBorrow,
|
||||
Borrowed: b.Borrowed + a.Borrowed,
|
||||
UpdatedAt: u,
|
||||
}
|
||||
}
|
||||
|
||||
// Public returns a copy of the currencyBalances converted to CurrencyBalances for use outside this package.
|
||||
func (c currencyBalances) Public() CurrencyBalances {
|
||||
n := make(CurrencyBalances, len(c))
|
||||
for curr, bal := range c {
|
||||
n[curr.Currency()] = bal.Balance()
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// update checks that an incoming change has a valid change, and returns if the balances were changed.
|
||||
// If change does not have a Currency set, the existing Currency is preserved.
|
||||
func (b *balance) update(change Balance) (bool, error) { //nolint:gocritic // hugeparam not relevant; We'd need to copy it later anyway
|
||||
if err := common.NilGuard(b); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if change.UpdatedAt.IsZero() {
|
||||
return false, errUpdatedAtIsZero
|
||||
}
|
||||
b.m.Lock()
|
||||
defer b.m.Unlock()
|
||||
if b.internal.Currency != currency.EMPTYCODE {
|
||||
if change.Currency == currency.EMPTYCODE {
|
||||
change.Currency = b.internal.Currency
|
||||
} else if !change.Currency.Equal(b.internal.Currency) {
|
||||
return false, fmt.Errorf("%w %q != %q", errBalanceCurrencyMismatch, b.internal.Currency, change.Currency)
|
||||
}
|
||||
}
|
||||
if b.internal.UpdatedAt.After(change.UpdatedAt) {
|
||||
return false, errOutOfSequence
|
||||
}
|
||||
b.internal.UpdatedAt = change.UpdatedAt // Set just the time, and then can compare easily
|
||||
if b.internal == change {
|
||||
return false, nil
|
||||
}
|
||||
b.internal = change
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// balance returns a balance for a currency.
|
||||
func (c currencyBalances) balance(curr *currency.Item) *balance {
|
||||
b, ok := c[curr]
|
||||
if !ok {
|
||||
b = &balance{internal: Balance{Currency: curr.Currency()}}
|
||||
c[curr] = b
|
||||
}
|
||||
return b
|
||||
}
|
||||
152
exchange/accounts/balance_test.go
Normal file
152
exchange/accounts/balance_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package accounts
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
"github.com/thrasher-corp/gocryptotrader/currency"
|
||||
)
|
||||
|
||||
// TestCurrencyBalancesSet exercises CurrencyBalances.Set
|
||||
func TestCurrencyBalancesSet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c := CurrencyBalances{}
|
||||
|
||||
c.Set(currency.BTC, Balance{Total: 4.2})
|
||||
require.Contains(t, c, currency.BTC, "must add an entry to an uninitialised CurrencyBalances")
|
||||
assert.Equal(t, currency.BTC, c[currency.BTC].Currency, "should set the Currency")
|
||||
|
||||
c.Set(currency.LTC, Balance{Currency: currency.ETH, Total: 52.4})
|
||||
require.Contains(t, c, currency.LTC, "must add an entry to an existing CurrencyBalances")
|
||||
assert.Equal(t, currency.LTC, c[currency.LTC].Currency, "should overwrite Currency")
|
||||
}
|
||||
|
||||
// TestCurrencyBalancesAdd exercises CurrencyBalances.Add
|
||||
func TestCurrencyBalancesAdd(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c := CurrencyBalances{}
|
||||
assert.ErrorIs(t, c.Add(currency.EMPTYCODE, Balance{}), currency.ErrCurrencyCodeEmpty)
|
||||
|
||||
err := c.Add(currency.BTC, Balance{Total: 4.2})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Contains(t, c, currency.BTC, "must add an entry to an uninitialised CurrencyBalances")
|
||||
assert.Equal(t, currency.BTC, c[currency.BTC].Currency, "should set the Currency")
|
||||
assert.Equal(t, 4.2, c[currency.BTC].Total, "should initialise the Total")
|
||||
|
||||
err = c.Add(currency.BTC, Balance{Total: 1.3, Hold: 2.4})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 5.5, c[currency.BTC].Total, "should add to existing Total")
|
||||
assert.Equal(t, 2.4, c[currency.BTC].Hold, "should initialise Hold")
|
||||
|
||||
err = c.Add(currency.LTC, Balance{Currency: currency.LTC, Total: 14.3})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, c, currency.LTC, "must add an entry to an existing CurrencyBalances")
|
||||
assert.Equal(t, 14.3, c[currency.LTC].Total, "should add when Balance.Currency is equal")
|
||||
|
||||
err = c.Add(currency.ETH, Balance{Currency: currency.LTC, Total: 14.2})
|
||||
assert.ErrorIs(t, err, errBalanceCurrencyMismatch)
|
||||
}
|
||||
|
||||
// TestCurrencyBalancesPublic exercises currencyBalances.Public
|
||||
func TestCurrencyBalancesPublic(t *testing.T) {
|
||||
t.Parallel()
|
||||
b := (¤cyBalances{
|
||||
currency.BTC.Item: &balance{internal: Balance{Total: 4.2}},
|
||||
currency.LTC.Item: &balance{internal: Balance{Total: 1.7}},
|
||||
}).Public()
|
||||
require.Equal(t, 2, len(b), "Pulbic must return the correct number of Balances")
|
||||
require.Contains(t, b, currency.BTC)
|
||||
require.Contains(t, b, currency.LTC)
|
||||
assert.Equal(t, 4.2, b[currency.BTC].Total)
|
||||
assert.Equal(t, 1.7, b[currency.LTC].Total)
|
||||
}
|
||||
|
||||
// TestCurrencyBalancesBalance exercises currencyBalances.balance
|
||||
func TestCurrencyBalancesBalance(t *testing.T) {
|
||||
t.Parallel()
|
||||
c := currencyBalances{}
|
||||
b := c.balance(currency.BTC.Item)
|
||||
require.NotNil(t, b)
|
||||
assert.Same(t, c[currency.BTC.Item], b, "should make and return the same entry")
|
||||
assert.Same(t, b, c.balance(currency.BTC.Item), "should make and return the same entry")
|
||||
}
|
||||
|
||||
// TestBalanceBalance exercises balance.Balance
|
||||
func TestBalanceBalance(t *testing.T) {
|
||||
t.Parallel()
|
||||
b := &balance{internal: Balance{Currency: currency.BTC}}
|
||||
i := b.Balance()
|
||||
assert.Equal(t, b.internal, i)
|
||||
}
|
||||
|
||||
// TestBalanceAdd exercises Balance.Add
|
||||
func TestBalanceAdd(t *testing.T) {
|
||||
t.Parallel()
|
||||
n1 := time.Now()
|
||||
n2 := n1.Add(-2 * time.Minute)
|
||||
b := new(Balance).Add(Balance{Total: 4.2, UpdatedAt: n2})
|
||||
assert.Equal(t, 4.2, b.Total, "should initialise Total")
|
||||
assert.Equal(t, n2, b.UpdatedAt, "should set UpdatedAt")
|
||||
b = b.Add(Balance{Total: 1.3, Hold: 3.0, UpdatedAt: n1})
|
||||
assert.Equal(t, 5.5, b.Total, "should add to Total")
|
||||
assert.Equal(t, 3.0, b.Hold, "should initialise Hold")
|
||||
assert.Equal(t, n1, b.UpdatedAt, "should set UpdatedAt")
|
||||
b = b.Add(Balance{Total: 2.2, Hold: 4.0, UpdatedAt: n1.Add(-time.Minute)})
|
||||
assert.Equal(t, 7.7, b.Total, "should add to Total")
|
||||
assert.Equal(t, 7.0, b.Hold, "should add to Hold")
|
||||
assert.Equal(t, n1, b.UpdatedAt, "should keep newer UpdatedAt")
|
||||
}
|
||||
|
||||
// TestBalanceUpdate exercises balance.update
|
||||
func TestBalanceUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := (*balance)(nil).update(Balance{})
|
||||
require.ErrorIs(t, err, common.ErrNilPointer)
|
||||
|
||||
n := time.Now()
|
||||
b := &balance{internal: Balance{
|
||||
Currency: currency.LTC,
|
||||
Total: 4.2,
|
||||
UpdatedAt: n,
|
||||
}}
|
||||
|
||||
_, err = b.update(Balance{})
|
||||
require.ErrorIs(t, err, errUpdatedAtIsZero)
|
||||
|
||||
_, err = b.update(Balance{UpdatedAt: n, Currency: currency.ETH})
|
||||
assert.ErrorIs(t, err, errBalanceCurrencyMismatch)
|
||||
|
||||
_, err = b.update(Balance{UpdatedAt: n.Add(-time.Millisecond)})
|
||||
assert.ErrorIs(t, err, errOutOfSequence, "should error when time out of sequence")
|
||||
|
||||
u, err := b.update(Balance{UpdatedAt: n, Total: 5.1})
|
||||
require.NoError(t, err, "must not error when time is the same instant and currency is empty")
|
||||
assert.Equal(t, 5.1, b.internal.Total, "Total should be correct")
|
||||
assert.True(t, u, "should return updated")
|
||||
|
||||
n = time.Now()
|
||||
u, err = b.update(Balance{UpdatedAt: n, Total: 5.1})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, n, b.internal.UpdatedAt, "should update UpdatedAt")
|
||||
assert.False(t, u, "should not return updated when nothing really changed")
|
||||
|
||||
n = time.Now()
|
||||
u, err = b.update(Balance{UpdatedAt: n, Currency: currency.LTC, Total: 5.1})
|
||||
require.NoError(t, err, "must not error when Currency matches")
|
||||
assert.Equal(t, n, b.internal.UpdatedAt, "should update UpdatedAt")
|
||||
assert.False(t, u, "should return not updated when only time changed")
|
||||
|
||||
n = time.Now()
|
||||
u, err = b.update(Balance{UpdatedAt: n, Currency: currency.LTC, Total: 4.4})
|
||||
require.NoError(t, err, "must not error when Currency matches")
|
||||
assert.Equal(t, n, b.internal.UpdatedAt, "should update UpdatedAt")
|
||||
assert.Equal(t, 4.4, b.internal.Total, "should update Total")
|
||||
assert.True(t, u, "should return updated")
|
||||
}
|
||||
212
exchange/accounts/credentials.go
Normal file
212
exchange/accounts/credentials.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package accounts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
// contextCredential is a string flag for use with context values when setting
|
||||
// credentials internally or via gRPC.
|
||||
type contextCredential string
|
||||
|
||||
const (
|
||||
// ContextCredentialsFlag used for retrieving api credentials from context
|
||||
ContextCredentialsFlag contextCredential = "apicredentials"
|
||||
// ContextSubAccountFlag used for retrieving just the sub account from
|
||||
// context, when the default config credentials sub account needs to be
|
||||
// changed while the same keys can be used.
|
||||
ContextSubAccountFlag contextCredential = "subaccountoverride"
|
||||
|
||||
apiKeyDisplaySize = 16
|
||||
)
|
||||
|
||||
// Default credential values
|
||||
const (
|
||||
Key = "key"
|
||||
Secret = "secret"
|
||||
SubAccountSTR = "subaccount"
|
||||
ClientID = "clientid"
|
||||
OneTimePassword = "otp"
|
||||
PEMKey = "pemkey"
|
||||
)
|
||||
|
||||
var (
|
||||
errMetaDataIsNil = errors.New("meta data is nil")
|
||||
errInvalidCredentialMetaDataLength = errors.New("invalid meta data to process credentials")
|
||||
errMissingInfo = errors.New("cannot parse meta data missing information in key value pair")
|
||||
)
|
||||
|
||||
// Credentials define parameters that allow for an authenticated request.
|
||||
type Credentials struct {
|
||||
Key string
|
||||
Secret string
|
||||
ClientID string // TODO: Implement with exchange orders functionality
|
||||
PEMKey string
|
||||
SubAccount string
|
||||
OneTimePassword string
|
||||
SecretBase64Decoded bool
|
||||
// TODO: Add AccessControl uint8 for READ/WRITE/Withdraw capabilities.
|
||||
}
|
||||
|
||||
// GetMetaData returns the credentials for metadata context deployment
|
||||
func (c *Credentials) GetMetaData() (flag, values string) {
|
||||
vals := make([]string, 0, 6)
|
||||
if c.Key != "" {
|
||||
vals = append(vals, Key+":"+c.Key)
|
||||
}
|
||||
if c.Secret != "" {
|
||||
vals = append(vals, Secret+":"+c.Secret)
|
||||
}
|
||||
if c.SubAccount != "" {
|
||||
vals = append(vals, SubAccountSTR+":"+c.SubAccount)
|
||||
}
|
||||
if c.ClientID != "" {
|
||||
vals = append(vals, ClientID+":"+c.ClientID)
|
||||
}
|
||||
if c.PEMKey != "" {
|
||||
vals = append(vals, PEMKey+":"+c.PEMKey)
|
||||
}
|
||||
if c.OneTimePassword != "" {
|
||||
vals = append(vals, OneTimePassword+":"+c.OneTimePassword)
|
||||
}
|
||||
return string(ContextCredentialsFlag), strings.Join(vals, ",")
|
||||
}
|
||||
|
||||
// String prints out basic credential info (obfuscated) to track key instances
|
||||
// associated with exchanges.
|
||||
func (c *Credentials) String() string {
|
||||
obfuscated := c.Key
|
||||
if len(obfuscated) > apiKeyDisplaySize {
|
||||
obfuscated = obfuscated[:apiKeyDisplaySize]
|
||||
}
|
||||
return fmt.Sprintf("Key:[%s...] SubAccount:[%s] ClientID:[%s]",
|
||||
obfuscated,
|
||||
c.SubAccount,
|
||||
c.ClientID)
|
||||
}
|
||||
|
||||
// getInternal returns the values for assignment to an internal context
|
||||
func (c *Credentials) getInternal() (contextCredential, *ContextCredentialsStore) {
|
||||
if c.IsEmpty() {
|
||||
return "", nil
|
||||
}
|
||||
store := &ContextCredentialsStore{}
|
||||
store.Load(c)
|
||||
return ContextCredentialsFlag, store
|
||||
}
|
||||
|
||||
// IsEmpty return true if the underlying credentials type has not been filled
|
||||
// with at least one item.
|
||||
func (c *Credentials) IsEmpty() bool {
|
||||
return c == nil || c.ClientID == "" &&
|
||||
c.Key == "" &&
|
||||
c.OneTimePassword == "" &&
|
||||
c.PEMKey == "" &&
|
||||
c.Secret == "" &&
|
||||
c.SubAccount == ""
|
||||
}
|
||||
|
||||
// Equal determines if the keys are the same.
|
||||
// OTP omitted because it's generated per request.
|
||||
// PEMKey and Secret omitted because of direct correlation with api key.
|
||||
func (c *Credentials) Equal(other *Credentials) bool {
|
||||
return c != nil &&
|
||||
other != nil &&
|
||||
c.Key == other.Key &&
|
||||
c.ClientID == other.ClientID &&
|
||||
(c.SubAccount == other.SubAccount || c.SubAccount == "" && other.SubAccount == "main" || c.SubAccount == "main" && other.SubAccount == "")
|
||||
}
|
||||
|
||||
// ContextCredentialsStore protects the stored credentials for use in a context
|
||||
type ContextCredentialsStore struct {
|
||||
creds *Credentials
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Load stores provided credentials
|
||||
func (c *ContextCredentialsStore) Load(creds *Credentials) {
|
||||
// Segregate from external call
|
||||
cpy := *creds
|
||||
c.mu.Lock()
|
||||
c.creds = &cpy
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Get returns the full credentials from the store
|
||||
func (c *ContextCredentialsStore) Get() *Credentials {
|
||||
c.mu.RLock()
|
||||
creds := *c.creds
|
||||
c.mu.RUnlock()
|
||||
return &creds
|
||||
}
|
||||
|
||||
// ParseCredentialsMetadata intercepts and converts credentials metadata to a
|
||||
// static type for authentication processing and protection.
|
||||
func ParseCredentialsMetadata(ctx context.Context, md metadata.MD) (context.Context, error) {
|
||||
if md == nil {
|
||||
return ctx, errMetaDataIsNil
|
||||
}
|
||||
|
||||
credMD, ok := md[string(ContextCredentialsFlag)]
|
||||
if !ok || len(credMD) == 0 {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
if len(credMD) != 1 {
|
||||
return ctx, errInvalidCredentialMetaDataLength
|
||||
}
|
||||
|
||||
segregatedCreds := strings.Split(credMD[0], ",")
|
||||
var ctxCreds Credentials
|
||||
var subAccountHere string
|
||||
for x := range segregatedCreds {
|
||||
keyVals := strings.Split(segregatedCreds[x], ":")
|
||||
if len(keyVals) != 2 {
|
||||
return ctx, fmt.Errorf("%w received %v fields, expected 2 contains: %s",
|
||||
errMissingInfo,
|
||||
len(keyVals),
|
||||
keyVals)
|
||||
}
|
||||
switch keyVals[0] {
|
||||
case Key:
|
||||
ctxCreds.Key = keyVals[1]
|
||||
case Secret:
|
||||
ctxCreds.Secret = keyVals[1]
|
||||
case SubAccountSTR:
|
||||
// Capture sub account as this can override if other values are
|
||||
// not included in metadata.
|
||||
subAccountHere = keyVals[1]
|
||||
case ClientID:
|
||||
ctxCreds.ClientID = keyVals[1]
|
||||
case PEMKey:
|
||||
ctxCreds.PEMKey = keyVals[1]
|
||||
case OneTimePassword:
|
||||
ctxCreds.OneTimePassword = keyVals[1]
|
||||
}
|
||||
}
|
||||
if ctxCreds.IsEmpty() && subAccountHere != "" {
|
||||
// This will override default sub account details if needed.
|
||||
return DeploySubAccountOverrideToContext(ctx, subAccountHere), nil
|
||||
}
|
||||
// merge sub account to main context credentials
|
||||
ctxCreds.SubAccount = subAccountHere
|
||||
return DeployCredentialsToContext(ctx, &ctxCreds), nil
|
||||
}
|
||||
|
||||
// DeployCredentialsToContext sets credentials for internal use to context which
|
||||
// can override default credential values.
|
||||
func DeployCredentialsToContext(ctx context.Context, creds *Credentials) context.Context {
|
||||
flag, store := creds.getInternal()
|
||||
return context.WithValue(ctx, flag, store)
|
||||
}
|
||||
|
||||
// DeploySubAccountOverrideToContext sets subaccount as override to credentials
|
||||
// as a separate flag.
|
||||
func DeploySubAccountOverrideToContext(ctx context.Context, subAccount string) context.Context {
|
||||
return context.WithValue(ctx, ContextSubAccountFlag, subAccount)
|
||||
}
|
||||
174
exchange/accounts/credentials_test.go
Normal file
174
exchange/accounts/credentials_test.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package accounts
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func TestIsEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
var c *Credentials
|
||||
if !c.IsEmpty() {
|
||||
t.Fatalf("expected: %v but received: %v", true, c.IsEmpty())
|
||||
}
|
||||
c = new(Credentials)
|
||||
if !c.IsEmpty() {
|
||||
t.Fatalf("expected: %v but received: %v", true, c.IsEmpty())
|
||||
}
|
||||
|
||||
c.SubAccount = "woow"
|
||||
if c.IsEmpty() {
|
||||
t.Fatalf("expected: %v but received: %v", false, c.IsEmpty())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCredentialsMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := ParseCredentialsMetadata(t.Context(), nil)
|
||||
require.ErrorIs(t, err, errMetaDataIsNil)
|
||||
|
||||
_, err = ParseCredentialsMetadata(t.Context(), metadata.MD{})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := metadata.AppendToOutgoingContext(t.Context(),
|
||||
string(ContextCredentialsFlag), "wow", string(ContextCredentialsFlag), "wow2")
|
||||
nortyMD, _ := metadata.FromOutgoingContext(ctx)
|
||||
|
||||
_, err = ParseCredentialsMetadata(t.Context(), nortyMD)
|
||||
require.ErrorIs(t, err, errInvalidCredentialMetaDataLength)
|
||||
|
||||
ctx = metadata.AppendToOutgoingContext(t.Context(),
|
||||
string(ContextCredentialsFlag), "brokenstring")
|
||||
nortyMD, _ = metadata.FromOutgoingContext(ctx)
|
||||
|
||||
_, err = ParseCredentialsMetadata(t.Context(), nortyMD)
|
||||
require.ErrorIs(t, err, errMissingInfo)
|
||||
|
||||
beforeCreds := Credentials{
|
||||
Key: "superkey",
|
||||
Secret: "supersecret",
|
||||
SubAccount: "supersub",
|
||||
ClientID: "superclient",
|
||||
PEMKey: "superpem",
|
||||
OneTimePassword: "superOneTimePasssssss",
|
||||
}
|
||||
|
||||
flag, outGoing := beforeCreds.GetMetaData()
|
||||
ctx = metadata.AppendToOutgoingContext(t.Context(), flag, outGoing)
|
||||
lovelyMD, _ := metadata.FromOutgoingContext(ctx)
|
||||
|
||||
ctx, err = ParseCredentialsMetadata(t.Context(), lovelyMD)
|
||||
require.NoError(t, err)
|
||||
|
||||
store, ok := ctx.Value(ContextCredentialsFlag).(*ContextCredentialsStore)
|
||||
if !ok {
|
||||
t.Fatal("should have processed")
|
||||
}
|
||||
|
||||
afterCreds := store.Get()
|
||||
|
||||
if afterCreds.Key != "superkey" &&
|
||||
afterCreds.Secret != "supersecret" &&
|
||||
afterCreds.SubAccount != "supersub" &&
|
||||
afterCreds.ClientID != "superclient" &&
|
||||
afterCreds.PEMKey != "superpem" &&
|
||||
afterCreds.OneTimePassword != "superOneTimePasssssss" {
|
||||
t.Fatal("unexpected values")
|
||||
}
|
||||
|
||||
// subaccount override
|
||||
subaccount := Credentials{
|
||||
SubAccount: "supersub",
|
||||
}
|
||||
|
||||
flag, outGoing = subaccount.GetMetaData()
|
||||
ctx = metadata.AppendToOutgoingContext(t.Context(), flag, outGoing)
|
||||
lovelyMD, _ = metadata.FromOutgoingContext(ctx)
|
||||
|
||||
ctx, err = ParseCredentialsMetadata(t.Context(), lovelyMD)
|
||||
require.NoError(t, err)
|
||||
|
||||
sa, ok := ctx.Value(ContextSubAccountFlag).(string)
|
||||
if !ok {
|
||||
t.Fatal("should have processed")
|
||||
}
|
||||
|
||||
if sa != "supersub" {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInternal(t *testing.T) {
|
||||
t.Parallel()
|
||||
flag, store := (&Credentials{}).getInternal()
|
||||
if flag != "" {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
if store != nil {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
flag, store = (&Credentials{Key: "wow"}).getInternal()
|
||||
if flag != ContextCredentialsFlag {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
if store == nil {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
if store.Get().Key != "wow" {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestString(t *testing.T) {
|
||||
t.Parallel()
|
||||
creds := Credentials{}
|
||||
if s := creds.String(); s != "Key:[...] SubAccount:[] ClientID:[]" {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
|
||||
creds.Key = "12345678910111234"
|
||||
creds.SubAccount = "sub"
|
||||
creds.ClientID = "client"
|
||||
|
||||
if s := creds.String(); s != "Key:[1234567891011123...] SubAccount:[sub] ClientID:[client]" {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCredentialsEqual(t *testing.T) {
|
||||
t.Parallel()
|
||||
var this, that *Credentials
|
||||
if this.Equal(that) {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
this = &Credentials{}
|
||||
if this.Equal(that) {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
that = &Credentials{Key: "1337"}
|
||||
if this.Equal(that) {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
this.Key = "1337"
|
||||
if !this.Equal(that) {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
this.ClientID = "1337"
|
||||
if this.Equal(that) {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
that.ClientID = "1337"
|
||||
if !this.Equal(that) {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
this.SubAccount = "someSub"
|
||||
if this.Equal(that) {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
that.SubAccount = "someSub"
|
||||
if !this.Equal(that) {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
}
|
||||
65
exchange/accounts/store.go
Normal file
65
exchange/accounts/store.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package accounts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/thrasher-corp/gocryptotrader/dispatch"
|
||||
)
|
||||
|
||||
// Store contains accounts for exchanges.
|
||||
type Store struct {
|
||||
exchangeAccounts exchangeMap
|
||||
mu sync.Mutex
|
||||
mux *dispatch.Mux
|
||||
}
|
||||
|
||||
type exchangeMap map[exchange]*Accounts
|
||||
|
||||
type exchange interface {
|
||||
GetName() string
|
||||
GetCredentials(context.Context) (*Credentials, error)
|
||||
}
|
||||
|
||||
type exchangeWrapper interface {
|
||||
GetBase() exchange
|
||||
}
|
||||
|
||||
var global atomic.Pointer[Store]
|
||||
|
||||
// NewStore returns a new store with the default global dispatcher mux.
|
||||
func NewStore() *Store {
|
||||
return &Store{
|
||||
exchangeAccounts: make(exchangeMap),
|
||||
mux: dispatch.GetNewMux(nil),
|
||||
}
|
||||
}
|
||||
|
||||
// GetStore returns the singleton accounts store for global use; Initialising if necessary.
|
||||
func GetStore() *Store {
|
||||
if s := global.Load(); s != nil {
|
||||
return s
|
||||
}
|
||||
_ = global.CompareAndSwap(nil, NewStore())
|
||||
return global.Load()
|
||||
}
|
||||
|
||||
// GetExchangeAccounts returns accounts for a specific exchange.
|
||||
func (s *Store) GetExchangeAccounts(e exchange) (a *Accounts, err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if w, ok := e.(exchangeWrapper); ok {
|
||||
// Because SetupDefaults is called on Base, it's easiest to just use the Base pointer as the key
|
||||
e = w.GetBase()
|
||||
}
|
||||
a, ok := s.exchangeAccounts[e]
|
||||
if !ok {
|
||||
a, err = NewAccounts(e, s.mux)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.exchangeAccounts[e] = a
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
86
exchange/accounts/store_test.go
Normal file
86
exchange/accounts/store_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package accounts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
)
|
||||
|
||||
func TestNewStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStore()
|
||||
require.NotNil(t, s, "NewStore must return a store")
|
||||
require.NotNil(t, s.mux, "NewStore must set mux")
|
||||
require.NotNil(t, s.exchangeAccounts, "NewStore must set exchangeAccounts")
|
||||
}
|
||||
|
||||
func TestGetStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Initialise global in case of -count=N+; No other tests should be relying on it
|
||||
global.Store(nil)
|
||||
s := GetStore()
|
||||
require.NotNil(t, s, "GetStore must return a Store")
|
||||
require.Same(t, global.Load(), s, "GetStore must set the global store")
|
||||
require.Same(t, s, GetStore(), "GetStore must return the global store on second call")
|
||||
}
|
||||
|
||||
func TestGetExchangeAccounts(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := NewStore()
|
||||
m := &mockEx{"mocky"}
|
||||
a := &Accounts{}
|
||||
s.exchangeAccounts[m] = a
|
||||
got, err := s.GetExchangeAccounts(m)
|
||||
require.NoError(t, err)
|
||||
assert.Same(t, a, got, "Should retrieve same existing Accounts")
|
||||
|
||||
m = &mockEx{"new"}
|
||||
got, err = s.GetExchangeAccounts(m)
|
||||
require.NoError(t, err)
|
||||
assert.Same(t, s.exchangeAccounts[m], got, "Should retrieve the new exchange")
|
||||
|
||||
w := &mockExBase{m}
|
||||
got, err = s.GetExchangeAccounts(w)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, got)
|
||||
|
||||
_, err = s.GetExchangeAccounts(nil)
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly on nil exchange")
|
||||
}
|
||||
|
||||
type mockEx struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (m *mockEx) GetName() string {
|
||||
return "mocky"
|
||||
}
|
||||
|
||||
func (m *mockEx) GetCredentials(ctx context.Context) (*Credentials, error) {
|
||||
if value := ctx.Value(ContextCredentialsFlag); value != nil {
|
||||
if s, ok := value.(*ContextCredentialsStore); ok {
|
||||
return s.Get(), nil
|
||||
}
|
||||
return nil, common.GetTypeAssertError("*accounts.ContextCredentialsStore", value)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type mockExBase struct {
|
||||
base exchange
|
||||
}
|
||||
|
||||
func (m *mockExBase) GetBase() exchange {
|
||||
return m.base
|
||||
}
|
||||
|
||||
func (m *mockExBase) GetCredentials(ctx context.Context) (*Credentials, error) {
|
||||
return m.base.GetCredentials(ctx)
|
||||
}
|
||||
|
||||
func (m *mockExBase) GetName() string {
|
||||
return m.base.GetName()
|
||||
}
|
||||
Reference in New Issue
Block a user