diff --git a/exchanges/account/credentials.go b/exchanges/account/credentials.go index 3ac0e8e1..250de87a 100644 --- a/exchanges/account/credentials.go +++ b/exchanges/account/credentials.go @@ -43,12 +43,13 @@ var ( // Credentials define parameters that allow for an authenticated request. type Credentials struct { - Key string - Secret string - ClientID string // TODO: Implement with exchange orders functionality - PEMKey string - SubAccount string - OneTimePassword string + Key string + Secret string + ClientID string // TODO: Implement with exchange orders functionality + PEMKey string + SubAccount string + OneTimePassword string + SecretBase64Decoded bool // TODO: Add AccessControl uint8 for READ/WRITE/Withdraw capabilities. } diff --git a/exchanges/credentials.go b/exchanges/credentials.go index e6f8be99..cfc925cd 100644 --- a/exchanges/credentials.go +++ b/exchanges/credentials.go @@ -185,12 +185,15 @@ func (b *Base) VerifyAPICredentials(creds *account.Credentials) error { return fmt.Errorf("%s %w", b.Name, errRequiresAPIClientID) } - if b.API.CredentialsValidator.RequiresBase64DecodeSecret && !b.LoadedByConfig { - _, err := crypto.Base64Decode(creds.Secret) + if b.API.CredentialsValidator.RequiresBase64DecodeSecret && !creds.SecretBase64Decoded { + decodedResult, err := crypto.Base64Decode(creds.Secret) if err != nil { return fmt.Errorf("%s API secret %w: %s", b.Name, errBase64DecodeFailure, err) } + creds.Secret = string(decodedResult) + creds.SecretBase64Decoded = true } + return nil } @@ -218,6 +221,7 @@ func (b *Base) SetCredentials(apiKey, apiSecret, clientID, subaccount, pemKey, o return } b.API.credentials.Secret = string(result) + b.API.credentials.SecretBase64Decoded = true } else { b.API.credentials.Secret = apiSecret } diff --git a/exchanges/credentials_test.go b/exchanges/credentials_test.go index e2cb6339..b7de7571 100644 --- a/exchanges/credentials_test.go +++ b/exchanges/credentials_test.go @@ -31,12 +31,35 @@ func TestGetCredentials(t *testing.T) { t.Fatalf("received: %v but expected: %v", err, errRequiresAPISecret) } + b.API.CredentialsValidator.RequiresBase64DecodeSecret = true + ctx = account.DeployCredentialsToContext(context.Background(), &account.Credentials{ + Key: "meow", + Secret: "invalidb64", + }) + if _, err = b.GetCredentials(ctx); !errors.Is(err, errBase64DecodeFailure) { + t.Fatalf("received: %v but expected: %v", err, errBase64DecodeFailure) + } + + const expectedBase64DecodedOutput = "hello world" + ctx = account.DeployCredentialsToContext(context.Background(), &account.Credentials{ + Key: "meow", + Secret: "aGVsbG8gd29ybGQ=", + }) + creds, err := b.GetCredentials(ctx) + if !errors.Is(err, nil) { + t.Fatalf("received: %v but expected: %v", err, nil) + } + if creds.Secret != expectedBase64DecodedOutput { + t.Fatalf("received: %v but expected: %v", creds.Secret, expectedBase64DecodedOutput) + } + ctx = context.WithValue(context.Background(), account.ContextCredentialsFlag, "pewpew") _, err = b.GetCredentials(ctx) if !errors.Is(err, errContextCredentialsFailure) { t.Fatalf("received: %v but expected: %v", err, errContextCredentialsFailure) } + b.API.CredentialsValidator.RequiresBase64DecodeSecret = false fullCred := &account.Credentials{ Key: "superkey", Secret: "supersecret", @@ -47,7 +70,7 @@ func TestGetCredentials(t *testing.T) { } ctx = account.DeployCredentialsToContext(context.Background(), fullCred) - creds, err := b.GetCredentials(ctx) + creds, err = b.GetCredentials(ctx) if !errors.Is(err, nil) { t.Fatalf("received: %v but expected: %v", err, nil) } @@ -131,9 +154,13 @@ func TestVerifyAPICredentials(t *testing.T) { RequiresSecret bool RequiresClientID bool RequiresBase64DecodeSecret bool + UseSetCredentials bool + CheckBase64DecodedOutput bool Expected error } + const expectedBase64DecodedOutput = "hello world" + testCases := []tester{ // Empty credentials {Expected: ErrCredentialsAreEmpty}, @@ -152,31 +179,45 @@ func TestVerifyAPICredentials(t *testing.T) { // test requires base64 decode secret {RequiresBase64DecodeSecret: true, RequiresSecret: true, Expected: errRequiresAPISecret, Key: "bruh"}, {RequiresBase64DecodeSecret: true, Secret: "%%", Expected: errBase64DecodeFailure}, - {RequiresBase64DecodeSecret: true, Secret: "aGVsbG8gd29ybGQ="}, + {RequiresBase64DecodeSecret: true, Secret: "aGVsbG8gd29ybGQ=", CheckBase64DecodedOutput: true}, + {RequiresBase64DecodeSecret: true, Secret: "aGVsbG8gd29ybGQ=", UseSetCredentials: true, CheckBase64DecodedOutput: true}, } setupBase := func(tData *tester) *Base { - b := &Base{} - b.API.SetKey(tData.Key) - b.API.SetSecret(tData.Secret) - b.API.SetClientID(tData.ClientID) - b.API.SetPEMKey(tData.PEMKey) - b.API.CredentialsValidator.RequiresKey = tData.RequiresKey - b.API.CredentialsValidator.RequiresSecret = tData.RequiresSecret - b.API.CredentialsValidator.RequiresPEM = tData.RequiresPEM - b.API.CredentialsValidator.RequiresClientID = tData.RequiresClientID - b.API.CredentialsValidator.RequiresBase64DecodeSecret = tData.RequiresBase64DecodeSecret + b := &Base{ + API: API{ + CredentialsValidator: CredentialsValidator{ + RequiresKey: tData.RequiresKey, + RequiresSecret: tData.RequiresSecret, + RequiresClientID: tData.RequiresClientID, + RequiresPEM: tData.RequiresPEM, + RequiresBase64DecodeSecret: tData.RequiresBase64DecodeSecret, + }, + }, + } + if tData.UseSetCredentials { + b.SetCredentials(tData.Key, tData.Secret, tData.ClientID, "", tData.PEMKey, "") + } else { + b.API.SetKey(tData.Key) + b.API.SetSecret(tData.Secret) + b.API.SetClientID(tData.ClientID) + b.API.SetPEMKey(tData.PEMKey) + } return b } - for x := range testCases { - testData := &testCases[x] - x := x + for x, tc := range testCases { + x, tc := x, tc t.Run("", func(t *testing.T) { t.Parallel() - b := setupBase(testData) - if err := b.VerifyAPICredentials(b.API.credentials); !errors.Is(err, testData.Expected) { - t.Errorf("Test %d: expected: %v: got %v", x+1, testData.Expected, err) + b := setupBase(&tc) + if err := b.VerifyAPICredentials(b.API.credentials); !errors.Is(err, tc.Expected) { + t.Errorf("Test %d: expected: %v: got %v", x+1, tc.Expected, err) + } + if tc.CheckBase64DecodedOutput { + if b.API.credentials.Secret != expectedBase64DecodedOutput { + t.Errorf("Test %d: expected: %v: got %v", x+1, expectedBase64DecodedOutput, b.API.credentials.Secret) + } } }) } @@ -185,50 +226,103 @@ func TestVerifyAPICredentials(t *testing.T) { func TestCheckCredentials(t *testing.T) { t.Parallel() - b := Base{ - SkipAuthCheck: true, - API: API{credentials: &account.Credentials{}}, + testCases := []struct { + name string + base *Base + checkBase64Output bool + expectedErr error + }{ + { + name: "Test SkipAuthCheck", + base: &Base{ + SkipAuthCheck: true, + API: API{credentials: &account.Credentials{}}, + }, + expectedErr: nil, + }, + { + name: "Test credentials failure", + base: &Base{ + API: API{ + CredentialsValidator: CredentialsValidator{RequiresKey: true}, + credentials: &account.Credentials{OneTimePassword: "wow"}, + }, + }, + expectedErr: errRequiresAPIKey, + }, + { + name: "Test exchange usage with authenticated API support disabled, but with valid credentials", + base: &Base{ + LoadedByConfig: true, + API: API{ + CredentialsValidator: CredentialsValidator{RequiresKey: true}, + credentials: &account.Credentials{Key: "k3y"}, + }, + }, + expectedErr: ErrAuthenticationSupportNotEnabled, + }, + { + name: "Test enabled authenticated API support and loaded by config but invalid credentials", + base: &Base{ + LoadedByConfig: true, + API: API{ + AuthenticatedSupport: true, + CredentialsValidator: CredentialsValidator{RequiresKey: true}, + credentials: &account.Credentials{}, + }, + }, + expectedErr: ErrCredentialsAreEmpty, + }, + { + name: "Test base64 decoded invalid credentials", + base: &Base{ + API: API{ + CredentialsValidator: CredentialsValidator{RequiresBase64DecodeSecret: true}, + credentials: &account.Credentials{Secret: "invalid"}, + }, + }, + expectedErr: errBase64DecodeFailure, + }, + { + name: "Test base64 decoded valid credentials", + base: &Base{ + API: API{ + CredentialsValidator: CredentialsValidator{RequiresBase64DecodeSecret: true}, + credentials: &account.Credentials{Secret: "aGVsbG8gd29ybGQ="}, + }, + }, + checkBase64Output: true, + expectedErr: nil, + }, + { + name: "Test valid credentials", + base: &Base{ + API: API{ + AuthenticatedSupport: true, + CredentialsValidator: CredentialsValidator{RequiresKey: true}, + credentials: &account.Credentials{Key: "k3y"}, + }, + }, + expectedErr: nil, + }, } - // Test SkipAuthCheck - err := b.CheckCredentials(&account.Credentials{}, false) - if !errors.Is(err, nil) { - t.Errorf("received '%v' expected '%v'", err, nil) - } - - // Test credentials failure - b.SkipAuthCheck = false - b.API.CredentialsValidator.RequiresKey = true - b.API.credentials.OneTimePassword = "wow" - err = b.CheckCredentials(b.API.credentials, false) - if !errors.Is(err, errRequiresAPIKey) { - t.Errorf("received '%v' expected '%v'", err, errRequiresAPIKey) - } - b.API.credentials.OneTimePassword = "" - - // Test bot usage with authenticated API support disabled, but with - // valid credentials - b.LoadedByConfig = true - b.API.credentials.Key = "k3y" - err = b.CheckCredentials(b.API.credentials, false) - if !errors.Is(err, ErrAuthenticationSupportNotEnabled) { - t.Errorf("received '%v' expected '%v'", err, ErrAuthenticationSupportNotEnabled) - } - - // Test enabled authenticated API support and loaded by config - // but invalid credentials - b.API.AuthenticatedSupport = true - b.API.credentials.Key = "" - err = b.CheckCredentials(b.API.credentials, false) - if !errors.Is(err, ErrCredentialsAreEmpty) { - t.Errorf("received '%v' expected '%v'", err, ErrCredentialsAreEmpty) - } - - // Finally a valid one - b.API.credentials.Key = "k3y" - err = b.CheckCredentials(b.API.credentials, false) - if !errors.Is(err, nil) { - t.Errorf("received '%v' expected '%v'", err, nil) + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if err := tc.base.CheckCredentials(tc.base.API.credentials, false); !errors.Is(err, tc.expectedErr) { + t.Errorf("%s: received '%v' but expected '%v'", tc.name, err, tc.expectedErr) + } + if tc.checkBase64Output { + if tc.base.API.credentials.SecretBase64Decoded != true { + t.Errorf("%s: expected secret to be base64 decoded", tc.name) + } + if tc.base.API.credentials.Secret != "hello world" { + t.Errorf("%s: expected %q but received %q", "hello world", tc.name, tc.base.API.credentials.Secret) + } + } + }) } }