From b74888577c5f9f6ecaf1d3712a837d540a849d98 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Thu, 11 Sep 2025 16:03:54 +0700 Subject: [PATCH] Okx: Fix websocket candle subscription handling (#1990) * Okx: Fix websocket candle subscription handling * Okx: Fix panic in BusinessSubscribe when no pairs * Okx: Fix MarkPriceCandles not including Pair --- exchanges/okx/okx_business_websocket.go | 44 ++++++++++++++----------- exchanges/okx/okx_test.go | 41 +++++++++++++++++++++++ exchanges/okx/okx_types.go | 13 ++++---- exchanges/okx/okx_websocket.go | 19 ++++++----- 4 files changed, 83 insertions(+), 34 deletions(-) diff --git a/exchanges/okx/okx_business_websocket.go b/exchanges/okx/okx_business_websocket.go index 2f3ace0b..3e965bab 100644 --- a/exchanges/okx/okx_business_websocket.go +++ b/exchanges/okx/okx_business_websocket.go @@ -6,9 +6,11 @@ import ( "fmt" "net/http" "strconv" + "strings" "time" gws "github.com/gorilla/websocket" + "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" @@ -178,28 +180,32 @@ func (e *Exchange) handleBusinessSubscription(ctx context.Context, operation str arg := SubscriptionInfo{ Channel: subscriptions[i].Channel, } - var instrumentFamily, spreadID string - var instrumentID currency.Pair + switch arg.Channel { - case okxSpreadOrders, - okxSpreadTrades, - okxSpreadOrderbookLevel1, - okxSpreadOrderbook, - okxSpreadPublicTrades, - okxSpreadPublicTicker: - spreadID = subscriptions[i].Pairs[0].String() - case channelPublicBlockTrades, - channelBlockTickers: - instrumentID = subscriptions[i].Pairs[0] - } - instrumentFamilyInterface, okay := subscriptions[i].Params["instFamily"] - if okay { - instrumentFamily, _ = instrumentFamilyInterface.(string) + case okxSpreadOrders, okxSpreadTrades, okxSpreadOrderbookLevel1, okxSpreadOrderbook, okxSpreadPublicTrades, okxSpreadPublicTicker: + if len(subscriptions[i].Pairs) != 1 { + return currency.ErrCurrencyPairEmpty + } + arg.SpreadID = subscriptions[i].Pairs[0].String() + case channelPublicBlockTrades, channelBlockTickers: + if len(subscriptions[i].Pairs) != 1 { + return currency.ErrCurrencyPairEmpty + } + arg.InstrumentID = subscriptions[i].Pairs[0] } - arg.InstrumentFamily = instrumentFamily - arg.SpreadID = spreadID - arg.InstrumentID = instrumentID + if strings.HasPrefix(arg.Channel, candle) || strings.HasPrefix(arg.Channel, indexCandlestick) || strings.HasPrefix(arg.Channel, markPrice) { + if len(subscriptions[i].Pairs) != 1 { + return currency.ErrCurrencyPairEmpty + } + arg.InstrumentID = subscriptions[i].Pairs[0] + } + + if ifAny, ok := subscriptions[i].Params["instFamily"]; ok { + if arg.InstrumentFamily, ok = ifAny.(string); !ok { + return common.GetTypeAssertError("string", ifAny, "instFamily") + } + } var chunk []byte channels = append(channels, subscriptions[i]) diff --git a/exchanges/okx/okx_test.go b/exchanges/okx/okx_test.go index f2765658..611d46e0 100644 --- a/exchanges/okx/okx_test.go +++ b/exchanges/okx/okx_test.go @@ -20,6 +20,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/encoding/json" "github.com/thrasher-corp/gocryptotrader/exchange/order/limits" + "github.com/thrasher-corp/gocryptotrader/exchange/websocket" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/collateral" @@ -6089,6 +6090,46 @@ func TestGenerateSubscriptions(t *testing.T) { testsubs.EqualLists(t, exp, subs) } +func TestBusinessWSCandleSubscriptions(t *testing.T) { + t.Parallel() + e := new(Exchange) //nolint:govet // Intentional shadow + require.NoError(t, testexch.Setup(e), "Setup must not error") + + err := e.WsConnectBusiness(t.Context()) + require.NoError(t, err, "WsConnectBusiness must not error") + + err = e.BusinessSubscribe(t.Context(), subscription.List{{Channel: channelCandle1D}}) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + p := currency.Pairs{ + mainPair, + currency.NewPairWithDelimiter("ETH", "USDT", "-"), + currency.NewPairWithDelimiter("OKB", "USDT", "-"), + } + + for i, ch := range []string{channelCandle1D, channelMarkPriceCandle1M, channelIndexCandle1H} { + err := e.BusinessSubscribe(t.Context(), subscription.List{{Channel: ch, Pairs: p[i : i+1]}}) + require.NoErrorf(t, err, "BusinessSubscribe %s-%s must not error", ch, p[i]) + } + + var got currency.Pairs + assert.Eventually(t, func() bool { + select { + case a := <-e.Websocket.DataHandler: + switch v := a.(type) { + case websocket.KlineData: + got = got.Add(v.Pair) + case []CandlestickMarkPrice: + if len(v) > 0 { + got = got.Add(v[0].Pair) + } + } + default: + } + return len(got) == 3 + }, 4*time.Second, 100*time.Millisecond, "Should eventually get candles from the datahandler") +} + const ( processSpreadOrderbookJSON = `{"arg":{"channel":"sprd-books5", "sprdId": "BTC-USDT_BTC-USDT-SWAP" }, "data": [ { "asks": [ ["111.06","55154","2"], ["111.07","53276","2"], ["111.08","72435","2"], ["111.09","70312","2"], ["111.1","67272","2"]], "bids": [ ["111.05","57745","2"], ["111.04","57109","2"], ["111.03","69563","2"], ["111.02","71248","2"], ["111.01","65090","2"]], "ts": "1670324386802"}]}` wsProcessPublicSpreadTradesJSON = `{"arg":{"channel":"sprd-public-trades", "sprdId": "BTC-USDT_BTC-USDT-SWAP" }, "data": [ { "sprdId": "BTC-USDT_BTC-USDT-SWAP", "tradeId": "2499206329160695808", "px": "-10", "sz": "0.001", "side": "sell", "ts": "1726801105519"}]}` diff --git a/exchanges/okx/okx_types.go b/exchanges/okx/okx_types.go index c2fe843d..50cbe3d7 100644 --- a/exchanges/okx/okx_types.go +++ b/exchanges/okx/okx_types.go @@ -3637,13 +3637,14 @@ type WsDeliveryEstimatedPrice struct { Data []DeliveryEstimatedPrice `json:"data"` } -// CandlestickMarkPrice represents candlestick mark price push data as a result of subscription to "mark-price-candle*" channel +// CandlestickMarkPrice contains mark-price-candle subscription candles type CandlestickMarkPrice struct { - Timestamp time.Time `json:"ts"` - OpenPrice float64 `json:"o"` - HighestPrice float64 `json:"h"` - LowestPrice float64 `json:"l"` - ClosePrice float64 `json:"s"` + Pair currency.Pair + Timestamp time.Time + OpenPrice float64 + HighestPrice float64 + LowestPrice float64 + ClosePrice float64 } // WsOrderBook order book represents order book push data which is returned as a result of subscription to "books*" channel diff --git a/exchanges/okx/okx_websocket.go b/exchanges/okx/okx_websocket.go index 59bb41cd..72d57537 100644 --- a/exchanges/okx/okx_websocket.go +++ b/exchanges/okx/okx_websocket.go @@ -1121,22 +1121,23 @@ func (e *Exchange) CalculateOrderbookChecksum(orderbookData *WsOrderBookData) (u // wsHandleMarkPriceCandles processes candlestick mark price push data as a result of subscription to "mark-price-candle*" channel. func (e *Exchange) wsHandleMarkPriceCandles(data []byte) error { - tempo := &struct { + m := &struct { Argument SubscriptionInfo `json:"arg"` Data [][5]types.Number `json:"data"` }{} - err := json.Unmarshal(data, tempo) + err := json.Unmarshal(data, m) if err != nil { return err } - candles := make([]CandlestickMarkPrice, len(tempo.Data)) - for x := range tempo.Data { + candles := make([]CandlestickMarkPrice, len(m.Data)) + for x := range m.Data { candles[x] = CandlestickMarkPrice{ - Timestamp: time.UnixMilli(tempo.Data[x][0].Int64()), - OpenPrice: tempo.Data[x][1].Float64(), - HighestPrice: tempo.Data[x][2].Float64(), - LowestPrice: tempo.Data[x][3].Float64(), - ClosePrice: tempo.Data[x][4].Float64(), + Pair: m.Argument.InstrumentID, + Timestamp: time.UnixMilli(m.Data[x][0].Int64()), + OpenPrice: m.Data[x][1].Float64(), + HighestPrice: m.Data[x][2].Float64(), + LowestPrice: m.Data[x][3].Float64(), + ClosePrice: m.Data[x][4].Float64(), } } e.Websocket.DataHandler <- candles