mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-13 23:16:45 +00:00
OKX: Fix authenticated websocket login (#2051)
* Okx:remove * Okx:replace * Okx:ping gws.PingMessage * Okx:ping gws.PingMessage * Okx: add authenticateConnection * Okx: fix pingHandler * Update exchanges/okx/okx_business_websocket.go Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> * Update exchanges/okx/okx_business_websocket.go Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> * Update exchanges/okx/okx_websocket.go Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> * Update exchanges/okx/okx_business_websocket.go Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> * Okx:UseMultiConnectionManagement * Okx:rm UseMultiConnectionManagement * Okx:roll back * Okx:apply diff * Okx:make lint fix * Okx:make lint fix * Okx:make lint fix * Okx:fix * Okx:fix name * Okx:fix NilGuard depends on #2076 * Okx:remove comment --------- Co-authored-by: Ryan O'Hara-Reid <oharareid.ryan@gmail.com> Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>
This commit is contained in:
@@ -19,6 +19,7 @@ import (
|
||||
"time"
|
||||
|
||||
gws "github.com/gorilla/websocket"
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
"github.com/thrasher-corp/gocryptotrader/encoding/json"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
|
||||
@@ -240,15 +241,12 @@ func (c *connection) SetupPingHandler(epl request.EndpointLimit, handler PingHan
|
||||
return
|
||||
}
|
||||
c.Wg.Go(func() {
|
||||
ticker := time.NewTicker(handler.Delay)
|
||||
for {
|
||||
select {
|
||||
case <-c.shutdown:
|
||||
ticker.Stop()
|
||||
return
|
||||
case <-ticker.C:
|
||||
err := c.SendRawMessage(context.TODO(), epl, handler.MessageType, handler.Message)
|
||||
if err != nil {
|
||||
case <-time.After(handler.Delay):
|
||||
if err := c.SendRawMessage(context.Background(), epl, handler.MessageType, handler.Message); err != nil {
|
||||
log.Errorf(log.WebsocketMgr, "%v websocket connection: ping handler failed to send message [%s]: %v", c.ExchangeName, handler.Message, err)
|
||||
return
|
||||
}
|
||||
@@ -368,8 +366,8 @@ func (c *connection) defaultGenerateMessageID(highPrec bool) int64 {
|
||||
|
||||
// Shutdown shuts down and closes specific connection
|
||||
func (c *connection) Shutdown() error {
|
||||
if c == nil || c.Connection == nil {
|
||||
return nil
|
||||
if err := common.NilGuard(c, c.Connection); err != nil {
|
||||
return err
|
||||
}
|
||||
c.setConnectedStatus(false)
|
||||
c.writeControl.Lock()
|
||||
|
||||
@@ -633,15 +633,14 @@ func (m *Manager) Shutdown() error {
|
||||
}
|
||||
|
||||
func (m *Manager) shutdown() error {
|
||||
if !m.IsConnected() {
|
||||
return fmt.Errorf("%v %w: %w", m.exchangeName, errCannotShutdown, ErrNotConnected)
|
||||
}
|
||||
|
||||
// TODO: Interrupt connection and or close connection when it is re-established.
|
||||
if m.IsConnecting() {
|
||||
return fmt.Errorf("%v %w: %w ", m.exchangeName, errCannotShutdown, errAlreadyReconnecting)
|
||||
}
|
||||
|
||||
if !m.IsConnected() {
|
||||
return fmt.Errorf("%v %w: %w", m.exchangeName, errCannotShutdown, ErrNotConnected)
|
||||
}
|
||||
|
||||
if m.verbose {
|
||||
log.Debugf(log.WebsocketMgr, "%v websocket: shutting down websocket", m.exchangeName)
|
||||
}
|
||||
@@ -682,11 +681,22 @@ func (m *Manager) shutdown() error {
|
||||
// flush any subscriptions from last connection if needed
|
||||
m.subscriptions.Clear()
|
||||
|
||||
m.setState(disconnectedState)
|
||||
|
||||
close(m.ShutdownC)
|
||||
m.setState(disconnectedState)
|
||||
m.Wg.Wait()
|
||||
m.ShutdownC = make(chan struct{})
|
||||
|
||||
for _, conn := range []Connection{m.Conn, m.AuthConn} {
|
||||
if conn == nil {
|
||||
continue
|
||||
}
|
||||
conn, ok := conn.(*connection)
|
||||
if !ok {
|
||||
return fmt.Errorf("%s websocket: %w", m.exchangeName, common.GetTypeAssertError("*connection", conn))
|
||||
}
|
||||
conn.shutdown = m.ShutdownC
|
||||
}
|
||||
|
||||
if m.verbose {
|
||||
log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", m.exchangeName)
|
||||
}
|
||||
|
||||
@@ -1088,7 +1088,7 @@ func TestConnectionShutdown(t *testing.T) {
|
||||
t.Parallel()
|
||||
wc := connection{shutdown: make(chan struct{})}
|
||||
err := wc.Shutdown()
|
||||
assert.NoError(t, err, "Shutdown should not error")
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "Shutdown should error correctly")
|
||||
|
||||
err = wc.Dial(t.Context(), &gws.Dialer{}, nil)
|
||||
assert.ErrorContains(t, err, "malformed ws or wss URL", "Dial should error correctly")
|
||||
@@ -1336,3 +1336,54 @@ func TestGetConnection(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Same(t, expected, conn)
|
||||
}
|
||||
|
||||
func TestShutdown(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := Manager{}
|
||||
m.setState(connectingState)
|
||||
require.ErrorIs(t, m.Shutdown(), errAlreadyReconnecting, "Shutdown must error correctly")
|
||||
m.setState(disconnectedState)
|
||||
require.ErrorIs(t, m.Shutdown(), ErrNotConnected, "Shutdown must error correctly")
|
||||
m.setState(connectedState)
|
||||
require.Panics(t, func() { _ = m.Shutdown() }, "Shutdown must panic on nil shutdown channel")
|
||||
m.ShutdownC = make(chan struct{})
|
||||
require.NoError(t, m.Shutdown(), "Shutdown must not error with no connections")
|
||||
m.setState(connectedState)
|
||||
m.Conn = &struct{ *connection }{&connection{}}
|
||||
m.AuthConn = &struct{ *connection }{&connection{}}
|
||||
require.ErrorIs(t, m.Shutdown(), common.ErrTypeAssertFailure, "Shutdown must error with unhandled connection type")
|
||||
|
||||
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) }))
|
||||
defer mock.Close()
|
||||
|
||||
wsURL := "ws" + mock.URL[len("http"):] + "/ws"
|
||||
conn, resp, err := gws.DefaultDialer.DialContext(t.Context(), wsURL, nil)
|
||||
require.NoError(t, err, "DialContext must not error")
|
||||
defer resp.Body.Close()
|
||||
|
||||
m.AuthConn = nil
|
||||
m.Conn = nil
|
||||
m.connectionManager = []*connectionWrapper{{connection: &connection{Connection: nil}}, {connection: &connection{Connection: conn}}}
|
||||
m.setState(connectedState)
|
||||
require.NoError(t, m.Shutdown(), "Shutdown must not error with faulty connection in connectionManager")
|
||||
|
||||
gwsConnAuth, respAuth, err := gws.DefaultDialer.DialContext(t.Context(), wsURL, nil)
|
||||
require.NoError(t, err, "DialContext must not error")
|
||||
defer respAuth.Body.Close()
|
||||
|
||||
gwsConnUnAuth, respUnAuth, err := gws.DefaultDialer.DialContext(t.Context(), wsURL, nil)
|
||||
require.NoError(t, err, "DialContext must not error")
|
||||
defer respUnAuth.Body.Close()
|
||||
|
||||
m.connectionManager = nil
|
||||
authConn := &connection{Connection: gwsConnAuth, shutdown: m.ShutdownC}
|
||||
m.AuthConn = authConn
|
||||
unauthConn := &connection{Connection: gwsConnUnAuth, shutdown: m.ShutdownC}
|
||||
m.Conn = unauthConn
|
||||
|
||||
m.setState(connectedState)
|
||||
require.NoError(t, m.Shutdown(), "Shutdown must not error with good connections")
|
||||
|
||||
require.Equal(t, m.ShutdownC, authConn.shutdown, "shutdown channels must be the same after original shutdown channel is closed")
|
||||
require.Equal(t, m.ShutdownC, unauthConn.shutdown, "shutdown channels must be the same after original shutdown channel is closed")
|
||||
}
|
||||
|
||||
@@ -2,16 +2,12 @@ package okx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
gws "github.com/gorilla/websocket"
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
"github.com/thrasher-corp/gocryptotrader/common/crypto"
|
||||
"github.com/thrasher-corp/gocryptotrader/currency"
|
||||
"github.com/thrasher-corp/gocryptotrader/encoding/json"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchange/websocket"
|
||||
@@ -55,15 +51,13 @@ func (e *Exchange) WsConnectBusiness(ctx context.Context) error {
|
||||
dialer.WriteBufferSize = 8192
|
||||
|
||||
e.Websocket.Conn.SetURL(okxBusinessWebsocketURL)
|
||||
err := e.Websocket.Conn.Dial(ctx, &dialer, http.Header{})
|
||||
if err != nil {
|
||||
if err := e.Websocket.Conn.Dial(ctx, &dialer, http.Header{}); err != nil {
|
||||
return err
|
||||
}
|
||||
e.Websocket.Wg.Add(1)
|
||||
go e.wsReadData(ctx, e.Websocket.Conn)
|
||||
e.Websocket.Wg.Go(func() { e.wsReadData(ctx, e.Websocket.Conn) })
|
||||
|
||||
if e.Verbose {
|
||||
log.Debugf(log.ExchangeSys, "Successful connection to %v\n",
|
||||
e.Websocket.GetWebsocketURL())
|
||||
log.Debugf(log.ExchangeSys, "Successful connection to %v", e.Websocket.GetWebsocketURL())
|
||||
}
|
||||
e.Websocket.Conn.SetupPingHandler(request.UnAuth, websocket.PingHandler{
|
||||
MessageType: gws.TextMessage,
|
||||
@@ -71,45 +65,14 @@ func (e *Exchange) WsConnectBusiness(ctx context.Context) error {
|
||||
Delay: time.Second * 20,
|
||||
})
|
||||
if e.Websocket.CanUseAuthenticatedEndpoints() {
|
||||
err = e.WsSpreadAuth(ctx)
|
||||
if err != nil {
|
||||
log.Errorf(log.ExchangeSys, "Error connecting auth socket: %s\n", err.Error())
|
||||
if err := e.authenticateConnection(ctx, e.Websocket.Conn); err != nil { // business WS uses same conn for public and private
|
||||
log.Errorf(log.ExchangeSys, "Error authenticating business websocket: %s", err)
|
||||
e.Websocket.SetCanUseAuthenticatedEndpoints(false)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WsSpreadAuth will connect to Okx's Private websocket connection and Authenticate with a login payload.
|
||||
func (e *Exchange) WsSpreadAuth(ctx context.Context) error {
|
||||
if !e.Websocket.CanUseAuthenticatedEndpoints() {
|
||||
return fmt.Errorf("%v AuthenticatedWebsocketAPISupport not enabled", e.Name)
|
||||
}
|
||||
creds, err := e.GetCredentials(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.Websocket.SetCanUseAuthenticatedEndpoints(true)
|
||||
ts := time.Now().Unix()
|
||||
signPath := "/users/self/verify"
|
||||
hmac, err := crypto.GetHMAC(crypto.HashSHA256,
|
||||
[]byte(strconv.FormatInt(ts, 10)+http.MethodGet+signPath),
|
||||
[]byte(creds.Secret),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
args := []WebsocketLoginData{
|
||||
{
|
||||
APIKey: creds.Key,
|
||||
Passphrase: creds.ClientID,
|
||||
Timestamp: ts,
|
||||
Sign: base64.StdEncoding.EncodeToString(hmac),
|
||||
},
|
||||
}
|
||||
return e.SendAuthenticatedWebsocketRequest(ctx, request.Unset, "login-response", operationLogin, args, nil)
|
||||
}
|
||||
|
||||
// GenerateDefaultBusinessSubscriptions returns a list of default subscriptions to business websocket.
|
||||
func (e *Exchange) GenerateDefaultBusinessSubscriptions() ([]subscription.Subscription, error) {
|
||||
var subs []string
|
||||
|
||||
@@ -3171,6 +3171,12 @@ type WebsocketLoginData struct {
|
||||
Sign string `json:"sign"`
|
||||
}
|
||||
|
||||
// WebsocketAuthLogin holds the operation and arguments
|
||||
type WebsocketAuthLogin struct {
|
||||
Operation string `json:"op"`
|
||||
Arguments []WebsocketLoginData `json:"args"`
|
||||
}
|
||||
|
||||
// SubscriptionInfo holds the channel and instrument IDs
|
||||
type SubscriptionInfo struct {
|
||||
Channel string `json:"channel,omitempty"`
|
||||
|
||||
@@ -41,7 +41,7 @@ var (
|
||||
// See: https://www.okx.com/docs-v5/en/#error-code-websocket-public
|
||||
authConnErrorCodes = []string{
|
||||
"60007", "60022", "60023", "60024", "60026", "63999", "60032", "60011", "60009",
|
||||
"60005", "60021", "60031", "50110",
|
||||
"60005", "60021", "60031", "50110", "60033",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -245,15 +245,12 @@ func (e *Exchange) WsConnect() error {
|
||||
dialer.ReadBufferSize = 8192
|
||||
dialer.WriteBufferSize = 8192
|
||||
|
||||
err := e.Websocket.Conn.Dial(ctx, &dialer, http.Header{})
|
||||
if err != nil {
|
||||
if err := e.Websocket.Conn.Dial(ctx, &dialer, http.Header{}); err != nil {
|
||||
return err
|
||||
}
|
||||
e.Websocket.Wg.Add(1)
|
||||
go e.wsReadData(ctx, e.Websocket.Conn)
|
||||
e.Websocket.Wg.Go(func() { e.wsReadData(ctx, e.Websocket.Conn) })
|
||||
if e.Verbose {
|
||||
log.Debugf(log.ExchangeSys, "Successful connection to %v\n",
|
||||
e.Websocket.GetWebsocketURL())
|
||||
log.Debugf(log.ExchangeSys, "Successful connection to %v", e.Websocket.GetWebsocketURL())
|
||||
}
|
||||
e.Websocket.Conn.SetupPingHandler(request.Unset, websocket.PingHandler{
|
||||
MessageType: gws.TextMessage,
|
||||
@@ -261,8 +258,7 @@ func (e *Exchange) WsConnect() error {
|
||||
Delay: time.Second * 20,
|
||||
})
|
||||
if e.Websocket.CanUseAuthenticatedEndpoints() {
|
||||
err = e.WsAuth(ctx)
|
||||
if err != nil {
|
||||
if err := e.WsAuth(ctx); err != nil {
|
||||
log.Errorf(log.ExchangeSys, "Error connecting auth socket: %s\n", err.Error())
|
||||
e.Websocket.SetCanUseAuthenticatedEndpoints(false)
|
||||
}
|
||||
@@ -275,24 +271,24 @@ func (e *Exchange) WsAuth(ctx context.Context) error {
|
||||
if !e.AreCredentialsValid(ctx) || !e.Websocket.CanUseAuthenticatedEndpoints() {
|
||||
return fmt.Errorf("%v AuthenticatedWebsocketAPISupport not enabled", e.Name)
|
||||
}
|
||||
creds, err := e.GetCredentials(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var dialer gws.Dialer
|
||||
err = e.Websocket.AuthConn.Dial(ctx, &dialer, http.Header{})
|
||||
if err != nil {
|
||||
if err := e.Websocket.AuthConn.Dial(ctx, &dialer, http.Header{}); err != nil {
|
||||
return err
|
||||
}
|
||||
e.Websocket.Wg.Add(1)
|
||||
go e.wsReadData(ctx, e.Websocket.AuthConn)
|
||||
e.Websocket.Wg.Go(func() { e.wsReadData(ctx, e.Websocket.AuthConn) })
|
||||
e.Websocket.AuthConn.SetupPingHandler(request.Unset, websocket.PingHandler{
|
||||
MessageType: gws.TextMessage,
|
||||
Message: pingMsg,
|
||||
Delay: time.Second * 20,
|
||||
})
|
||||
return e.authenticateConnection(ctx, e.Websocket.AuthConn)
|
||||
}
|
||||
|
||||
e.Websocket.SetCanUseAuthenticatedEndpoints(true)
|
||||
func (e *Exchange) authenticateConnection(ctx context.Context, conn websocket.Connection) error {
|
||||
creds, err := e.GetCredentials(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ts := time.Now().Unix()
|
||||
signPath := "/users/self/verify"
|
||||
hmac, err := crypto.GetHMAC(crypto.HashSHA256,
|
||||
@@ -303,21 +299,37 @@ func (e *Exchange) WsAuth(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
args := []WebsocketLoginData{
|
||||
{
|
||||
APIKey: creds.Key,
|
||||
Passphrase: creds.ClientID,
|
||||
Timestamp: ts,
|
||||
Sign: base64.StdEncoding.EncodeToString(hmac),
|
||||
op := WebsocketAuthLogin{
|
||||
Operation: operationLogin,
|
||||
Arguments: []WebsocketLoginData{
|
||||
{
|
||||
APIKey: creds.Key,
|
||||
Passphrase: creds.ClientID,
|
||||
Timestamp: ts,
|
||||
Sign: base64.StdEncoding.EncodeToString(hmac),
|
||||
},
|
||||
},
|
||||
}
|
||||
resp, err := conn.SendMessageReturnResponse(ctx, request.Unset, "login-response", op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var intermediary struct {
|
||||
Code int64 `json:"code,string"`
|
||||
Message string `json:"msg"`
|
||||
}
|
||||
if err := json.Unmarshal(resp, &intermediary); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.SendAuthenticatedWebsocketRequest(ctx, request.Unset, "login-response", operationLogin, args, nil)
|
||||
if intermediary.Code != 0 {
|
||||
return getStatusError(intermediary.Code, intermediary.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// wsReadData sends msgs from public and auth websockets to data handler
|
||||
func (e *Exchange) wsReadData(ctx context.Context, ws websocket.Connection) {
|
||||
defer e.Websocket.Wg.Done()
|
||||
for {
|
||||
resp := ws.ReadMessage()
|
||||
if resp.Raw == nil {
|
||||
|
||||
Reference in New Issue
Block a user