account: segregate holdings by credentials for future multi-key management (#956)

* exchanges/account: shift credentials to account package and segregate funds to keys

* merge: fixes

* linter: fix

* Update exchanges/account/account.go

Co-authored-by: Scott <gloriousCode@users.noreply.github.com>

* glorious: nits + protection for string panic

* glorious_suggestion: add method for matching keys

* linter: fix tests

* account: add protected method for credentials minimizing access, display full account details to rpc.

* linter: spelling kweeeeeeen

* accounts/portfolio: clean/check portfolio code and quickly check balances from change. Add protected method for future matching.

* accounts: theres no point in pointerising everything

* linter: ok pointerise this then...

* exchanges: fix regression add in little notes.

* glorious: nits

* Update exchanges/account/credentials.go

Co-authored-by: Scott <gloriousCode@users.noreply.github.com>

* Update exchanges/account/credentials_test.go

Co-authored-by: Scott <gloriousCode@users.noreply.github.com>

* Update exchanges/account/credentials_test.go

Co-authored-by: Scott <gloriousCode@users.noreply.github.com>

* glorious: nits

* gloriously: fix glorious glorious test gloriously

Co-authored-by: Ryan O'Hara-Reid <ryan.oharareid@thrasher.io>
Co-authored-by: Scott <gloriousCode@users.noreply.github.com>
This commit is contained in:
Ryan O'Hara-Reid
2022-07-21 15:05:31 +10:00
committed by GitHub
parent 455738f25f
commit 663e753f52
48 changed files with 1010 additions and 549 deletions

View File

@@ -26,6 +26,8 @@ var (
errNoExchangeSubAccountBalances = errors.New("no exchange sub account balances")
errNoBalanceFound = errors.New("no balance found")
errBalanceIsNil = errors.New("balance is nil")
errNoCredentialBalances = errors.New("no balances associated with credentials")
errCredentialsAreNil = errors.New("credentials are nil")
)
// CollectBalances converts a map of sub-account balances into a slice
@@ -64,16 +66,22 @@ func SubscribeToExchangeAccount(exchange string) (dispatch.Pipe, error) {
}
// Process processes new account holdings updates
func Process(h *Holdings) error {
return service.Update(h)
func Process(h *Holdings, c *Credentials) error {
return service.Update(h, c)
}
// GetHoldings returns full holdings for an exchange
func GetHoldings(exch string, assetType asset.Item) (Holdings, error) {
// GetHoldings returns full holdings for an exchange.
// NOTE: Due to credentials these amounts could be N*APIKEY actual holdings.
// TODO: Add jurisdiction and differentiation between APIKEY holdings.
func GetHoldings(exch string, creds *Credentials, assetType asset.Item) (Holdings, error) {
if exch == "" {
return Holdings{}, errExchangeNameUnset
}
if creds.IsEmpty() {
return Holdings{}, fmt.Errorf("%s %s %w", exch, assetType, errCredentialsAreNil)
}
if !assetType.IsValid() {
return Holdings{}, fmt.Errorf("%s %s %w", exch, assetType, asset.ErrNotSupported)
}
@@ -88,7 +96,16 @@ func GetHoldings(exch string, assetType asset.Item) (Holdings, error) {
}
var accountsHoldings []SubAccount
for subAccount, assetHoldings := range accounts.SubAccounts {
subAccountHoldings, ok := accounts.SubAccounts[*creds]
if !ok {
return Holdings{}, fmt.Errorf("%s %s %s %w",
exch,
creds,
assetType,
errNoCredentialBalances)
}
for subAccount, assetHoldings := range subAccountHoldings {
for ai, currencyHoldings := range assetHoldings {
if ai != assetType {
continue
@@ -113,10 +130,16 @@ func GetHoldings(exch string, assetType asset.Item) (Holdings, error) {
continue
}
cpy := *creds
if cpy.SubAccount == "" {
cpy.SubAccount = subAccount
}
accountsHoldings = append(accountsHoldings, SubAccount{
ID: subAccount,
AssetType: ai,
Currencies: currencyBalances,
Credentials: Protected{creds: cpy},
ID: subAccount,
AssetType: ai,
Currencies: currencyBalances,
})
break
}
@@ -132,22 +155,24 @@ func GetHoldings(exch string, assetType asset.Item) (Holdings, error) {
}
// GetBalance returns the internal balance for that asset item.
func GetBalance(exch, subAccount string, ai asset.Item, c currency.Code) (*ProtectedBalance, error) {
func GetBalance(exch, subAccount string, creds *Credentials, ai asset.Item, c currency.Code) (*ProtectedBalance, error) {
if exch == "" {
return nil, errExchangeNameUnset
return nil, fmt.Errorf("cannot get balance: %w", errExchangeNameUnset)
}
if !ai.IsValid() {
return nil, fmt.Errorf("%s %w", ai, asset.ErrNotSupported)
return nil, fmt.Errorf("cannot get balance: %s %w", ai, asset.ErrNotSupported)
}
if creds.IsEmpty() {
return nil, fmt.Errorf("cannot get balance: %w", errCredentialsAreNil)
}
if c.IsEmpty() {
return nil, currency.ErrCurrencyCodeEmpty
return nil, fmt.Errorf("cannot get balance: %w", currency.ErrCurrencyCodeEmpty)
}
exch = strings.ToLower(exch)
subAccount = strings.ToLower(subAccount)
service.mu.Lock()
defer service.mu.Unlock()
@@ -156,7 +181,13 @@ func GetBalance(exch, subAccount string, ai asset.Item, c currency.Code) (*Prote
return nil, fmt.Errorf("%s %w", exch, errExchangeHoldingsNotFound)
}
assetBalances, ok := accounts.SubAccounts[subAccount]
subAccounts, ok := accounts.SubAccounts[*creds]
if !ok {
return nil, fmt.Errorf("%s %s %w",
exch, creds, errNoCredentialBalances)
}
assetBalances, ok := subAccounts[subAccount]
if !ok {
return nil, fmt.Errorf("%s %s %w",
exch, subAccount, errNoExchangeSubAccountBalances)
@@ -177,16 +208,20 @@ func GetBalance(exch, subAccount string, ai asset.Item, c currency.Code) (*Prote
}
// Update updates holdings with new account info
func (s *Service) Update(a *Holdings) error {
if a == nil {
return errHoldingsIsNil
func (s *Service) Update(incoming *Holdings, creds *Credentials) error {
if incoming == nil {
return fmt.Errorf("cannot update holdings: %w", errHoldingsIsNil)
}
if a.Exchange == "" {
return errExchangeNameUnset
if incoming.Exchange == "" {
return fmt.Errorf("cannot update holdings: %w", errExchangeNameUnset)
}
exch := strings.ToLower(a.Exchange)
if creds.IsEmpty() {
return fmt.Errorf("cannot update holdings: %w", errCredentialsAreNil)
}
exch := strings.ToLower(incoming.Exchange)
s.mu.Lock()
defer s.mu.Unlock()
accounts, ok := s.exchangeAccounts[exch]
@@ -197,46 +232,65 @@ func (s *Service) Update(a *Holdings) error {
}
accounts = &Accounts{
ID: id,
SubAccounts: make(map[string]map[asset.Item]map[*currency.Item]*ProtectedBalance),
SubAccounts: make(map[Credentials]map[string]map[asset.Item]map[*currency.Item]*ProtectedBalance),
}
s.exchangeAccounts[exch] = accounts
}
var errs common.Errors
for x := range a.Accounts {
if !a.Accounts[x].AssetType.IsValid() {
for x := range incoming.Accounts {
if !incoming.Accounts[x].AssetType.IsValid() {
errs = append(errs, fmt.Errorf("cannot load sub account holdings for %s [%s] %w",
a.Accounts[x].ID,
a.Accounts[x].AssetType,
incoming.Accounts[x].ID,
incoming.Accounts[x].AssetType,
asset.ErrNotSupported))
continue
}
lowerSA := strings.ToLower(a.Accounts[x].ID)
// This assignment outside of scope is designed to have minimal impact
// on the exchange implementation UpdateAccountInfo() and portfoio
// management.
// TODO: Update incoming Holdings type to already be populated. (Suggestion)
cpy := *creds
if cpy.SubAccount == "" {
cpy.SubAccount = incoming.Accounts[x].ID
}
incoming.Accounts[x].Credentials.creds = cpy
var subAccounts map[string]map[asset.Item]map[*currency.Item]*ProtectedBalance
subAccounts, ok = accounts.SubAccounts[*creds]
if !ok {
subAccounts = make(map[string]map[asset.Item]map[*currency.Item]*ProtectedBalance)
accounts.SubAccounts[*creds] = subAccounts
}
var accountAssets map[asset.Item]map[*currency.Item]*ProtectedBalance
accountAssets, ok = accounts.SubAccounts[lowerSA]
accountAssets, ok = subAccounts[incoming.Accounts[x].ID]
if !ok {
accountAssets = make(map[asset.Item]map[*currency.Item]*ProtectedBalance)
accounts.SubAccounts[lowerSA] = accountAssets
// Note: Sub accounts are case sensitive and an account "name" is
// different to account "naMe".
subAccounts[incoming.Accounts[x].ID] = accountAssets
}
var currencyBalances map[*currency.Item]*ProtectedBalance
currencyBalances, ok = accountAssets[a.Accounts[x].AssetType]
currencyBalances, ok = accountAssets[incoming.Accounts[x].AssetType]
if !ok {
currencyBalances = make(map[*currency.Item]*ProtectedBalance)
accountAssets[a.Accounts[x].AssetType] = currencyBalances
accountAssets[incoming.Accounts[x].AssetType] = currencyBalances
}
for y := range a.Accounts[x].Currencies {
bal := currencyBalances[a.Accounts[x].Currencies[y].CurrencyName.Item]
for y := range incoming.Accounts[x].Currencies {
bal := currencyBalances[incoming.Accounts[x].Currencies[y].CurrencyName.Item]
if bal == nil {
bal = &ProtectedBalance{}
currencyBalances[a.Accounts[x].Currencies[y].CurrencyName.Item] = bal
currencyBalances[incoming.Accounts[x].Currencies[y].CurrencyName.Item] = bal
}
bal.load(a.Accounts[x].Currencies[y])
bal.load(incoming.Accounts[x].Currencies[y])
}
}
err := s.mux.Publish(a, accounts.ID)
err := s.mux.Publish(incoming, accounts.ID)
if err != nil {
return err
}

View File

@@ -11,6 +11,8 @@ import (
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
)
var happyCredentials = &Credentials{Key: "AAAAA"}
func TestCollectBalances(t *testing.T) {
t.Parallel()
accounts, err := CollectBalances(
@@ -63,12 +65,12 @@ func TestGetHoldings(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = Process(nil)
err = Process(nil, nil)
if !errors.Is(err, errHoldingsIsNil) {
t.Fatalf("received: '%v' but expected: '%v'", err, errHoldingsIsNil)
}
err = Process(&Holdings{})
err = Process(&Holdings{}, nil)
if !errors.Is(err, errExchangeNameUnset) {
t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeNameUnset)
}
@@ -77,9 +79,14 @@ func TestGetHoldings(t *testing.T) {
Exchange: "Test",
}
err = Process(&holdings)
if err != nil {
t.Error(err)
err = Process(&holdings, nil)
if !errors.Is(err, errCredentialsAreNil) {
t.Fatalf("received: '%v' but expected: '%v'", err, errCredentialsAreNil)
}
err = Process(&holdings, happyCredentials)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
err = Process(&Holdings{
@@ -88,7 +95,7 @@ func TestGetHoldings(t *testing.T) {
{
ID: "1337",
}},
})
}, happyCredentials)
if !errors.Is(err, asset.ErrNotSupported) {
t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported)
}
@@ -111,7 +118,7 @@ func TestGetHoldings(t *testing.T) {
},
},
}},
})
}, happyCredentials)
if err != nil {
t.Error(err)
}
@@ -131,32 +138,42 @@ func TestGetHoldings(t *testing.T) {
},
},
}},
})
}, happyCredentials)
if err != nil {
t.Error(err)
}
_, err = GetHoldings("", asset.Spot)
_, err = GetHoldings("", nil, asset.Spot)
if !errors.Is(err, errExchangeNameUnset) {
t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeNameUnset)
}
_, err = GetHoldings("bla", asset.Spot)
_, err = GetHoldings("bla", nil, asset.Spot)
if !errors.Is(err, errCredentialsAreNil) {
t.Fatalf("received: '%v' but expected: '%v'", err, errCredentialsAreNil)
}
_, err = GetHoldings("bla", happyCredentials, asset.Spot)
if !errors.Is(err, errExchangeHoldingsNotFound) {
t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeHoldingsNotFound)
}
_, err = GetHoldings("bla", asset.Empty)
_, err = GetHoldings("bla", happyCredentials, asset.Empty)
if !errors.Is(err, asset.ErrNotSupported) {
t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported)
}
_, err = GetHoldings("Test", asset.UpsideProfitContract)
_, err = GetHoldings("Test", happyCredentials, asset.UpsideProfitContract)
if !errors.Is(err, errAssetHoldingsNotFound) {
t.Fatalf("received: '%v' but expected: '%v'", err, errAssetHoldingsNotFound)
}
u, err := GetHoldings("Test", asset.Spot)
_, err = GetHoldings("Test", &Credentials{Key: "BBBBB"}, asset.Spot)
if !errors.Is(err, errNoCredentialBalances) {
t.Fatalf("received: '%v' but expected: '%v'", err, errNoCredentialBalances)
}
u, err := GetHoldings("Test", happyCredentials, asset.Spot)
if err != nil {
t.Error(err)
}
@@ -217,7 +234,7 @@ func TestGetHoldings(t *testing.T) {
},
},
}},
})
}, happyCredentials)
if err != nil {
t.Error(err)
}
@@ -226,22 +243,27 @@ func TestGetHoldings(t *testing.T) {
}
func TestGetBalance(t *testing.T) {
_, err := GetBalance("", "", asset.Empty, currency.Code{})
_, err := GetBalance("", "", nil, asset.Empty, currency.Code{})
if !errors.Is(err, errExchangeNameUnset) {
t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeNameUnset)
}
_, err = GetBalance("bruh", "", asset.Empty, currency.Code{})
_, err = GetBalance("bruh", "", nil, asset.Empty, currency.Code{})
if !errors.Is(err, asset.ErrNotSupported) {
t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported)
}
_, err = GetBalance("bruh", "", asset.Spot, currency.Code{})
_, err = GetBalance("bruh", "", nil, asset.Spot, currency.Code{})
if !errors.Is(err, errCredentialsAreNil) {
t.Fatalf("received: '%v' but expected: '%v'", err, errCredentialsAreNil)
}
_, err = GetBalance("bruh", "", happyCredentials, asset.Spot, currency.Code{})
if !errors.Is(err, currency.ErrCurrencyCodeEmpty) {
t.Fatalf("received: '%v' but expected: '%v'", err, currency.ErrCurrencyCodeEmpty)
}
_, err = GetBalance("bruh", "", asset.Spot, currency.BTC)
_, err = GetBalance("bruh", "", happyCredentials, asset.Spot, currency.BTC)
if !errors.Is(err, errExchangeHoldingsNotFound) {
t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeHoldingsNotFound)
}
@@ -254,22 +276,27 @@ func TestGetBalance(t *testing.T) {
ID: "1337",
},
},
})
}, happyCredentials)
if err != nil {
t.Error(err)
}
_, err = GetBalance("bruh", "1336", asset.Spot, currency.BTC)
_, err = GetBalance("bruh", "1336", &Credentials{Key: "BBBBB"}, asset.Spot, currency.BTC)
if !errors.Is(err, errNoCredentialBalances) {
t.Fatalf("received: '%v' but expected: '%v'", err, errNoCredentialBalances)
}
_, err = GetBalance("bruh", "1336", happyCredentials, asset.Spot, currency.BTC)
if !errors.Is(err, errNoExchangeSubAccountBalances) {
t.Fatalf("received: '%v' but expected: '%v'", err, errNoExchangeSubAccountBalances)
}
_, err = GetBalance("bruh", "1337", asset.Futures, currency.BTC)
_, err = GetBalance("bruh", "1337", happyCredentials, asset.Futures, currency.BTC)
if !errors.Is(err, errAssetHoldingsNotFound) {
t.Fatalf("received: '%v' but expected: '%v'", err, errAssetHoldingsNotFound)
}
_, err = GetBalance("bruh", "1337", asset.Spot, currency.BTC)
_, err = GetBalance("bruh", "1337", happyCredentials, asset.Spot, currency.BTC)
if !errors.Is(err, errNoBalanceFound) {
t.Fatalf("received: '%v' but expected: '%v'", err, errNoBalanceFound)
}
@@ -289,12 +316,12 @@ func TestGetBalance(t *testing.T) {
},
},
},
})
}, happyCredentials)
if err != nil {
t.Error(err)
}
bal, err := GetBalance("bruh", "1337", asset.Spot, currency.BTC)
bal, err := GetBalance("bruh", "1337", happyCredentials, asset.Spot, currency.BTC)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
@@ -379,12 +406,12 @@ func TestGetFree(t *testing.T) {
func TestUpdate(t *testing.T) {
t.Parallel()
s := &Service{exchangeAccounts: make(map[string]*Accounts), mux: dispatch.GetNewMux(nil)}
err := s.Update(nil)
err := s.Update(nil, nil)
if !errors.Is(err, errHoldingsIsNil) {
t.Fatalf("received: '%v' but expected: '%v'", err, errHoldingsIsNil)
}
err = s.Update(&Holdings{})
err = s.Update(&Holdings{}, nil)
if !errors.Is(err, errExchangeNameUnset) {
t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeNameUnset)
}
@@ -416,7 +443,7 @@ func TestUpdate(t *testing.T) {
},
},
},
})
}, happyCredentials)
if !errors.Is(err, asset.ErrNotSupported) {
t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported)
}
@@ -436,7 +463,7 @@ func TestUpdate(t *testing.T) {
},
},
},
})
}, happyCredentials)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
@@ -446,7 +473,7 @@ func TestUpdate(t *testing.T) {
t.Fatal("account should be loaded")
}
b, ok := acc.SubAccounts["1337"][asset.Spot][currency.BTC.Item]
b, ok := acc.SubAccounts[Credentials{Key: "AAAAA"}]["1337"][asset.Spot][currency.BTC.Item]
if !ok {
t.Fatal("account should be loaded")
}

