stream: force subscription store check as stop gap for wrapper side implementation (#1717)

* Add check for subscription store insertion and validation with tests

* bybit: fix subscriptions

* deribit: fix subscriptions

* linter: fix

* glorious/nits: add test for updateChannelSubscriptions RM GetChannelDifference method as its only used locally and Diff method can be accessed directly

* glorious/nits: add to store in the loop; add correct formatting to template for edge case perps with settlement

* spelling: fix

* glorious/nit: silly billy

* gk: nits

* gk/nits: split PartitionByPresence into Contained and Missing

* gk/nits: formatPerpetualPairWithSettlement -> formatChannelPair

* stream/websocket: stop full websocket disconnect on Connect when encountering subscription specific error paths

* stream/websocket: rm nil assignment

* glorious: nits

* gk: niterinos

* Update exchanges/stream/websocket_test.go

Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io>

* Update exchanges/stream/websocket_test.go

Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io>

* Update exchanges/stream/websocket_test.go

Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io>

* thrasher: nits

---------

Co-authored-by: shazbert <ryan.oharareid@thrasher.io>
Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io>
This commit is contained in:
Ryan O'Hara-Reid
2025-02-24 15:25:49 +11:00
committed by GitHub
parent ef0f398455
commit 3a80cd2871
8 changed files with 376 additions and 181 deletions

View File

@@ -63,31 +63,33 @@ type testSubKey struct {
Mood string
}
var defaultSetup = &WebsocketSetup{
ExchangeConfig: &config.Exchange{
Features: &config.FeaturesConfig{
Enabled: config.FeaturesEnabledConfig{Websocket: true},
func newDefaultSetup() *WebsocketSetup {
return &WebsocketSetup{
ExchangeConfig: &config.Exchange{
Features: &config.FeaturesConfig{
Enabled: config.FeaturesEnabledConfig{Websocket: true},
},
API: config.APIConfig{
AuthenticatedWebsocketSupport: true,
},
WebsocketTrafficTimeout: time.Second * 5,
Name: "GTX",
},
API: config.APIConfig{
AuthenticatedWebsocketSupport: true,
DefaultURL: "testDefaultURL",
RunningURL: "wss://testRunningURL",
Connector: func() error { return nil },
Subscriber: func(subscription.List) error { return nil },
Unsubscriber: func(subscription.List) error { return nil },
GenerateSubscriptions: func() (subscription.List, error) {
return subscription.List{
{Channel: "TestSub"},
{Channel: "TestSub2", Key: "purple"},
{Channel: "TestSub3", Key: testSubKey{"mauve"}},
{Channel: "TestSub4", Key: 42},
}, nil
},
WebsocketTrafficTimeout: time.Second * 5,
Name: "GTX",
},
DefaultURL: "testDefaultURL",
RunningURL: "wss://testRunningURL",
Connector: func() error { return nil },
Subscriber: func(subscription.List) error { return nil },
Unsubscriber: func(subscription.List) error { return nil },
GenerateSubscriptions: func() (subscription.List, error) {
return subscription.List{
{Channel: "TestSub"},
{Channel: "TestSub2", Key: "purple"},
{Channel: "TestSub3", Key: testSubKey{"mauve"}},
{Channel: "TestSub4", Key: 42},
}, nil
},
Features: &protocol.Features{Subscribe: true, Unsubscribe: true},
Features: &protocol.Features{Subscribe: true, Unsubscribe: true},
}
}
func TestSetup(t *testing.T) {
@@ -186,13 +188,23 @@ func TestConnectionMessageErrors(t *testing.T) {
assert.ErrorIs(t, err, errDastardlyReason, "Connect should error correctly")
ws := NewWebsocket()
err = ws.Setup(defaultSetup)
err = ws.Setup(newDefaultSetup())
require.NoError(t, err, "Setup must not error")
ws.trafficTimeout = time.Minute
ws.connector = connect
err = ws.Connect()
require.NoError(t, err, "Connect must not error")
require.ErrorIs(t, ws.Connect(), ErrSubscriptionsNotAdded)
require.NoError(t, ws.Shutdown())
ws.Subscriber = func(subs subscription.List) error {
for _, sub := range subs {
if err := ws.subscriptions.Add(sub); err != nil {
return err
}
}
return nil
}
require.NoError(t, ws.Connect(), "Connect must not error")
checkToRoutineResult := func(t *testing.T) {
t.Helper()
@@ -240,7 +252,7 @@ func TestConnectionMessageErrors(t *testing.T) {
require.ErrorIs(t, err, errDastardlyReason)
ws.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) {
return subscription.List{{}}, nil
return subscription.List{{Channel: "test"}}, nil
}
err = ws.Connect()
require.ErrorIs(t, err, errNoConnectFunc)
@@ -275,10 +287,26 @@ func TestConnectionMessageErrors(t *testing.T) {
err = ws.Connect()
require.ErrorIs(t, err, errDastardlyReason)
ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error {
return errDastardlyReason
}
ws.connectionManager[0].Setup.Authenticate = nil
err = ws.Connect()
require.ErrorIs(t, err, errDastardlyReason)
require.NoError(t, ws.shutdown())
ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error {
return nil
}
err = ws.Connect()
require.ErrorIs(t, err, ErrSubscriptionsNotAdded)
require.NoError(t, ws.shutdown())
ws.connectionManager[0].Subscriptions = subscription.NewStore()
ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error {
return ws.connectionManager[0].Subscriptions.Add(&subscription.Subscription{Channel: "test"})
}
err = ws.Connect()
require.NoError(t, err)
err = ws.connectionManager[0].Connection.SendRawMessage(context.Background(), request.Unset, websocket.TextMessage, []byte("test"))
@@ -297,6 +325,7 @@ func TestWebsocket(t *testing.T) {
assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly")
ws.setEnabled(true)
defaultSetup := newDefaultSetup()
err = ws.Setup(defaultSetup) // Sets to enabled again
require.NoError(t, err, "Setup may not error")
@@ -340,8 +369,19 @@ func TestWebsocket(t *testing.T) {
assert.NoError(t, ws.Shutdown())
ws.connector = func() error { return nil }
err = ws.Connect()
assert.NoError(t, err, "Connect should not error")
require.ErrorIs(t, ws.Connect(), ErrSubscriptionsNotAdded)
require.NoError(t, ws.Shutdown())
ws.Subscriber = func(subs subscription.List) error {
for _, sub := range subs {
if err := ws.subscriptions.Add(sub); err != nil {
return err
}
}
return nil
}
assert.NoError(t, ws.Connect(), "Connect should not error")
ws.defaultURL = "ws://demos.kaazing.com/echo"
ws.defaultURLAuth = "ws://demos.kaazing.com/echo"
@@ -407,14 +447,14 @@ func currySimpleUnsubConn(w *Websocket) func(context.Context, Connection, subscr
func TestSubscribeUnsubscribe(t *testing.T) {
t.Parallel()
ws := NewWebsocket()
assert.NoError(t, ws.Setup(defaultSetup), "WS Setup should not error")
assert.NoError(t, ws.Setup(newDefaultSetup()), "WS Setup should not error")
ws.Subscriber = currySimpleSub(ws)
ws.Unsubscriber = currySimpleUnsub(ws)
subs, err := ws.GenerateSubs()
require.NoError(t, err, "Generating test subscriptions should not error")
assert.NoError(t, new(Websocket).UnsubscribeChannels(nil, subs), "Should not error when w.subscriptions is nil")
assert.ErrorIs(t, new(Websocket).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
assert.NoError(t, ws.UnsubscribeChannels(nil, nil), "Unsubscribing from nil should not error")
assert.ErrorIs(t, ws.UnsubscribeChannels(nil, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return")
@@ -447,9 +487,9 @@ func TestSubscribeUnsubscribe(t *testing.T) {
assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription")
multi := NewWebsocket()
set := *defaultSetup
set := newDefaultSetup()
set.UseMultiConnectionManagement = true
assert.NoError(t, multi.Setup(&set))
assert.NoError(t, multi.Setup(set))
amazingCandidate := &ConnectionSetup{
URL: "AMAZING",
@@ -472,8 +512,8 @@ func TestSubscribeUnsubscribe(t *testing.T) {
subs, err = amazingCandidate.GenerateSubscriptions()
require.NoError(t, err, "Generating test subscriptions should not error")
assert.NoError(t, new(Websocket).UnsubscribeChannels(nil, subs), "Should not error when w.subscriptions is nil")
assert.NoError(t, new(Websocket).UnsubscribeChannels(amazingConn, subs), "Should not error when w.subscriptions is nil")
assert.ErrorIs(t, new(Websocket).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
assert.ErrorIs(t, new(Websocket).UnsubscribeChannels(amazingConn, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
assert.NoError(t, multi.UnsubscribeChannels(amazingConn, nil), "Unsubscribing from nil should not error")
assert.ErrorIs(t, multi.UnsubscribeChannels(amazingConn, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
assert.Nil(t, multi.GetSubscription(42), "GetSubscription on empty internal map should return")
@@ -514,12 +554,12 @@ func TestResubscribe(t *testing.T) {
t.Parallel()
ws := NewWebsocket()
wackedOutSetup := *defaultSetup
wackedOutSetup := newDefaultSetup()
wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1
err := ws.Setup(&wackedOutSetup)
err := ws.Setup(wackedOutSetup)
assert.ErrorIs(t, err, errInvalidMaxSubscriptions, "Invalid MaxWebsocketSubscriptionsPerConnection should error")
err = ws.Setup(defaultSetup)
err = ws.Setup(newDefaultSetup())
assert.NoError(t, err, "WS Setup should not error")
ws.Subscriber = currySimpleSub(ws)
@@ -992,52 +1032,6 @@ func TestCheckWebsocketURL(t *testing.T) {
assert.NoError(t, err, "checkWebsocketURL should not error")
}
// TestGetChannelDifference exercises GetChannelDifference
// See subscription.TestStoreDiff for further testing
func TestGetChannelDifference(t *testing.T) {
t.Parallel()
w := &Websocket{}
assert.NotPanics(t, func() { w.GetChannelDifference(nil, subscription.List{}) }, "Should not panic when called without a store")
subs, unsubs := w.GetChannelDifference(nil, subscription.List{{Channel: subscription.CandlesChannel}})
require.Equal(t, 1, len(subs), "Should get the correct number of subs")
require.Empty(t, unsubs, "Should get no unsubs")
require.NoError(t, w.AddSubscriptions(nil, subs...), "AddSubscriptions must not error")
subs, unsubs = w.GetChannelDifference(nil, subscription.List{{Channel: subscription.TickerChannel}})
require.Equal(t, 1, len(subs), "Should get the correct number of subs")
assert.Equal(t, 1, len(unsubs), "Should get the correct number of unsubs")
w = &Websocket{}
sweetConn := &WebsocketConnection{}
subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}})
require.Equal(t, 1, len(subs))
require.Empty(t, unsubs, "Should get no unsubs")
w.connections = map[Connection]*ConnectionWrapper{
sweetConn: {Setup: &ConnectionSetup{URL: "ws://localhost:8080/ws"}},
}
naughtyConn := &WebsocketConnection{}
subs, unsubs = w.GetChannelDifference(naughtyConn, subscription.List{{Channel: subscription.CandlesChannel}})
require.Equal(t, 1, len(subs))
require.Empty(t, unsubs, "Should get no unsubs")
subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}})
require.Equal(t, 1, len(subs))
require.Empty(t, unsubs, "Should get no unsubs")
err := w.connections[sweetConn].Subscriptions.Add(&subscription.Subscription{Channel: subscription.CandlesChannel})
require.NoError(t, err)
subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}})
require.Empty(t, subs, "Should get no subs")
require.Empty(t, unsubs, "Should get no unsubs")
subs, unsubs = w.GetChannelDifference(sweetConn, nil)
require.Empty(t, subs, "Should get no subs")
require.Equal(t, 1, len(unsubs))
}
// GenSubs defines a theoretical exchange with pair management
type GenSubs struct {
EnabledPairs currency.Pairs
@@ -1111,16 +1105,37 @@ func TestFlushChannels(t *testing.T) {
// Disable pair and flush system
newgen.EnabledPairs = []currency.Pair{currency.NewPair(currency.BTC, currency.AUD)}
w.GenerateSubs = func() (subscription.List, error) { return subscription.List{{Channel: "test"}}, nil }
err = w.FlushChannels()
require.NoError(t, err, "Flush Channels must not error")
require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotAdded, "FlushChannels must error correctly on no subscriptions added")
w.Subscriber = func(subs subscription.List) error {
for _, sub := range subs {
if err := w.subscriptions.Add(sub); err != nil {
return err
}
}
return nil
}
require.NoError(t, w.FlushChannels(), "FlushChannels must not error")
w.GenerateSubs = func() (subscription.List, error) { return nil, errDastardlyReason } // error on generateSubs
err = w.FlushChannels() // error on full subscribeToChannels
assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly on GenerateSubs")
w.GenerateSubs = func() (subscription.List, error) { return nil, nil } // No subs to sub
err = w.FlushChannels() // No subs to sub
assert.NoError(t, err, "Flush Channels should not error")
require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotRemoved)
w.Unsubscriber = func(subs subscription.List) error {
for _, sub := range subs {
if err := w.subscriptions.Remove(sub); err != nil {
return err
}
}
return nil
}
assert.NoError(t, w.FlushChannels(), "FlushChannels should not error")
w.GenerateSubs = newgen.generateSubs
subs, err := w.GenerateSubs()
@@ -1156,21 +1171,24 @@ func TestFlushChannels(t *testing.T) {
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) }))
defer mock.Close()
w.subscriptions = subscription.NewStore()
amazingCandidate := &ConnectionSetup{
URL: "ws" + mock.URL[len("http"):] + "/ws",
Connector: func(ctx context.Context, conn Connection) error {
return conn.DialContext(ctx, websocket.DefaultDialer, nil)
},
GenerateSubscriptions: newgen.generateSubs,
Subscriber: func(ctx context.Context, c Connection, s subscription.List) error {
return currySimpleSubConn(w)(ctx, c, s)
},
Unsubscriber: func(ctx context.Context, c Connection, s subscription.List) error {
return currySimpleUnsubConn(w)(ctx, c, s)
},
Handler: func(context.Context, []byte) error { return nil },
Subscriber: func(context.Context, Connection, subscription.List) error { return nil },
Unsubscriber: func(context.Context, Connection, subscription.List) error { return nil },
Handler: func(context.Context, []byte) error { return nil },
}
require.NoError(t, w.SetupNewConnection(amazingCandidate))
require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotAdded, "Must error when no subscriptions are added to the subscription store")
w.connectionManager[0].Setup.Subscriber = func(ctx context.Context, c Connection, s subscription.List) error {
return currySimpleSubConn(w)(ctx, c, s)
}
require.NoError(t, w.FlushChannels(), "FlushChannels must not error")
// Forces full connection cycle (shutdown, connect, subscribe). This will also start monitoring routines.
@@ -1181,6 +1199,11 @@ func TestFlushChannels(t *testing.T) {
// of the connection from management.
w.features.Subscribe = true
w.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil }
require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotRemoved, "Must error when no subscriptions are removed from subscription store")
w.connectionManager[0].Setup.Unsubscriber = func(ctx context.Context, c Connection, s subscription.List) error {
return currySimpleUnsubConn(w)(ctx, c, s)
}
require.NoError(t, w.FlushChannels(), "FlushChannels must not error")
}
@@ -1224,7 +1247,7 @@ func TestSetupNewConnection(t *testing.T) {
web := NewWebsocket()
err = web.Setup(defaultSetup)
err = web.Setup(newDefaultSetup())
assert.NoError(t, err, "Setup should not error")
err = web.SetupNewConnection(&ConnectionSetup{URL: "urlstring"})
@@ -1235,9 +1258,9 @@ func TestSetupNewConnection(t *testing.T) {
// Test connection candidates for multi connection tracking.
multi := NewWebsocket()
set := *defaultSetup
set := newDefaultSetup()
set.UseMultiConnectionManagement = true
require.NoError(t, multi.Setup(&set))
require.NoError(t, multi.Setup(set))
err = multi.SetupNewConnection(nil)
require.ErrorIs(t, err, errExchangeConfigEmpty)
@@ -1566,3 +1589,43 @@ func TestGetConnection(t *testing.T) {
require.NoError(t, err)
assert.Same(t, expected, conn)
}
func TestUpdateChannelSubscriptions(t *testing.T) {
t.Parallel()
ws := Websocket{}
store := subscription.NewStore()
err := ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}})
require.ErrorIs(t, err, common.ErrNilPointer)
require.Zero(t, store.Len())
ws.Subscriber = func(subs subscription.List) error {
for _, sub := range subs {
if err := store.Add(sub); err != nil {
return err
}
}
return nil
}
ws.subscriptions = store
err = ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}})
require.NoError(t, err)
require.Equal(t, 1, store.Len())
err = ws.updateChannelSubscriptions(nil, store, subscription.List{})
require.ErrorIs(t, err, common.ErrNilPointer)
ws.Unsubscriber = func(subs subscription.List) error {
for _, sub := range subs {
if err := store.Remove(sub); err != nil {
return err
}
}
return nil
}
err = ws.updateChannelSubscriptions(nil, store, subscription.List{})
require.NoError(t, err)
require.Zero(t, store.Len())
}