diff --git a/exchanges/btcmarkets/btcmarkets_websocket.go b/exchanges/btcmarkets/btcmarkets_websocket.go index 21eaeff9..8e9fd56d 100644 --- a/exchanges/btcmarkets/btcmarkets_websocket.go +++ b/exchanges/btcmarkets/btcmarkets_websocket.go @@ -320,13 +320,15 @@ func (b *BTCMarkets) wsHandleData(respRaw []byte) error { } } - creds, err := b.GetCredentials(context.TODO()) - if err != nil { + clientID := "" + if creds, err := b.GetCredentials(context.TODO()); err != nil { b.Websocket.DataHandler <- order.ClassificationError{ Exchange: b.Name, OrderID: orderID, Err: err, } + } else if creds != nil { + clientID = creds.ClientID } b.Websocket.DataHandler <- &order.Detail{ @@ -335,7 +337,7 @@ func (b *BTCMarkets) wsHandleData(respRaw []byte) error { RemainingAmount: orderData.OpenVolume, Exchange: b.Name, OrderID: orderID, - ClientID: creds.ClientID, + ClientID: clientID, Type: oType, Side: oSide, Status: oStatus, diff --git a/exchanges/coinbasepro/coinbasepro_websocket.go b/exchanges/coinbasepro/coinbasepro_websocket.go index 3c4258da..98438a60 100644 --- a/exchanges/coinbasepro/coinbasepro_websocket.go +++ b/exchanges/coinbasepro/coinbasepro_websocket.go @@ -173,6 +173,11 @@ func (c *CoinbasePro) wsHandleData(respRaw []byte) error { } } + clientID := "" + if creds != nil { + clientID = creds.ClientID + } + if wsOrder.UserID != "" { var p currency.Pair var a asset.Item @@ -191,7 +196,7 @@ func (c *CoinbasePro) wsHandleData(respRaw []byte) error { Exchange: c.Name, OrderID: wsOrder.OrderID, AccountID: wsOrder.ProfileID, - ClientID: creds.ClientID, + ClientID: clientID, Type: oType, Side: oSide, Status: oStatus, diff --git a/exchanges/credentials.go b/exchanges/credentials.go index 56c9d80d..03e19495 100644 --- a/exchanges/credentials.go +++ b/exchanges/credentials.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" + "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/exchanges/account" @@ -22,12 +23,11 @@ var ( // undertake an authenticated HTTP request. ErrCredentialsAreEmpty = errors.New("credentials are empty") // Errors related to API requirements and failures - errRequiresAPIKey = errors.New("requires API key but default/empty one set") - errRequiresAPISecret = errors.New("requires API secret but default/empty one set") - errRequiresAPIPEMKey = errors.New("requires API PEM key but default/empty one set") - errRequiresAPIClientID = errors.New("requires API Client ID but default/empty one set") - errBase64DecodeFailure = errors.New("base64 decode has failed") - errContextCredentialsFailure = errors.New("context credentials type assertion failure") + errRequiresAPIKey = errors.New("requires API key but default/empty one set") + errRequiresAPISecret = errors.New("requires API secret but default/empty one set") + errRequiresAPIPEMKey = errors.New("requires API PEM key but default/empty one set") + errRequiresAPIClientID = errors.New("requires API Client ID but default/empty one set") + errBase64DecodeFailure = errors.New("base64 decode has failed") ) // SetKey sets new key for the default credentials @@ -115,31 +115,28 @@ func (b *Base) GetCredentials(ctx context.Context) (*account.Credentials, error) if value != nil { ctxCredStore, ok := value.(*account.ContextCredentialsStore) if !ok { - // NOTE: Return empty credentials on error to limit panic on - // websocket handling. - return &account.Credentials{}, errContextCredentialsFailure + return nil, common.GetTypeAssertError("*account.ContextCredentialsStore", value) } creds := ctxCredStore.Get() if err := b.CheckCredentials(creds, true); err != nil { - return creds, fmt.Errorf("context credentials issue: %w", err) + return nil, fmt.Errorf("error checking credentials from context: %w", err) } return creds, nil } - creds := b.API.credentials - err := b.CheckCredentials(&creds, false) - if err != nil { - // NOTE: Return empty credentials on error to limit panic on websocket - // handling. - return &account.Credentials{}, err - } - subAccountOverride, ok := ctx.Value(account.ContextSubAccountFlag).(string) + // Fallback to exchange loaded credentials b.API.credMu.RLock() - defer b.API.credMu.RUnlock() - if ok { + creds := b.API.credentials + b.API.credMu.RUnlock() + if err := b.CheckCredentials(&creds, false); err != nil { + return nil, fmt.Errorf("error checking credentials: %w", err) + } + + if subAccountOverride, ok := ctx.Value(account.ContextSubAccountFlag).(string); ok { creds.SubAccount = subAccountOverride } + return &creds, nil } diff --git a/exchanges/credentials_test.go b/exchanges/credentials_test.go index 5da6deed..e85bbb49 100644 --- a/exchanges/credentials_test.go +++ b/exchanges/credentials_test.go @@ -5,6 +5,8 @@ import ( "errors" "testing" + "github.com/stretchr/testify/require" + "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/exchanges/account" ) @@ -55,9 +57,7 @@ func TestGetCredentials(t *testing.T) { ctx = context.WithValue(t.Context(), account.ContextCredentialsFlag, "pewpew") _, err = b.GetCredentials(ctx) - if !errors.Is(err, errContextCredentialsFailure) { - t.Fatalf("received: %v but expected: %v", err, errContextCredentialsFailure) - } + require.ErrorIs(t, err, common.ErrTypeAssertFailure) b.API.CredentialsValidator.RequiresBase64DecodeSecret = false fullCred := &account.Credentials{ diff --git a/gctscript/wrappers/gct/gctwrapper_test.go b/gctscript/wrappers/gct/gctwrapper_test.go index 60f98978..5cc38b63 100644 --- a/gctscript/wrappers/gct/gctwrapper_test.go +++ b/gctscript/wrappers/gct/gctwrapper_test.go @@ -10,6 +10,7 @@ import ( "testing" objects "github.com/d5/tengo/v2" + "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/engine" @@ -183,21 +184,15 @@ func TestExchangePairs(t *testing.T) { } } -func TestAccountInfo(t *testing.T) { +func TestExchangeAccountInfo(t *testing.T) { t.Parallel() _, err := gct.ExchangeAccountInfo() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Fatal(err) - } + require.ErrorIs(t, err, objects.ErrWrongNumArguments) obj, err := gct.ExchangeAccountInfo(ctx, exch, assetType) - if err != nil { - t.Fatalf("received: %v but expected: %v", err, nil) - } - rString, _ := objects.ToString(obj) - if rString != `error: "Bitstamp REST or Websocket authentication support is not enabled"` { - t.Errorf("received: %v but expected: %v", - rString, `error: "Bitstamp REST or Websocket authentication support is not enabled"`) - } + require.NoError(t, err) + rString, ok := objects.ToString(obj) + require.True(t, ok, "ExchangeAccountInfo return value must return correctly from objects.ToString") + require.Contains(t, rString, "Bitstamp REST or Websocket authentication support is not enabled") } func TestExchangeOrderQuery(t *testing.T) { @@ -229,9 +224,7 @@ func TestExchangeOrderCancel(t *testing.T) { func TestExchangeOrderSubmit(t *testing.T) { t.Parallel() _, err := gct.ExchangeOrderSubmit() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Fatal(err) - } + require.ErrorIs(t, err, objects.ErrWrongNumArguments) orderSide := &objects.String{Value: "ASK"} orderType := &objects.String{Value: "LIMIT"} @@ -249,16 +242,11 @@ func TestExchangeOrderSubmit(t *testing.T) { orderAmount, orderID, orderAsset) - if err != nil { - t.Fatalf("received: %v but expected: %v", err, nil) - } + require.NoError(t, err) - rString, _ := objects.ToString(obj) - if rString != `error: "Bitstamp REST or Websocket authentication support is not enabled"` { - t.Errorf("received: [%v] but expected: %v", - rString, - `error: "Bitstamp REST or Websocket authentication support is not enabled"`) - } + rString, ok := objects.ToString(obj) + require.True(t, ok, "ExchangeOrderSubmit return value must return correctly from objects.ToString") + require.Contains(t, rString, "Bitstamp REST or Websocket authentication support is not enabled") } func TestAllModuleNames(t *testing.T) {