From 663e753f52f772badaa6dd33b14bbed648aa8f9b Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Thu, 21 Jul 2022 15:05:31 +1000 Subject: [PATCH] account: segregate holdings by credentials for future multi-key management (#956) * exchanges/account: shift credentials to account package and segregate funds to keys * merge: fixes * linter: fix * Update exchanges/account/account.go Co-authored-by: Scott * glorious: nits + protection for string panic * glorious_suggestion: add method for matching keys * linter: fix tests * account: add protected method for credentials minimizing access, display full account details to rpc. * linter: spelling kweeeeeeen * accounts/portfolio: clean/check portfolio code and quickly check balances from change. Add protected method for future matching. * accounts: theres no point in pointerising everything * linter: ok pointerise this then... * exchanges: fix regression add in little notes. * glorious: nits * Update exchanges/account/credentials.go Co-authored-by: Scott * Update exchanges/account/credentials_test.go Co-authored-by: Scott * Update exchanges/account/credentials_test.go Co-authored-by: Scott * glorious: nits * gloriously: fix glorious glorious test gloriously Co-authored-by: Ryan O'Hara-Reid Co-authored-by: Scott --- cmd/exchange_template/wrapper_file.tmpl | 10 + cmd/gctcli/main.go | 4 +- engine/portfolio_manager.go | 88 +++---- engine/rpcserver.go | 4 +- engine/rpcserver_test.go | 16 +- exchanges/account/account.go | 126 ++++++--- exchanges/account/account_test.go | 85 ++++--- exchanges/account/account_types.go | 22 +- exchanges/account/credentials.go | 218 ++++++++++++++++ exchanges/account/credentials_test.go | 240 ++++++++++++++++++ exchanges/alphapoint/alphapoint_wrapper.go | 13 +- exchanges/binance/binance_wrapper.go | 12 +- exchanges/bitfinex/bitfinex_wrapper.go | 13 +- exchanges/bitflyer/bitflyer_wrapper.go | 7 +- exchanges/bithumb/bithumb_wrapper.go | 13 +- exchanges/bitmex/bitmex_wrapper.go | 13 +- exchanges/bitstamp/bitstamp_wrapper.go | 13 +- exchanges/bittrex/bittrex_wrapper.go | 12 +- exchanges/btcmarkets/btcmarkets_wrapper.go | 12 +- exchanges/btse/btse_wrapper.go | 12 +- .../coinbasepro/coinbasepro_websocket.go | 4 +- exchanges/coinbasepro/coinbasepro_wrapper.go | 13 +- exchanges/coinut/coinut.go | 3 +- exchanges/coinut/coinut_test.go | 4 +- exchanges/coinut/coinut_wrapper.go | 13 +- exchanges/credentials.go | 212 ++-------------- exchanges/credentials_test.go | 174 ++----------- exchanges/exchange.go | 3 +- exchanges/exchange_types.go | 3 +- exchanges/exmo/exmo_wrapper.go | 13 +- exchanges/ftx/ftx_wrapper.go | 9 +- exchanges/gateio/gateio_wrapper.go | 13 +- exchanges/gemini/gemini_wrapper.go | 13 +- exchanges/hitbtc/hitbtc_wrapper.go | 12 +- exchanges/huobi/huobi_websocket.go | 10 +- exchanges/huobi/huobi_wrapper.go | 12 +- exchanges/interfaces.go | 2 +- exchanges/itbit/itbit_wrapper.go | 13 +- exchanges/kraken/kraken_wrapper.go | 12 +- exchanges/lbank/lbank_test.go | 4 +- exchanges/lbank/lbank_wrapper.go | 12 +- .../localbitcoins/localbitcoins_wrapper.go | 13 +- exchanges/okgroup/okgroup.go | 3 +- exchanges/okgroup/okgroup_wrapper.go | 13 +- exchanges/poloniex/poloniex_websocket.go | 5 +- exchanges/poloniex/poloniex_wrapper.go | 12 +- exchanges/yobit/yobit_wrapper.go | 13 +- exchanges/zb/zb_wrapper.go | 13 +- 48 files changed, 1010 insertions(+), 549 deletions(-) create mode 100644 exchanges/account/credentials.go create mode 100644 exchanges/account/credentials_test.go diff --git a/cmd/exchange_template/wrapper_file.tmpl b/cmd/exchange_template/wrapper_file.tmpl index b01c5169..908b13c6 100644 --- a/cmd/exchange_template/wrapper_file.tmpl +++ b/cmd/exchange_template/wrapper_file.tmpl @@ -368,6 +368,16 @@ func ({{.Variable}} *{{.CapitalName}}) UpdateAccountInfo(ctx context.Context, as // FetchAccountInfo retrieves balances for all enabled currencies func ({{.Variable}} *{{.CapitalName}}) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { + // Example implementation below: + // creds, err := {{.Variable}}.GetCredentials(ctx) + // if err != nil { + // return account.Holdings{}, err + // } + // acc, err := account.GetHoldings({{.Variable}}.Name, creds, assetType) + // if err != nil { + // return {{.Variable}}.UpdateAccountInfo(ctx, assetType) + // } + // return acc, nil return account.Holdings{}, common.ErrNotYetImplemented } diff --git a/cmd/gctcli/main.go b/cmd/gctcli/main.go index f368bb77..60dab27c 100644 --- a/cmd/gctcli/main.go +++ b/cmd/gctcli/main.go @@ -12,7 +12,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/core" - exchange "github.com/thrasher-corp/gocryptotrader/exchanges" + "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/gctrpc/auth" "github.com/thrasher-corp/gocryptotrader/signaler" "github.com/urfave/cli/v2" @@ -28,7 +28,7 @@ var ( pairDelimiter string certPath string timeout time.Duration - exchangeCreds exchange.Credentials + exchangeCreds account.Credentials verbose bool ) diff --git a/engine/portfolio_manager.go b/engine/portfolio_manager.go index 7bda1633..11fef5e5 100644 --- a/engine/portfolio_manager.go +++ b/engine/portfolio_manager.go @@ -57,10 +57,7 @@ func setupPortfolioManager(e *ExchangeManager, portfolioManagerDelay time.Durati // IsRunning safely checks whether the subsystem is running func (m *portfolioManager) IsRunning() bool { - if m == nil { - return false - } - return atomic.LoadInt32(&m.started) == 1 + return m != nil && atomic.LoadInt32(&m.started) == 1 } // Start runs the subsystem @@ -160,11 +157,10 @@ func (m *portfolioManager) seedExchangeAccountInfo(accounts []account.Holdings) return } for x := range accounts { - exchangeName := accounts[x].Exchange var currencies []account.Balance for y := range accounts[x].Accounts { + next: for z := range accounts[x].Accounts[y].Currencies { - var update bool for i := range currencies { if !accounts[x].Accounts[y].Currencies[z].CurrencyName.Equal(currencies[i].CurrencyName) { continue @@ -174,10 +170,7 @@ func (m *portfolioManager) seedExchangeAccountInfo(accounts []account.Holdings) currencies[i].AvailableWithoutBorrow += accounts[x].Accounts[y].Currencies[z].AvailableWithoutBorrow currencies[i].Free += accounts[x].Accounts[y].Currencies[z].Free currencies[i].Borrowed += accounts[x].Accounts[y].Currencies[z].Borrowed - update = true - } - if update { - continue + continue next } currencies = append(currencies, account.Balance{ CurrencyName: accounts[x].Accounts[y].Currencies[z].CurrencyName, @@ -190,51 +183,50 @@ func (m *portfolioManager) seedExchangeAccountInfo(accounts []account.Holdings) } } - for x := range currencies { - currencyName := currencies[x].CurrencyName - total := currencies[x].Total - - if !m.base.ExchangeAddressExists(exchangeName, currencyName) { - if total <= 0 { + for j := range currencies { + if !m.base.ExchangeAddressExists(accounts[x].Exchange, currencies[j].CurrencyName) { + if currencies[j].Total <= 0 { continue } log.Debugf(log.PortfolioMgr, "Portfolio: Adding new exchange address: %s, %s, %f, %s\n", - exchangeName, - currencyName, - total, + accounts[x].Exchange, + currencies[j].CurrencyName, + currencies[j].Total, portfolio.ExchangeAddress) - m.base.Addresses = append( - m.base.Addresses, - portfolio.Address{Address: exchangeName, - CoinType: currencyName, - Balance: total, - Description: portfolio.ExchangeAddress}) - } else { - if total <= 0 { - log.Debugf(log.PortfolioMgr, "Portfolio: Removing %s %s entry.\n", - exchangeName, - currencyName) - m.base.RemoveExchangeAddress(exchangeName, currencyName) - } else { - balance, ok := m.base.GetAddressBalance(exchangeName, - portfolio.ExchangeAddress, - currencyName) - if !ok { - continue - } + m.base.Addresses = append(m.base.Addresses, portfolio.Address{ + Address: accounts[x].Exchange, + CoinType: currencies[j].CurrencyName, + Balance: currencies[j].Total, + Description: portfolio.ExchangeAddress, + }) + continue + } - if balance != total { - log.Debugf(log.PortfolioMgr, "Portfolio: Updating %s %s entry with balance %f.\n", - exchangeName, - currencyName, - total) - m.base.UpdateExchangeAddressBalance(exchangeName, - currencyName, - total) - } - } + if currencies[j].Total <= 0 { + log.Debugf(log.PortfolioMgr, "Portfolio: Removing %s %s entry.\n", + accounts[x].Exchange, + currencies[j].CurrencyName) + m.base.RemoveExchangeAddress(accounts[x].Exchange, currencies[j].CurrencyName) + continue + } + + balance, ok := m.base.GetAddressBalance(accounts[x].Exchange, + portfolio.ExchangeAddress, + currencies[j].CurrencyName) + if !ok { + continue + } + + if balance != currencies[j].Total { + log.Debugf(log.PortfolioMgr, "Portfolio: Updating %s %s entry with balance %f.\n", + accounts[x].Exchange, + currencies[j].CurrencyName, + currencies[j].Total) + m.base.UpdateExchangeAddressBalance(accounts[x].Exchange, + currencies[j].CurrencyName, + currencies[j].Total) } } } diff --git a/engine/rpcserver.go b/engine/rpcserver.go index ee651223..1642111a 100644 --- a/engine/rpcserver.go +++ b/engine/rpcserver.go @@ -112,7 +112,7 @@ func (s *RPCServer) authenticateClient(ctx context.Context) (context.Context, er password != s.Config.RemoteControl.Password { return ctx, fmt.Errorf("username/password mismatch") } - ctx, err = exchange.ParseCredentialsMetadata(ctx, md) + ctx, err = account.ParseCredentialsMetadata(ctx, md) if err != nil { return ctx, err } @@ -618,7 +618,7 @@ func createAccountInfoRequest(h account.Holdings) (*gctrpc.GetAccountInfoRespons accounts := make([]*gctrpc.Account, len(h.Accounts)) for x := range h.Accounts { var a gctrpc.Account - a.Id = h.Accounts[x].ID + a.Id = h.Accounts[x].Credentials.String() for _, y := range h.Accounts[x].Currencies { if y.Total == 0 && y.Hold == 0 && diff --git a/engine/rpcserver_test.go b/engine/rpcserver_test.go index aab43081..a1fde255 100644 --- a/engine/rpcserver_test.go +++ b/engine/rpcserver_test.go @@ -2198,10 +2198,12 @@ func TestGetFuturesPositions(t *testing.T) { t.Fatalf("received '%v', expected '%v'", err, exchange.ErrCredentialsAreEmpty) } - ctx := exchange.DeployCredentialsToContext(context.Background(), &exchange.Credentials{ - Key: "wow", - Secret: "super wow", - }) + ctx := account.DeployCredentialsToContext(context.Background(), + &account.Credentials{ + Key: "wow", + Secret: "super wow", + }, + ) _, err = s.GetFuturesPositions(ctx, &gctrpc.GetFuturesPositionsRequest{ Exchange: fakeExchangeName, @@ -2312,7 +2314,8 @@ func TestGetCollateral(t *testing.T) { t.Fatalf("received '%v', expected '%v'", err, exchange.ErrCredentialsAreEmpty) } - ctx := exchange.DeployCredentialsToContext(context.Background(), &exchange.Credentials{Key: "fakerino", Secret: "supafake"}) + ctx := account.DeployCredentialsToContext(context.Background(), + &account.Credentials{Key: "fakerino", Secret: "supafake"}) _, err = s.GetCollateral(ctx, &gctrpc.GetCollateralRequest{ Exchange: fakeExchangeName, @@ -2322,7 +2325,8 @@ func TestGetCollateral(t *testing.T) { t.Fatalf("received '%v', expected '%v'", err, errNoAccountInformation) } - ctx = exchange.DeployCredentialsToContext(context.Background(), &exchange.Credentials{Key: "fakerino", Secret: "supafake", SubAccount: "1337"}) + ctx = account.DeployCredentialsToContext(context.Background(), + &account.Credentials{Key: "fakerino", Secret: "supafake", SubAccount: "1337"}) r, err := s.GetCollateral(ctx, &gctrpc.GetCollateralRequest{ Exchange: fakeExchangeName, diff --git a/exchanges/account/account.go b/exchanges/account/account.go index 11cb22e4..8c573cf9 100644 --- a/exchanges/account/account.go +++ b/exchanges/account/account.go @@ -26,6 +26,8 @@ var ( errNoExchangeSubAccountBalances = errors.New("no exchange sub account balances") errNoBalanceFound = errors.New("no balance found") errBalanceIsNil = errors.New("balance is nil") + errNoCredentialBalances = errors.New("no balances associated with credentials") + errCredentialsAreNil = errors.New("credentials are nil") ) // CollectBalances converts a map of sub-account balances into a slice @@ -64,16 +66,22 @@ func SubscribeToExchangeAccount(exchange string) (dispatch.Pipe, error) { } // Process processes new account holdings updates -func Process(h *Holdings) error { - return service.Update(h) +func Process(h *Holdings, c *Credentials) error { + return service.Update(h, c) } -// GetHoldings returns full holdings for an exchange -func GetHoldings(exch string, assetType asset.Item) (Holdings, error) { +// GetHoldings returns full holdings for an exchange. +// NOTE: Due to credentials these amounts could be N*APIKEY actual holdings. +// TODO: Add jurisdiction and differentiation between APIKEY holdings. +func GetHoldings(exch string, creds *Credentials, assetType asset.Item) (Holdings, error) { if exch == "" { return Holdings{}, errExchangeNameUnset } + if creds.IsEmpty() { + return Holdings{}, fmt.Errorf("%s %s %w", exch, assetType, errCredentialsAreNil) + } + if !assetType.IsValid() { return Holdings{}, fmt.Errorf("%s %s %w", exch, assetType, asset.ErrNotSupported) } @@ -88,7 +96,16 @@ func GetHoldings(exch string, assetType asset.Item) (Holdings, error) { } var accountsHoldings []SubAccount - for subAccount, assetHoldings := range accounts.SubAccounts { + subAccountHoldings, ok := accounts.SubAccounts[*creds] + if !ok { + return Holdings{}, fmt.Errorf("%s %s %s %w", + exch, + creds, + assetType, + errNoCredentialBalances) + } + + for subAccount, assetHoldings := range subAccountHoldings { for ai, currencyHoldings := range assetHoldings { if ai != assetType { continue @@ -113,10 +130,16 @@ func GetHoldings(exch string, assetType asset.Item) (Holdings, error) { continue } + cpy := *creds + if cpy.SubAccount == "" { + cpy.SubAccount = subAccount + } + accountsHoldings = append(accountsHoldings, SubAccount{ - ID: subAccount, - AssetType: ai, - Currencies: currencyBalances, + Credentials: Protected{creds: cpy}, + ID: subAccount, + AssetType: ai, + Currencies: currencyBalances, }) break } @@ -132,22 +155,24 @@ func GetHoldings(exch string, assetType asset.Item) (Holdings, error) { } // GetBalance returns the internal balance for that asset item. -func GetBalance(exch, subAccount string, ai asset.Item, c currency.Code) (*ProtectedBalance, error) { +func GetBalance(exch, subAccount string, creds *Credentials, ai asset.Item, c currency.Code) (*ProtectedBalance, error) { if exch == "" { - return nil, errExchangeNameUnset + return nil, fmt.Errorf("cannot get balance: %w", errExchangeNameUnset) } if !ai.IsValid() { - return nil, fmt.Errorf("%s %w", ai, asset.ErrNotSupported) + return nil, fmt.Errorf("cannot get balance: %s %w", ai, asset.ErrNotSupported) + } + + if creds.IsEmpty() { + return nil, fmt.Errorf("cannot get balance: %w", errCredentialsAreNil) } if c.IsEmpty() { - return nil, currency.ErrCurrencyCodeEmpty + return nil, fmt.Errorf("cannot get balance: %w", currency.ErrCurrencyCodeEmpty) } exch = strings.ToLower(exch) - subAccount = strings.ToLower(subAccount) - service.mu.Lock() defer service.mu.Unlock() @@ -156,7 +181,13 @@ func GetBalance(exch, subAccount string, ai asset.Item, c currency.Code) (*Prote return nil, fmt.Errorf("%s %w", exch, errExchangeHoldingsNotFound) } - assetBalances, ok := accounts.SubAccounts[subAccount] + subAccounts, ok := accounts.SubAccounts[*creds] + if !ok { + return nil, fmt.Errorf("%s %s %w", + exch, creds, errNoCredentialBalances) + } + + assetBalances, ok := subAccounts[subAccount] if !ok { return nil, fmt.Errorf("%s %s %w", exch, subAccount, errNoExchangeSubAccountBalances) @@ -177,16 +208,20 @@ func GetBalance(exch, subAccount string, ai asset.Item, c currency.Code) (*Prote } // Update updates holdings with new account info -func (s *Service) Update(a *Holdings) error { - if a == nil { - return errHoldingsIsNil +func (s *Service) Update(incoming *Holdings, creds *Credentials) error { + if incoming == nil { + return fmt.Errorf("cannot update holdings: %w", errHoldingsIsNil) } - if a.Exchange == "" { - return errExchangeNameUnset + if incoming.Exchange == "" { + return fmt.Errorf("cannot update holdings: %w", errExchangeNameUnset) } - exch := strings.ToLower(a.Exchange) + if creds.IsEmpty() { + return fmt.Errorf("cannot update holdings: %w", errCredentialsAreNil) + } + + exch := strings.ToLower(incoming.Exchange) s.mu.Lock() defer s.mu.Unlock() accounts, ok := s.exchangeAccounts[exch] @@ -197,46 +232,65 @@ func (s *Service) Update(a *Holdings) error { } accounts = &Accounts{ ID: id, - SubAccounts: make(map[string]map[asset.Item]map[*currency.Item]*ProtectedBalance), + SubAccounts: make(map[Credentials]map[string]map[asset.Item]map[*currency.Item]*ProtectedBalance), } s.exchangeAccounts[exch] = accounts } var errs common.Errors - for x := range a.Accounts { - if !a.Accounts[x].AssetType.IsValid() { + for x := range incoming.Accounts { + if !incoming.Accounts[x].AssetType.IsValid() { errs = append(errs, fmt.Errorf("cannot load sub account holdings for %s [%s] %w", - a.Accounts[x].ID, - a.Accounts[x].AssetType, + incoming.Accounts[x].ID, + incoming.Accounts[x].AssetType, asset.ErrNotSupported)) continue } - lowerSA := strings.ToLower(a.Accounts[x].ID) + // This assignment outside of scope is designed to have minimal impact + // on the exchange implementation UpdateAccountInfo() and portfoio + // management. + // TODO: Update incoming Holdings type to already be populated. (Suggestion) + cpy := *creds + if cpy.SubAccount == "" { + cpy.SubAccount = incoming.Accounts[x].ID + } + incoming.Accounts[x].Credentials.creds = cpy + + var subAccounts map[string]map[asset.Item]map[*currency.Item]*ProtectedBalance + subAccounts, ok = accounts.SubAccounts[*creds] + if !ok { + subAccounts = make(map[string]map[asset.Item]map[*currency.Item]*ProtectedBalance) + accounts.SubAccounts[*creds] = subAccounts + } + var accountAssets map[asset.Item]map[*currency.Item]*ProtectedBalance - accountAssets, ok = accounts.SubAccounts[lowerSA] + accountAssets, ok = subAccounts[incoming.Accounts[x].ID] if !ok { accountAssets = make(map[asset.Item]map[*currency.Item]*ProtectedBalance) - accounts.SubAccounts[lowerSA] = accountAssets + // Note: Sub accounts are case sensitive and an account "name" is + // different to account "naMe". + subAccounts[incoming.Accounts[x].ID] = accountAssets } var currencyBalances map[*currency.Item]*ProtectedBalance - currencyBalances, ok = accountAssets[a.Accounts[x].AssetType] + currencyBalances, ok = accountAssets[incoming.Accounts[x].AssetType] if !ok { currencyBalances = make(map[*currency.Item]*ProtectedBalance) - accountAssets[a.Accounts[x].AssetType] = currencyBalances + accountAssets[incoming.Accounts[x].AssetType] = currencyBalances } - for y := range a.Accounts[x].Currencies { - bal := currencyBalances[a.Accounts[x].Currencies[y].CurrencyName.Item] + for y := range incoming.Accounts[x].Currencies { + bal := currencyBalances[incoming.Accounts[x].Currencies[y].CurrencyName.Item] if bal == nil { bal = &ProtectedBalance{} - currencyBalances[a.Accounts[x].Currencies[y].CurrencyName.Item] = bal + currencyBalances[incoming.Accounts[x].Currencies[y].CurrencyName.Item] = bal } - bal.load(a.Accounts[x].Currencies[y]) + bal.load(incoming.Accounts[x].Currencies[y]) } } - err := s.mux.Publish(a, accounts.ID) + + err := s.mux.Publish(incoming, accounts.ID) if err != nil { return err } diff --git a/exchanges/account/account_test.go b/exchanges/account/account_test.go index 3d1de036..520b9020 100644 --- a/exchanges/account/account_test.go +++ b/exchanges/account/account_test.go @@ -11,6 +11,8 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/asset" ) +var happyCredentials = &Credentials{Key: "AAAAA"} + func TestCollectBalances(t *testing.T) { t.Parallel() accounts, err := CollectBalances( @@ -63,12 +65,12 @@ func TestGetHoldings(t *testing.T) { if err != nil { t.Fatal(err) } - err = Process(nil) + err = Process(nil, nil) if !errors.Is(err, errHoldingsIsNil) { t.Fatalf("received: '%v' but expected: '%v'", err, errHoldingsIsNil) } - err = Process(&Holdings{}) + err = Process(&Holdings{}, nil) if !errors.Is(err, errExchangeNameUnset) { t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeNameUnset) } @@ -77,9 +79,14 @@ func TestGetHoldings(t *testing.T) { Exchange: "Test", } - err = Process(&holdings) - if err != nil { - t.Error(err) + err = Process(&holdings, nil) + if !errors.Is(err, errCredentialsAreNil) { + t.Fatalf("received: '%v' but expected: '%v'", err, errCredentialsAreNil) + } + + err = Process(&holdings, happyCredentials) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) } err = Process(&Holdings{ @@ -88,7 +95,7 @@ func TestGetHoldings(t *testing.T) { { ID: "1337", }}, - }) + }, happyCredentials) if !errors.Is(err, asset.ErrNotSupported) { t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported) } @@ -111,7 +118,7 @@ func TestGetHoldings(t *testing.T) { }, }, }}, - }) + }, happyCredentials) if err != nil { t.Error(err) } @@ -131,32 +138,42 @@ func TestGetHoldings(t *testing.T) { }, }, }}, - }) + }, happyCredentials) if err != nil { t.Error(err) } - _, err = GetHoldings("", asset.Spot) + _, err = GetHoldings("", nil, asset.Spot) if !errors.Is(err, errExchangeNameUnset) { t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeNameUnset) } - _, err = GetHoldings("bla", asset.Spot) + _, err = GetHoldings("bla", nil, asset.Spot) + if !errors.Is(err, errCredentialsAreNil) { + t.Fatalf("received: '%v' but expected: '%v'", err, errCredentialsAreNil) + } + + _, err = GetHoldings("bla", happyCredentials, asset.Spot) if !errors.Is(err, errExchangeHoldingsNotFound) { t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeHoldingsNotFound) } - _, err = GetHoldings("bla", asset.Empty) + _, err = GetHoldings("bla", happyCredentials, asset.Empty) if !errors.Is(err, asset.ErrNotSupported) { t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported) } - _, err = GetHoldings("Test", asset.UpsideProfitContract) + _, err = GetHoldings("Test", happyCredentials, asset.UpsideProfitContract) if !errors.Is(err, errAssetHoldingsNotFound) { t.Fatalf("received: '%v' but expected: '%v'", err, errAssetHoldingsNotFound) } - u, err := GetHoldings("Test", asset.Spot) + _, err = GetHoldings("Test", &Credentials{Key: "BBBBB"}, asset.Spot) + if !errors.Is(err, errNoCredentialBalances) { + t.Fatalf("received: '%v' but expected: '%v'", err, errNoCredentialBalances) + } + + u, err := GetHoldings("Test", happyCredentials, asset.Spot) if err != nil { t.Error(err) } @@ -217,7 +234,7 @@ func TestGetHoldings(t *testing.T) { }, }, }}, - }) + }, happyCredentials) if err != nil { t.Error(err) } @@ -226,22 +243,27 @@ func TestGetHoldings(t *testing.T) { } func TestGetBalance(t *testing.T) { - _, err := GetBalance("", "", asset.Empty, currency.Code{}) + _, err := GetBalance("", "", nil, asset.Empty, currency.Code{}) if !errors.Is(err, errExchangeNameUnset) { t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeNameUnset) } - _, err = GetBalance("bruh", "", asset.Empty, currency.Code{}) + _, err = GetBalance("bruh", "", nil, asset.Empty, currency.Code{}) if !errors.Is(err, asset.ErrNotSupported) { t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported) } - _, err = GetBalance("bruh", "", asset.Spot, currency.Code{}) + _, err = GetBalance("bruh", "", nil, asset.Spot, currency.Code{}) + if !errors.Is(err, errCredentialsAreNil) { + t.Fatalf("received: '%v' but expected: '%v'", err, errCredentialsAreNil) + } + + _, err = GetBalance("bruh", "", happyCredentials, asset.Spot, currency.Code{}) if !errors.Is(err, currency.ErrCurrencyCodeEmpty) { t.Fatalf("received: '%v' but expected: '%v'", err, currency.ErrCurrencyCodeEmpty) } - _, err = GetBalance("bruh", "", asset.Spot, currency.BTC) + _, err = GetBalance("bruh", "", happyCredentials, asset.Spot, currency.BTC) if !errors.Is(err, errExchangeHoldingsNotFound) { t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeHoldingsNotFound) } @@ -254,22 +276,27 @@ func TestGetBalance(t *testing.T) { ID: "1337", }, }, - }) + }, happyCredentials) if err != nil { t.Error(err) } - _, err = GetBalance("bruh", "1336", asset.Spot, currency.BTC) + _, err = GetBalance("bruh", "1336", &Credentials{Key: "BBBBB"}, asset.Spot, currency.BTC) + if !errors.Is(err, errNoCredentialBalances) { + t.Fatalf("received: '%v' but expected: '%v'", err, errNoCredentialBalances) + } + + _, err = GetBalance("bruh", "1336", happyCredentials, asset.Spot, currency.BTC) if !errors.Is(err, errNoExchangeSubAccountBalances) { t.Fatalf("received: '%v' but expected: '%v'", err, errNoExchangeSubAccountBalances) } - _, err = GetBalance("bruh", "1337", asset.Futures, currency.BTC) + _, err = GetBalance("bruh", "1337", happyCredentials, asset.Futures, currency.BTC) if !errors.Is(err, errAssetHoldingsNotFound) { t.Fatalf("received: '%v' but expected: '%v'", err, errAssetHoldingsNotFound) } - _, err = GetBalance("bruh", "1337", asset.Spot, currency.BTC) + _, err = GetBalance("bruh", "1337", happyCredentials, asset.Spot, currency.BTC) if !errors.Is(err, errNoBalanceFound) { t.Fatalf("received: '%v' but expected: '%v'", err, errNoBalanceFound) } @@ -289,12 +316,12 @@ func TestGetBalance(t *testing.T) { }, }, }, - }) + }, happyCredentials) if err != nil { t.Error(err) } - bal, err := GetBalance("bruh", "1337", asset.Spot, currency.BTC) + bal, err := GetBalance("bruh", "1337", happyCredentials, asset.Spot, currency.BTC) if !errors.Is(err, nil) { t.Fatalf("received: '%v' but expected: '%v'", err, nil) } @@ -379,12 +406,12 @@ func TestGetFree(t *testing.T) { func TestUpdate(t *testing.T) { t.Parallel() s := &Service{exchangeAccounts: make(map[string]*Accounts), mux: dispatch.GetNewMux(nil)} - err := s.Update(nil) + err := s.Update(nil, nil) if !errors.Is(err, errHoldingsIsNil) { t.Fatalf("received: '%v' but expected: '%v'", err, errHoldingsIsNil) } - err = s.Update(&Holdings{}) + err = s.Update(&Holdings{}, nil) if !errors.Is(err, errExchangeNameUnset) { t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeNameUnset) } @@ -416,7 +443,7 @@ func TestUpdate(t *testing.T) { }, }, }, - }) + }, happyCredentials) if !errors.Is(err, asset.ErrNotSupported) { t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported) } @@ -436,7 +463,7 @@ func TestUpdate(t *testing.T) { }, }, }, - }) + }, happyCredentials) if !errors.Is(err, nil) { t.Fatalf("received: '%v' but expected: '%v'", err, nil) } @@ -446,7 +473,7 @@ func TestUpdate(t *testing.T) { t.Fatal("account should be loaded") } - b, ok := acc.SubAccounts["1337"][asset.Spot][currency.BTC.Item] + b, ok := acc.SubAccounts[Credentials{Key: "AAAAA"}]["1337"][asset.Spot][currency.BTC.Item] if !ok { t.Fatal("account should be loaded") } diff --git a/exchanges/account/account_types.go b/exchanges/account/account_types.go index e2cdb916..644c34a6 100644 --- a/exchanges/account/account_types.go +++ b/exchanges/account/account_types.go @@ -26,8 +26,13 @@ type Service struct { // Accounts holds a stream ID and a map to the exchange holdings type Accounts struct { - ID uuid.UUID - SubAccounts map[string]map[asset.Item]map[*currency.Item]*ProtectedBalance + ID uuid.UUID + // NOTE: Credentials is a place holder for a future interface type, which + // will need - + // TODO: Credential tracker to match to keys that are managed and return + // pointer. + // TODO: Have different cred struct for centralized verse DEFI exchanges. + SubAccounts map[Credentials]map[string]map[asset.Item]map[*currency.Item]*ProtectedBalance } // Holdings is a generic type to hold each exchange's holdings for all enabled @@ -39,9 +44,10 @@ type Holdings struct { // SubAccount defines a singular account type with associated currency balances type SubAccount struct { - ID string - AssetType asset.Item - Currencies []Balance + Credentials Protected + ID string + AssetType asset.Item + Currencies []Balance } // Balance is a sub type to store currency name and individual totals @@ -76,3 +82,9 @@ type ProtectedBalance struct { // usage. notice alert.Notice } + +// Protected limits the access to the underlying credentials outside of this +// package. +type Protected struct { + creds Credentials +} diff --git a/exchanges/account/credentials.go b/exchanges/account/credentials.go new file mode 100644 index 00000000..c3455b1c --- /dev/null +++ b/exchanges/account/credentials.go @@ -0,0 +1,218 @@ +package account + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + + "google.golang.org/grpc/metadata" +) + +// contextCredential is a string flag for use with context values when setting +// credentials internally or via gRPC. +type contextCredential string + +const ( + // ContextCredentialsFlag used for retrieving api credentials from context + ContextCredentialsFlag contextCredential = "apicredentials" + // ContextSubAccountFlag used for retrieving just the sub account from + // context, when the default config credentials sub account needs to be + // changed while the same keys can be used. + ContextSubAccountFlag contextCredential = "subaccountoverride" + + Key = "key" + Secret = "secret" + SubAccountSTR = "subaccount" + ClientID = "clientid" + OneTimePassword = "otp" + PEMKey = "pemkey" + + apiKeyDisplaySize = 16 +) + +var ( + errMetaDataIsNil = errors.New("meta data is nil") + errInvalidCredentialMetaDataLength = errors.New("invalid meta data to process credentials") + errMissingInfo = errors.New("cannot parse meta data missing information in key value pair") +) + +// Credentials define parameters that allow for an authenticated request. +type Credentials struct { + Key string + Secret string + ClientID string // TODO: Implement with exchange orders functionality + PEMKey string + SubAccount string + OneTimePassword string + // TODO: Add AccessControl uint8 for READ/WRITE/Withdraw capabilities. +} + +// GetMetaData returns the credentials for metadata context deployment +func (c *Credentials) GetMetaData() (flag, values string) { + vals := make([]string, 0, 6) + if c.Key != "" { + vals = append(vals, Key+":"+c.Key) + } + if c.Secret != "" { + vals = append(vals, Secret+":"+c.Secret) + } + if c.SubAccount != "" { + vals = append(vals, SubAccountSTR+":"+c.SubAccount) + } + if c.ClientID != "" { + vals = append(vals, ClientID+":"+c.ClientID) + } + if c.PEMKey != "" { + vals = append(vals, PEMKey+":"+c.PEMKey) + } + if c.OneTimePassword != "" { + vals = append(vals, OneTimePassword+":"+c.OneTimePassword) + } + return string(ContextCredentialsFlag), strings.Join(vals, ",") +} + +// String prints out basic credential info (obfuscated) to track key instances +// associated with exchanges. +func (c *Credentials) String() string { + obfuscated := c.Key + if len(obfuscated) > apiKeyDisplaySize { + obfuscated = obfuscated[:apiKeyDisplaySize] + } + return fmt.Sprintf("Key:[%s...] SubAccount:[%s] ClientID:[%s]", + obfuscated, + c.SubAccount, + c.ClientID) +} + +// getInternal returns the values for assignment to an internal context +func (c *Credentials) getInternal() (contextCredential, *ContextCredentialsStore) { + if c.IsEmpty() { + return "", nil + } + store := &ContextCredentialsStore{} + store.Load(c) + return ContextCredentialsFlag, store +} + +// IsEmpty return true if the underlying credentials type has not been filled +// with at least one item. +func (c *Credentials) IsEmpty() bool { + return c == nil || c.ClientID == "" && + c.Key == "" && + c.OneTimePassword == "" && + c.PEMKey == "" && + c.Secret == "" && + c.SubAccount == "" +} + +// Equal determines if the keys are the same. +// OTP omitted because it's generated per request. +// PEMKey and Secret omitted because of direct correlation with api key. +func (c *Credentials) Equal(other *Credentials) bool { + return c != nil && + other != nil && + c.Key == other.Key && + c.ClientID == other.ClientID && + c.SubAccount == other.SubAccount +} + +// ContextCredentialsStore protects the stored credentials for use in a context +type ContextCredentialsStore struct { + creds *Credentials + mu sync.RWMutex +} + +// Load stores provided credentials +func (c *ContextCredentialsStore) Load(creds *Credentials) { + // Segregate from external call + cpy := *creds + c.mu.Lock() + c.creds = &cpy + c.mu.Unlock() +} + +// Get returns the full credentials from the store +func (c *ContextCredentialsStore) Get() *Credentials { + c.mu.RLock() + creds := *c.creds + c.mu.RUnlock() + return &creds +} + +// ParseCredentialsMetadata intercepts and converts credentials metadata to a +// static type for authentication processing and protection. +func ParseCredentialsMetadata(ctx context.Context, md metadata.MD) (context.Context, error) { + if md == nil { + return ctx, errMetaDataIsNil + } + + credMD, ok := md[string(ContextCredentialsFlag)] + if !ok || len(credMD) == 0 { + return ctx, nil + } + + if len(credMD) != 1 { + return ctx, errInvalidCredentialMetaDataLength + } + + segregatedCreds := strings.Split(credMD[0], ",") + var ctxCreds Credentials + var subAccountHere string + for x := range segregatedCreds { + keyVals := strings.Split(segregatedCreds[x], ":") + if len(keyVals) != 2 { + return ctx, fmt.Errorf("%w received %v fields, expected 2 contains: %s", + errMissingInfo, + len(keyVals), + keyVals) + } + switch keyVals[0] { + case Key: + ctxCreds.Key = keyVals[1] + case Secret: + ctxCreds.Secret = keyVals[1] + case SubAccountSTR: + // Capture sub account as this can override if other values are + // not included in metadata. + subAccountHere = keyVals[1] + case ClientID: + ctxCreds.ClientID = keyVals[1] + case PEMKey: + ctxCreds.PEMKey = keyVals[1] + case OneTimePassword: + ctxCreds.OneTimePassword = keyVals[1] + } + } + if ctxCreds.IsEmpty() && subAccountHere != "" { + // This will override default sub account details if needed. + return deploySubAccountOverrideToContext(ctx, subAccountHere), nil + } + // merge sub account to main context credentials + ctxCreds.SubAccount = subAccountHere + return DeployCredentialsToContext(ctx, &ctxCreds), nil +} + +// DeployCredentialsToContext sets credentials for internal use to context which +// can override default credential values. +func DeployCredentialsToContext(ctx context.Context, creds *Credentials) context.Context { + flag, store := creds.getInternal() + return context.WithValue(ctx, flag, store) +} + +// deploySubAccountOverrideToContext sets subaccount as override to credentials +// as a separate flag. +func deploySubAccountOverrideToContext(ctx context.Context, subAccount string) context.Context { + return context.WithValue(ctx, ContextSubAccountFlag, subAccount) +} + +// String strings the credentials in a protected way. +func (p *Protected) String() string { + return p.creds.String() +} + +// Equal determines if the keys are the same +func (p *Protected) Equal(other *Credentials) bool { + return p.creds.Equal(other) +} diff --git a/exchanges/account/credentials_test.go b/exchanges/account/credentials_test.go new file mode 100644 index 00000000..25b6e494 --- /dev/null +++ b/exchanges/account/credentials_test.go @@ -0,0 +1,240 @@ +package account + +import ( + "context" + "errors" + "testing" + + "google.golang.org/grpc/metadata" +) + +func TestIsEmpty(t *testing.T) { + t.Parallel() + var c *Credentials + if !c.IsEmpty() { + t.Fatalf("expected: %v but received: %v", true, c.IsEmpty()) + } + c = new(Credentials) + if !c.IsEmpty() { + t.Fatalf("expected: %v but received: %v", true, c.IsEmpty()) + } + + c.SubAccount = "woow" + if c.IsEmpty() { + t.Fatalf("expected: %v but received: %v", false, c.IsEmpty()) + } +} + +func TestParseCredentialsMetadata(t *testing.T) { + t.Parallel() + _, err := ParseCredentialsMetadata(context.Background(), nil) + if !errors.Is(err, errMetaDataIsNil) { + t.Fatalf("received: '%v' but expected: '%v'", err, errMetaDataIsNil) + } + + _, err = ParseCredentialsMetadata(context.Background(), metadata.MD{}) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + ctx := metadata.AppendToOutgoingContext(context.Background(), + string(ContextCredentialsFlag), "wow", string(ContextCredentialsFlag), "wow2") + nortyMD, _ := metadata.FromOutgoingContext(ctx) + + _, err = ParseCredentialsMetadata(context.Background(), nortyMD) + if !errors.Is(err, errInvalidCredentialMetaDataLength) { + t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidCredentialMetaDataLength) + } + + ctx = metadata.AppendToOutgoingContext(context.Background(), + string(ContextCredentialsFlag), "brokenstring") + nortyMD, _ = metadata.FromOutgoingContext(ctx) + + _, err = ParseCredentialsMetadata(context.Background(), nortyMD) + if !errors.Is(err, errMissingInfo) { + t.Fatalf("received: '%v' but expected: '%v'", err, errMissingInfo) + } + + beforeCreds := Credentials{ + Key: "superkey", + Secret: "supersecret", + SubAccount: "supersub", + ClientID: "superclient", + PEMKey: "superpem", + OneTimePassword: "superOneTimePasssssss", + } + + flag, outGoing := beforeCreds.GetMetaData() + ctx = metadata.AppendToOutgoingContext(context.Background(), flag, outGoing) + lovelyMD, _ := metadata.FromOutgoingContext(ctx) + + ctx, err = ParseCredentialsMetadata(context.Background(), lovelyMD) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + store, ok := ctx.Value(ContextCredentialsFlag).(*ContextCredentialsStore) + if !ok { + t.Fatal("should have processed") + } + + afterCreds := store.Get() + + if afterCreds.Key != "superkey" && + afterCreds.Secret != "supersecret" && + afterCreds.SubAccount != "supersub" && + afterCreds.ClientID != "superclient" && + afterCreds.PEMKey != "superpem" && + afterCreds.OneTimePassword != "superOneTimePasssssss" { + t.Fatal("unexpected values") + } + + // subaccount override + subaccount := Credentials{ + SubAccount: "supersub", + } + + flag, outGoing = subaccount.GetMetaData() + ctx = metadata.AppendToOutgoingContext(context.Background(), flag, outGoing) + lovelyMD, _ = metadata.FromOutgoingContext(ctx) + + ctx, err = ParseCredentialsMetadata(context.Background(), lovelyMD) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + sa, ok := ctx.Value(ContextSubAccountFlag).(string) + if !ok { + t.Fatal("should have processed") + } + + if sa != "supersub" { + t.Fatal("unexpected value") + } +} + +func TestGetInternal(t *testing.T) { + t.Parallel() + flag, store := (&Credentials{}).getInternal() + if flag != "" { + t.Fatal("unexpected value") + } + if store != nil { + t.Fatal("unexpected value") + } + flag, store = (&Credentials{Key: "wow"}).getInternal() + if flag != ContextCredentialsFlag { + t.Fatal("unexpected value") + } + if store == nil { + t.Fatal("unexpected value") + } + if store.Get().Key != "wow" { + t.Fatal("unexpected value") + } +} + +func TestString(t *testing.T) { + t.Parallel() + creds := Credentials{} + if s := creds.String(); s != "Key:[...] SubAccount:[] ClientID:[]" { + t.Fatal("unexpected value") + } + + creds.Key = "12345678910111234" + creds.SubAccount = "sub" + creds.ClientID = "client" + + if s := creds.String(); s != "Key:[1234567891011123...] SubAccount:[sub] ClientID:[client]" { + t.Fatal("unexpected value") + } +} + +func TestCredentialsEqual(t *testing.T) { + t.Parallel() + var this, that *Credentials + if this.Equal(that) { + t.Fatal("unexpected value") + } + this = &Credentials{} + if this.Equal(that) { + t.Fatal("unexpected value") + } + that = &Credentials{Key: "1337"} + if this.Equal(that) { + t.Fatal("unexpected value") + } + this.Key = "1337" + if !this.Equal(that) { + t.Fatal("unexpected value") + } + this.ClientID = "1337" + if this.Equal(that) { + t.Fatal("unexpected value") + } + that.ClientID = "1337" + if !this.Equal(that) { + t.Fatal("unexpected value") + } + this.SubAccount = "someSub" + if this.Equal(that) { + t.Fatal("unexpected value") + } + that.SubAccount = "someSub" + if !this.Equal(that) { + t.Fatal("unexpected value") + } +} + +func TestProtectedString(t *testing.T) { + t.Parallel() + p := Protected{} + if s := p.String(); s != "Key:[...] SubAccount:[] ClientID:[]" { + t.Fatal("unexpected value") + } + + p.creds.Key = "12345678910111234" + p.creds.SubAccount = "sub" + p.creds.ClientID = "client" + + if s := p.creds.String(); s != "Key:[1234567891011123...] SubAccount:[sub] ClientID:[client]" { + t.Fatal("unexpected value") + } +} + +func TestProtectedCredentialsEqual(t *testing.T) { + t.Parallel() + var this Protected + var that *Credentials + if this.Equal(that) { + t.Fatal("unexpected value") + } + this.creds = Credentials{} + if this.Equal(that) { + t.Fatal("unexpected value") + } + that = &Credentials{Key: "1337"} + if this.Equal(that) { + t.Fatal("unexpected value") + } + this.creds.Key = "1337" + if !this.Equal(that) { + t.Fatal("unexpected value") + } + this.creds.ClientID = "1337" + if this.Equal(that) { + t.Fatal("unexpected value") + } + that.ClientID = "1337" + if !this.Equal(that) { + t.Fatal("unexpected value") + } + this.creds.SubAccount = "someSub" + if this.Equal(that) { + t.Fatal("unexpected value") + } + that.SubAccount = "someSub" + if !this.Equal(that) { + t.Fatal("unexpected value") + } +} diff --git a/exchanges/alphapoint/alphapoint_wrapper.go b/exchanges/alphapoint/alphapoint_wrapper.go index b6495eee..b5f4a4b3 100644 --- a/exchanges/alphapoint/alphapoint_wrapper.go +++ b/exchanges/alphapoint/alphapoint_wrapper.go @@ -116,7 +116,12 @@ func (a *Alphapoint) UpdateAccountInfo(ctx context.Context, assetType asset.Item AssetType: assetType, }) - err = account.Process(&response) + creds, err := a.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + + err = account.Process(&response, creds) if err != nil { return account.Holdings{}, err } @@ -127,7 +132,11 @@ func (a *Alphapoint) UpdateAccountInfo(ctx context.Context, assetType asset.Item // FetchAccountInfo retrieves balances for all enabled currencies on the // Alphapoint exchange func (a *Alphapoint) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(a.Name, assetType) + creds, err := a.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(a.Name, creds, assetType) if err != nil { return a.UpdateAccountInfo(ctx, assetType) } diff --git a/exchanges/binance/binance_wrapper.go b/exchanges/binance/binance_wrapper.go index 8718cb94..40bef77d 100644 --- a/exchanges/binance/binance_wrapper.go +++ b/exchanges/binance/binance_wrapper.go @@ -788,7 +788,11 @@ func (b *Binance) UpdateAccountInfo(ctx context.Context, assetType asset.Item) ( } acc.AssetType = assetType info.Accounts = append(info.Accounts, acc) - if err := account.Process(&info); err != nil { + creds, err := b.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + if err := account.Process(&info, creds); err != nil { return account.Holdings{}, err } return info, nil @@ -796,7 +800,11 @@ func (b *Binance) UpdateAccountInfo(ctx context.Context, assetType asset.Item) ( // FetchAccountInfo retrieves balances for all enabled currencies func (b *Binance) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(b.Name, assetType) + creds, err := b.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(b.Name, creds, assetType) if err != nil { return b.UpdateAccountInfo(ctx, assetType) } diff --git a/exchanges/bitfinex/bitfinex_wrapper.go b/exchanges/bitfinex/bitfinex_wrapper.go index 0bbc28fc..ff70771e 100644 --- a/exchanges/bitfinex/bitfinex_wrapper.go +++ b/exchanges/bitfinex/bitfinex_wrapper.go @@ -519,7 +519,11 @@ func (b *Bitfinex) UpdateAccountInfo(ctx context.Context, assetType asset.Item) } response.Accounts = Accounts - err = account.Process(&response) + creds, err := b.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + err = account.Process(&response, creds) if err != nil { return account.Holdings{}, err } @@ -529,11 +533,14 @@ func (b *Bitfinex) UpdateAccountInfo(ctx context.Context, assetType asset.Item) // FetchAccountInfo retrieves balances for all enabled currencies func (b *Bitfinex) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(b.Name, assetType) + creds, err := b.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(b.Name, creds, assetType) if err != nil { return b.UpdateAccountInfo(ctx, assetType) } - return acc, nil } diff --git a/exchanges/bitflyer/bitflyer_wrapper.go b/exchanges/bitflyer/bitflyer_wrapper.go index fbf6238e..11bd7827 100644 --- a/exchanges/bitflyer/bitflyer_wrapper.go +++ b/exchanges/bitflyer/bitflyer_wrapper.go @@ -320,11 +320,14 @@ func (b *Bitflyer) UpdateAccountInfo(_ context.Context, _ asset.Item) (account.H // FetchAccountInfo retrieves balances for all enabled currencies func (b *Bitflyer) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(b.Name, assetType) + creds, err := b.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(b.Name, creds, assetType) if err != nil { return b.UpdateAccountInfo(ctx, assetType) } - return acc, nil } diff --git a/exchanges/bithumb/bithumb_wrapper.go b/exchanges/bithumb/bithumb_wrapper.go index 2156fc01..b4b869a1 100644 --- a/exchanges/bithumb/bithumb_wrapper.go +++ b/exchanges/bithumb/bithumb_wrapper.go @@ -394,7 +394,11 @@ func (b *Bithumb) UpdateAccountInfo(ctx context.Context, assetType asset.Item) ( }) info.Exchange = b.Name - err = account.Process(&info) + creds, err := b.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + err = account.Process(&info, creds) if err != nil { return account.Holdings{}, err } @@ -404,11 +408,14 @@ func (b *Bithumb) UpdateAccountInfo(ctx context.Context, assetType asset.Item) ( // FetchAccountInfo retrieves balances for all enabled currencies func (b *Bithumb) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(b.Name, assetType) + creds, err := b.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(b.Name, creds, assetType) if err != nil { return b.UpdateAccountInfo(ctx, assetType) } - return acc, nil } diff --git a/exchanges/bitmex/bitmex_wrapper.go b/exchanges/bitmex/bitmex_wrapper.go index dcc34277..f2bc1fd7 100644 --- a/exchanges/bitmex/bitmex_wrapper.go +++ b/exchanges/bitmex/bitmex_wrapper.go @@ -457,7 +457,11 @@ func (b *Bitmex) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (a } info.Exchange = b.Name - if err := account.Process(&info); err != nil { + creds, err := b.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + if err := account.Process(&info, creds); err != nil { return account.Holdings{}, err } @@ -466,11 +470,14 @@ func (b *Bitmex) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (a // FetchAccountInfo retrieves balances for all enabled currencies func (b *Bitmex) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(b.Name, assetType) + creds, err := b.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(b.Name, creds, assetType) if err != nil { return b.UpdateAccountInfo(ctx, assetType) } - return acc, nil } diff --git a/exchanges/bitstamp/bitstamp_wrapper.go b/exchanges/bitstamp/bitstamp_wrapper.go index e9b6fb47..fbe71100 100644 --- a/exchanges/bitstamp/bitstamp_wrapper.go +++ b/exchanges/bitstamp/bitstamp_wrapper.go @@ -452,7 +452,11 @@ func (b *Bitstamp) UpdateAccountInfo(ctx context.Context, assetType asset.Item) Currencies: currencies, }) - err = account.Process(&response) + creds, err := b.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + err = account.Process(&response, creds) if err != nil { return account.Holdings{}, err } @@ -462,11 +466,14 @@ func (b *Bitstamp) UpdateAccountInfo(ctx context.Context, assetType asset.Item) // FetchAccountInfo retrieves balances for all enabled currencies func (b *Bitstamp) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(b.Name, assetType) + creds, err := b.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(b.Name, creds, assetType) if err != nil { return b.UpdateAccountInfo(ctx, assetType) } - return acc, nil } diff --git a/exchanges/bittrex/bittrex_wrapper.go b/exchanges/bittrex/bittrex_wrapper.go index 108fc3f8..dad979f8 100644 --- a/exchanges/bittrex/bittrex_wrapper.go +++ b/exchanges/bittrex/bittrex_wrapper.go @@ -425,12 +425,20 @@ func (b *Bittrex) UpdateAccountInfo(ctx context.Context, assetType asset.Item) ( }) resp.Exchange = b.Name - return resp, account.Process(&resp) + creds, err := b.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + return resp, account.Process(&resp, creds) } // FetchAccountInfo retrieves balances for all enabled currencies func (b *Bittrex) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - resp, err := account.GetHoldings(b.Name, assetType) + creds, err := b.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + resp, err := account.GetHoldings(b.Name, creds, assetType) if err != nil { return b.UpdateAccountInfo(ctx, assetType) } diff --git a/exchanges/btcmarkets/btcmarkets_wrapper.go b/exchanges/btcmarkets/btcmarkets_wrapper.go index e010e359..5e4e9683 100644 --- a/exchanges/btcmarkets/btcmarkets_wrapper.go +++ b/exchanges/btcmarkets/btcmarkets_wrapper.go @@ -457,7 +457,11 @@ func (b *BTCMarkets) UpdateAccountInfo(ctx context.Context, assetType asset.Item resp.Accounts = append(resp.Accounts, acc) resp.Exchange = b.Name - err = account.Process(&resp) + creds, err := b.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + err = account.Process(&resp, creds) if err != nil { return account.Holdings{}, err } @@ -467,7 +471,11 @@ func (b *BTCMarkets) UpdateAccountInfo(ctx context.Context, assetType asset.Item // FetchAccountInfo retrieves balances for all enabled currencies func (b *BTCMarkets) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(b.Name, assetType) + creds, err := b.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(b.Name, creds, assetType) if err != nil { return b.UpdateAccountInfo(ctx, assetType) } diff --git a/exchanges/btse/btse_wrapper.go b/exchanges/btse/btse_wrapper.go index bb597bc6..d4021021 100644 --- a/exchanges/btse/btse_wrapper.go +++ b/exchanges/btse/btse_wrapper.go @@ -416,7 +416,11 @@ func (b *BTSE) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (acc }, } - err = account.Process(&a) + creds, err := b.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + err = account.Process(&a, creds) if err != nil { return account.Holdings{}, err } @@ -426,7 +430,11 @@ func (b *BTSE) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (acc // FetchAccountInfo retrieves balances for all enabled currencies func (b *BTSE) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(b.Name, assetType) + creds, err := b.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(b.Name, creds, assetType) if err != nil { return b.UpdateAccountInfo(ctx, assetType) } diff --git a/exchanges/coinbasepro/coinbasepro_websocket.go b/exchanges/coinbasepro/coinbasepro_websocket.go index 450558f7..0f9971d4 100644 --- a/exchanges/coinbasepro/coinbasepro_websocket.go +++ b/exchanges/coinbasepro/coinbasepro_websocket.go @@ -14,7 +14,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/common/convert" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" - exchange "github.com/thrasher-corp/gocryptotrader/exchanges" + "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" @@ -401,7 +401,7 @@ func (c *CoinbasePro) GenerateDefaultSubscriptions() ([]stream.ChannelSubscripti // Subscribe sends a websocket message to receive data from the channel func (c *CoinbasePro) Subscribe(channelsToSubscribe []stream.ChannelSubscription) error { - var creds *exchange.Credentials + var creds *account.Credentials var err error if c.IsWebsocketAuthenticationSupported() { creds, err = c.GetCredentials(context.TODO()) diff --git a/exchanges/coinbasepro/coinbasepro_wrapper.go b/exchanges/coinbasepro/coinbasepro_wrapper.go index b509778c..4f0c46f5 100644 --- a/exchanges/coinbasepro/coinbasepro_wrapper.go +++ b/exchanges/coinbasepro/coinbasepro_wrapper.go @@ -346,7 +346,11 @@ func (c *CoinbasePro) UpdateAccountInfo(ctx context.Context, assetType asset.Ite return account.Holdings{}, err } - err = account.Process(&response) + creds, err := c.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + err = account.Process(&response, creds) if err != nil { return account.Holdings{}, err } @@ -356,11 +360,14 @@ func (c *CoinbasePro) UpdateAccountInfo(ctx context.Context, assetType asset.Ite // FetchAccountInfo retrieves balances for all enabled currencies func (c *CoinbasePro) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(c.Name, assetType) + creds, err := c.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(c.Name, creds, assetType) if err != nil { return c.UpdateAccountInfo(ctx, assetType) } - return acc, nil } diff --git a/exchanges/coinut/coinut.go b/exchanges/coinut/coinut.go index 1dd0388b..7a19becd 100644 --- a/exchanges/coinut/coinut.go +++ b/exchanges/coinut/coinut.go @@ -14,6 +14,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" + "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/request" @@ -275,7 +276,7 @@ func (c *COINUT) SendHTTPRequest(ctx context.Context, ep exchange.URL, apiReques headers := make(map[string]string) if authenticated { - var creds *exchange.Credentials + var creds *account.Credentials creds, err = c.GetCredentials(ctx) if err != nil { return nil, err diff --git a/exchanges/coinut/coinut_test.go b/exchanges/coinut/coinut_test.go index 84502a7b..07e0f00a 100644 --- a/exchanges/coinut/coinut_test.go +++ b/exchanges/coinut/coinut_test.go @@ -16,6 +16,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/core" "github.com/thrasher-corp/gocryptotrader/currency" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" + "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" @@ -832,7 +833,8 @@ func TestWsLogin(t *testing.T) { "unverified_email":"", "username":"test" }`) - ctx := exchange.DeployCredentialsToContext(context.Background(), &exchange.Credentials{Key: "b46e658f-d4c4-433c-b032-093423b1aaa4", ClientID: "dummy"}) + ctx := account.DeployCredentialsToContext(context.Background(), + &account.Credentials{Key: "b46e658f-d4c4-433c-b032-093423b1aaa4", ClientID: "dummy"}) err := c.wsHandleData(ctx, pressXToJSON) if err != nil { t.Error(err) diff --git a/exchanges/coinut/coinut_wrapper.go b/exchanges/coinut/coinut_wrapper.go index 0db4514b..29153421 100644 --- a/exchanges/coinut/coinut_wrapper.go +++ b/exchanges/coinut/coinut_wrapper.go @@ -395,7 +395,11 @@ func (c *COINUT) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (a Currencies: balances, }) - err = account.Process(&info) + creds, err := c.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + err = account.Process(&info, creds) if err != nil { return account.Holdings{}, err } @@ -405,11 +409,14 @@ func (c *COINUT) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (a // FetchAccountInfo retrieves balances for all enabled currencies func (c *COINUT) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(c.Name, assetType) + creds, err := c.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(c.Name, creds, assetType) if err != nil { return c.UpdateAccountInfo(ctx, assetType) } - return acc, nil } diff --git a/exchanges/credentials.go b/exchanges/credentials.go index 4806962a..1380b9d2 100644 --- a/exchanges/credentials.go +++ b/exchanges/credentials.go @@ -5,28 +5,11 @@ import ( "errors" "fmt" "strings" - "sync" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/config" + "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/log" - "google.golang.org/grpc/metadata" -) - -// contextCredential is a string flag for use with context values when setting -// credentials internally or via gRPC. -type contextCredential string - -const ( - contextCredentialsFlag contextCredential = "apicredentials" - contextSubAccountFlag contextCredential = "subaccountoverride" - - key = "key" - secret = "secret" - subAccount = "subaccount" - clientID = "clientid" - oneTimePassword = "otp" - _PEMKey = "pemkey" ) var ( @@ -39,167 +22,20 @@ var ( // undertake an authenticated HTTP request. ErrCredentialsAreEmpty = errors.New("credentials are empty") - errRequiresAPIKey = errors.New("requires API key but default/empty one set") - errRequiresAPISecret = errors.New("requires API secret but default/empty one set") - errRequiresAPIPEMKey = errors.New("requires API PEM key but default/empty one set") - errRequiresAPIClientID = errors.New("requires API Client ID but default/empty one set") - errBase64DecodeFailure = errors.New("base64 decode has failed") - errMissingInfo = errors.New("cannot parse meta data missing information in key value pair") - errInvalidCredentialMetaDataLength = errors.New("invalid meta data to process credentials") - errContextCredentialsFailure = errors.New("context credentials type assertion failure") - errMetaDataIsNil = errors.New("meta data is nil") + errRequiresAPIKey = errors.New("requires API key but default/empty one set") + errRequiresAPISecret = errors.New("requires API secret but default/empty one set") + errRequiresAPIPEMKey = errors.New("requires API PEM key but default/empty one set") + errRequiresAPIClientID = errors.New("requires API Client ID but default/empty one set") + errBase64DecodeFailure = errors.New("base64 decode has failed") + errContextCredentialsFailure = errors.New("context credentials type assertion failure") ) -// ParseCredentialsMetadata intercepts and converts credentials metadata to a -// static type for authentication processing and protection. -func ParseCredentialsMetadata(ctx context.Context, md metadata.MD) (context.Context, error) { - if md == nil { - return ctx, errMetaDataIsNil - } - - credMD, ok := md[string(contextCredentialsFlag)] - if !ok || len(credMD) == 0 { - return ctx, nil - } - - if len(credMD) != 1 { - return ctx, errInvalidCredentialMetaDataLength - } - - segregatedCreds := strings.Split(credMD[0], ",") - var ctxCreds Credentials - var subAccountHere string - for x := range segregatedCreds { - keyVals := strings.Split(segregatedCreds[x], ":") - if len(keyVals) != 2 { - return ctx, fmt.Errorf("%w received %v fields, expected 2 contains: %s", - errMissingInfo, - len(keyVals), - keyVals) - } - switch keyVals[0] { - case key: - ctxCreds.Key = keyVals[1] - case secret: - ctxCreds.Secret = keyVals[1] - case subAccount: - // Capture sub account as this can override if other values are - // not included in metadata. - subAccountHere = keyVals[1] - case clientID: - ctxCreds.ClientID = keyVals[1] - case _PEMKey: - ctxCreds.PEMKey = keyVals[1] - case oneTimePassword: - ctxCreds.OneTimePassword = keyVals[1] - } - } - if ctxCreds.IsEmpty() && subAccountHere != "" { - // This will override default sub account details if needed. - return deploySubAccountOverrideToContext(ctx, subAccountHere), nil - } - // merge sub account to main context credentials - ctxCreds.SubAccount = subAccountHere - return DeployCredentialsToContext(ctx, &ctxCreds), nil -} - -// Credentials define parameters that allow for an authenticated request. -type Credentials struct { - Key string - Secret string - ClientID string - PEMKey string - SubAccount string - OneTimePassword string -} - -// DeployCredentialsToContext sets credentials for internal use to context which -// can override default credential values. -func DeployCredentialsToContext(ctx context.Context, creds *Credentials) context.Context { - flag, store := creds.getInternal() - return context.WithValue(ctx, flag, store) -} - -// deploySubAccountOverrideToContext sets subaccount as override to credentials -// as a separate flag. -func deploySubAccountOverrideToContext(ctx context.Context, subAccount string) context.Context { - return context.WithValue(ctx, contextSubAccountFlag, subAccount) -} - -// GetMetaData returns the credentials for metadata context deployment -func (c *Credentials) GetMetaData() (flag, values string) { - vals := make([]string, 0, 6) - if c.Key != "" { - vals = append(vals, key+":"+c.Key) - } - if c.Secret != "" { - vals = append(vals, secret+":"+c.Secret) - } - if c.SubAccount != "" { - vals = append(vals, subAccount+":"+c.SubAccount) - } - if c.ClientID != "" { - vals = append(vals, clientID+":"+c.ClientID) - } - if c.PEMKey != "" { - vals = append(vals, _PEMKey+":"+c.PEMKey) - } - if c.OneTimePassword != "" { - vals = append(vals, oneTimePassword+":"+c.OneTimePassword) - } - return string(contextCredentialsFlag), strings.Join(vals, ",") -} - -// IsEmpty return true if the underlying credentials type has not been filled -// with at least one item. -func (c *Credentials) IsEmpty() bool { - return c == nil || c.ClientID == "" && - c.Key == "" && - c.OneTimePassword == "" && - c.PEMKey == "" && - c.Secret == "" && - c.SubAccount == "" -} - -// contextCredentialsStore protects the stored credentials for use in a context -type contextCredentialsStore struct { - creds *Credentials - mu sync.RWMutex -} - -// getInternal returns the values for assignment to an internal context -func (c *Credentials) getInternal() (contextCredential, *contextCredentialsStore) { - if c.IsEmpty() { - return "", nil - } - store := &contextCredentialsStore{} - store.Load(c) - return contextCredentialsFlag, store -} - -// Load stores provided credentials -func (c *contextCredentialsStore) Load(creds *Credentials) { - // Segregate from external call - cpy := *creds - c.mu.Lock() - c.creds = &cpy - c.mu.Unlock() -} - -// Get returns the full credentials from the store -func (c *contextCredentialsStore) Get() *Credentials { - c.mu.RLock() - creds := *c.creds - c.mu.RUnlock() - return &creds -} - // SetKey sets new key for the default credentials func (a *API) SetKey(key string) { a.credMu.Lock() defer a.credMu.Unlock() if a.credentials == nil { - a.credentials = &Credentials{} + a.credentials = &account.Credentials{} } a.credentials.Key = key } @@ -209,7 +45,7 @@ func (a *API) SetSecret(secret string) { a.credMu.Lock() defer a.credMu.Unlock() if a.credentials == nil { - a.credentials = &Credentials{} + a.credentials = &account.Credentials{} } a.credentials.Secret = secret } @@ -219,7 +55,7 @@ func (a *API) SetClientID(clientID string) { a.credMu.Lock() defer a.credMu.Unlock() if a.credentials == nil { - a.credentials = &Credentials{} + a.credentials = &account.Credentials{} } a.credentials.ClientID = clientID } @@ -229,7 +65,7 @@ func (a *API) SetPEMKey(pem string) { a.credMu.Lock() defer a.credMu.Unlock() if a.credentials == nil { - a.credentials = &Credentials{} + a.credentials = &account.Credentials{} } a.credentials.PEMKey = pem } @@ -239,14 +75,14 @@ func (a *API) SetSubAccount(sub string) { a.credMu.Lock() defer a.credMu.Unlock() if a.credentials == nil { - a.credentials = &Credentials{} + a.credentials = &account.Credentials{} } a.credentials.SubAccount = sub } // CheckCredentials checks to see if the required fields have been set before // sending an authenticated API request -func (b *Base) CheckCredentials(creds *Credentials, isContext bool) error { +func (b *Base) CheckCredentials(creds *account.Credentials, isContext bool) error { if b.SkipAuthCheck { return nil } @@ -277,7 +113,7 @@ func (b *Base) AreCredentialsValid(ctx context.Context) bool { // GetDefaultCredentials returns the exchange.Base api credentials loaded by // config.json -func (b *Base) GetDefaultCredentials() *Credentials { +func (b *Base) GetDefaultCredentials() *account.Credentials { b.API.credMu.RLock() defer b.API.credMu.RUnlock() if b.API.credentials == nil { @@ -289,12 +125,14 @@ func (b *Base) GetDefaultCredentials() *Credentials { // GetCredentials checks and validates current credentials, context credentials // override default credentials, if no credentials found, will return an error. -func (b *Base) GetCredentials(ctx context.Context) (*Credentials, error) { - value := ctx.Value(contextCredentialsFlag) +func (b *Base) GetCredentials(ctx context.Context) (*account.Credentials, error) { + value := ctx.Value(account.ContextCredentialsFlag) if value != nil { - ctxCredStore, ok := value.(*contextCredentialsStore) + ctxCredStore, ok := value.(*account.ContextCredentialsStore) if !ok { - return &Credentials{}, errContextCredentialsFailure + // NOTE: Return empty credentials on error to limit panic on + // websocket handling. + return &account.Credentials{}, errContextCredentialsFailure } creds := ctxCredStore.Get() @@ -306,9 +144,11 @@ func (b *Base) GetCredentials(ctx context.Context) (*Credentials, error) { err := b.CheckCredentials(b.API.credentials, false) if err != nil { - return &Credentials{}, err + // NOTE: Return empty credentials on error to limit panic on websocket + // handling. + return &account.Credentials{}, err } - subAccountOverride, ok := ctx.Value(contextSubAccountFlag).(string) + subAccountOverride, ok := ctx.Value(account.ContextSubAccountFlag).(string) b.API.credMu.RLock() defer b.API.credMu.RUnlock() creds := *b.API.credentials @@ -319,7 +159,7 @@ func (b *Base) GetCredentials(ctx context.Context) (*Credentials, error) { } // ValidateAPICredentials validates the exchanges API credentials -func (b *Base) ValidateAPICredentials(creds *Credentials) error { +func (b *Base) ValidateAPICredentials(creds *account.Credentials) error { b.API.credMu.RLock() defer b.API.credMu.RUnlock() if creds.IsEmpty() { @@ -359,7 +199,7 @@ func (b *Base) SetCredentials(apiKey, apiSecret, clientID, subaccount, pemKey, o b.API.credMu.Lock() defer b.API.credMu.Unlock() if b.API.credentials == nil { - b.API.credentials = &Credentials{} + b.API.credentials = &account.Credentials{} } b.API.credentials.Key = apiKey b.API.credentials.ClientID = clientID diff --git a/exchanges/credentials_test.go b/exchanges/credentials_test.go index 6538504d..7e2bef11 100644 --- a/exchanges/credentials_test.go +++ b/exchanges/credentials_test.go @@ -6,97 +6,9 @@ import ( "testing" "github.com/thrasher-corp/gocryptotrader/config" - "google.golang.org/grpc/metadata" + "github.com/thrasher-corp/gocryptotrader/exchanges/account" ) -func TestParseCredentialsMetadata(t *testing.T) { - t.Parallel() - _, err := ParseCredentialsMetadata(context.Background(), nil) - if !errors.Is(err, errMetaDataIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errMetaDataIsNil) - } - - _, err = ParseCredentialsMetadata(context.Background(), metadata.MD{}) - if !errors.Is(err, nil) { - t.Fatalf("received: '%v' but expected: '%v'", err, nil) - } - - ctx := metadata.AppendToOutgoingContext(context.Background(), - string(contextCredentialsFlag), "wow", string(contextCredentialsFlag), "wow2") - nortyMD, _ := metadata.FromOutgoingContext(ctx) - - _, err = ParseCredentialsMetadata(context.Background(), nortyMD) - if !errors.Is(err, errInvalidCredentialMetaDataLength) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidCredentialMetaDataLength) - } - - ctx = metadata.AppendToOutgoingContext(context.Background(), - string(contextCredentialsFlag), "brokenstring") - nortyMD, _ = metadata.FromOutgoingContext(ctx) - - _, err = ParseCredentialsMetadata(context.Background(), nortyMD) - if !errors.Is(err, errMissingInfo) { - t.Fatalf("received: '%v' but expected: '%v'", err, errMissingInfo) - } - - beforeCreds := Credentials{ - Key: "superkey", - Secret: "supersecret", - SubAccount: "supersub", - ClientID: "superclient", - PEMKey: "superpem", - OneTimePassword: "superOneTimePasssssss", - } - - flag, outGoing := beforeCreds.GetMetaData() - ctx = metadata.AppendToOutgoingContext(context.Background(), flag, outGoing) - lovelyMD, _ := metadata.FromOutgoingContext(ctx) - - ctx, err = ParseCredentialsMetadata(context.Background(), lovelyMD) - if !errors.Is(err, nil) { - t.Fatalf("received: '%v' but expected: '%v'", err, nil) - } - - store, ok := ctx.Value(contextCredentialsFlag).(*contextCredentialsStore) - if !ok { - t.Fatal("should have processed") - } - - afterCreds := store.Get() - - if afterCreds.Key != "superkey" && - afterCreds.Secret != "supersecret" && - afterCreds.SubAccount != "supersub" && - afterCreds.ClientID != "superclient" && - afterCreds.PEMKey != "superpem" && - afterCreds.OneTimePassword != "superOneTimePasssssss" { - t.Fatal("unexpected values") - } - - // subaccount override - subaccount := Credentials{ - SubAccount: "supersub", - } - - flag, outGoing = subaccount.GetMetaData() - ctx = metadata.AppendToOutgoingContext(context.Background(), flag, outGoing) - lovelyMD, _ = metadata.FromOutgoingContext(ctx) - - ctx, err = ParseCredentialsMetadata(context.Background(), lovelyMD) - if !errors.Is(err, nil) { - t.Fatalf("received: '%v' but expected: '%v'", err, nil) - } - - sa, ok := ctx.Value(contextSubAccountFlag).(string) - if !ok { - t.Fatal("should have processed") - } - - if sa != "supersub" { - t.Fatal("unexpected value") - } -} - func TestGetCredentials(t *testing.T) { t.Parallel() var b Base @@ -106,26 +18,26 @@ func TestGetCredentials(t *testing.T) { } b.API.CredentialsValidator.RequiresKey = true - ctx := DeployCredentialsToContext(context.Background(), &Credentials{Secret: "wow"}) + ctx := account.DeployCredentialsToContext(context.Background(), &account.Credentials{Secret: "wow"}) _, err = b.GetCredentials(ctx) if !errors.Is(err, errRequiresAPIKey) { t.Fatalf("received: %v but expected: %v", err, errRequiresAPIKey) } b.API.CredentialsValidator.RequiresSecret = true - ctx = DeployCredentialsToContext(context.Background(), &Credentials{Key: "wow"}) + ctx = account.DeployCredentialsToContext(context.Background(), &account.Credentials{Key: "wow"}) _, err = b.GetCredentials(ctx) if !errors.Is(err, errRequiresAPISecret) { t.Fatalf("received: %v but expected: %v", err, errRequiresAPISecret) } - ctx = context.WithValue(context.Background(), contextCredentialsFlag, "pewpew") + ctx = context.WithValue(context.Background(), account.ContextCredentialsFlag, "pewpew") _, err = b.GetCredentials(ctx) if !errors.Is(err, errContextCredentialsFailure) { t.Fatalf("received: %v but expected: %v", err, errContextCredentialsFailure) } - fullCred := Credentials{ + fullCred := &account.Credentials{ Key: "superkey", Secret: "supersecret", SubAccount: "supersub", @@ -134,9 +46,7 @@ func TestGetCredentials(t *testing.T) { OneTimePassword: "superOneTimePasssssss", } - flag, store := fullCred.getInternal() - - ctx = context.WithValue(context.Background(), flag, store) + ctx = account.DeployCredentialsToContext(context.Background(), fullCred) creds, err := b.GetCredentials(ctx) if !errors.Is(err, nil) { t.Fatalf("received: %v but expected: %v", err, nil) @@ -151,7 +61,7 @@ func TestGetCredentials(t *testing.T) { t.Fatal("unexpected values") } - lonelyCred := Credentials{ + lonelyCred := &account.Credentials{ Key: "superkey", Secret: "supersecret", SubAccount: "supersub", @@ -159,9 +69,7 @@ func TestGetCredentials(t *testing.T) { OneTimePassword: "superOneTimePasssssss", } - flag, store = lonelyCred.getInternal() - - ctx = context.WithValue(context.Background(), flag, store) + ctx = account.DeployCredentialsToContext(context.Background(), lonelyCred) b.API.CredentialsValidator.RequiresClientID = true _, err = b.GetCredentials(ctx) if !errors.Is(err, errRequiresAPIClientID) { @@ -171,7 +79,8 @@ func TestGetCredentials(t *testing.T) { b.API.SetKey("hello") b.API.SetSecret("sir") b.API.SetClientID("1337") - ctx = deploySubAccountOverrideToContext(context.Background(), "superaccount") + + ctx = context.WithValue(context.Background(), account.ContextSubAccountFlag, "superaccount") overridedSA, err := b.GetCredentials(ctx) if !errors.Is(err, nil) { t.Fatalf("received: %v but expected: %v", err, nil) @@ -203,7 +112,7 @@ func TestAreCredentialsValid(t *testing.T) { if b.AreCredentialsValid(context.Background()) { t.Fatal("should not be valid") } - ctx := DeployCredentialsToContext(context.Background(), &Credentials{Key: "hello"}) + ctx := account.DeployCredentialsToContext(context.Background(), &account.Credentials{Key: "hello"}) if !b.AreCredentialsValid(ctx) { t.Fatal("should be valid") } @@ -277,11 +186,11 @@ func TestCheckCredentials(t *testing.T) { b := Base{ SkipAuthCheck: true, - API: API{credentials: &Credentials{}}, + API: API{credentials: &account.Credentials{}}, } // Test SkipAuthCheck - err := b.CheckCredentials(&Credentials{}, false) + err := b.CheckCredentials(&account.Credentials{}, false) if !errors.Is(err, nil) { t.Errorf("received '%v' expected '%v'", err, nil) } @@ -322,56 +231,35 @@ func TestCheckCredentials(t *testing.T) { } } -func TestGetInternal(t *testing.T) { - t.Parallel() - flag, store := (&Credentials{}).getInternal() - if flag != "" { - t.Fatal("unexpected value") - } - if store != nil { - t.Fatal("unexpected value") - } - flag, store = (&Credentials{Key: "wow"}).getInternal() - if flag != contextCredentialsFlag { - t.Fatal("unexpected value") - } - if store == nil { - t.Fatal("unexpected value") - } - if store.Get().Key != "wow" { - t.Fatal("unexpected value") - } -} - func TestAPISetters(t *testing.T) { t.Parallel() api := API{} - api.SetKey(key) - if api.credentials.Key != key { + api.SetKey(account.Key) + if api.credentials.Key != account.Key { t.Fatal("unexpected value") } api = API{} - api.SetSecret(secret) - if api.credentials.Secret != secret { + api.SetSecret(account.Secret) + if api.credentials.Secret != account.Secret { t.Fatal("unexpected value") } api = API{} - api.SetClientID((clientID)) - if api.credentials.ClientID != clientID { + api.SetClientID(account.ClientID) + if api.credentials.ClientID != account.ClientID { t.Fatal("unexpected value") } api = API{} - api.SetPEMKey(_PEMKey) - if api.credentials.PEMKey != _PEMKey { + api.SetPEMKey(account.PEMKey) + if api.credentials.PEMKey != account.PEMKey { t.Fatal("unexpected value") } api = API{} - api.SetSubAccount(subAccount) - if api.credentials.SubAccount != subAccount { + api.SetSubAccount(account.SubAccountSTR) + if api.credentials.SubAccount != account.SubAccountSTR { t.Fatal("unexpected value") } } @@ -471,19 +359,3 @@ func TestGetAuthenticatedAPISupport(t *testing.T) { t.Fatal("Expected WebsocketAuthentication to return true") } } - -func TestIsEmpty(t *testing.T) { - var c *Credentials - if !c.IsEmpty() { - t.Fatalf("expected: %v but received: %v", true, c.IsEmpty()) - } - c = new(Credentials) - if !c.IsEmpty() { - t.Fatalf("expected: %v but received: %v", true, c.IsEmpty()) - } - - c.SubAccount = "woow" - if c.IsEmpty() { - t.Fatalf("expected: %v but received: %v", false, c.IsEmpty()) - } -} diff --git a/exchanges/exchange.go b/exchanges/exchange.go index 1529f423..e734456e 100644 --- a/exchanges/exchange.go +++ b/exchanges/exchange.go @@ -14,6 +14,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/common/convert" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/currencystate" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" @@ -470,7 +471,7 @@ func (b *Base) SetupDefaults(exch *config.Exchange) error { b.API.AuthenticatedSupport = exch.API.AuthenticatedSupport b.API.AuthenticatedWebsocketSupport = exch.API.AuthenticatedWebsocketSupport if b.API.credentials == nil { - b.API.credentials = &Credentials{} + b.API.credentials = &account.Credentials{} } b.API.credentials.SubAccount = exch.API.Credentials.Subaccount if b.API.AuthenticatedSupport || b.API.AuthenticatedWebsocketSupport { diff --git a/exchanges/exchange_types.go b/exchanges/exchange_types.go index 0e914607..2488e153 100644 --- a/exchanges/exchange_types.go +++ b/exchanges/exchange_types.go @@ -6,6 +6,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/currencystate" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" @@ -186,7 +187,7 @@ type API struct { Endpoints *Endpoints - credentials *Credentials + credentials *account.Credentials credMu sync.RWMutex CredentialsValidator struct { diff --git a/exchanges/exmo/exmo_wrapper.go b/exchanges/exmo/exmo_wrapper.go index 54deb3a8..2bcdf7ff 100644 --- a/exchanges/exmo/exmo_wrapper.go +++ b/exchanges/exmo/exmo_wrapper.go @@ -383,7 +383,11 @@ func (e *EXMO) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (acc Currencies: currencies, }) - err = account.Process(&response) + creds, err := e.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + err = account.Process(&response, creds) if err != nil { return account.Holdings{}, err } @@ -393,11 +397,14 @@ func (e *EXMO) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (acc // FetchAccountInfo retrieves balances for all enabled currencies func (e *EXMO) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(e.Name, assetType) + creds, err := e.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(e.Name, creds, assetType) if err != nil { return e.UpdateAccountInfo(ctx, assetType) } - return acc, nil } diff --git a/exchanges/ftx/ftx_wrapper.go b/exchanges/ftx/ftx_wrapper.go index 5f14ba03..d189b6fe 100644 --- a/exchanges/ftx/ftx_wrapper.go +++ b/exchanges/ftx/ftx_wrapper.go @@ -507,7 +507,7 @@ func (f *FTX) UpdateAccountInfo(ctx context.Context, a asset.Item) (account.Hold } resp.Exchange = f.Name - if err := account.Process(&resp); err != nil { + if err := account.Process(&resp, creds); err != nil { return account.Holdings{}, err } @@ -516,11 +516,14 @@ func (f *FTX) UpdateAccountInfo(ctx context.Context, a asset.Item) (account.Hold // FetchAccountInfo retrieves balances for all enabled currencies func (f *FTX) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(f.Name, assetType) + creds, err := f.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(f.Name, creds, assetType) if err != nil { return f.UpdateAccountInfo(ctx, assetType) } - return acc, nil } diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 7c4fc8a9..79c9fa89 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -427,7 +427,11 @@ func (g *Gateio) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (a } info.Exchange = g.Name - if err := account.Process(&info); err != nil { + creds, err := g.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + if err := account.Process(&info, creds); err != nil { return account.Holdings{}, err } @@ -436,11 +440,14 @@ func (g *Gateio) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (a // FetchAccountInfo retrieves balances for all enabled currencies func (g *Gateio) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(g.Name, assetType) + creds, err := g.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(g.Name, creds, assetType) if err != nil { return g.UpdateAccountInfo(ctx, assetType) } - return acc, nil } diff --git a/exchanges/gemini/gemini_wrapper.go b/exchanges/gemini/gemini_wrapper.go index 22666372..516038c0 100644 --- a/exchanges/gemini/gemini_wrapper.go +++ b/exchanges/gemini/gemini_wrapper.go @@ -333,7 +333,11 @@ func (g *Gemini) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (a Currencies: currencies, }) - err = account.Process(&response) + creds, err := g.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + err = account.Process(&response, creds) if err != nil { return account.Holdings{}, err } @@ -343,11 +347,14 @@ func (g *Gemini) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (a // FetchAccountInfo retrieves balances for all enabled currencies func (g *Gemini) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(g.Name, assetType) + creds, err := g.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(g.Name, creds, assetType) if err != nil { return g.UpdateAccountInfo(ctx, assetType) } - return acc, nil } diff --git a/exchanges/hitbtc/hitbtc_wrapper.go b/exchanges/hitbtc/hitbtc_wrapper.go index fb466267..e998ebe6 100644 --- a/exchanges/hitbtc/hitbtc_wrapper.go +++ b/exchanges/hitbtc/hitbtc_wrapper.go @@ -453,7 +453,11 @@ func (h *HitBTC) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (a Currencies: currencies, }) - err = account.Process(&response) + creds, err := h.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + err = account.Process(&response, creds) if err != nil { return account.Holdings{}, err } @@ -463,7 +467,11 @@ func (h *HitBTC) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (a // FetchAccountInfo retrieves balances for all enabled currencies func (h *HitBTC) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(h.Name, assetType) + creds, err := h.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(h.Name, creds, assetType) if err != nil { return h.UpdateAccountInfo(ctx, assetType) } diff --git a/exchanges/huobi/huobi_websocket.go b/exchanges/huobi/huobi_websocket.go index 273c67e9..f3d59ff0 100644 --- a/exchanges/huobi/huobi_websocket.go +++ b/exchanges/huobi/huobi_websocket.go @@ -15,7 +15,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" - exchange "github.com/thrasher-corp/gocryptotrader/exchanges" + "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" @@ -545,7 +545,7 @@ func (h *HUOBI) GenerateDefaultSubscriptions() ([]stream.ChannelSubscription, er // Subscribe sends a websocket message to receive data from the channel func (h *HUOBI) Subscribe(channelsToSubscribe []stream.ChannelSubscription) error { - var creds *exchange.Credentials + var creds *account.Credentials if h.Websocket.CanUseAuthenticatedEndpoints() { var err error creds, err = h.GetCredentials(context.TODO()) @@ -585,7 +585,7 @@ func (h *HUOBI) Subscribe(channelsToSubscribe []stream.ChannelSubscription) erro // Unsubscribe sends a websocket message to stop receiving data from the channel func (h *HUOBI) Unsubscribe(channelsToUnsubscribe []stream.ChannelSubscription) error { - var creds *exchange.Credentials + var creds *account.Credentials if h.Websocket.CanUseAuthenticatedEndpoints() { var err error creds, err = h.GetCredentials(context.TODO()) @@ -623,7 +623,7 @@ func (h *HUOBI) Unsubscribe(channelsToUnsubscribe []stream.ChannelSubscription) return nil } -func (h *HUOBI) wsGenerateSignature(creds *exchange.Credentials, timestamp, endpoint string) ([]byte, error) { +func (h *HUOBI) wsGenerateSignature(creds *account.Credentials, timestamp, endpoint string) ([]byte, error) { values := url.Values{} values.Set("AccessKeyId", creds.Key) values.Set("SignatureMethod", signatureMethod) @@ -668,7 +668,7 @@ func (h *HUOBI) wsLogin(ctx context.Context) error { return nil } -func (h *HUOBI) wsAuthenticatedSubscribe(creds *exchange.Credentials, operation, endpoint, topic string) error { +func (h *HUOBI) wsAuthenticatedSubscribe(creds *account.Credentials, operation, endpoint, topic string) error { timestamp := time.Now().UTC().Format(wsDateTimeFormatting) request := WsAuthenticatedSubscriptionRequest{ Op: operation, diff --git a/exchanges/huobi/huobi_wrapper.go b/exchanges/huobi/huobi_wrapper.go index 4b748fe9..b3bcfa5d 100644 --- a/exchanges/huobi/huobi_wrapper.go +++ b/exchanges/huobi/huobi_wrapper.go @@ -808,7 +808,11 @@ func (h *HUOBI) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (ac } acc.AssetType = assetType info.Accounts = append(info.Accounts, acc) - if err := account.Process(&info); err != nil { + creds, err := h.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + if err := account.Process(&info, creds); err != nil { return info, err } return info, nil @@ -816,7 +820,11 @@ func (h *HUOBI) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (ac // FetchAccountInfo retrieves balances for all enabled currencies func (h *HUOBI) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(h.Name, assetType) + creds, err := h.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(h.Name, creds, assetType) if err != nil { return h.UpdateAccountInfo(ctx, assetType) } diff --git a/exchanges/interfaces.go b/exchanges/interfaces.go index d95c9d4f..8d7cf589 100644 --- a/exchanges/interfaces.go +++ b/exchanges/interfaces.go @@ -85,7 +85,7 @@ type IBotExchange interface { UpdateOrderExecutionLimits(ctx context.Context, a asset.Item) error AccountManagement - GetCredentials(ctx context.Context) (*Credentials, error) + GetCredentials(ctx context.Context) (*account.Credentials, error) ValidateCredentials(ctx context.Context, a asset.Item) error FunctionalityChecker diff --git a/exchanges/itbit/itbit_wrapper.go b/exchanges/itbit/itbit_wrapper.go index b540c436..bc636254 100644 --- a/exchanges/itbit/itbit_wrapper.go +++ b/exchanges/itbit/itbit_wrapper.go @@ -302,7 +302,11 @@ func (i *ItBit) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (ac Currencies: fullBalance, }) - err = account.Process(&info) + creds, err := i.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + err = account.Process(&info, creds) if err != nil { return account.Holdings{}, err } @@ -312,11 +316,14 @@ func (i *ItBit) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (ac // FetchAccountInfo retrieves balances for all enabled currencies func (i *ItBit) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(i.Name, assetType) + creds, err := i.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(i.Name, creds, assetType) if err != nil { return i.UpdateAccountInfo(ctx, assetType) } - return acc, nil } diff --git a/exchanges/kraken/kraken_wrapper.go b/exchanges/kraken/kraken_wrapper.go index 6f6328d6..67dfa881 100644 --- a/exchanges/kraken/kraken_wrapper.go +++ b/exchanges/kraken/kraken_wrapper.go @@ -630,7 +630,11 @@ func (k *Kraken) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (a }) } } - if err := account.Process(&info); err != nil { + creds, err := k.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + if err := account.Process(&info, creds); err != nil { return account.Holdings{}, err } return info, nil @@ -638,7 +642,11 @@ func (k *Kraken) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (a // FetchAccountInfo retrieves balances for all enabled currencies func (k *Kraken) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(k.Name, assetType) + creds, err := k.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(k.Name, creds, assetType) if err != nil { return k.UpdateAccountInfo(ctx, assetType) } diff --git a/exchanges/lbank/lbank_test.go b/exchanges/lbank/lbank_test.go index 6afb2605..c5306e66 100644 --- a/exchanges/lbank/lbank_test.go +++ b/exchanges/lbank/lbank_test.go @@ -14,6 +14,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/currency" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" + "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/order" @@ -310,7 +311,8 @@ func TestLoadPrivKey(t *testing.T) { t.Error(err) } - ctx := exchange.DeployCredentialsToContext(context.Background(), &exchange.Credentials{Secret: "errortest"}) + ctx := account.DeployCredentialsToContext(context.Background(), + &account.Credentials{Secret: "errortest"}) err = l.loadPrivKey(ctx) if err == nil { t.Errorf("Expected error due to pemblock nil") diff --git a/exchanges/lbank/lbank_wrapper.go b/exchanges/lbank/lbank_wrapper.go index 9c0d3844..659da3c1 100644 --- a/exchanges/lbank/lbank_wrapper.go +++ b/exchanges/lbank/lbank_wrapper.go @@ -353,7 +353,11 @@ func (l *Lbank) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (ac info.Accounts = append(info.Accounts, acc) info.Exchange = l.Name - err = account.Process(&info) + creds, err := l.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + err = account.Process(&info, creds) if err != nil { return account.Holdings{}, err } @@ -362,7 +366,11 @@ func (l *Lbank) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (ac // FetchAccountInfo retrieves balances for all enabled currencies func (l *Lbank) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(l.Name, assetType) + creds, err := l.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(l.Name, creds, assetType) if err != nil { return l.UpdateAccountInfo(ctx, assetType) } diff --git a/exchanges/localbitcoins/localbitcoins_wrapper.go b/exchanges/localbitcoins/localbitcoins_wrapper.go index f1443e07..13147c48 100644 --- a/exchanges/localbitcoins/localbitcoins_wrapper.go +++ b/exchanges/localbitcoins/localbitcoins_wrapper.go @@ -300,7 +300,11 @@ func (l *LocalBitcoins) UpdateAccountInfo(ctx context.Context, assetType asset.I }}, }) - err = account.Process(&response) + creds, err := l.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + err = account.Process(&response, creds) if err != nil { return account.Holdings{}, err } @@ -310,11 +314,14 @@ func (l *LocalBitcoins) UpdateAccountInfo(ctx context.Context, assetType asset.I // FetchAccountInfo retrieves balances for all enabled currencies func (l *LocalBitcoins) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(l.Name, assetType) + creds, err := l.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(l.Name, creds, assetType) if err != nil { return l.UpdateAccountInfo(ctx, assetType) } - return acc, nil } diff --git a/exchanges/okgroup/okgroup.go b/exchanges/okgroup/okgroup.go index 57459bdc..7248f318 100644 --- a/exchanges/okgroup/okgroup.go +++ b/exchanges/okgroup/okgroup.go @@ -16,6 +16,7 @@ import ( "github.com/google/go-querystring/query" "github.com/thrasher-corp/gocryptotrader/common/crypto" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" + "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/log" @@ -585,7 +586,7 @@ func (o *OKGroup) SendHTTPRequest(ctx context.Context, ep exchange.URL, httpMeth headers := make(map[string]string) headers["Content-Type"] = "application/json" if authenticated { - var creds *exchange.Credentials + var creds *account.Credentials creds, err = o.GetCredentials(ctx) if err != nil { return nil, err diff --git a/exchanges/okgroup/okgroup_wrapper.go b/exchanges/okgroup/okgroup_wrapper.go index 52261268..795c0430 100644 --- a/exchanges/okgroup/okgroup_wrapper.go +++ b/exchanges/okgroup/okgroup_wrapper.go @@ -216,7 +216,11 @@ func (o *OKGroup) UpdateAccountInfo(ctx context.Context, assetType asset.Item) ( resp.Accounts = append(resp.Accounts, currencyAccount) - err = account.Process(&resp) + creds, err := o.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + err = account.Process(&resp, creds) if err != nil { return resp, err } @@ -226,11 +230,14 @@ func (o *OKGroup) UpdateAccountInfo(ctx context.Context, assetType asset.Item) ( // FetchAccountInfo retrieves balances for all enabled currencies func (o *OKGroup) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(o.Name, assetType) + creds, err := o.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(o.Name, creds, assetType) if err != nil { return o.UpdateAccountInfo(ctx, assetType) } - return acc, nil } diff --git a/exchanges/poloniex/poloniex_websocket.go b/exchanges/poloniex/poloniex_websocket.go index 080d6f6d..f7924f2f 100644 --- a/exchanges/poloniex/poloniex_websocket.go +++ b/exchanges/poloniex/poloniex_websocket.go @@ -14,7 +14,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" - exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/order" @@ -542,7 +541,7 @@ func (p *Poloniex) GenerateDefaultSubscriptions() ([]stream.ChannelSubscription, // Subscribe sends a websocket message to receive data from the channel func (p *Poloniex) Subscribe(sub []stream.ChannelSubscription) error { - var creds *exchange.Credentials + var creds *account.Credentials if p.IsWebsocketAuthenticationSupported() { var err error creds, err = p.GetCredentials(context.TODO()) @@ -589,7 +588,7 @@ channels: // Unsubscribe sends a websocket message to stop receiving data from the channel func (p *Poloniex) Unsubscribe(unsub []stream.ChannelSubscription) error { - var creds *exchange.Credentials + var creds *account.Credentials if p.IsWebsocketAuthenticationSupported() { var err error creds, err = p.GetCredentials(context.TODO()) diff --git a/exchanges/poloniex/poloniex_wrapper.go b/exchanges/poloniex/poloniex_wrapper.go index cee92067..fda750f8 100644 --- a/exchanges/poloniex/poloniex_wrapper.go +++ b/exchanges/poloniex/poloniex_wrapper.go @@ -427,7 +427,11 @@ func (p *Poloniex) UpdateAccountInfo(ctx context.Context, assetType asset.Item) Currencies: currencies, }) - err = account.Process(&response) + creds, err := p.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + err = account.Process(&response, creds) if err != nil { return account.Holdings{}, err } @@ -437,7 +441,11 @@ func (p *Poloniex) UpdateAccountInfo(ctx context.Context, assetType asset.Item) // FetchAccountInfo retrieves balances for all enabled currencies func (p *Poloniex) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(p.Name, assetType) + creds, err := p.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(p.Name, creds, assetType) if err != nil { return p.UpdateAccountInfo(ctx, assetType) } diff --git a/exchanges/yobit/yobit_wrapper.go b/exchanges/yobit/yobit_wrapper.go index 8558c375..dd503189 100644 --- a/exchanges/yobit/yobit_wrapper.go +++ b/exchanges/yobit/yobit_wrapper.go @@ -328,7 +328,11 @@ func (y *Yobit) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (ac Currencies: currencies, }) - err = account.Process(&response) + creds, err := y.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + err = account.Process(&response, creds) if err != nil { return account.Holdings{}, err } @@ -338,11 +342,14 @@ func (y *Yobit) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (ac // FetchAccountInfo retrieves balances for all enabled currencies func (y *Yobit) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(y.Name, assetType) + creds, err := y.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(y.Name, creds, assetType) if err != nil { return y.UpdateAccountInfo(ctx, assetType) } - return acc, nil } diff --git a/exchanges/zb/zb_wrapper.go b/exchanges/zb/zb_wrapper.go index 52a676e7..b2c55205 100644 --- a/exchanges/zb/zb_wrapper.go +++ b/exchanges/zb/zb_wrapper.go @@ -397,7 +397,11 @@ func (z *ZB) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (accou Currencies: balances, }) - if err := account.Process(&info); err != nil { + creds, err := z.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + if err := account.Process(&info, creds); err != nil { return account.Holdings{}, err } @@ -406,11 +410,14 @@ func (z *ZB) UpdateAccountInfo(ctx context.Context, assetType asset.Item) (accou // FetchAccountInfo retrieves balances for all enabled currencies func (z *ZB) FetchAccountInfo(ctx context.Context, assetType asset.Item) (account.Holdings, error) { - acc, err := account.GetHoldings(z.Name, assetType) + creds, err := z.GetCredentials(ctx) + if err != nil { + return account.Holdings{}, err + } + acc, err := account.GetHoldings(z.Name, creds, assetType) if err != nil { return z.UpdateAccountInfo(ctx, assetType) } - return acc, nil }