mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-28 07:26:57 +00:00
stream: force subscription store check as stop gap for wrapper side implementation (#1717)
* Add check for subscription store insertion and validation with tests * bybit: fix subscriptions * deribit: fix subscriptions * linter: fix * glorious/nits: add test for updateChannelSubscriptions RM GetChannelDifference method as its only used locally and Diff method can be accessed directly * glorious/nits: add to store in the loop; add correct formatting to template for edge case perps with settlement * spelling: fix * glorious/nit: silly billy * gk: nits * gk/nits: split PartitionByPresence into Contained and Missing * gk/nits: formatPerpetualPairWithSettlement -> formatChannelPair * stream/websocket: stop full websocket disconnect on Connect when encountering subscription specific error paths * stream/websocket: rm nil assignment * glorious: nits * gk: niterinos * Update exchanges/stream/websocket_test.go Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io> * Update exchanges/stream/websocket_test.go Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io> * Update exchanges/stream/websocket_test.go Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io> * thrasher: nits --------- Co-authored-by: shazbert <ryan.oharareid@thrasher.io> Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io>
This commit is contained in:
@@ -22,16 +22,17 @@ const jobBuffer = 5000
|
||||
|
||||
// Public websocket errors
|
||||
var (
|
||||
ErrWebsocketNotEnabled = errors.New("websocket not enabled")
|
||||
ErrSubscriptionFailure = errors.New("subscription failure")
|
||||
ErrSubscriptionNotSupported = errors.New("subscription channel not supported ")
|
||||
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
|
||||
ErrAlreadyDisabled = errors.New("websocket already disabled")
|
||||
ErrNotConnected = errors.New("websocket is not connected")
|
||||
ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature")
|
||||
ErrRequestRouteNotFound = errors.New("request route not found")
|
||||
ErrSignatureNotSet = errors.New("signature not set")
|
||||
ErrRequestPayloadNotSet = errors.New("request payload not set")
|
||||
ErrWebsocketNotEnabled = errors.New("websocket not enabled")
|
||||
ErrSubscriptionFailure = errors.New("subscription failure")
|
||||
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
|
||||
ErrSubscriptionsNotAdded = errors.New("subscriptions not added")
|
||||
ErrSubscriptionsNotRemoved = errors.New("subscriptions not removed")
|
||||
ErrAlreadyDisabled = errors.New("websocket already disabled")
|
||||
ErrNotConnected = errors.New("websocket is not connected")
|
||||
ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature")
|
||||
ErrRequestRouteNotFound = errors.New("request route not found")
|
||||
ErrSignatureNotSet = errors.New("signature not set")
|
||||
ErrRequestPayloadNotSet = errors.New("request payload not set")
|
||||
)
|
||||
|
||||
// Private websocket errors
|
||||
@@ -372,6 +373,10 @@ func (w *Websocket) connect() error {
|
||||
if err := w.SubscribeToChannels(nil, subs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if missing := w.subscriptions.Missing(subs); len(missing) > 0 {
|
||||
return fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotAdded, missing)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -385,6 +390,9 @@ func (w *Websocket) connect() error {
|
||||
// be shutdown and the websocket to be disconnected.
|
||||
var multiConnectFatalError error
|
||||
|
||||
// subscriptionError is a non-fatal error that does not shutdown connections
|
||||
var subscriptionError error
|
||||
|
||||
// TODO: Implement concurrency below.
|
||||
for i := range w.connectionManager {
|
||||
if w.connectionManager[i].Setup.GenerateSubscriptions == nil {
|
||||
@@ -443,16 +451,20 @@ func (w *Websocket) connect() error {
|
||||
if w.connectionManager[i].Setup.Authenticate != nil && w.CanUseAuthenticatedEndpoints() {
|
||||
err = w.connectionManager[i].Setup.Authenticate(context.TODO(), conn)
|
||||
if err != nil {
|
||||
// Opted to not fail entirely here for POC. This should be
|
||||
// revisited and handled more gracefully.
|
||||
log.Errorf(log.WebsocketMgr, "%s websocket: [conn:%d] [URL:%s] failed to authenticate %v", w.exchangeName, i+1, conn.URL, err)
|
||||
multiConnectFatalError = fmt.Errorf("%s websocket: [conn:%d] [URL:%s] failed to authenticate %w", w.exchangeName, i+1, conn.URL, err)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
err = w.connectionManager[i].Setup.Subscriber(context.TODO(), conn, subs)
|
||||
if err != nil {
|
||||
multiConnectFatalError = fmt.Errorf("%v Error subscribing %w", w.exchangeName, err)
|
||||
break
|
||||
subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("%v Error subscribing %w", w.exchangeName, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if missing := w.connectionManager[i].Subscriptions.Missing(subs); len(missing) > 0 {
|
||||
subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotAdded, missing))
|
||||
continue
|
||||
}
|
||||
|
||||
if w.verbose {
|
||||
@@ -496,7 +508,7 @@ func (w *Websocket) connect() error {
|
||||
go w.monitorFrame(nil, w.monitorConnection)
|
||||
}
|
||||
|
||||
return nil
|
||||
return subscriptionError
|
||||
}
|
||||
|
||||
// Disable disables the exchange websocket protocol
|
||||
@@ -625,14 +637,7 @@ func (w *Websocket) FlushChannels() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
subs, unsubs := w.GetChannelDifference(nil, newSubs)
|
||||
if err := w.UnsubscribeChannels(nil, unsubs); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(subs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return w.SubscribeToChannels(nil, subs)
|
||||
return w.updateChannelSubscriptions(nil, w.subscriptions, newSubs)
|
||||
}
|
||||
|
||||
for x := range w.connectionManager {
|
||||
@@ -658,17 +663,9 @@ func (w *Websocket) FlushChannels() error {
|
||||
w.connectionManager[x].Connection = conn
|
||||
}
|
||||
|
||||
subs, unsubs := w.GetChannelDifference(w.connectionManager[x].Connection, newSubs)
|
||||
|
||||
if len(unsubs) != 0 {
|
||||
if err := w.UnsubscribeChannels(w.connectionManager[x].Connection, unsubs); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(subs) != 0 {
|
||||
if err := w.SubscribeToChannels(w.connectionManager[x].Connection, subs); err != nil {
|
||||
return err
|
||||
}
|
||||
err = w.updateChannelSubscriptions(w.connectionManager[x].Connection, w.connectionManager[x].Subscriptions, newSubs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If there are no subscriptions to subscribe to, close the connection as it is no longer needed.
|
||||
@@ -683,6 +680,31 @@ func (w *Websocket) FlushChannels() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateChannelSubscriptions subscribes or unsubscribes from channels and checks that the correct number of channels
|
||||
// have been subscribed to or unsubscribed from.
|
||||
func (w *Websocket) updateChannelSubscriptions(c Connection, store *subscription.Store, incoming subscription.List) error {
|
||||
subs, unsubs := store.Diff(incoming)
|
||||
if len(unsubs) != 0 {
|
||||
if err := w.UnsubscribeChannels(c, unsubs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if contained := store.Contained(unsubs); len(contained) > 0 {
|
||||
return fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotRemoved, contained)
|
||||
}
|
||||
}
|
||||
if len(subs) != 0 {
|
||||
if err := w.SubscribeToChannels(c, subs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if missing := store.Missing(subs); len(missing) > 0 {
|
||||
return fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotAdded, missing)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Websocket) setState(s uint32) {
|
||||
w.state.Store(s)
|
||||
}
|
||||
@@ -830,21 +852,6 @@ func (w *Websocket) GetName() string {
|
||||
return w.exchangeName
|
||||
}
|
||||
|
||||
// GetChannelDifference finds the difference between the subscribed channels
|
||||
// and the new subscription list when pairs are disabled or enabled.
|
||||
func (w *Websocket) GetChannelDifference(conn Connection, newSubs subscription.List) (sub, unsub subscription.List) {
|
||||
var subscriptionStore **subscription.Store
|
||||
if wrapper, ok := w.connections[conn]; ok && conn != nil {
|
||||
subscriptionStore = &wrapper.Subscriptions
|
||||
} else {
|
||||
subscriptionStore = &w.subscriptions
|
||||
}
|
||||
if *subscriptionStore == nil {
|
||||
*subscriptionStore = subscription.NewStore()
|
||||
}
|
||||
return (*subscriptionStore).Diff(newSubs)
|
||||
}
|
||||
|
||||
// UnsubscribeChannels unsubscribes from a list of websocket channel
|
||||
func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.List) error {
|
||||
if len(channels) == 0 {
|
||||
@@ -855,6 +862,11 @@ func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.L
|
||||
return wrapper.Setup.Unsubscriber(context.TODO(), conn, channels)
|
||||
})
|
||||
}
|
||||
|
||||
if w.Unsubscriber == nil {
|
||||
return fmt.Errorf("%w: Global Unsubscriber not set", common.ErrNilPointer)
|
||||
}
|
||||
|
||||
return w.unsubscribe(w.subscriptions, channels, func(channels subscription.List) error {
|
||||
return w.Unsubscriber(channels)
|
||||
})
|
||||
|
||||
@@ -63,31 +63,33 @@ type testSubKey struct {
|
||||
Mood string
|
||||
}
|
||||
|
||||
var defaultSetup = &WebsocketSetup{
|
||||
ExchangeConfig: &config.Exchange{
|
||||
Features: &config.FeaturesConfig{
|
||||
Enabled: config.FeaturesEnabledConfig{Websocket: true},
|
||||
func newDefaultSetup() *WebsocketSetup {
|
||||
return &WebsocketSetup{
|
||||
ExchangeConfig: &config.Exchange{
|
||||
Features: &config.FeaturesConfig{
|
||||
Enabled: config.FeaturesEnabledConfig{Websocket: true},
|
||||
},
|
||||
API: config.APIConfig{
|
||||
AuthenticatedWebsocketSupport: true,
|
||||
},
|
||||
WebsocketTrafficTimeout: time.Second * 5,
|
||||
Name: "GTX",
|
||||
},
|
||||
API: config.APIConfig{
|
||||
AuthenticatedWebsocketSupport: true,
|
||||
DefaultURL: "testDefaultURL",
|
||||
RunningURL: "wss://testRunningURL",
|
||||
Connector: func() error { return nil },
|
||||
Subscriber: func(subscription.List) error { return nil },
|
||||
Unsubscriber: func(subscription.List) error { return nil },
|
||||
GenerateSubscriptions: func() (subscription.List, error) {
|
||||
return subscription.List{
|
||||
{Channel: "TestSub"},
|
||||
{Channel: "TestSub2", Key: "purple"},
|
||||
{Channel: "TestSub3", Key: testSubKey{"mauve"}},
|
||||
{Channel: "TestSub4", Key: 42},
|
||||
}, nil
|
||||
},
|
||||
WebsocketTrafficTimeout: time.Second * 5,
|
||||
Name: "GTX",
|
||||
},
|
||||
DefaultURL: "testDefaultURL",
|
||||
RunningURL: "wss://testRunningURL",
|
||||
Connector: func() error { return nil },
|
||||
Subscriber: func(subscription.List) error { return nil },
|
||||
Unsubscriber: func(subscription.List) error { return nil },
|
||||
GenerateSubscriptions: func() (subscription.List, error) {
|
||||
return subscription.List{
|
||||
{Channel: "TestSub"},
|
||||
{Channel: "TestSub2", Key: "purple"},
|
||||
{Channel: "TestSub3", Key: testSubKey{"mauve"}},
|
||||
{Channel: "TestSub4", Key: 42},
|
||||
}, nil
|
||||
},
|
||||
Features: &protocol.Features{Subscribe: true, Unsubscribe: true},
|
||||
Features: &protocol.Features{Subscribe: true, Unsubscribe: true},
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
@@ -186,13 +188,23 @@ func TestConnectionMessageErrors(t *testing.T) {
|
||||
assert.ErrorIs(t, err, errDastardlyReason, "Connect should error correctly")
|
||||
|
||||
ws := NewWebsocket()
|
||||
err = ws.Setup(defaultSetup)
|
||||
err = ws.Setup(newDefaultSetup())
|
||||
require.NoError(t, err, "Setup must not error")
|
||||
ws.trafficTimeout = time.Minute
|
||||
ws.connector = connect
|
||||
|
||||
err = ws.Connect()
|
||||
require.NoError(t, err, "Connect must not error")
|
||||
require.ErrorIs(t, ws.Connect(), ErrSubscriptionsNotAdded)
|
||||
require.NoError(t, ws.Shutdown())
|
||||
|
||||
ws.Subscriber = func(subs subscription.List) error {
|
||||
for _, sub := range subs {
|
||||
if err := ws.subscriptions.Add(sub); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
require.NoError(t, ws.Connect(), "Connect must not error")
|
||||
|
||||
checkToRoutineResult := func(t *testing.T) {
|
||||
t.Helper()
|
||||
@@ -240,7 +252,7 @@ func TestConnectionMessageErrors(t *testing.T) {
|
||||
require.ErrorIs(t, err, errDastardlyReason)
|
||||
|
||||
ws.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) {
|
||||
return subscription.List{{}}, nil
|
||||
return subscription.List{{Channel: "test"}}, nil
|
||||
}
|
||||
err = ws.Connect()
|
||||
require.ErrorIs(t, err, errNoConnectFunc)
|
||||
@@ -275,10 +287,26 @@ func TestConnectionMessageErrors(t *testing.T) {
|
||||
err = ws.Connect()
|
||||
require.ErrorIs(t, err, errDastardlyReason)
|
||||
|
||||
ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error {
|
||||
return errDastardlyReason
|
||||
}
|
||||
ws.connectionManager[0].Setup.Authenticate = nil
|
||||
err = ws.Connect()
|
||||
require.ErrorIs(t, err, errDastardlyReason)
|
||||
require.NoError(t, ws.shutdown())
|
||||
|
||||
ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error {
|
||||
return nil
|
||||
}
|
||||
err = ws.Connect()
|
||||
require.ErrorIs(t, err, ErrSubscriptionsNotAdded)
|
||||
require.NoError(t, ws.shutdown())
|
||||
|
||||
ws.connectionManager[0].Subscriptions = subscription.NewStore()
|
||||
ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error {
|
||||
return ws.connectionManager[0].Subscriptions.Add(&subscription.Subscription{Channel: "test"})
|
||||
}
|
||||
err = ws.Connect()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ws.connectionManager[0].Connection.SendRawMessage(context.Background(), request.Unset, websocket.TextMessage, []byte("test"))
|
||||
@@ -297,6 +325,7 @@ func TestWebsocket(t *testing.T) {
|
||||
assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly")
|
||||
|
||||
ws.setEnabled(true)
|
||||
defaultSetup := newDefaultSetup()
|
||||
err = ws.Setup(defaultSetup) // Sets to enabled again
|
||||
require.NoError(t, err, "Setup may not error")
|
||||
|
||||
@@ -340,8 +369,19 @@ func TestWebsocket(t *testing.T) {
|
||||
assert.NoError(t, ws.Shutdown())
|
||||
|
||||
ws.connector = func() error { return nil }
|
||||
err = ws.Connect()
|
||||
assert.NoError(t, err, "Connect should not error")
|
||||
|
||||
require.ErrorIs(t, ws.Connect(), ErrSubscriptionsNotAdded)
|
||||
require.NoError(t, ws.Shutdown())
|
||||
|
||||
ws.Subscriber = func(subs subscription.List) error {
|
||||
for _, sub := range subs {
|
||||
if err := ws.subscriptions.Add(sub); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
assert.NoError(t, ws.Connect(), "Connect should not error")
|
||||
|
||||
ws.defaultURL = "ws://demos.kaazing.com/echo"
|
||||
ws.defaultURLAuth = "ws://demos.kaazing.com/echo"
|
||||
@@ -407,14 +447,14 @@ func currySimpleUnsubConn(w *Websocket) func(context.Context, Connection, subscr
|
||||
func TestSubscribeUnsubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
ws := NewWebsocket()
|
||||
assert.NoError(t, ws.Setup(defaultSetup), "WS Setup should not error")
|
||||
assert.NoError(t, ws.Setup(newDefaultSetup()), "WS Setup should not error")
|
||||
|
||||
ws.Subscriber = currySimpleSub(ws)
|
||||
ws.Unsubscriber = currySimpleUnsub(ws)
|
||||
|
||||
subs, err := ws.GenerateSubs()
|
||||
require.NoError(t, err, "Generating test subscriptions should not error")
|
||||
assert.NoError(t, new(Websocket).UnsubscribeChannels(nil, subs), "Should not error when w.subscriptions is nil")
|
||||
assert.ErrorIs(t, new(Websocket).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
|
||||
assert.NoError(t, ws.UnsubscribeChannels(nil, nil), "Unsubscribing from nil should not error")
|
||||
assert.ErrorIs(t, ws.UnsubscribeChannels(nil, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
|
||||
assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return")
|
||||
@@ -447,9 +487,9 @@ func TestSubscribeUnsubscribe(t *testing.T) {
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription")
|
||||
|
||||
multi := NewWebsocket()
|
||||
set := *defaultSetup
|
||||
set := newDefaultSetup()
|
||||
set.UseMultiConnectionManagement = true
|
||||
assert.NoError(t, multi.Setup(&set))
|
||||
assert.NoError(t, multi.Setup(set))
|
||||
|
||||
amazingCandidate := &ConnectionSetup{
|
||||
URL: "AMAZING",
|
||||
@@ -472,8 +512,8 @@ func TestSubscribeUnsubscribe(t *testing.T) {
|
||||
|
||||
subs, err = amazingCandidate.GenerateSubscriptions()
|
||||
require.NoError(t, err, "Generating test subscriptions should not error")
|
||||
assert.NoError(t, new(Websocket).UnsubscribeChannels(nil, subs), "Should not error when w.subscriptions is nil")
|
||||
assert.NoError(t, new(Websocket).UnsubscribeChannels(amazingConn, subs), "Should not error when w.subscriptions is nil")
|
||||
assert.ErrorIs(t, new(Websocket).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
|
||||
assert.ErrorIs(t, new(Websocket).UnsubscribeChannels(amazingConn, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
|
||||
assert.NoError(t, multi.UnsubscribeChannels(amazingConn, nil), "Unsubscribing from nil should not error")
|
||||
assert.ErrorIs(t, multi.UnsubscribeChannels(amazingConn, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
|
||||
assert.Nil(t, multi.GetSubscription(42), "GetSubscription on empty internal map should return")
|
||||
@@ -514,12 +554,12 @@ func TestResubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
ws := NewWebsocket()
|
||||
|
||||
wackedOutSetup := *defaultSetup
|
||||
wackedOutSetup := newDefaultSetup()
|
||||
wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1
|
||||
err := ws.Setup(&wackedOutSetup)
|
||||
err := ws.Setup(wackedOutSetup)
|
||||
assert.ErrorIs(t, err, errInvalidMaxSubscriptions, "Invalid MaxWebsocketSubscriptionsPerConnection should error")
|
||||
|
||||
err = ws.Setup(defaultSetup)
|
||||
err = ws.Setup(newDefaultSetup())
|
||||
assert.NoError(t, err, "WS Setup should not error")
|
||||
|
||||
ws.Subscriber = currySimpleSub(ws)
|
||||
@@ -992,52 +1032,6 @@ func TestCheckWebsocketURL(t *testing.T) {
|
||||
assert.NoError(t, err, "checkWebsocketURL should not error")
|
||||
}
|
||||
|
||||
// TestGetChannelDifference exercises GetChannelDifference
|
||||
// See subscription.TestStoreDiff for further testing
|
||||
func TestGetChannelDifference(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
w := &Websocket{}
|
||||
assert.NotPanics(t, func() { w.GetChannelDifference(nil, subscription.List{}) }, "Should not panic when called without a store")
|
||||
subs, unsubs := w.GetChannelDifference(nil, subscription.List{{Channel: subscription.CandlesChannel}})
|
||||
require.Equal(t, 1, len(subs), "Should get the correct number of subs")
|
||||
require.Empty(t, unsubs, "Should get no unsubs")
|
||||
require.NoError(t, w.AddSubscriptions(nil, subs...), "AddSubscriptions must not error")
|
||||
subs, unsubs = w.GetChannelDifference(nil, subscription.List{{Channel: subscription.TickerChannel}})
|
||||
require.Equal(t, 1, len(subs), "Should get the correct number of subs")
|
||||
assert.Equal(t, 1, len(unsubs), "Should get the correct number of unsubs")
|
||||
|
||||
w = &Websocket{}
|
||||
sweetConn := &WebsocketConnection{}
|
||||
subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}})
|
||||
require.Equal(t, 1, len(subs))
|
||||
require.Empty(t, unsubs, "Should get no unsubs")
|
||||
|
||||
w.connections = map[Connection]*ConnectionWrapper{
|
||||
sweetConn: {Setup: &ConnectionSetup{URL: "ws://localhost:8080/ws"}},
|
||||
}
|
||||
|
||||
naughtyConn := &WebsocketConnection{}
|
||||
subs, unsubs = w.GetChannelDifference(naughtyConn, subscription.List{{Channel: subscription.CandlesChannel}})
|
||||
require.Equal(t, 1, len(subs))
|
||||
require.Empty(t, unsubs, "Should get no unsubs")
|
||||
|
||||
subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}})
|
||||
require.Equal(t, 1, len(subs))
|
||||
require.Empty(t, unsubs, "Should get no unsubs")
|
||||
|
||||
err := w.connections[sweetConn].Subscriptions.Add(&subscription.Subscription{Channel: subscription.CandlesChannel})
|
||||
require.NoError(t, err)
|
||||
|
||||
subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}})
|
||||
require.Empty(t, subs, "Should get no subs")
|
||||
require.Empty(t, unsubs, "Should get no unsubs")
|
||||
|
||||
subs, unsubs = w.GetChannelDifference(sweetConn, nil)
|
||||
require.Empty(t, subs, "Should get no subs")
|
||||
require.Equal(t, 1, len(unsubs))
|
||||
}
|
||||
|
||||
// GenSubs defines a theoretical exchange with pair management
|
||||
type GenSubs struct {
|
||||
EnabledPairs currency.Pairs
|
||||
@@ -1111,16 +1105,37 @@ func TestFlushChannels(t *testing.T) {
|
||||
// Disable pair and flush system
|
||||
newgen.EnabledPairs = []currency.Pair{currency.NewPair(currency.BTC, currency.AUD)}
|
||||
w.GenerateSubs = func() (subscription.List, error) { return subscription.List{{Channel: "test"}}, nil }
|
||||
err = w.FlushChannels()
|
||||
require.NoError(t, err, "Flush Channels must not error")
|
||||
|
||||
require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotAdded, "FlushChannels must error correctly on no subscriptions added")
|
||||
|
||||
w.Subscriber = func(subs subscription.List) error {
|
||||
for _, sub := range subs {
|
||||
if err := w.subscriptions.Add(sub); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
require.NoError(t, w.FlushChannels(), "FlushChannels must not error")
|
||||
|
||||
w.GenerateSubs = func() (subscription.List, error) { return nil, errDastardlyReason } // error on generateSubs
|
||||
err = w.FlushChannels() // error on full subscribeToChannels
|
||||
assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly on GenerateSubs")
|
||||
|
||||
w.GenerateSubs = func() (subscription.List, error) { return nil, nil } // No subs to sub
|
||||
err = w.FlushChannels() // No subs to sub
|
||||
assert.NoError(t, err, "Flush Channels should not error")
|
||||
|
||||
require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotRemoved)
|
||||
|
||||
w.Unsubscriber = func(subs subscription.List) error {
|
||||
for _, sub := range subs {
|
||||
if err := w.subscriptions.Remove(sub); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
assert.NoError(t, w.FlushChannels(), "FlushChannels should not error")
|
||||
|
||||
w.GenerateSubs = newgen.generateSubs
|
||||
subs, err := w.GenerateSubs()
|
||||
@@ -1156,21 +1171,24 @@ func TestFlushChannels(t *testing.T) {
|
||||
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) }))
|
||||
defer mock.Close()
|
||||
|
||||
w.subscriptions = subscription.NewStore()
|
||||
|
||||
amazingCandidate := &ConnectionSetup{
|
||||
URL: "ws" + mock.URL[len("http"):] + "/ws",
|
||||
Connector: func(ctx context.Context, conn Connection) error {
|
||||
return conn.DialContext(ctx, websocket.DefaultDialer, nil)
|
||||
},
|
||||
GenerateSubscriptions: newgen.generateSubs,
|
||||
Subscriber: func(ctx context.Context, c Connection, s subscription.List) error {
|
||||
return currySimpleSubConn(w)(ctx, c, s)
|
||||
},
|
||||
Unsubscriber: func(ctx context.Context, c Connection, s subscription.List) error {
|
||||
return currySimpleUnsubConn(w)(ctx, c, s)
|
||||
},
|
||||
Handler: func(context.Context, []byte) error { return nil },
|
||||
Subscriber: func(context.Context, Connection, subscription.List) error { return nil },
|
||||
Unsubscriber: func(context.Context, Connection, subscription.List) error { return nil },
|
||||
Handler: func(context.Context, []byte) error { return nil },
|
||||
}
|
||||
require.NoError(t, w.SetupNewConnection(amazingCandidate))
|
||||
require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotAdded, "Must error when no subscriptions are added to the subscription store")
|
||||
|
||||
w.connectionManager[0].Setup.Subscriber = func(ctx context.Context, c Connection, s subscription.List) error {
|
||||
return currySimpleSubConn(w)(ctx, c, s)
|
||||
}
|
||||
require.NoError(t, w.FlushChannels(), "FlushChannels must not error")
|
||||
|
||||
// Forces full connection cycle (shutdown, connect, subscribe). This will also start monitoring routines.
|
||||
@@ -1181,6 +1199,11 @@ func TestFlushChannels(t *testing.T) {
|
||||
// of the connection from management.
|
||||
w.features.Subscribe = true
|
||||
w.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil }
|
||||
require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotRemoved, "Must error when no subscriptions are removed from subscription store")
|
||||
|
||||
w.connectionManager[0].Setup.Unsubscriber = func(ctx context.Context, c Connection, s subscription.List) error {
|
||||
return currySimpleUnsubConn(w)(ctx, c, s)
|
||||
}
|
||||
require.NoError(t, w.FlushChannels(), "FlushChannels must not error")
|
||||
}
|
||||
|
||||
@@ -1224,7 +1247,7 @@ func TestSetupNewConnection(t *testing.T) {
|
||||
|
||||
web := NewWebsocket()
|
||||
|
||||
err = web.Setup(defaultSetup)
|
||||
err = web.Setup(newDefaultSetup())
|
||||
assert.NoError(t, err, "Setup should not error")
|
||||
|
||||
err = web.SetupNewConnection(&ConnectionSetup{URL: "urlstring"})
|
||||
@@ -1235,9 +1258,9 @@ func TestSetupNewConnection(t *testing.T) {
|
||||
|
||||
// Test connection candidates for multi connection tracking.
|
||||
multi := NewWebsocket()
|
||||
set := *defaultSetup
|
||||
set := newDefaultSetup()
|
||||
set.UseMultiConnectionManagement = true
|
||||
require.NoError(t, multi.Setup(&set))
|
||||
require.NoError(t, multi.Setup(set))
|
||||
|
||||
err = multi.SetupNewConnection(nil)
|
||||
require.ErrorIs(t, err, errExchangeConfigEmpty)
|
||||
@@ -1566,3 +1589,43 @@ func TestGetConnection(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Same(t, expected, conn)
|
||||
}
|
||||
|
||||
func TestUpdateChannelSubscriptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ws := Websocket{}
|
||||
store := subscription.NewStore()
|
||||
err := ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}})
|
||||
require.ErrorIs(t, err, common.ErrNilPointer)
|
||||
require.Zero(t, store.Len())
|
||||
|
||||
ws.Subscriber = func(subs subscription.List) error {
|
||||
for _, sub := range subs {
|
||||
if err := store.Add(sub); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
ws.subscriptions = store
|
||||
err = ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, store.Len())
|
||||
|
||||
err = ws.updateChannelSubscriptions(nil, store, subscription.List{})
|
||||
require.ErrorIs(t, err, common.ErrNilPointer)
|
||||
|
||||
ws.Unsubscriber = func(subs subscription.List) error {
|
||||
for _, sub := range subs {
|
||||
if err := store.Remove(sub); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err = ws.updateChannelSubscriptions(nil, store, subscription.List{})
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, store.Len())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user