stream: force subscription store check as stop gap for wrapper side implementation (#1717)

* Add check for subscription store insertion and validation with tests

* bybit: fix subscriptions

* deribit: fix subscriptions

* linter: fix

* glorious/nits: add test for updateChannelSubscriptions RM GetChannelDifference method as its only used locally and Diff method can be accessed directly

* glorious/nits: add to store in the loop; add correct formatting to template for edge case perps with settlement

* spelling: fix

* glorious/nit: silly billy

* gk: nits

* gk/nits: split PartitionByPresence into Contained and Missing

* gk/nits: formatPerpetualPairWithSettlement -> formatChannelPair

* stream/websocket: stop full websocket disconnect on Connect when encountering subscription specific error paths

* stream/websocket: rm nil assignment

* glorious: nits

* gk: niterinos

* Update exchanges/stream/websocket_test.go

Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io>

* Update exchanges/stream/websocket_test.go

Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io>

* Update exchanges/stream/websocket_test.go

Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io>

* thrasher: nits

---------

Co-authored-by: shazbert <ryan.oharareid@thrasher.io>
Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io>
This commit is contained in:
Ryan O'Hara-Reid
2025-02-24 15:25:49 +11:00
committed by GitHub
parent ef0f398455
commit 3a80cd2871
8 changed files with 376 additions and 181 deletions

View File

@@ -22,16 +22,17 @@ const jobBuffer = 5000
// Public websocket errors
var (
ErrWebsocketNotEnabled = errors.New("websocket not enabled")
ErrSubscriptionFailure = errors.New("subscription failure")
ErrSubscriptionNotSupported = errors.New("subscription channel not supported ")
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
ErrAlreadyDisabled = errors.New("websocket already disabled")
ErrNotConnected = errors.New("websocket is not connected")
ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature")
ErrRequestRouteNotFound = errors.New("request route not found")
ErrSignatureNotSet = errors.New("signature not set")
ErrRequestPayloadNotSet = errors.New("request payload not set")
ErrWebsocketNotEnabled = errors.New("websocket not enabled")
ErrSubscriptionFailure = errors.New("subscription failure")
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
ErrSubscriptionsNotAdded = errors.New("subscriptions not added")
ErrSubscriptionsNotRemoved = errors.New("subscriptions not removed")
ErrAlreadyDisabled = errors.New("websocket already disabled")
ErrNotConnected = errors.New("websocket is not connected")
ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature")
ErrRequestRouteNotFound = errors.New("request route not found")
ErrSignatureNotSet = errors.New("signature not set")
ErrRequestPayloadNotSet = errors.New("request payload not set")
)
// Private websocket errors
@@ -372,6 +373,10 @@ func (w *Websocket) connect() error {
if err := w.SubscribeToChannels(nil, subs); err != nil {
return err
}
if missing := w.subscriptions.Missing(subs); len(missing) > 0 {
return fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotAdded, missing)
}
}
return nil
}
@@ -385,6 +390,9 @@ func (w *Websocket) connect() error {
// be shutdown and the websocket to be disconnected.
var multiConnectFatalError error
// subscriptionError is a non-fatal error that does not shutdown connections
var subscriptionError error
// TODO: Implement concurrency below.
for i := range w.connectionManager {
if w.connectionManager[i].Setup.GenerateSubscriptions == nil {
@@ -443,16 +451,20 @@ func (w *Websocket) connect() error {
if w.connectionManager[i].Setup.Authenticate != nil && w.CanUseAuthenticatedEndpoints() {
err = w.connectionManager[i].Setup.Authenticate(context.TODO(), conn)
if err != nil {
// Opted to not fail entirely here for POC. This should be
// revisited and handled more gracefully.
log.Errorf(log.WebsocketMgr, "%s websocket: [conn:%d] [URL:%s] failed to authenticate %v", w.exchangeName, i+1, conn.URL, err)
multiConnectFatalError = fmt.Errorf("%s websocket: [conn:%d] [URL:%s] failed to authenticate %w", w.exchangeName, i+1, conn.URL, err)
break
}
}
err = w.connectionManager[i].Setup.Subscriber(context.TODO(), conn, subs)
if err != nil {
multiConnectFatalError = fmt.Errorf("%v Error subscribing %w", w.exchangeName, err)
break
subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("%v Error subscribing %w", w.exchangeName, err))
continue
}
if missing := w.connectionManager[i].Subscriptions.Missing(subs); len(missing) > 0 {
subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotAdded, missing))
continue
}
if w.verbose {
@@ -496,7 +508,7 @@ func (w *Websocket) connect() error {
go w.monitorFrame(nil, w.monitorConnection)
}
return nil
return subscriptionError
}
// Disable disables the exchange websocket protocol
@@ -625,14 +637,7 @@ func (w *Websocket) FlushChannels() error {
if err != nil {
return err
}
subs, unsubs := w.GetChannelDifference(nil, newSubs)
if err := w.UnsubscribeChannels(nil, unsubs); err != nil {
return err
}
if len(subs) == 0 {
return nil
}
return w.SubscribeToChannels(nil, subs)
return w.updateChannelSubscriptions(nil, w.subscriptions, newSubs)
}
for x := range w.connectionManager {
@@ -658,17 +663,9 @@ func (w *Websocket) FlushChannels() error {
w.connectionManager[x].Connection = conn
}
subs, unsubs := w.GetChannelDifference(w.connectionManager[x].Connection, newSubs)
if len(unsubs) != 0 {
if err := w.UnsubscribeChannels(w.connectionManager[x].Connection, unsubs); err != nil {
return err
}
}
if len(subs) != 0 {
if err := w.SubscribeToChannels(w.connectionManager[x].Connection, subs); err != nil {
return err
}
err = w.updateChannelSubscriptions(w.connectionManager[x].Connection, w.connectionManager[x].Subscriptions, newSubs)
if err != nil {
return err
}
// If there are no subscriptions to subscribe to, close the connection as it is no longer needed.
@@ -683,6 +680,31 @@ func (w *Websocket) FlushChannels() error {
return nil
}
// updateChannelSubscriptions subscribes or unsubscribes from channels and checks that the correct number of channels
// have been subscribed to or unsubscribed from.
func (w *Websocket) updateChannelSubscriptions(c Connection, store *subscription.Store, incoming subscription.List) error {
subs, unsubs := store.Diff(incoming)
if len(unsubs) != 0 {
if err := w.UnsubscribeChannels(c, unsubs); err != nil {
return err
}
if contained := store.Contained(unsubs); len(contained) > 0 {
return fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotRemoved, contained)
}
}
if len(subs) != 0 {
if err := w.SubscribeToChannels(c, subs); err != nil {
return err
}
if missing := store.Missing(subs); len(missing) > 0 {
return fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotAdded, missing)
}
}
return nil
}
func (w *Websocket) setState(s uint32) {
w.state.Store(s)
}
@@ -830,21 +852,6 @@ func (w *Websocket) GetName() string {
return w.exchangeName
}
// GetChannelDifference finds the difference between the subscribed channels
// and the new subscription list when pairs are disabled or enabled.
func (w *Websocket) GetChannelDifference(conn Connection, newSubs subscription.List) (sub, unsub subscription.List) {
var subscriptionStore **subscription.Store
if wrapper, ok := w.connections[conn]; ok && conn != nil {
subscriptionStore = &wrapper.Subscriptions
} else {
subscriptionStore = &w.subscriptions
}
if *subscriptionStore == nil {
*subscriptionStore = subscription.NewStore()
}
return (*subscriptionStore).Diff(newSubs)
}
// UnsubscribeChannels unsubscribes from a list of websocket channel
func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.List) error {
if len(channels) == 0 {
@@ -855,6 +862,11 @@ func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.L
return wrapper.Setup.Unsubscriber(context.TODO(), conn, channels)
})
}
if w.Unsubscriber == nil {
return fmt.Errorf("%w: Global Unsubscriber not set", common.ErrNilPointer)
}
return w.unsubscribe(w.subscriptions, channels, func(channels subscription.List) error {
return w.Unsubscriber(channels)
})