mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-13 15:09:42 +00:00
Kraken: Protect authToken with RWMutex to prevent race (#1926)
* Fix(kraken): Protect authToken with RWMutex to prevent race This commit introduces a sync.RWMutex to protect the global `authToken` variable in `exchanges/kraken/kraken_websocket.go`. The race condition occurred due to concurrent read/write access to `authToken` from different goroutines, notably between `WsConnect` (write) and functions like `wsCancelOrder`, `wsAddOrder`, `wsCancelAllOrders`, and `manageSubs` (read). The fix involves: - Adding `authTokenMutex.Lock()` before writing to `authToken` in `WsConnect` and `authTokenMutex.Unlock()` after. - Adding `authTokenMutex.RLock()` before reading `authToken` in `wsAddOrder`, `wsCancelOrder`, `wsCancelAllOrders`, and `manageSubs`, and `authTokenMutex.RUnlock()` after. This change resolves the data race reported in https://github.com/thrasher-corp/gocryptotrader/issues/1762. I ran tests in the `exchanges/kraken` package with the `-race` detector, and all tests passed without detecting any race conditions. * kraken: Add common websocketAuthToken func for concurrent read access * Kraken: Only set authToken on mission success * Refactor: Adjust websocket authentication handling to use setWebsocketAuthToken method and rename wsAuthMu to wsAuthMtx for clarity --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
@@ -37,6 +38,8 @@ const (
|
||||
// Kraken is the overarching type across the kraken package
|
||||
type Kraken struct {
|
||||
exchange.Base
|
||||
wsAuthToken string
|
||||
wsAuthMtx sync.RWMutex
|
||||
}
|
||||
|
||||
// GetCurrentServerTime returns current server time
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -1741,3 +1742,38 @@ func TestEnforceStandardChannelNames(t *testing.T) {
|
||||
assert.ErrorIsf(t, err, subscription.ErrUseConstChannelName, "Private channel names should not be allowed for %s", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketAuthToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
k := new(Kraken)
|
||||
k.setWebsocketAuthToken("meep")
|
||||
const n = 69
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2 * n)
|
||||
|
||||
start := make(chan struct{})
|
||||
for range n {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
k.setWebsocketAuthToken("69420")
|
||||
}()
|
||||
}
|
||||
for range n {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
k.websocketAuthToken()
|
||||
}()
|
||||
}
|
||||
close(start)
|
||||
wg.Wait()
|
||||
assert.Equal(t, "69420", k.websocketAuthToken(), "websocketAuthToken should return correctly after concurrent reads and writes")
|
||||
}
|
||||
|
||||
func TestSetWebsocketAuthToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
k := new(Kraken)
|
||||
k.setWebsocketAuthToken("69420")
|
||||
assert.Equal(t, "69420", k.websocketAuthToken())
|
||||
}
|
||||
|
||||
@@ -79,7 +79,6 @@ func init() {
|
||||
}
|
||||
|
||||
var (
|
||||
authToken string
|
||||
errCancellingOrder = errors.New("error cancelling order")
|
||||
errSubPairMissing = errors.New("pair missing from subscription response")
|
||||
errInvalidChecksum = errors.New("invalid checksum")
|
||||
@@ -112,22 +111,15 @@ func (k *Kraken) WsConnect() error {
|
||||
go k.wsFunnelConnectionData(k.Websocket.Conn, comms)
|
||||
|
||||
if k.IsWebsocketAuthenticationSupported() {
|
||||
authToken, err = k.GetWebsocketToken(context.TODO())
|
||||
if err != nil {
|
||||
if authToken, err := k.GetWebsocketToken(context.TODO()); err != nil {
|
||||
k.Websocket.SetCanUseAuthenticatedEndpoints(false)
|
||||
log.Errorf(log.ExchangeSys,
|
||||
"%v - authentication failed: %v\n",
|
||||
k.Name,
|
||||
err)
|
||||
log.Errorf(log.ExchangeSys, "%s - authentication failed: %v\n", k.Name, err)
|
||||
} else {
|
||||
err = k.Websocket.AuthConn.Dial(&dialer, http.Header{})
|
||||
if err != nil {
|
||||
if err := k.Websocket.AuthConn.Dial(&dialer, http.Header{}); err != nil {
|
||||
k.Websocket.SetCanUseAuthenticatedEndpoints(false)
|
||||
log.Errorf(log.ExchangeSys,
|
||||
"%v - failed to connect to authenticated endpoint: %v\n",
|
||||
k.Name,
|
||||
err)
|
||||
log.Errorf(log.ExchangeSys, "%s - failed to connect to authenticated endpoint: %v\n", k.Name, err)
|
||||
} else {
|
||||
k.setWebsocketAuthToken(authToken)
|
||||
k.Websocket.SetCanUseAuthenticatedEndpoints(true)
|
||||
k.Websocket.Wg.Add(1)
|
||||
go k.wsFunnelConnectionData(k.Websocket.AuthConn, comms)
|
||||
@@ -1091,7 +1083,7 @@ func (k *Kraken) manageSubs(op string, subs subscription.List) error {
|
||||
|
||||
conn := k.Websocket.Conn
|
||||
if s.Authenticated {
|
||||
r.Subscription.Token = authToken
|
||||
r.Subscription.Token = k.websocketAuthToken()
|
||||
conn = k.Websocket.AuthConn
|
||||
}
|
||||
|
||||
@@ -1305,7 +1297,7 @@ func (k *Kraken) wsAddOrder(req *WsAddOrderRequest) (string, error) {
|
||||
}
|
||||
req.RequestID = k.Websocket.AuthConn.GenerateMessageID(false)
|
||||
req.Event = krakenWsAddOrder
|
||||
req.Token = authToken
|
||||
req.Token = k.websocketAuthToken()
|
||||
jsonResp, err := k.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), request.Unset, req.RequestID, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -1345,7 +1337,7 @@ func (k *Kraken) wsCancelOrder(orderID string) error {
|
||||
id := k.Websocket.AuthConn.GenerateMessageID(false)
|
||||
req := WsCancelOrderRequest{
|
||||
Event: krakenWsCancelOrder,
|
||||
Token: authToken,
|
||||
Token: k.websocketAuthToken(),
|
||||
TransactionIDs: []string{orderID},
|
||||
RequestID: id,
|
||||
}
|
||||
@@ -1376,7 +1368,7 @@ func (k *Kraken) wsCancelAllOrders() (*WsCancelOrderResponse, error) {
|
||||
id := k.Websocket.AuthConn.GenerateMessageID(false)
|
||||
req := WsCancelOrderRequest{
|
||||
Event: krakenWsCancelAll,
|
||||
Token: authToken,
|
||||
Token: k.websocketAuthToken(),
|
||||
RequestID: id,
|
||||
}
|
||||
|
||||
@@ -1414,3 +1406,16 @@ const subTplText = `
|
||||
{{- channelName $.S }}
|
||||
{{- end }}
|
||||
`
|
||||
|
||||
// websocketAuthToken retrieves the current websocket session's auth token
|
||||
func (k *Kraken) websocketAuthToken() string {
|
||||
k.wsAuthMtx.RLock()
|
||||
defer k.wsAuthMtx.RUnlock()
|
||||
return k.wsAuthToken
|
||||
}
|
||||
|
||||
func (k *Kraken) setWebsocketAuthToken(token string) {
|
||||
k.wsAuthMtx.Lock()
|
||||
k.wsAuthToken = token
|
||||
k.wsAuthMtx.Unlock()
|
||||
}
|
||||
|
||||
@@ -1436,10 +1436,12 @@ func (k *Kraken) GetOrderHistory(ctx context.Context, getOrdersRequest *order.Mu
|
||||
// AuthenticateWebsocket sends an authentication message to the websocket
|
||||
func (k *Kraken) AuthenticateWebsocket(ctx context.Context) error {
|
||||
resp, err := k.GetWebsocketToken(ctx)
|
||||
if resp != "" {
|
||||
authToken = resp
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
|
||||
k.setWebsocketAuthToken(resp)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAPICredentials validates current credentials used for wrapper
|
||||
|
||||
Reference in New Issue
Block a user