OKX: Fix authenticated websocket login (#2051)

* Okx:remove

* Okx:replace

* Okx:ping gws.PingMessage

* Okx:ping gws.PingMessage

* Okx: add authenticateConnection

* Okx: fix pingHandler

* Update exchanges/okx/okx_business_websocket.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* Update exchanges/okx/okx_business_websocket.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* Update exchanges/okx/okx_websocket.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* Update exchanges/okx/okx_business_websocket.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* Okx:UseMultiConnectionManagement

* Okx:rm UseMultiConnectionManagement

* Okx:roll back

* Okx:apply diff

* Okx:make lint fix

* Okx:make lint fix

* Okx:make lint fix

* Okx:fix

* Okx:fix name

* Okx:fix NilGuard depends on #2076

* Okx:remove comment

---------

Co-authored-by: Ryan O'Hara-Reid <oharareid.ryan@gmail.com>
Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>
This commit is contained in:
dazi005
2025-10-10 07:59:08 +08:00
committed by GitHub
parent 3ae0197e68
commit cce28f9d2b
6 changed files with 124 additions and 84 deletions

View File

@@ -19,6 +19,7 @@ import (
"time"
gws "github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/encoding/json"
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
@@ -240,15 +241,12 @@ func (c *connection) SetupPingHandler(epl request.EndpointLimit, handler PingHan
return
}
c.Wg.Go(func() {
ticker := time.NewTicker(handler.Delay)
for {
select {
case <-c.shutdown:
ticker.Stop()
return
case <-ticker.C:
err := c.SendRawMessage(context.TODO(), epl, handler.MessageType, handler.Message)
if err != nil {
case <-time.After(handler.Delay):
if err := c.SendRawMessage(context.Background(), epl, handler.MessageType, handler.Message); err != nil {
log.Errorf(log.WebsocketMgr, "%v websocket connection: ping handler failed to send message [%s]: %v", c.ExchangeName, handler.Message, err)
return
}
@@ -368,8 +366,8 @@ func (c *connection) defaultGenerateMessageID(highPrec bool) int64 {
// Shutdown shuts down and closes specific connection
func (c *connection) Shutdown() error {
if c == nil || c.Connection == nil {
return nil
if err := common.NilGuard(c, c.Connection); err != nil {
return err
}
c.setConnectedStatus(false)
c.writeControl.Lock()

View File

@@ -633,15 +633,14 @@ func (m *Manager) Shutdown() error {
}
func (m *Manager) shutdown() error {
if !m.IsConnected() {
return fmt.Errorf("%v %w: %w", m.exchangeName, errCannotShutdown, ErrNotConnected)
}
// TODO: Interrupt connection and or close connection when it is re-established.
if m.IsConnecting() {
return fmt.Errorf("%v %w: %w ", m.exchangeName, errCannotShutdown, errAlreadyReconnecting)
}
if !m.IsConnected() {
return fmt.Errorf("%v %w: %w", m.exchangeName, errCannotShutdown, ErrNotConnected)
}
if m.verbose {
log.Debugf(log.WebsocketMgr, "%v websocket: shutting down websocket", m.exchangeName)
}
@@ -682,11 +681,22 @@ func (m *Manager) shutdown() error {
// flush any subscriptions from last connection if needed
m.subscriptions.Clear()
m.setState(disconnectedState)
close(m.ShutdownC)
m.setState(disconnectedState)
m.Wg.Wait()
m.ShutdownC = make(chan struct{})
for _, conn := range []Connection{m.Conn, m.AuthConn} {
if conn == nil {
continue
}
conn, ok := conn.(*connection)
if !ok {
return fmt.Errorf("%s websocket: %w", m.exchangeName, common.GetTypeAssertError("*connection", conn))
}
conn.shutdown = m.ShutdownC
}
if m.verbose {
log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", m.exchangeName)
}

View File

@@ -1088,7 +1088,7 @@ func TestConnectionShutdown(t *testing.T) {
t.Parallel()
wc := connection{shutdown: make(chan struct{})}
err := wc.Shutdown()
assert.NoError(t, err, "Shutdown should not error")
assert.ErrorIs(t, err, common.ErrNilPointer, "Shutdown should error correctly")
err = wc.Dial(t.Context(), &gws.Dialer{}, nil)
assert.ErrorContains(t, err, "malformed ws or wss URL", "Dial should error correctly")
@@ -1336,3 +1336,54 @@ func TestGetConnection(t *testing.T) {
require.NoError(t, err)
assert.Same(t, expected, conn)
}
func TestShutdown(t *testing.T) {
t.Parallel()
m := Manager{}
m.setState(connectingState)
require.ErrorIs(t, m.Shutdown(), errAlreadyReconnecting, "Shutdown must error correctly")
m.setState(disconnectedState)
require.ErrorIs(t, m.Shutdown(), ErrNotConnected, "Shutdown must error correctly")
m.setState(connectedState)
require.Panics(t, func() { _ = m.Shutdown() }, "Shutdown must panic on nil shutdown channel")
m.ShutdownC = make(chan struct{})
require.NoError(t, m.Shutdown(), "Shutdown must not error with no connections")
m.setState(connectedState)
m.Conn = &struct{ *connection }{&connection{}}
m.AuthConn = &struct{ *connection }{&connection{}}
require.ErrorIs(t, m.Shutdown(), common.ErrTypeAssertFailure, "Shutdown must error with unhandled connection type")
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) }))
defer mock.Close()
wsURL := "ws" + mock.URL[len("http"):] + "/ws"
conn, resp, err := gws.DefaultDialer.DialContext(t.Context(), wsURL, nil)
require.NoError(t, err, "DialContext must not error")
defer resp.Body.Close()
m.AuthConn = nil
m.Conn = nil
m.connectionManager = []*connectionWrapper{{connection: &connection{Connection: nil}}, {connection: &connection{Connection: conn}}}
m.setState(connectedState)
require.NoError(t, m.Shutdown(), "Shutdown must not error with faulty connection in connectionManager")
gwsConnAuth, respAuth, err := gws.DefaultDialer.DialContext(t.Context(), wsURL, nil)
require.NoError(t, err, "DialContext must not error")
defer respAuth.Body.Close()
gwsConnUnAuth, respUnAuth, err := gws.DefaultDialer.DialContext(t.Context(), wsURL, nil)
require.NoError(t, err, "DialContext must not error")
defer respUnAuth.Body.Close()
m.connectionManager = nil
authConn := &connection{Connection: gwsConnAuth, shutdown: m.ShutdownC}
m.AuthConn = authConn
unauthConn := &connection{Connection: gwsConnUnAuth, shutdown: m.ShutdownC}
m.Conn = unauthConn
m.setState(connectedState)
require.NoError(t, m.Shutdown(), "Shutdown must not error with good connections")
require.Equal(t, m.ShutdownC, authConn.shutdown, "shutdown channels must be the same after original shutdown channel is closed")
require.Equal(t, m.ShutdownC, unauthConn.shutdown, "shutdown channels must be the same after original shutdown channel is closed")
}

View File

@@ -2,16 +2,12 @@ package okx
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"strconv"
"strings"
"time"
gws "github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/crypto"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/encoding/json"
"github.com/thrasher-corp/gocryptotrader/exchange/websocket"
@@ -55,15 +51,13 @@ func (e *Exchange) WsConnectBusiness(ctx context.Context) error {
dialer.WriteBufferSize = 8192
e.Websocket.Conn.SetURL(okxBusinessWebsocketURL)
err := e.Websocket.Conn.Dial(ctx, &dialer, http.Header{})
if err != nil {
if err := e.Websocket.Conn.Dial(ctx, &dialer, http.Header{}); err != nil {
return err
}
e.Websocket.Wg.Add(1)
go e.wsReadData(ctx, e.Websocket.Conn)
e.Websocket.Wg.Go(func() { e.wsReadData(ctx, e.Websocket.Conn) })
if e.Verbose {
log.Debugf(log.ExchangeSys, "Successful connection to %v\n",
e.Websocket.GetWebsocketURL())
log.Debugf(log.ExchangeSys, "Successful connection to %v", e.Websocket.GetWebsocketURL())
}
e.Websocket.Conn.SetupPingHandler(request.UnAuth, websocket.PingHandler{
MessageType: gws.TextMessage,
@@ -71,45 +65,14 @@ func (e *Exchange) WsConnectBusiness(ctx context.Context) error {
Delay: time.Second * 20,
})
if e.Websocket.CanUseAuthenticatedEndpoints() {
err = e.WsSpreadAuth(ctx)
if err != nil {
log.Errorf(log.ExchangeSys, "Error connecting auth socket: %s\n", err.Error())
if err := e.authenticateConnection(ctx, e.Websocket.Conn); err != nil { // business WS uses same conn for public and private
log.Errorf(log.ExchangeSys, "Error authenticating business websocket: %s", err)
e.Websocket.SetCanUseAuthenticatedEndpoints(false)
}
}
return nil
}
// WsSpreadAuth will connect to Okx's Private websocket connection and Authenticate with a login payload.
func (e *Exchange) WsSpreadAuth(ctx context.Context) error {
if !e.Websocket.CanUseAuthenticatedEndpoints() {
return fmt.Errorf("%v AuthenticatedWebsocketAPISupport not enabled", e.Name)
}
creds, err := e.GetCredentials(ctx)
if err != nil {
return err
}
e.Websocket.SetCanUseAuthenticatedEndpoints(true)
ts := time.Now().Unix()
signPath := "/users/self/verify"
hmac, err := crypto.GetHMAC(crypto.HashSHA256,
[]byte(strconv.FormatInt(ts, 10)+http.MethodGet+signPath),
[]byte(creds.Secret),
)
if err != nil {
return err
}
args := []WebsocketLoginData{
{
APIKey: creds.Key,
Passphrase: creds.ClientID,
Timestamp: ts,
Sign: base64.StdEncoding.EncodeToString(hmac),
},
}
return e.SendAuthenticatedWebsocketRequest(ctx, request.Unset, "login-response", operationLogin, args, nil)
}
// GenerateDefaultBusinessSubscriptions returns a list of default subscriptions to business websocket.
func (e *Exchange) GenerateDefaultBusinessSubscriptions() ([]subscription.Subscription, error) {
var subs []string

View File

@@ -3171,6 +3171,12 @@ type WebsocketLoginData struct {
Sign string `json:"sign"`
}
// WebsocketAuthLogin holds the operation and arguments
type WebsocketAuthLogin struct {
Operation string `json:"op"`
Arguments []WebsocketLoginData `json:"args"`
}
// SubscriptionInfo holds the channel and instrument IDs
type SubscriptionInfo struct {
Channel string `json:"channel,omitempty"`

View File

@@ -41,7 +41,7 @@ var (
// See: https://www.okx.com/docs-v5/en/#error-code-websocket-public
authConnErrorCodes = []string{
"60007", "60022", "60023", "60024", "60026", "63999", "60032", "60011", "60009",
"60005", "60021", "60031", "50110",
"60005", "60021", "60031", "50110", "60033",
}
)
@@ -245,15 +245,12 @@ func (e *Exchange) WsConnect() error {
dialer.ReadBufferSize = 8192
dialer.WriteBufferSize = 8192
err := e.Websocket.Conn.Dial(ctx, &dialer, http.Header{})
if err != nil {
if err := e.Websocket.Conn.Dial(ctx, &dialer, http.Header{}); err != nil {
return err
}
e.Websocket.Wg.Add(1)
go e.wsReadData(ctx, e.Websocket.Conn)
e.Websocket.Wg.Go(func() { e.wsReadData(ctx, e.Websocket.Conn) })
if e.Verbose {
log.Debugf(log.ExchangeSys, "Successful connection to %v\n",
e.Websocket.GetWebsocketURL())
log.Debugf(log.ExchangeSys, "Successful connection to %v", e.Websocket.GetWebsocketURL())
}
e.Websocket.Conn.SetupPingHandler(request.Unset, websocket.PingHandler{
MessageType: gws.TextMessage,
@@ -261,8 +258,7 @@ func (e *Exchange) WsConnect() error {
Delay: time.Second * 20,
})
if e.Websocket.CanUseAuthenticatedEndpoints() {
err = e.WsAuth(ctx)
if err != nil {
if err := e.WsAuth(ctx); err != nil {
log.Errorf(log.ExchangeSys, "Error connecting auth socket: %s\n", err.Error())
e.Websocket.SetCanUseAuthenticatedEndpoints(false)
}
@@ -275,24 +271,24 @@ func (e *Exchange) WsAuth(ctx context.Context) error {
if !e.AreCredentialsValid(ctx) || !e.Websocket.CanUseAuthenticatedEndpoints() {
return fmt.Errorf("%v AuthenticatedWebsocketAPISupport not enabled", e.Name)
}
creds, err := e.GetCredentials(ctx)
if err != nil {
return err
}
var dialer gws.Dialer
err = e.Websocket.AuthConn.Dial(ctx, &dialer, http.Header{})
if err != nil {
if err := e.Websocket.AuthConn.Dial(ctx, &dialer, http.Header{}); err != nil {
return err
}
e.Websocket.Wg.Add(1)
go e.wsReadData(ctx, e.Websocket.AuthConn)
e.Websocket.Wg.Go(func() { e.wsReadData(ctx, e.Websocket.AuthConn) })
e.Websocket.AuthConn.SetupPingHandler(request.Unset, websocket.PingHandler{
MessageType: gws.TextMessage,
Message: pingMsg,
Delay: time.Second * 20,
})
return e.authenticateConnection(ctx, e.Websocket.AuthConn)
}
e.Websocket.SetCanUseAuthenticatedEndpoints(true)
func (e *Exchange) authenticateConnection(ctx context.Context, conn websocket.Connection) error {
creds, err := e.GetCredentials(ctx)
if err != nil {
return err
}
ts := time.Now().Unix()
signPath := "/users/self/verify"
hmac, err := crypto.GetHMAC(crypto.HashSHA256,
@@ -303,21 +299,37 @@ func (e *Exchange) WsAuth(ctx context.Context) error {
return err
}
args := []WebsocketLoginData{
{
APIKey: creds.Key,
Passphrase: creds.ClientID,
Timestamp: ts,
Sign: base64.StdEncoding.EncodeToString(hmac),
op := WebsocketAuthLogin{
Operation: operationLogin,
Arguments: []WebsocketLoginData{
{
APIKey: creds.Key,
Passphrase: creds.ClientID,
Timestamp: ts,
Sign: base64.StdEncoding.EncodeToString(hmac),
},
},
}
resp, err := conn.SendMessageReturnResponse(ctx, request.Unset, "login-response", op)
if err != nil {
return err
}
var intermediary struct {
Code int64 `json:"code,string"`
Message string `json:"msg"`
}
if err := json.Unmarshal(resp, &intermediary); err != nil {
return err
}
return e.SendAuthenticatedWebsocketRequest(ctx, request.Unset, "login-response", operationLogin, args, nil)
if intermediary.Code != 0 {
return getStatusError(intermediary.Code, intermediary.Message)
}
return nil
}
// wsReadData sends msgs from public and auth websockets to data handler
func (e *Exchange) wsReadData(ctx context.Context, ws websocket.Connection) {
defer e.Websocket.Wg.Done()
for {
resp := ws.ReadMessage()
if resp.Raw == nil {