diff --git a/exchanges/account/account.go b/exchanges/account/account.go index bd3b8a9c..146016e7 100644 --- a/exchanges/account/account.go +++ b/exchanges/account/account.go @@ -16,6 +16,29 @@ func init() { service.mux = dispatch.GetNewMux() } +// CollectBalances converts a map of sub-account balances into a slice +func CollectBalances(accountBalances map[string][]Balance, assetType asset.Item) (accounts []SubAccount, err error) { + if accountBalances == nil { + return nil, errAccountBalancesIsNil + } + + if !assetType.IsValid() { + return nil, fmt.Errorf("%s, %w", assetType, asset.ErrNotSupported) + } + + accounts = make([]SubAccount, len(accountBalances)) + i := 0 + for accountID, balances := range accountBalances { + accounts[i] = SubAccount{ + ID: accountID, + AssetType: assetType, + Currencies: balances, + } + i++ + } + return +} + // SubscribeToExchangeAccount subcribes to your exchange account func SubscribeToExchangeAccount(exchange string) (dispatch.Pipe, error) { exchange = strings.ToLower(exchange) diff --git a/exchanges/account/account_test.go b/exchanges/account/account_test.go index 16c9ceff..5f688229 100644 --- a/exchanges/account/account_test.go +++ b/exchanges/account/account_test.go @@ -10,6 +10,47 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/asset" ) +func TestCollectBalances(t *testing.T) { + accounts, err := CollectBalances( + map[string][]Balance{ + "someAccountID": { + {CurrencyName: currency.BTC, TotalValue: 40000, Hold: 1}, + }, + }, + asset.Spot, + ) + subAccount := accounts[0] + balance := subAccount.Currencies[0] + if subAccount.ID != "someAccountID" { + t.Error("subAccount ID not set correctly") + } + if subAccount.AssetType != asset.Spot { + t.Error("subAccount AssetType not set correctly") + } + if balance.CurrencyName != currency.BTC || balance.TotalValue != 40000 || balance.Hold != 1 { + t.Error("subAccount currency balance not set correctly") + } + if err != nil { + t.Error("err is not expected") + } + + accounts, err = CollectBalances(map[string][]Balance{}, asset.Spot) + if len(accounts) != 0 { + t.Error("accounts should be empty") + } + if err != nil { + t.Error("err is not expected") + } + + accounts, err = CollectBalances(nil, asset.Spot) + if len(accounts) != 0 { + t.Error("accounts should be empty") + } + if err == nil { + t.Errorf("expecting err %s", errAccountBalancesIsNil.Error()) + } +} + func TestHoldings(t *testing.T) { err := dispatch.Start(dispatch.DefaultMaxWorkers, dispatch.DefaultJobsLimit) if err != nil { diff --git a/exchanges/account/account_types.go b/exchanges/account/account_types.go index 7e1b2339..c1815d40 100644 --- a/exchanges/account/account_types.go +++ b/exchanges/account/account_types.go @@ -1,6 +1,7 @@ package account import ( + "errors" "sync" "github.com/gofrs/uuid" @@ -11,7 +12,8 @@ import ( // Vars for the ticker package var ( - service *Service + service *Service + errAccountBalancesIsNil = errors.New("account balances is nil") ) // Service holds ticker information for each individual exchange diff --git a/exchanges/binance/binance_wrapper.go b/exchanges/binance/binance_wrapper.go index e554aefd..406c7178 100644 --- a/exchanges/binance/binance_wrapper.go +++ b/exchanges/binance/binance_wrapper.go @@ -745,16 +745,21 @@ func (b *Binance) UpdateAccountInfo(ctx context.Context, assetType asset.Item) ( if err != nil { return info, err } - var currencyDetails []account.Balance + accountCurrencyDetails := make(map[string][]account.Balance) for i := range accData { - currencyDetails = append(currencyDetails, account.Balance{ - CurrencyName: currency.NewCode(accData[i].Asset), - TotalValue: accData[i].Balance, - Hold: accData[i].Balance - accData[i].AvailableBalance, - }) + currencyDetails := accountCurrencyDetails[accData[i].AccountAlias] + accountCurrencyDetails[accData[i].AccountAlias] = append( + currencyDetails, account.Balance{ + CurrencyName: currency.NewCode(accData[i].Asset), + TotalValue: accData[i].Balance, + Hold: accData[i].Balance - accData[i].AvailableBalance, + }, + ) } - acc.Currencies = currencyDetails + if info.Accounts, err = account.CollectBalances(accountCurrencyDetails, assetType); err != nil { + return account.Holdings{}, err + } case asset.Margin: accData, err := b.GetMarginAccount(ctx) if err != nil { diff --git a/exchanges/bitmex/bitmex_wrapper.go b/exchanges/bitmex/bitmex_wrapper.go index 90fbbb04..2abbbfac 100644 --- a/exchanges/bitmex/bitmex_wrapper.go +++ b/exchanges/bitmex/bitmex_wrapper.go @@ -426,28 +426,28 @@ func (b *Bitmex) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (a return info, err } - var accountID string - var balances []account.Balance + accountBalances := make(map[string][]account.Balance) // Need to update to add Margin/Liquidity availability for i := range userMargins { - accountID = strconv.FormatInt(userMargins[i].Account, 10) + accountID := strconv.FormatInt(userMargins[i].Account, 10) - wallet, err := b.GetWalletInfo(ctx, userMargins[i].Currency) + var wallet WalletInfo + wallet, err = b.GetWalletInfo(ctx, userMargins[i].Currency) if err != nil { continue } - balances = append(balances, account.Balance{ - CurrencyName: currency.NewCode(wallet.Currency), - TotalValue: wallet.Amount, - }) + accountBalances[accountID] = append( + accountBalances[accountID], account.Balance{ + CurrencyName: currency.NewCode(wallet.Currency), + TotalValue: wallet.Amount, + }, + ) } - info.Accounts = append(info.Accounts, - account.SubAccount{ - ID: accountID, - Currencies: balances, - }) + if info.Accounts, err = account.CollectBalances(accountBalances, assetType); err != nil { + return account.Holdings{}, err + } info.Exchange = b.Name if err := account.Process(&info); err != nil { diff --git a/exchanges/coinbasepro/coinbasepro_wrapper.go b/exchanges/coinbasepro/coinbasepro_wrapper.go index 9d117d34..fed210f1 100644 --- a/exchanges/coinbasepro/coinbasepro_wrapper.go +++ b/exchanges/coinbasepro/coinbasepro_wrapper.go @@ -328,19 +328,21 @@ func (c *CoinbasePro) UpdateAccountInfo(ctx context.Context, assetType asset.Ite return response, err } - var currencies []account.Balance + accountCurrencies := make(map[string][]account.Balance) for i := range accountBalance { var exchangeCurrency account.Balance exchangeCurrency.CurrencyName = currency.NewCode(accountBalance[i].Currency) exchangeCurrency.TotalValue = accountBalance[i].Available exchangeCurrency.Hold = accountBalance[i].Hold - currencies = append(currencies, exchangeCurrency) + profileID := accountBalance[i].ProfileID + currencies := accountCurrencies[profileID] + accountCurrencies[profileID] = append(currencies, exchangeCurrency) } - response.Accounts = append(response.Accounts, account.SubAccount{ - Currencies: currencies, - }) + if response.Accounts, err = account.CollectBalances(accountCurrencies, assetType); err != nil { + return account.Holdings{}, err + } err = account.Process(&response) if err != nil {