websocket/exchanges: populate context before multi connection upgrade (#1933)

* websocket/exchanges: populate context before multi connection upgrade

* fix test

* linter: fix

* gk: dial

* gk: nits rm param names

---------

Co-authored-by: Ryan O'Hara-Reid <ryan.oharareid@thrasher.io>
This commit is contained in:
Ryan O'Hara-Reid
2025-06-17 13:43:00 +10:00
committed by GitHub
parent 2958e64afe
commit 3e80f1b9e5
64 changed files with 1160 additions and 1124 deletions

View File

@@ -14,12 +14,13 @@ import (
// WsInverseConnect connects to inverse websocket feed
func (by *Bybit) WsInverseConnect() error {
ctx := context.TODO()
if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.CoinMarginedFutures) {
return websocket.ErrWebsocketNotEnabled
}
by.Websocket.Conn.SetURL(inversePublic)
var dialer gws.Dialer
err := by.Websocket.Conn.Dial(&dialer, http.Header{})
err := by.Websocket.Conn.Dial(ctx, &dialer, http.Header{})
if err != nil {
return err
}
@@ -30,7 +31,7 @@ func (by *Bybit) WsInverseConnect() error {
})
by.Websocket.Wg.Add(1)
go by.wsReadData(asset.CoinMarginedFutures, by.Websocket.Conn)
go by.wsReadData(ctx, asset.CoinMarginedFutures, by.Websocket.Conn)
return nil
}
@@ -57,15 +58,17 @@ func (by *Bybit) GenerateInverseDefaultSubscriptions() (subscription.List, error
// InverseSubscribe sends a subscription message to linear public channels.
func (by *Bybit) InverseSubscribe(channelSubscriptions subscription.List) error {
return by.handleInversePayloadSubscription("subscribe", channelSubscriptions)
ctx := context.TODO()
return by.handleInversePayloadSubscription(ctx, "subscribe", channelSubscriptions)
}
// InverseUnsubscribe sends an unsubscription messages through linear public channels.
func (by *Bybit) InverseUnsubscribe(channelSubscriptions subscription.List) error {
return by.handleInversePayloadSubscription("unsubscribe", channelSubscriptions)
ctx := context.TODO()
return by.handleInversePayloadSubscription(ctx, "unsubscribe", channelSubscriptions)
}
func (by *Bybit) handleInversePayloadSubscription(operation string, channelSubscriptions subscription.List) error {
func (by *Bybit) handleInversePayloadSubscription(ctx context.Context, operation string, channelSubscriptions subscription.List) error {
payloads, err := by.handleSubscriptions(operation, channelSubscriptions)
if err != nil {
return err
@@ -73,7 +76,7 @@ func (by *Bybit) handleInversePayloadSubscription(operation string, channelSubsc
for a := range payloads {
// The options connection does not send the subscription request id back with the subscription notification payload
// therefore the code doesn't wait for the response to check whether the subscription is successful or not.
err = by.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, payloads[a])
err = by.Websocket.Conn.SendJSONMessage(ctx, request.Unset, payloads[a])
if err != nil {
return err
}

View File

@@ -14,12 +14,13 @@ import (
// WsLinearConnect connects to linear a websocket feed
func (by *Bybit) WsLinearConnect() error {
ctx := context.TODO()
if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.LinearContract) {
return websocket.ErrWebsocketNotEnabled
}
by.Websocket.Conn.SetURL(linearPublic)
var dialer gws.Dialer
err := by.Websocket.Conn.Dial(&dialer, http.Header{})
err := by.Websocket.Conn.Dial(ctx, &dialer, http.Header{})
if err != nil {
return err
}
@@ -30,9 +31,9 @@ func (by *Bybit) WsLinearConnect() error {
})
by.Websocket.Wg.Add(1)
go by.wsReadData(asset.LinearContract, by.Websocket.Conn)
go by.wsReadData(ctx, asset.LinearContract, by.Websocket.Conn)
if by.IsWebsocketAuthenticationSupported() {
err = by.WsAuth(context.TODO())
err = by.WsAuth(ctx)
if err != nil {
by.Websocket.DataHandler <- err
by.Websocket.SetCanUseAuthenticatedEndpoints(false)
@@ -75,15 +76,17 @@ func (by *Bybit) GenerateLinearDefaultSubscriptions() (subscription.List, error)
// LinearSubscribe sends a subscription message to linear public channels.
func (by *Bybit) LinearSubscribe(channelSubscriptions subscription.List) error {
return by.handleLinearPayloadSubscription("subscribe", channelSubscriptions)
ctx := context.TODO()
return by.handleLinearPayloadSubscription(ctx, "subscribe", channelSubscriptions)
}
// LinearUnsubscribe sends an unsubscription messages through linear public channels.
func (by *Bybit) LinearUnsubscribe(channelSubscriptions subscription.List) error {
return by.handleLinearPayloadSubscription("unsubscribe", channelSubscriptions)
ctx := context.TODO()
return by.handleLinearPayloadSubscription(ctx, "unsubscribe", channelSubscriptions)
}
func (by *Bybit) handleLinearPayloadSubscription(operation string, channelSubscriptions subscription.List) error {
func (by *Bybit) handleLinearPayloadSubscription(ctx context.Context, operation string, channelSubscriptions subscription.List) error {
payloads, err := by.handleSubscriptions(operation, channelSubscriptions)
if err != nil {
return err
@@ -91,7 +94,7 @@ func (by *Bybit) handleLinearPayloadSubscription(operation string, channelSubscr
for a := range payloads {
// The options connection does not send the subscription request id back with the subscription notification payload
// therefore the code doesn't wait for the response to check whether the subscription is successful or not.
err = by.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, payloads[a])
err = by.Websocket.Conn.SendJSONMessage(ctx, request.Unset, payloads[a])
if err != nil {
return err
}

View File

@@ -16,12 +16,13 @@ import (
// WsOptionsConnect connects to options a websocket feed
func (by *Bybit) WsOptionsConnect() error {
ctx := context.TODO()
if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.Options) {
return websocket.ErrWebsocketNotEnabled
}
by.Websocket.Conn.SetURL(optionPublic)
var dialer gws.Dialer
err := by.Websocket.Conn.Dial(&dialer, http.Header{})
err := by.Websocket.Conn.Dial(ctx, &dialer, http.Header{})
if err != nil {
return err
}
@@ -37,7 +38,7 @@ func (by *Bybit) WsOptionsConnect() error {
})
by.Websocket.Wg.Add(1)
go by.wsReadData(asset.Options, by.Websocket.Conn)
go by.wsReadData(ctx, asset.Options, by.Websocket.Conn)
return nil
}
@@ -64,15 +65,17 @@ func (by *Bybit) GenerateOptionsDefaultSubscriptions() (subscription.List, error
// OptionSubscribe sends a subscription message to options public channels.
func (by *Bybit) OptionSubscribe(channelSubscriptions subscription.List) error {
return by.handleOptionsPayloadSubscription("subscribe", channelSubscriptions)
ctx := context.TODO()
return by.handleOptionsPayloadSubscription(ctx, "subscribe", channelSubscriptions)
}
// OptionUnsubscribe sends an unsubscription messages through options public channels.
func (by *Bybit) OptionUnsubscribe(channelSubscriptions subscription.List) error {
return by.handleOptionsPayloadSubscription("unsubscribe", channelSubscriptions)
ctx := context.TODO()
return by.handleOptionsPayloadSubscription(ctx, "unsubscribe", channelSubscriptions)
}
func (by *Bybit) handleOptionsPayloadSubscription(operation string, channelSubscriptions subscription.List) error {
func (by *Bybit) handleOptionsPayloadSubscription(ctx context.Context, operation string, channelSubscriptions subscription.List) error {
payloads, err := by.handleSubscriptions(operation, channelSubscriptions)
if err != nil {
return err
@@ -80,7 +83,7 @@ func (by *Bybit) handleOptionsPayloadSubscription(operation string, channelSubsc
for a := range payloads {
// The options connection does not send the subscription request id back with the subscription notification payload
// therefore the code doesn't wait for the response to check whether the subscription is successful or not.
err = by.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, payloads[a])
err = by.Websocket.Conn.SendJSONMessage(ctx, request.Unset, payloads[a])
if err != nil {
return err
}

View File

@@ -1,6 +1,7 @@
package bybit
import (
"context"
"errors"
"fmt"
"maps"
@@ -3054,7 +3055,7 @@ func TestPushData(t *testing.T) {
slices.Sort(keys)
for x := range keys {
err := b.wsHandleData(asset.Spot, []byte(pushDataMap[keys[x]]))
err := b.wsHandleData(t.Context(), asset.Spot, []byte(pushDataMap[keys[x]]))
assert.NoError(t, err, "wsHandleData should not error")
}
}
@@ -3067,9 +3068,9 @@ func TestWsTicker(t *testing.T) {
asset.USDCMarginedFutures, asset.USDCMarginedFutures, asset.CoinMarginedFutures, asset.CoinMarginedFutures,
}
require.NoError(t, testexch.Setup(b), "Test instance Setup must not error")
testexch.FixtureToDataHandler(t, "testdata/wsTicker.json", func(r []byte) error {
testexch.FixtureToDataHandler(t, "testdata/wsTicker.json", func(_ context.Context, r []byte) error {
defer slices.Delete(assetRouting, 0, 1)
return b.wsHandleData(assetRouting[0], r)
return b.wsHandleData(t.Context(), assetRouting[0], r)
})
close(b.Websocket.DataHandler)
expected := 8
@@ -3318,7 +3319,7 @@ func TestFetchTradablePairs(t *testing.T) {
func TestDeltaUpdateOrderbook(t *testing.T) {
t.Parallel()
data := []byte(`{"topic":"orderbook.50.WEMIXUSDT","ts":1697573183768,"type":"snapshot","data":{"s":"WEMIXUSDT","b":[["0.9511","260.703"],["0.9677","0"]],"a":[],"u":3119516,"seq":14126848493},"cts":1728966699481}`)
err := b.wsHandleData(asset.Spot, data)
err := b.wsHandleData(t.Context(), asset.Spot, data)
if err != nil {
t.Fatal(err)
}

View File

@@ -80,11 +80,12 @@ var subscriptionNames = map[string]string{
// WsConnect connects to a websocket feed
func (by *Bybit) WsConnect() error {
ctx := context.TODO()
if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.Spot) {
return websocket.ErrWebsocketNotEnabled
}
var dialer gws.Dialer
err := by.Websocket.Conn.Dial(&dialer, http.Header{})
err := by.Websocket.Conn.Dial(ctx, &dialer, http.Header{})
if err != nil {
return err
}
@@ -95,9 +96,9 @@ func (by *Bybit) WsConnect() error {
})
by.Websocket.Wg.Add(1)
go by.wsReadData(asset.Spot, by.Websocket.Conn)
go by.wsReadData(ctx, asset.Spot, by.Websocket.Conn)
if by.Websocket.CanUseAuthenticatedEndpoints() {
err = by.WsAuth(context.TODO())
err = by.WsAuth(ctx)
if err != nil {
by.Websocket.DataHandler <- err
by.Websocket.SetCanUseAuthenticatedEndpoints(false)
@@ -114,7 +115,7 @@ func (by *Bybit) WsAuth(ctx context.Context) error {
}
var dialer gws.Dialer
if err := by.Websocket.AuthConn.Dial(&dialer, http.Header{}); err != nil {
if err := by.Websocket.AuthConn.Dial(ctx, &dialer, http.Header{}); err != nil {
return err
}
@@ -125,7 +126,7 @@ func (by *Bybit) WsAuth(ctx context.Context) error {
})
by.Websocket.Wg.Add(1)
go by.wsReadData(asset.Spot, by.Websocket.AuthConn)
go by.wsReadData(ctx, asset.Spot, by.Websocket.AuthConn)
intNonce := time.Now().Add(time.Hour * 6).UnixMilli()
strNonce := strconv.FormatInt(intNonce, 10)
@@ -143,7 +144,7 @@ func (by *Bybit) WsAuth(ctx context.Context) error {
Operation: "auth",
Args: []any{creds.Key, intNonce, sign},
}
resp, err := by.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), request.Unset, req.RequestID, req)
resp, err := by.Websocket.AuthConn.SendMessageReturnResponse(ctx, request.Unset, req.RequestID, req)
if err != nil {
return err
}
@@ -160,7 +161,8 @@ func (by *Bybit) WsAuth(ctx context.Context) error {
// Subscribe sends a websocket message to receive data from the channel
func (by *Bybit) Subscribe(channelsToSubscribe subscription.List) error {
return by.handleSpotSubscription("subscribe", channelsToSubscribe)
ctx := context.TODO()
return by.handleSpotSubscription(ctx, "subscribe", channelsToSubscribe)
}
func (by *Bybit) handleSubscriptions(operation string, subs subscription.List) (args []SubscriptionArgument, err error) {
@@ -186,10 +188,11 @@ func (by *Bybit) handleSubscriptions(operation string, subs subscription.List) (
// Unsubscribe sends a websocket message to stop receiving data from the channel
func (by *Bybit) Unsubscribe(channelsToUnsubscribe subscription.List) error {
return by.handleSpotSubscription("unsubscribe", channelsToUnsubscribe)
ctx := context.TODO()
return by.handleSpotSubscription(ctx, "unsubscribe", channelsToUnsubscribe)
}
func (by *Bybit) handleSpotSubscription(operation string, channelsToSubscribe subscription.List) error {
func (by *Bybit) handleSpotSubscription(ctx context.Context, operation string, channelsToSubscribe subscription.List) error {
payloads, err := by.handleSubscriptions(operation, channelsToSubscribe)
if err != nil {
return err
@@ -197,12 +200,12 @@ func (by *Bybit) handleSpotSubscription(operation string, channelsToSubscribe su
for a := range payloads {
var response []byte
if payloads[a].auth {
response, err = by.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), request.Unset, payloads[a].RequestID, payloads[a])
response, err = by.Websocket.AuthConn.SendMessageReturnResponse(ctx, request.Unset, payloads[a].RequestID, payloads[a])
if err != nil {
return err
}
} else {
response, err = by.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.Unset, payloads[a].RequestID, payloads[a])
response, err = by.Websocket.Conn.SendMessageReturnResponse(ctx, request.Unset, payloads[a].RequestID, payloads[a])
if err != nil {
return err
}
@@ -252,7 +255,7 @@ func (by *Bybit) GetSubscriptionTemplate(_ *subscription.Subscription) (*templat
}
// wsReadData receives and passes on websocket messages for processing
func (by *Bybit) wsReadData(assetType asset.Item, ws websocket.Connection) {
func (by *Bybit) wsReadData(ctx context.Context, assetType asset.Item, ws websocket.Connection) {
defer by.Websocket.Wg.Done()
for {
select {
@@ -263,7 +266,7 @@ func (by *Bybit) wsReadData(assetType asset.Item, ws websocket.Connection) {
if resp.Raw == nil {
return
}
err := by.wsHandleData(assetType, resp.Raw)
err := by.wsHandleData(ctx, assetType, resp.Raw)
if err != nil {
by.Websocket.DataHandler <- err
}
@@ -271,7 +274,7 @@ func (by *Bybit) wsReadData(assetType asset.Item, ws websocket.Connection) {
}
}
func (by *Bybit) wsHandleData(assetType asset.Item, respRaw []byte) error {
func (by *Bybit) wsHandleData(ctx context.Context, assetType asset.Item, respRaw []byte) error {
var result WebsocketResponse
err := json.Unmarshal(respRaw, &result)
if err != nil {
@@ -322,7 +325,7 @@ func (by *Bybit) wsHandleData(assetType asset.Item, respRaw []byte) error {
case chanOrder:
return by.wsProcessOrder(asset.Spot, &result)
case chanWallet:
return by.wsProcessWalletPushData(asset.Spot, respRaw)
return by.wsProcessWalletPushData(ctx, asset.Spot, respRaw)
case chanGreeks:
return by.wsProcessGreeks(respRaw)
case chanDCP:
@@ -341,13 +344,13 @@ func (by *Bybit) wsProcessGreeks(resp []byte) error {
return nil
}
func (by *Bybit) wsProcessWalletPushData(assetType asset.Item, resp []byte) error {
func (by *Bybit) wsProcessWalletPushData(ctx context.Context, assetType asset.Item, resp []byte) error {
var result WebsocketWallet
err := json.Unmarshal(resp, &result)
if err != nil {
return err
}
creds, err := by.GetCredentials(context.TODO())
creds, err := by.GetCredentials(ctx)
if err != nil {
return err
}