From 233a65a7785a61ed60fe46c75cef1b82971c4ebf Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Tue, 10 Oct 2023 10:17:07 +1100 Subject: [PATCH] exchange: refactor credentials.go from PR #1320 (#1360) * refactor credentials * glorious: nits * shazbert: nits --------- Co-authored-by: romanornr Co-authored-by: Ryan O'Hara-Reid --- exchanges/credentials.go | 48 +++++------------------------------ exchanges/credentials_test.go | 31 +++++++++++----------- exchanges/exchange.go | 4 --- exchanges/exchange_types.go | 14 ++-------- 4 files changed, 24 insertions(+), 73 deletions(-) diff --git a/exchanges/credentials.go b/exchanges/credentials.go index cfc925cd..56c9d80d 100644 --- a/exchanges/credentials.go +++ b/exchanges/credentials.go @@ -21,7 +21,7 @@ var ( // completely empty but an attempt at retrieving credentials was made to // undertake an authenticated HTTP request. ErrCredentialsAreEmpty = errors.New("credentials are empty") - + // Errors related to API requirements and failures errRequiresAPIKey = errors.New("requires API key but default/empty one set") errRequiresAPISecret = errors.New("requires API secret but default/empty one set") errRequiresAPIPEMKey = errors.New("requires API PEM key but default/empty one set") @@ -34,9 +34,6 @@ var ( func (a *API) SetKey(key string) { a.credMu.Lock() defer a.credMu.Unlock() - if a.credentials == nil { - a.credentials = &account.Credentials{} - } a.credentials.Key = key } @@ -44,9 +41,6 @@ func (a *API) SetKey(key string) { func (a *API) SetSecret(secret string) { a.credMu.Lock() defer a.credMu.Unlock() - if a.credentials == nil { - a.credentials = &account.Credentials{} - } a.credentials.Secret = secret } @@ -54,9 +48,6 @@ func (a *API) SetSecret(secret string) { func (a *API) SetClientID(clientID string) { a.credMu.Lock() defer a.credMu.Unlock() - if a.credentials == nil { - a.credentials = &account.Credentials{} - } a.credentials.ClientID = clientID } @@ -64,9 +55,6 @@ func (a *API) SetClientID(clientID string) { func (a *API) SetPEMKey(pem string) { a.credMu.Lock() defer a.credMu.Unlock() - if a.credentials == nil { - a.credentials = &account.Credentials{} - } a.credentials.PEMKey = pem } @@ -74,9 +62,6 @@ func (a *API) SetPEMKey(pem string) { func (a *API) SetSubAccount(sub string) { a.credMu.Lock() defer a.credMu.Unlock() - if a.credentials == nil { - a.credentials = &account.Credentials{} - } a.credentials.SubAccount = sub } @@ -116,10 +101,10 @@ func (b *Base) AreCredentialsValid(ctx context.Context) bool { func (b *Base) GetDefaultCredentials() *account.Credentials { b.API.credMu.RLock() defer b.API.credMu.RUnlock() - if b.API.credentials == nil { + if b.API.credentials == (account.Credentials{}) { return nil } - creds := *b.API.credentials + creds := b.API.credentials return &creds } @@ -142,7 +127,8 @@ func (b *Base) GetCredentials(ctx context.Context) (*account.Credentials, error) return creds, nil } - err := b.CheckCredentials(b.API.credentials, false) + creds := b.API.credentials + err := b.CheckCredentials(&creds, false) if err != nil { // NOTE: Return empty credentials on error to limit panic on websocket // handling. @@ -151,7 +137,6 @@ func (b *Base) GetCredentials(ctx context.Context) (*account.Credentials, error) subAccountOverride, ok := ctx.Value(account.ContextSubAccountFlag).(string) b.API.credMu.RLock() defer b.API.credMu.RUnlock() - creds := *b.API.credentials if ok { creds.SubAccount = subAccountOverride } @@ -201,9 +186,6 @@ func (b *Base) VerifyAPICredentials(creds *account.Credentials) error { func (b *Base) SetCredentials(apiKey, apiSecret, clientID, subaccount, pemKey, oneTimePassword string) { b.API.credMu.Lock() defer b.API.credMu.Unlock() - if b.API.credentials == nil { - b.API.credentials = &account.Credentials{} - } b.API.credentials.Key = apiKey b.API.credentials.ClientID = clientID b.API.credentials.SubAccount = subaccount @@ -231,29 +213,13 @@ func (b *Base) SetCredentials(apiKey, apiSecret, clientID, subaccount, pemKey, o func (b *Base) SetAPICredentialDefaults() { b.API.credMu.Lock() defer b.API.credMu.Unlock() + // Exchange hardcoded settings take precedence and overwrite the config settings if b.Config.API.CredentialsValidator == nil { b.Config.API.CredentialsValidator = new(config.APICredentialsValidatorConfig) } - if b.Config.API.CredentialsValidator.RequiresKey != b.API.CredentialsValidator.RequiresKey { - b.Config.API.CredentialsValidator.RequiresKey = b.API.CredentialsValidator.RequiresKey - } - if b.Config.API.CredentialsValidator.RequiresSecret != b.API.CredentialsValidator.RequiresSecret { - b.Config.API.CredentialsValidator.RequiresSecret = b.API.CredentialsValidator.RequiresSecret - } - - if b.Config.API.CredentialsValidator.RequiresBase64DecodeSecret != b.API.CredentialsValidator.RequiresBase64DecodeSecret { - b.Config.API.CredentialsValidator.RequiresBase64DecodeSecret = b.API.CredentialsValidator.RequiresBase64DecodeSecret - } - - if b.Config.API.CredentialsValidator.RequiresClientID != b.API.CredentialsValidator.RequiresClientID { - b.Config.API.CredentialsValidator.RequiresClientID = b.API.CredentialsValidator.RequiresClientID - } - - if b.Config.API.CredentialsValidator.RequiresPEM != b.API.CredentialsValidator.RequiresPEM { - b.Config.API.CredentialsValidator.RequiresPEM = b.API.CredentialsValidator.RequiresPEM - } + *b.Config.API.CredentialsValidator = b.API.CredentialsValidator } // IsWebsocketAuthenticationSupported returns whether the exchange supports diff --git a/exchanges/credentials_test.go b/exchanges/credentials_test.go index b7de7571..9c978fd1 100644 --- a/exchanges/credentials_test.go +++ b/exchanges/credentials_test.go @@ -186,7 +186,7 @@ func TestVerifyAPICredentials(t *testing.T) { setupBase := func(tData *tester) *Base { b := &Base{ API: API{ - CredentialsValidator: CredentialsValidator{ + CredentialsValidator: config.APICredentialsValidatorConfig{ RequiresKey: tData.RequiresKey, RequiresSecret: tData.RequiresSecret, RequiresClientID: tData.RequiresClientID, @@ -211,7 +211,7 @@ func TestVerifyAPICredentials(t *testing.T) { t.Run("", func(t *testing.T) { t.Parallel() b := setupBase(&tc) - if err := b.VerifyAPICredentials(b.API.credentials); !errors.Is(err, tc.Expected) { + if err := b.VerifyAPICredentials(&b.API.credentials); !errors.Is(err, tc.Expected) { t.Errorf("Test %d: expected: %v: got %v", x+1, tc.Expected, err) } if tc.CheckBase64DecodedOutput { @@ -236,7 +236,6 @@ func TestCheckCredentials(t *testing.T) { name: "Test SkipAuthCheck", base: &Base{ SkipAuthCheck: true, - API: API{credentials: &account.Credentials{}}, }, expectedErr: nil, }, @@ -244,8 +243,8 @@ func TestCheckCredentials(t *testing.T) { name: "Test credentials failure", base: &Base{ API: API{ - CredentialsValidator: CredentialsValidator{RequiresKey: true}, - credentials: &account.Credentials{OneTimePassword: "wow"}, + CredentialsValidator: config.APICredentialsValidatorConfig{RequiresKey: true}, + credentials: account.Credentials{OneTimePassword: "wow"}, }, }, expectedErr: errRequiresAPIKey, @@ -255,8 +254,8 @@ func TestCheckCredentials(t *testing.T) { base: &Base{ LoadedByConfig: true, API: API{ - CredentialsValidator: CredentialsValidator{RequiresKey: true}, - credentials: &account.Credentials{Key: "k3y"}, + CredentialsValidator: config.APICredentialsValidatorConfig{RequiresKey: true}, + credentials: account.Credentials{Key: "k3y"}, }, }, expectedErr: ErrAuthenticationSupportNotEnabled, @@ -267,8 +266,8 @@ func TestCheckCredentials(t *testing.T) { LoadedByConfig: true, API: API{ AuthenticatedSupport: true, - CredentialsValidator: CredentialsValidator{RequiresKey: true}, - credentials: &account.Credentials{}, + CredentialsValidator: config.APICredentialsValidatorConfig{RequiresKey: true}, + credentials: account.Credentials{}, }, }, expectedErr: ErrCredentialsAreEmpty, @@ -277,8 +276,8 @@ func TestCheckCredentials(t *testing.T) { name: "Test base64 decoded invalid credentials", base: &Base{ API: API{ - CredentialsValidator: CredentialsValidator{RequiresBase64DecodeSecret: true}, - credentials: &account.Credentials{Secret: "invalid"}, + CredentialsValidator: config.APICredentialsValidatorConfig{RequiresBase64DecodeSecret: true}, + credentials: account.Credentials{Secret: "invalid"}, }, }, expectedErr: errBase64DecodeFailure, @@ -287,8 +286,8 @@ func TestCheckCredentials(t *testing.T) { name: "Test base64 decoded valid credentials", base: &Base{ API: API{ - CredentialsValidator: CredentialsValidator{RequiresBase64DecodeSecret: true}, - credentials: &account.Credentials{Secret: "aGVsbG8gd29ybGQ="}, + CredentialsValidator: config.APICredentialsValidatorConfig{RequiresBase64DecodeSecret: true}, + credentials: account.Credentials{Secret: "aGVsbG8gd29ybGQ="}, }, }, checkBase64Output: true, @@ -299,8 +298,8 @@ func TestCheckCredentials(t *testing.T) { base: &Base{ API: API{ AuthenticatedSupport: true, - CredentialsValidator: CredentialsValidator{RequiresKey: true}, - credentials: &account.Credentials{Key: "k3y"}, + CredentialsValidator: config.APICredentialsValidatorConfig{RequiresKey: true}, + credentials: account.Credentials{Key: "k3y"}, }, }, expectedErr: nil, @@ -311,7 +310,7 @@ func TestCheckCredentials(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - if err := tc.base.CheckCredentials(tc.base.API.credentials, false); !errors.Is(err, tc.expectedErr) { + if err := tc.base.CheckCredentials(&tc.base.API.credentials, false); !errors.Is(err, tc.expectedErr) { t.Errorf("%s: received '%v' but expected '%v'", tc.name, err, tc.expectedErr) } if tc.checkBase64Output { diff --git a/exchanges/exchange.go b/exchanges/exchange.go index 98f7f6ba..94ccfc0b 100644 --- a/exchanges/exchange.go +++ b/exchanges/exchange.go @@ -14,7 +14,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/common/convert" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/currency" - "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/collateral" "github.com/thrasher-corp/gocryptotrader/exchanges/currencystate" @@ -548,9 +547,6 @@ func (b *Base) SetupDefaults(exch *config.Exchange) error { b.API.AuthenticatedSupport = exch.API.AuthenticatedSupport b.API.AuthenticatedWebsocketSupport = exch.API.AuthenticatedWebsocketSupport - if b.API.credentials == nil { - b.API.credentials = &account.Credentials{} - } b.API.credentials.SubAccount = exch.API.Credentials.Subaccount if b.API.AuthenticatedSupport || b.API.AuthenticatedWebsocketSupport { b.SetCredentials(exch.API.Credentials.Key, diff --git a/exchanges/exchange_types.go b/exchanges/exchange_types.go index 651a2bb9..c5a1f748 100644 --- a/exchanges/exchange_types.go +++ b/exchanges/exchange_types.go @@ -209,20 +209,10 @@ type API struct { Endpoints *Endpoints - credentials *account.Credentials + credentials account.Credentials credMu sync.RWMutex - CredentialsValidator CredentialsValidator -} - -// CredentialsValidator determines what is required -// to make authenticated requests for an exchange -type CredentialsValidator struct { - RequiresPEM bool - RequiresKey bool - RequiresSecret bool - RequiresClientID bool - RequiresBase64DecodeSecret bool + CredentialsValidator config.APICredentialsValidatorConfig } // Base stores the individual exchange information