From cce28f9d2bf5831da6687377a2700f943b10e28c Mon Sep 17 00:00:00 2001 From: dazi005 <460497231@qq.com> Date: Fri, 10 Oct 2025 07:59:08 +0800 Subject: [PATCH] OKX: Fix authenticated websocket login (#2051) * Okx:remove * Okx:replace * Okx:ping gws.PingMessage * Okx:ping gws.PingMessage * Okx: add authenticateConnection * Okx: fix pingHandler * Update exchanges/okx/okx_business_websocket.go Co-authored-by: Gareth Kirwan * Update exchanges/okx/okx_business_websocket.go Co-authored-by: Gareth Kirwan * Update exchanges/okx/okx_websocket.go Co-authored-by: Gareth Kirwan * Update exchanges/okx/okx_business_websocket.go Co-authored-by: Gareth Kirwan * Okx:UseMultiConnectionManagement * Okx:rm UseMultiConnectionManagement * Okx:roll back * Okx:apply diff * Okx:make lint fix * Okx:make lint fix * Okx:make lint fix * Okx:fix * Okx:fix name * Okx:fix NilGuard depends on #2076 * Okx:remove comment --------- Co-authored-by: Ryan O'Hara-Reid Co-authored-by: Gareth Kirwan --- exchange/websocket/connection.go | 12 ++--- exchange/websocket/manager.go | 24 +++++++--- exchange/websocket/manager_test.go | 53 +++++++++++++++++++- exchanges/okx/okx_business_websocket.go | 49 +++---------------- exchanges/okx/okx_types.go | 6 +++ exchanges/okx/okx_websocket.go | 64 +++++++++++++++---------- 6 files changed, 124 insertions(+), 84 deletions(-) diff --git a/exchange/websocket/connection.go b/exchange/websocket/connection.go index d7a8bbf8..82473f65 100644 --- a/exchange/websocket/connection.go +++ b/exchange/websocket/connection.go @@ -19,6 +19,7 @@ import ( "time" gws "github.com/gorilla/websocket" + "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/encoding/json" "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" @@ -240,15 +241,12 @@ func (c *connection) SetupPingHandler(epl request.EndpointLimit, handler PingHan return } c.Wg.Go(func() { - ticker := time.NewTicker(handler.Delay) for { select { case <-c.shutdown: - ticker.Stop() return - case <-ticker.C: - err := c.SendRawMessage(context.TODO(), epl, handler.MessageType, handler.Message) - if err != nil { + case <-time.After(handler.Delay): + if err := c.SendRawMessage(context.Background(), epl, handler.MessageType, handler.Message); err != nil { log.Errorf(log.WebsocketMgr, "%v websocket connection: ping handler failed to send message [%s]: %v", c.ExchangeName, handler.Message, err) return } @@ -368,8 +366,8 @@ func (c *connection) defaultGenerateMessageID(highPrec bool) int64 { // Shutdown shuts down and closes specific connection func (c *connection) Shutdown() error { - if c == nil || c.Connection == nil { - return nil + if err := common.NilGuard(c, c.Connection); err != nil { + return err } c.setConnectedStatus(false) c.writeControl.Lock() diff --git a/exchange/websocket/manager.go b/exchange/websocket/manager.go index 45db1630..f0c6c0fc 100644 --- a/exchange/websocket/manager.go +++ b/exchange/websocket/manager.go @@ -633,15 +633,14 @@ func (m *Manager) Shutdown() error { } func (m *Manager) shutdown() error { - if !m.IsConnected() { - return fmt.Errorf("%v %w: %w", m.exchangeName, errCannotShutdown, ErrNotConnected) - } - - // TODO: Interrupt connection and or close connection when it is re-established. if m.IsConnecting() { return fmt.Errorf("%v %w: %w ", m.exchangeName, errCannotShutdown, errAlreadyReconnecting) } + if !m.IsConnected() { + return fmt.Errorf("%v %w: %w", m.exchangeName, errCannotShutdown, ErrNotConnected) + } + if m.verbose { log.Debugf(log.WebsocketMgr, "%v websocket: shutting down websocket", m.exchangeName) } @@ -682,11 +681,22 @@ func (m *Manager) shutdown() error { // flush any subscriptions from last connection if needed m.subscriptions.Clear() - m.setState(disconnectedState) - close(m.ShutdownC) + m.setState(disconnectedState) m.Wg.Wait() m.ShutdownC = make(chan struct{}) + + for _, conn := range []Connection{m.Conn, m.AuthConn} { + if conn == nil { + continue + } + conn, ok := conn.(*connection) + if !ok { + return fmt.Errorf("%s websocket: %w", m.exchangeName, common.GetTypeAssertError("*connection", conn)) + } + conn.shutdown = m.ShutdownC + } + if m.verbose { log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", m.exchangeName) } diff --git a/exchange/websocket/manager_test.go b/exchange/websocket/manager_test.go index 32f29231..af24cee6 100644 --- a/exchange/websocket/manager_test.go +++ b/exchange/websocket/manager_test.go @@ -1088,7 +1088,7 @@ func TestConnectionShutdown(t *testing.T) { t.Parallel() wc := connection{shutdown: make(chan struct{})} err := wc.Shutdown() - assert.NoError(t, err, "Shutdown should not error") + assert.ErrorIs(t, err, common.ErrNilPointer, "Shutdown should error correctly") err = wc.Dial(t.Context(), &gws.Dialer{}, nil) assert.ErrorContains(t, err, "malformed ws or wss URL", "Dial should error correctly") @@ -1336,3 +1336,54 @@ func TestGetConnection(t *testing.T) { require.NoError(t, err) assert.Same(t, expected, conn) } + +func TestShutdown(t *testing.T) { + t.Parallel() + m := Manager{} + m.setState(connectingState) + require.ErrorIs(t, m.Shutdown(), errAlreadyReconnecting, "Shutdown must error correctly") + m.setState(disconnectedState) + require.ErrorIs(t, m.Shutdown(), ErrNotConnected, "Shutdown must error correctly") + m.setState(connectedState) + require.Panics(t, func() { _ = m.Shutdown() }, "Shutdown must panic on nil shutdown channel") + m.ShutdownC = make(chan struct{}) + require.NoError(t, m.Shutdown(), "Shutdown must not error with no connections") + m.setState(connectedState) + m.Conn = &struct{ *connection }{&connection{}} + m.AuthConn = &struct{ *connection }{&connection{}} + require.ErrorIs(t, m.Shutdown(), common.ErrTypeAssertFailure, "Shutdown must error with unhandled connection type") + + mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) + defer mock.Close() + + wsURL := "ws" + mock.URL[len("http"):] + "/ws" + conn, resp, err := gws.DefaultDialer.DialContext(t.Context(), wsURL, nil) + require.NoError(t, err, "DialContext must not error") + defer resp.Body.Close() + + m.AuthConn = nil + m.Conn = nil + m.connectionManager = []*connectionWrapper{{connection: &connection{Connection: nil}}, {connection: &connection{Connection: conn}}} + m.setState(connectedState) + require.NoError(t, m.Shutdown(), "Shutdown must not error with faulty connection in connectionManager") + + gwsConnAuth, respAuth, err := gws.DefaultDialer.DialContext(t.Context(), wsURL, nil) + require.NoError(t, err, "DialContext must not error") + defer respAuth.Body.Close() + + gwsConnUnAuth, respUnAuth, err := gws.DefaultDialer.DialContext(t.Context(), wsURL, nil) + require.NoError(t, err, "DialContext must not error") + defer respUnAuth.Body.Close() + + m.connectionManager = nil + authConn := &connection{Connection: gwsConnAuth, shutdown: m.ShutdownC} + m.AuthConn = authConn + unauthConn := &connection{Connection: gwsConnUnAuth, shutdown: m.ShutdownC} + m.Conn = unauthConn + + m.setState(connectedState) + require.NoError(t, m.Shutdown(), "Shutdown must not error with good connections") + + require.Equal(t, m.ShutdownC, authConn.shutdown, "shutdown channels must be the same after original shutdown channel is closed") + require.Equal(t, m.ShutdownC, unauthConn.shutdown, "shutdown channels must be the same after original shutdown channel is closed") +} diff --git a/exchanges/okx/okx_business_websocket.go b/exchanges/okx/okx_business_websocket.go index 3e965bab..42143d3e 100644 --- a/exchanges/okx/okx_business_websocket.go +++ b/exchanges/okx/okx_business_websocket.go @@ -2,16 +2,12 @@ package okx import ( "context" - "encoding/base64" - "fmt" "net/http" - "strconv" "strings" "time" gws "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" - "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" "github.com/thrasher-corp/gocryptotrader/exchange/websocket" @@ -55,15 +51,13 @@ func (e *Exchange) WsConnectBusiness(ctx context.Context) error { dialer.WriteBufferSize = 8192 e.Websocket.Conn.SetURL(okxBusinessWebsocketURL) - err := e.Websocket.Conn.Dial(ctx, &dialer, http.Header{}) - if err != nil { + if err := e.Websocket.Conn.Dial(ctx, &dialer, http.Header{}); err != nil { return err } - e.Websocket.Wg.Add(1) - go e.wsReadData(ctx, e.Websocket.Conn) + e.Websocket.Wg.Go(func() { e.wsReadData(ctx, e.Websocket.Conn) }) + if e.Verbose { - log.Debugf(log.ExchangeSys, "Successful connection to %v\n", - e.Websocket.GetWebsocketURL()) + log.Debugf(log.ExchangeSys, "Successful connection to %v", e.Websocket.GetWebsocketURL()) } e.Websocket.Conn.SetupPingHandler(request.UnAuth, websocket.PingHandler{ MessageType: gws.TextMessage, @@ -71,45 +65,14 @@ func (e *Exchange) WsConnectBusiness(ctx context.Context) error { Delay: time.Second * 20, }) if e.Websocket.CanUseAuthenticatedEndpoints() { - err = e.WsSpreadAuth(ctx) - if err != nil { - log.Errorf(log.ExchangeSys, "Error connecting auth socket: %s\n", err.Error()) + if err := e.authenticateConnection(ctx, e.Websocket.Conn); err != nil { // business WS uses same conn for public and private + log.Errorf(log.ExchangeSys, "Error authenticating business websocket: %s", err) e.Websocket.SetCanUseAuthenticatedEndpoints(false) } } return nil } -// WsSpreadAuth will connect to Okx's Private websocket connection and Authenticate with a login payload. -func (e *Exchange) WsSpreadAuth(ctx context.Context) error { - if !e.Websocket.CanUseAuthenticatedEndpoints() { - return fmt.Errorf("%v AuthenticatedWebsocketAPISupport not enabled", e.Name) - } - creds, err := e.GetCredentials(ctx) - if err != nil { - return err - } - e.Websocket.SetCanUseAuthenticatedEndpoints(true) - ts := time.Now().Unix() - signPath := "/users/self/verify" - hmac, err := crypto.GetHMAC(crypto.HashSHA256, - []byte(strconv.FormatInt(ts, 10)+http.MethodGet+signPath), - []byte(creds.Secret), - ) - if err != nil { - return err - } - args := []WebsocketLoginData{ - { - APIKey: creds.Key, - Passphrase: creds.ClientID, - Timestamp: ts, - Sign: base64.StdEncoding.EncodeToString(hmac), - }, - } - return e.SendAuthenticatedWebsocketRequest(ctx, request.Unset, "login-response", operationLogin, args, nil) -} - // GenerateDefaultBusinessSubscriptions returns a list of default subscriptions to business websocket. func (e *Exchange) GenerateDefaultBusinessSubscriptions() ([]subscription.Subscription, error) { var subs []string diff --git a/exchanges/okx/okx_types.go b/exchanges/okx/okx_types.go index 50cbe3d7..b950e9d8 100644 --- a/exchanges/okx/okx_types.go +++ b/exchanges/okx/okx_types.go @@ -3171,6 +3171,12 @@ type WebsocketLoginData struct { Sign string `json:"sign"` } +// WebsocketAuthLogin holds the operation and arguments +type WebsocketAuthLogin struct { + Operation string `json:"op"` + Arguments []WebsocketLoginData `json:"args"` +} + // SubscriptionInfo holds the channel and instrument IDs type SubscriptionInfo struct { Channel string `json:"channel,omitempty"` diff --git a/exchanges/okx/okx_websocket.go b/exchanges/okx/okx_websocket.go index 72d57537..f06ce55f 100644 --- a/exchanges/okx/okx_websocket.go +++ b/exchanges/okx/okx_websocket.go @@ -41,7 +41,7 @@ var ( // See: https://www.okx.com/docs-v5/en/#error-code-websocket-public authConnErrorCodes = []string{ "60007", "60022", "60023", "60024", "60026", "63999", "60032", "60011", "60009", - "60005", "60021", "60031", "50110", + "60005", "60021", "60031", "50110", "60033", } ) @@ -245,15 +245,12 @@ func (e *Exchange) WsConnect() error { dialer.ReadBufferSize = 8192 dialer.WriteBufferSize = 8192 - err := e.Websocket.Conn.Dial(ctx, &dialer, http.Header{}) - if err != nil { + if err := e.Websocket.Conn.Dial(ctx, &dialer, http.Header{}); err != nil { return err } - e.Websocket.Wg.Add(1) - go e.wsReadData(ctx, e.Websocket.Conn) + e.Websocket.Wg.Go(func() { e.wsReadData(ctx, e.Websocket.Conn) }) if e.Verbose { - log.Debugf(log.ExchangeSys, "Successful connection to %v\n", - e.Websocket.GetWebsocketURL()) + log.Debugf(log.ExchangeSys, "Successful connection to %v", e.Websocket.GetWebsocketURL()) } e.Websocket.Conn.SetupPingHandler(request.Unset, websocket.PingHandler{ MessageType: gws.TextMessage, @@ -261,8 +258,7 @@ func (e *Exchange) WsConnect() error { Delay: time.Second * 20, }) if e.Websocket.CanUseAuthenticatedEndpoints() { - err = e.WsAuth(ctx) - if err != nil { + if err := e.WsAuth(ctx); err != nil { log.Errorf(log.ExchangeSys, "Error connecting auth socket: %s\n", err.Error()) e.Websocket.SetCanUseAuthenticatedEndpoints(false) } @@ -275,24 +271,24 @@ func (e *Exchange) WsAuth(ctx context.Context) error { if !e.AreCredentialsValid(ctx) || !e.Websocket.CanUseAuthenticatedEndpoints() { return fmt.Errorf("%v AuthenticatedWebsocketAPISupport not enabled", e.Name) } - creds, err := e.GetCredentials(ctx) - if err != nil { - return err - } var dialer gws.Dialer - err = e.Websocket.AuthConn.Dial(ctx, &dialer, http.Header{}) - if err != nil { + if err := e.Websocket.AuthConn.Dial(ctx, &dialer, http.Header{}); err != nil { return err } - e.Websocket.Wg.Add(1) - go e.wsReadData(ctx, e.Websocket.AuthConn) + e.Websocket.Wg.Go(func() { e.wsReadData(ctx, e.Websocket.AuthConn) }) e.Websocket.AuthConn.SetupPingHandler(request.Unset, websocket.PingHandler{ MessageType: gws.TextMessage, Message: pingMsg, Delay: time.Second * 20, }) + return e.authenticateConnection(ctx, e.Websocket.AuthConn) +} - e.Websocket.SetCanUseAuthenticatedEndpoints(true) +func (e *Exchange) authenticateConnection(ctx context.Context, conn websocket.Connection) error { + creds, err := e.GetCredentials(ctx) + if err != nil { + return err + } ts := time.Now().Unix() signPath := "/users/self/verify" hmac, err := crypto.GetHMAC(crypto.HashSHA256, @@ -303,21 +299,37 @@ func (e *Exchange) WsAuth(ctx context.Context) error { return err } - args := []WebsocketLoginData{ - { - APIKey: creds.Key, - Passphrase: creds.ClientID, - Timestamp: ts, - Sign: base64.StdEncoding.EncodeToString(hmac), + op := WebsocketAuthLogin{ + Operation: operationLogin, + Arguments: []WebsocketLoginData{ + { + APIKey: creds.Key, + Passphrase: creds.ClientID, + Timestamp: ts, + Sign: base64.StdEncoding.EncodeToString(hmac), + }, }, } + resp, err := conn.SendMessageReturnResponse(ctx, request.Unset, "login-response", op) + if err != nil { + return err + } + var intermediary struct { + Code int64 `json:"code,string"` + Message string `json:"msg"` + } + if err := json.Unmarshal(resp, &intermediary); err != nil { + return err + } - return e.SendAuthenticatedWebsocketRequest(ctx, request.Unset, "login-response", operationLogin, args, nil) + if intermediary.Code != 0 { + return getStatusError(intermediary.Code, intermediary.Message) + } + return nil } // wsReadData sends msgs from public and auth websockets to data handler func (e *Exchange) wsReadData(ctx context.Context, ws websocket.Connection) { - defer e.Websocket.Wg.Done() for { resp := ws.ReadMessage() if resp.Raw == nil {