mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-06-03 07:26:45 +00:00
stream: force subscription store check as stop gap for wrapper side implementation (#1717)
* Add check for subscription store insertion and validation with tests * bybit: fix subscriptions * deribit: fix subscriptions * linter: fix * glorious/nits: add test for updateChannelSubscriptions RM GetChannelDifference method as its only used locally and Diff method can be accessed directly * glorious/nits: add to store in the loop; add correct formatting to template for edge case perps with settlement * spelling: fix * glorious/nit: silly billy * gk: nits * gk/nits: split PartitionByPresence into Contained and Missing * gk/nits: formatPerpetualPairWithSettlement -> formatChannelPair * stream/websocket: stop full websocket disconnect on Connect when encountering subscription specific error paths * stream/websocket: rm nil assignment * glorious: nits * gk: niterinos * Update exchanges/stream/websocket_test.go Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io> * Update exchanges/stream/websocket_test.go Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io> * Update exchanges/stream/websocket_test.go Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io> * thrasher: nits --------- Co-authored-by: shazbert <ryan.oharareid@thrasher.io> Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io>
This commit is contained in:
@@ -22,16 +22,17 @@ const jobBuffer = 5000
|
||||
|
||||
// Public websocket errors
|
||||
var (
|
||||
ErrWebsocketNotEnabled = errors.New("websocket not enabled")
|
||||
ErrSubscriptionFailure = errors.New("subscription failure")
|
||||
ErrSubscriptionNotSupported = errors.New("subscription channel not supported ")
|
||||
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
|
||||
ErrAlreadyDisabled = errors.New("websocket already disabled")
|
||||
ErrNotConnected = errors.New("websocket is not connected")
|
||||
ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature")
|
||||
ErrRequestRouteNotFound = errors.New("request route not found")
|
||||
ErrSignatureNotSet = errors.New("signature not set")
|
||||
ErrRequestPayloadNotSet = errors.New("request payload not set")
|
||||
ErrWebsocketNotEnabled = errors.New("websocket not enabled")
|
||||
ErrSubscriptionFailure = errors.New("subscription failure")
|
||||
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
|
||||
ErrSubscriptionsNotAdded = errors.New("subscriptions not added")
|
||||
ErrSubscriptionsNotRemoved = errors.New("subscriptions not removed")
|
||||
ErrAlreadyDisabled = errors.New("websocket already disabled")
|
||||
ErrNotConnected = errors.New("websocket is not connected")
|
||||
ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature")
|
||||
ErrRequestRouteNotFound = errors.New("request route not found")
|
||||
ErrSignatureNotSet = errors.New("signature not set")
|
||||
ErrRequestPayloadNotSet = errors.New("request payload not set")
|
||||
)
|
||||
|
||||
// Private websocket errors
|
||||
@@ -372,6 +373,10 @@ func (w *Websocket) connect() error {
|
||||
if err := w.SubscribeToChannels(nil, subs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if missing := w.subscriptions.Missing(subs); len(missing) > 0 {
|
||||
return fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotAdded, missing)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -385,6 +390,9 @@ func (w *Websocket) connect() error {
|
||||
// be shutdown and the websocket to be disconnected.
|
||||
var multiConnectFatalError error
|
||||
|
||||
// subscriptionError is a non-fatal error that does not shutdown connections
|
||||
var subscriptionError error
|
||||
|
||||
// TODO: Implement concurrency below.
|
||||
for i := range w.connectionManager {
|
||||
if w.connectionManager[i].Setup.GenerateSubscriptions == nil {
|
||||
@@ -443,16 +451,20 @@ func (w *Websocket) connect() error {
|
||||
if w.connectionManager[i].Setup.Authenticate != nil && w.CanUseAuthenticatedEndpoints() {
|
||||
err = w.connectionManager[i].Setup.Authenticate(context.TODO(), conn)
|
||||
if err != nil {
|
||||
// Opted to not fail entirely here for POC. This should be
|
||||
// revisited and handled more gracefully.
|
||||
log.Errorf(log.WebsocketMgr, "%s websocket: [conn:%d] [URL:%s] failed to authenticate %v", w.exchangeName, i+1, conn.URL, err)
|
||||
multiConnectFatalError = fmt.Errorf("%s websocket: [conn:%d] [URL:%s] failed to authenticate %w", w.exchangeName, i+1, conn.URL, err)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
err = w.connectionManager[i].Setup.Subscriber(context.TODO(), conn, subs)
|
||||
if err != nil {
|
||||
multiConnectFatalError = fmt.Errorf("%v Error subscribing %w", w.exchangeName, err)
|
||||
break
|
||||
subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("%v Error subscribing %w", w.exchangeName, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if missing := w.connectionManager[i].Subscriptions.Missing(subs); len(missing) > 0 {
|
||||
subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotAdded, missing))
|
||||
continue
|
||||
}
|
||||
|
||||
if w.verbose {
|
||||
@@ -496,7 +508,7 @@ func (w *Websocket) connect() error {
|
||||
go w.monitorFrame(nil, w.monitorConnection)
|
||||
}
|
||||
|
||||
return nil
|
||||
return subscriptionError
|
||||
}
|
||||
|
||||
// Disable disables the exchange websocket protocol
|
||||
@@ -625,14 +637,7 @@ func (w *Websocket) FlushChannels() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
subs, unsubs := w.GetChannelDifference(nil, newSubs)
|
||||
if err := w.UnsubscribeChannels(nil, unsubs); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(subs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return w.SubscribeToChannels(nil, subs)
|
||||
return w.updateChannelSubscriptions(nil, w.subscriptions, newSubs)
|
||||
}
|
||||
|
||||
for x := range w.connectionManager {
|
||||
@@ -658,17 +663,9 @@ func (w *Websocket) FlushChannels() error {
|
||||
w.connectionManager[x].Connection = conn
|
||||
}
|
||||
|
||||
subs, unsubs := w.GetChannelDifference(w.connectionManager[x].Connection, newSubs)
|
||||
|
||||
if len(unsubs) != 0 {
|
||||
if err := w.UnsubscribeChannels(w.connectionManager[x].Connection, unsubs); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(subs) != 0 {
|
||||
if err := w.SubscribeToChannels(w.connectionManager[x].Connection, subs); err != nil {
|
||||
return err
|
||||
}
|
||||
err = w.updateChannelSubscriptions(w.connectionManager[x].Connection, w.connectionManager[x].Subscriptions, newSubs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If there are no subscriptions to subscribe to, close the connection as it is no longer needed.
|
||||
@@ -683,6 +680,31 @@ func (w *Websocket) FlushChannels() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateChannelSubscriptions subscribes or unsubscribes from channels and checks that the correct number of channels
|
||||
// have been subscribed to or unsubscribed from.
|
||||
func (w *Websocket) updateChannelSubscriptions(c Connection, store *subscription.Store, incoming subscription.List) error {
|
||||
subs, unsubs := store.Diff(incoming)
|
||||
if len(unsubs) != 0 {
|
||||
if err := w.UnsubscribeChannels(c, unsubs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if contained := store.Contained(unsubs); len(contained) > 0 {
|
||||
return fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotRemoved, contained)
|
||||
}
|
||||
}
|
||||
if len(subs) != 0 {
|
||||
if err := w.SubscribeToChannels(c, subs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if missing := store.Missing(subs); len(missing) > 0 {
|
||||
return fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotAdded, missing)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Websocket) setState(s uint32) {
|
||||
w.state.Store(s)
|
||||
}
|
||||
@@ -830,21 +852,6 @@ func (w *Websocket) GetName() string {
|
||||
return w.exchangeName
|
||||
}
|
||||
|
||||
// GetChannelDifference finds the difference between the subscribed channels
|
||||
// and the new subscription list when pairs are disabled or enabled.
|
||||
func (w *Websocket) GetChannelDifference(conn Connection, newSubs subscription.List) (sub, unsub subscription.List) {
|
||||
var subscriptionStore **subscription.Store
|
||||
if wrapper, ok := w.connections[conn]; ok && conn != nil {
|
||||
subscriptionStore = &wrapper.Subscriptions
|
||||
} else {
|
||||
subscriptionStore = &w.subscriptions
|
||||
}
|
||||
if *subscriptionStore == nil {
|
||||
*subscriptionStore = subscription.NewStore()
|
||||
}
|
||||
return (*subscriptionStore).Diff(newSubs)
|
||||
}
|
||||
|
||||
// UnsubscribeChannels unsubscribes from a list of websocket channel
|
||||
func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.List) error {
|
||||
if len(channels) == 0 {
|
||||
@@ -855,6 +862,11 @@ func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.L
|
||||
return wrapper.Setup.Unsubscriber(context.TODO(), conn, channels)
|
||||
})
|
||||
}
|
||||
|
||||
if w.Unsubscriber == nil {
|
||||
return fmt.Errorf("%w: Global Unsubscriber not set", common.ErrNilPointer)
|
||||
}
|
||||
|
||||
return w.unsubscribe(w.subscriptions, channels, func(channels subscription.List) error {
|
||||
return w.Unsubscriber(channels)
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user