View File

@@ -26,8 +26,13 @@ type Service struct {
// Accounts holds a stream ID and a map to the exchange holdings
type Accounts struct {
ID uuid.UUID
SubAccounts map[string]map[asset.Item]map[*currency.Item]*ProtectedBalance
ID uuid.UUID
// NOTE: Credentials is a place holder for a future interface type, which
// will need -
// TODO: Credential tracker to match to keys that are managed and return
// pointer.
// TODO: Have different cred struct for centralized verse DEFI exchanges.
SubAccounts map[Credentials]map[string]map[asset.Item]map[*currency.Item]*ProtectedBalance
}
// Holdings is a generic type to hold each exchange's holdings for all enabled
@@ -39,9 +44,10 @@ type Holdings struct {
// SubAccount defines a singular account type with associated currency balances
type SubAccount struct {
ID string
AssetType asset.Item
Currencies []Balance
Credentials Protected
ID string
AssetType asset.Item
Currencies []Balance
}
// Balance is a sub type to store currency name and individual totals
@@ -76,3 +82,9 @@ type ProtectedBalance struct {
// usage.
notice alert.Notice
}
// Protected limits the access to the underlying credentials outside of this
// package.
type Protected struct {
creds Credentials
}

View File

@@ -0,0 +1,218 @@
package account
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"google.golang.org/grpc/metadata"
)
// contextCredential is a string flag for use with context values when setting
// credentials internally or via gRPC.
type contextCredential string
const (
// ContextCredentialsFlag used for retrieving api credentials from context
ContextCredentialsFlag contextCredential = "apicredentials"
// ContextSubAccountFlag used for retrieving just the sub account from
// context, when the default config credentials sub account needs to be
// changed while the same keys can be used.
ContextSubAccountFlag contextCredential = "subaccountoverride"
Key = "key"
Secret = "secret"
SubAccountSTR = "subaccount"
ClientID = "clientid"
OneTimePassword = "otp"
PEMKey = "pemkey"
apiKeyDisplaySize = 16
)
var (
errMetaDataIsNil = errors.New("meta data is nil")
errInvalidCredentialMetaDataLength = errors.New("invalid meta data to process credentials")
errMissingInfo = errors.New("cannot parse meta data missing information in key value pair")
)
// Credentials define parameters that allow for an authenticated request.
type Credentials struct {
Key string
Secret string
ClientID string // TODO: Implement with exchange orders functionality
PEMKey string
SubAccount string
OneTimePassword string
// TODO: Add AccessControl uint8 for READ/WRITE/Withdraw capabilities.
}
// GetMetaData returns the credentials for metadata context deployment
func (c *Credentials) GetMetaData() (flag, values string) {
vals := make([]string, 0, 6)
if c.Key != "" {
vals = append(vals, Key+":"+c.Key)
}
if c.Secret != "" {
vals = append(vals, Secret+":"+c.Secret)
}
if c.SubAccount != "" {
vals = append(vals, SubAccountSTR+":"+c.SubAccount)
}
if c.ClientID != "" {
vals = append(vals, ClientID+":"+c.ClientID)
}
if c.PEMKey != "" {
vals = append(vals, PEMKey+":"+c.PEMKey)
}
if c.OneTimePassword != "" {
vals = append(vals, OneTimePassword+":"+c.OneTimePassword)
}
return string(ContextCredentialsFlag), strings.Join(vals, ",")
}
// String prints out basic credential info (obfuscated) to track key instances
// associated with exchanges.
func (c *Credentials) String() string {
obfuscated := c.Key
if len(obfuscated) > apiKeyDisplaySize {
obfuscated = obfuscated[:apiKeyDisplaySize]
}
return fmt.Sprintf("Key:[%s...] SubAccount:[%s] ClientID:[%s]",
obfuscated,
c.SubAccount,
c.ClientID)
}
// getInternal returns the values for assignment to an internal context
func (c *Credentials) getInternal() (contextCredential, *ContextCredentialsStore) {
if c.IsEmpty() {
return "", nil
}
store := &ContextCredentialsStore{}
store.Load(c)
return ContextCredentialsFlag, store
}
// IsEmpty return true if the underlying credentials type has not been filled
// with at least one item.
func (c *Credentials) IsEmpty() bool {
return c == nil || c.ClientID == "" &&
c.Key == "" &&
c.OneTimePassword == "" &&
c.PEMKey == "" &&
c.Secret == "" &&
c.SubAccount == ""
}
// Equal determines if the keys are the same.
// OTP omitted because it's generated per request.
// PEMKey and Secret omitted because of direct correlation with api key.
func (c *Credentials) Equal(other *Credentials) bool {
return c != nil &&
other != nil &&
c.Key == other.Key &&
c.ClientID == other.ClientID &&
c.SubAccount == other.SubAccount
}
// ContextCredentialsStore protects the stored credentials for use in a context
type ContextCredentialsStore struct {
creds *Credentials
mu sync.RWMutex
}
// Load stores provided credentials
func (c *ContextCredentialsStore) Load(creds *Credentials) {
// Segregate from external call
cpy := *creds
c.mu.Lock()
c.creds = &cpy
c.mu.Unlock()
}
// Get returns the full credentials from the store
func (c *ContextCredentialsStore) Get() *Credentials {
c.mu.RLock()
creds := *c.creds
c.mu.RUnlock()
return &creds
}
// ParseCredentialsMetadata intercepts and converts credentials metadata to a
// static type for authentication processing and protection.
func ParseCredentialsMetadata(ctx context.Context, md metadata.MD) (context.Context, error) {
if md == nil {
return ctx, errMetaDataIsNil
}
credMD, ok := md[string(ContextCredentialsFlag)]
if !ok || len(credMD) == 0 {
return ctx, nil
}
if len(credMD) != 1 {
return ctx, errInvalidCredentialMetaDataLength
}
segregatedCreds := strings.Split(credMD[0], ",")
var ctxCreds Credentials
var subAccountHere string
for x := range segregatedCreds {
keyVals := strings.Split(segregatedCreds[x], ":")
if len(keyVals) != 2 {
return ctx, fmt.Errorf("%w received %v fields, expected 2 contains: %s",
errMissingInfo,
len(keyVals),
keyVals)
}
switch keyVals[0] {
case Key:
ctxCreds.Key = keyVals[1]
case Secret:
ctxCreds.Secret = keyVals[1]
case SubAccountSTR:
// Capture sub account as this can override if other values are
// not included in metadata.
subAccountHere = keyVals[1]
case ClientID:
ctxCreds.ClientID = keyVals[1]
case PEMKey:
ctxCreds.PEMKey = keyVals[1]
case OneTimePassword:
ctxCreds.OneTimePassword = keyVals[1]
}
}
if ctxCreds.IsEmpty() && subAccountHere != "" {
// This will override default sub account details if needed.
return deploySubAccountOverrideToContext(ctx, subAccountHere), nil
}
// merge sub account to main context credentials
ctxCreds.SubAccount = subAccountHere
return DeployCredentialsToContext(ctx, &ctxCreds), nil
}
// DeployCredentialsToContext sets credentials for internal use to context which
// can override default credential values.
func DeployCredentialsToContext(ctx context.Context, creds *Credentials) context.Context {
flag, store := creds.getInternal()
return context.WithValue(ctx, flag, store)
}
// deploySubAccountOverrideToContext sets subaccount as override to credentials
// as a separate flag.
func deploySubAccountOverrideToContext(ctx context.Context, subAccount string) context.Context {
return context.WithValue(ctx, ContextSubAccountFlag, subAccount)
}
// String strings the credentials in a protected way.
func (p *Protected) String() string {
return p.creds.String()
}
// Equal determines if the keys are the same
func (p *Protected) Equal(other *Credentials) bool {
return p.creds.Equal(other)
}

View File

@@ -0,0 +1,240 @@
package account
import (
"context"
"errors"
"testing"
"google.golang.org/grpc/metadata"
)
func TestIsEmpty(t *testing.T) {
t.Parallel()
var c *Credentials
if !c.IsEmpty() {
t.Fatalf("expected: %v but received: %v", true, c.IsEmpty())
}
c = new(Credentials)
if !c.IsEmpty() {
t.Fatalf("expected: %v but received: %v", true, c.IsEmpty())
}
c.SubAccount = "woow"
if c.IsEmpty() {
t.Fatalf("expected: %v but received: %v", false, c.IsEmpty())
}
}
func TestParseCredentialsMetadata(t *testing.T) {
t.Parallel()
_, err := ParseCredentialsMetadata(context.Background(), nil)
if !errors.Is(err, errMetaDataIsNil) {
t.Fatalf("received: '%v' but expected: '%v'", err, errMetaDataIsNil)
}
_, err = ParseCredentialsMetadata(context.Background(), metadata.MD{})
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
ctx := metadata.AppendToOutgoingContext(context.Background(),
string(ContextCredentialsFlag), "wow", string(ContextCredentialsFlag), "wow2")
nortyMD, _ := metadata.FromOutgoingContext(ctx)
_, err = ParseCredentialsMetadata(context.Background(), nortyMD)
if !errors.Is(err, errInvalidCredentialMetaDataLength) {
t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidCredentialMetaDataLength)
}
ctx = metadata.AppendToOutgoingContext(context.Background(),
string(ContextCredentialsFlag), "brokenstring")
nortyMD, _ = metadata.FromOutgoingContext(ctx)
_, err = ParseCredentialsMetadata(context.Background(), nortyMD)
if !errors.Is(err, errMissingInfo) {
t.Fatalf("received: '%v' but expected: '%v'", err, errMissingInfo)
}
beforeCreds := Credentials{
Key: "superkey",
Secret: "supersecret",
SubAccount: "supersub",
ClientID: "superclient",
PEMKey: "superpem",
OneTimePassword: "superOneTimePasssssss",
}
flag, outGoing := beforeCreds.GetMetaData()
ctx = metadata.AppendToOutgoingContext(context.Background(), flag, outGoing)
lovelyMD, _ := metadata.FromOutgoingContext(ctx)
ctx, err = ParseCredentialsMetadata(context.Background(), lovelyMD)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
store, ok := ctx.Value(ContextCredentialsFlag).(*ContextCredentialsStore)
if !ok {
t.Fatal("should have processed")
}
afterCreds := store.Get()
if afterCreds.Key != "superkey" &&
afterCreds.Secret != "supersecret" &&
afterCreds.SubAccount != "supersub" &&
afterCreds.ClientID != "superclient" &&
afterCreds.PEMKey != "superpem" &&
afterCreds.OneTimePassword != "superOneTimePasssssss" {
t.Fatal("unexpected values")
}
// subaccount override
subaccount := Credentials{
SubAccount: "supersub",
}
flag, outGoing = subaccount.GetMetaData()
ctx = metadata.AppendToOutgoingContext(context.Background(), flag, outGoing)
lovelyMD, _ = metadata.FromOutgoingContext(ctx)
ctx, err = ParseCredentialsMetadata(context.Background(), lovelyMD)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
sa, ok := ctx.Value(ContextSubAccountFlag).(string)
if !ok {
t.Fatal("should have processed")
}
if sa != "supersub" {
t.Fatal("unexpected value")
}
}
func TestGetInternal(t *testing.T) {
t.Parallel()
flag, store := (&Credentials{}).getInternal()
if flag != "" {
t.Fatal("unexpected value")
}
if store != nil {
t.Fatal("unexpected value")
}
flag, store = (&Credentials{Key: "wow"}).getInternal()
if flag != ContextCredentialsFlag {
t.Fatal("unexpected value")
}
if store == nil {
t.Fatal("unexpected value")
}
if store.Get().Key != "wow" {
t.Fatal("unexpected value")
}
}
func TestString(t *testing.T) {
t.Parallel()
creds := Credentials{}
if s := creds.String(); s != "Key:[...] SubAccount:[] ClientID:[]" {
t.Fatal("unexpected value")
}
creds.Key = "12345678910111234"
creds.SubAccount = "sub"
creds.ClientID = "client"
if s := creds.String(); s != "Key:[1234567891011123...] SubAccount:[sub] ClientID:[client]" {
t.Fatal("unexpected value")
}
}
func TestCredentialsEqual(t *testing.T) {
t.Parallel()
var this, that *Credentials
if this.Equal(that) {
t.Fatal("unexpected value")
}
this = &Credentials{}
if this.Equal(that) {
t.Fatal("unexpected value")
}
that = &Credentials{Key: "1337"}
if this.Equal(that) {
t.Fatal("unexpected value")
}
this.Key = "1337"
if !this.Equal(that) {
t.Fatal("unexpected value")
}
this.ClientID = "1337"
if this.Equal(that) {
t.Fatal("unexpected value")
}
that.ClientID = "1337"
if !this.Equal(that) {
t.Fatal("unexpected value")
}
this.SubAccount = "someSub"
if this.Equal(that) {
t.Fatal("unexpected value")
}
that.SubAccount = "someSub"
if !this.Equal(that) {
t.Fatal("unexpected value")
}
}
func TestProtectedString(t *testing.T) {
t.Parallel()
p := Protected{}
if s := p.String(); s != "Key:[...] SubAccount:[] ClientID:[]" {
t.Fatal("unexpected value")
}
p.creds.Key = "12345678910111234"
p.creds.SubAccount = "sub"
p.creds.ClientID = "client"
if s := p.creds.String(); s != "Key:[1234567891011123...] SubAccount:[sub] ClientID:[client]" {
t.Fatal("unexpected value")
}
}
func TestProtectedCredentialsEqual(t *testing.T) {
t.Parallel()
var this Protected
var that *Credentials
if this.Equal(that) {
t.Fatal("unexpected value")
}
this.creds = Credentials{}
if this.Equal(that) {
t.Fatal("unexpected value")
}
that = &Credentials{Key: "1337"}
if this.Equal(that) {
t.Fatal("unexpected value")
}
this.creds.Key = "1337"
if !this.Equal(that) {
t.Fatal("unexpected value")
}
this.creds.ClientID = "1337"
if this.Equal(that) {
t.Fatal("unexpected value")
}
that.ClientID = "1337"
if !this.Equal(that) {
t.Fatal("unexpected value")
}
this.creds.SubAccount = "someSub"
if this.Equal(that) {
t.Fatal("unexpected value")
}
that.SubAccount = "someSub"
if !this.Equal(that) {
t.Fatal("unexpected value")
}
}