stream: force subscription store check as stop gap for wrapper side implementation (#1717)

* Add check for subscription store insertion and validation with tests

* bybit: fix subscriptions

* deribit: fix subscriptions

* linter: fix

* glorious/nits: add test for updateChannelSubscriptions RM GetChannelDifference method as its only used locally and Diff method can be accessed directly

* glorious/nits: add to store in the loop; add correct formatting to template for edge case perps with settlement

* spelling: fix

* glorious/nit: silly billy

* gk: nits

* gk/nits: split PartitionByPresence into Contained and Missing

* gk/nits: formatPerpetualPairWithSettlement -> formatChannelPair

* stream/websocket: stop full websocket disconnect on Connect when encountering subscription specific error paths

* stream/websocket: rm nil assignment

* glorious: nits

* gk: niterinos

* Update exchanges/stream/websocket_test.go

Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io>

* Update exchanges/stream/websocket_test.go

Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io>

* Update exchanges/stream/websocket_test.go

Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io>

* thrasher: nits

---------

Co-authored-by: shazbert <ryan.oharareid@thrasher.io>
Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io>
This commit is contained in:
Ryan O'Hara-Reid
2025-02-24 15:25:49 +11:00
committed by GitHub
parent ef0f398455
commit 3a80cd2871
8 changed files with 376 additions and 181 deletions

View File

@@ -8,6 +8,7 @@ import (
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/encoding/json"
"github.com/thrasher-corp/gocryptotrader/exchanges/orderbook"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
"github.com/thrasher-corp/gocryptotrader/types"
)
@@ -33,10 +34,11 @@ type Authenticate struct {
// SubscriptionArgument represents a subscription arguments.
type SubscriptionArgument struct {
auth bool `json:"-"`
RequestID string `json:"req_id"`
Operation string `json:"op"`
Arguments []string `json:"args"`
auth bool `json:"-"`
RequestID string `json:"req_id"`
Operation string `json:"op"`
Arguments []string `json:"args"`
associatedSubs subscription.List `json:"-"`
}
// Fee holds fee information

View File

@@ -167,30 +167,19 @@ func (by *Bybit) handleSubscriptions(operation string, subs subscription.List) (
if err != nil {
return
}
chans := []string{}
authChans := []string{}
for _, s := range subs {
if s.Authenticated {
authChans = append(authChans, s.QualifiedChannel)
} else {
chans = append(chans, s.QualifiedChannel)
for _, list := range []subscription.List{subs.Public(), subs.Private()} {
for _, b := range common.Batch(list, 10) {
args = append(args, SubscriptionArgument{
auth: b[0].Authenticated,
Operation: operation,
RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10),
Arguments: b.QualifiedChannels(),
associatedSubs: b,
})
}
}
for _, b := range common.Batch(chans, 10) {
args = append(args, SubscriptionArgument{
Operation: operation,
RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10),
Arguments: b,
})
}
if len(authChans) != 0 {
args = append(args, SubscriptionArgument{
auth: true,
Operation: operation,
RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10),
Arguments: authChans,
})
}
return
}
@@ -225,6 +214,22 @@ func (by *Bybit) handleSpotSubscription(operation string, channelsToSubscribe su
if !resp.Success {
return fmt.Errorf("%s with request ID %s msg: %s", resp.Operation, resp.RequestID, resp.RetMsg)
}
var conn stream.Connection
if payloads[a].auth {
conn = by.Websocket.AuthConn
} else {
conn = by.Websocket.Conn
}
if operation == "unsubscribe" {
err = by.Websocket.RemoveSubscriptions(conn, payloads[a].associatedSubs...)
} else {
err = by.Websocket.AddSubscriptions(conn, payloads[a].associatedSubs...)
}
if err != nil {
return err
}
}
return nil
}

View File

@@ -4094,3 +4094,14 @@ func TestGetCurrencyTradeURL(t *testing.T) {
require.NoError(t, err)
assert.NotEmpty(t, resp)
}
func TestFormatChannelPair(t *testing.T) {
t.Parallel()
pair := currency.NewPair(currency.BTC, currency.NewCode("USDC-PERPETUAL"))
pair.Delimiter = "-"
assert.Equal(t, "BTC_USDC-PERPETUAL", formatChannelPair(pair))
pair = currency.NewPair(currency.BTC, currency.NewCode("PERPETUAL"))
pair.Delimiter = "-"
assert.Equal(t, "BTC-PERPETUAL", formatChannelPair(pair))
}

View File

@@ -779,6 +779,7 @@ func (d *Deribit) GetSubscriptionTemplate(_ *subscription.Subscription) (*templa
"channelName": channelName,
"interval": channelInterval,
"isSymbolChannel": isSymbolChannel,
"fmt": formatChannelPair,
}).
Parse(subTplText)
}
@@ -805,18 +806,19 @@ func (d *Deribit) handleSubscription(method string, subs subscription.List) erro
if err != nil || len(subs) == 0 {
return err
}
r := WsSubscriptionInput{
JSONRPCVersion: rpcVersion,
ID: d.Websocket.Conn.GenerateMessageID(false),
Method: method,
Params: map[string][]string{
"channels": subs.QualifiedChannels(),
},
Params: map[string][]string{"channels": subs.QualifiedChannels()},
}
data, err := d.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.Unset, r.ID, r)
if err != nil {
return err
}
var response wsSubscriptionResponse
err = json.Unmarshal(data, &response)
if err != nil {
@@ -827,15 +829,25 @@ func (d *Deribit) handleSubscription(method string, subs subscription.List) erro
subAck[c] = true
}
if len(subAck) != len(subs) {
err = common.ErrUnknownError
err = stream.ErrSubscriptionFailure
}
for _, s := range subs {
if _, ok := subAck[s.QualifiedChannel]; ok {
err = common.AppendError(err, d.Websocket.AddSuccessfulSubscriptions(d.Websocket.Conn, s))
delete(subAck, s.QualifiedChannel)
if !strings.Contains(method, "unsubscribe") {
err = common.AppendError(err, d.Websocket.AddSuccessfulSubscriptions(d.Websocket.Conn, s))
} else {
err = common.AppendError(err, d.Websocket.RemoveSubscriptions(d.Websocket.Conn, s))
}
} else {
err = common.AppendError(err, errors.New(s.String()))
err = common.AppendError(err, errors.New(s.String()+" failed to "+method))
}
}
for key := range subAck {
err = common.AppendError(err, fmt.Errorf("unexpected channel `%s` in result", key))
}
return err
}
@@ -899,11 +911,18 @@ func isSymbolChannel(s *subscription.Subscription) bool {
return false
}
func formatChannelPair(pair currency.Pair) string {
if str := pair.Quote.String(); strings.Contains(str, "PERPETUAL") && strings.Contains(str, "-") {
pair.Delimiter = "_"
}
return pair.String()
}
const subTplText = `
{{- if isSymbolChannel $.S -}}
{{- range $asset, $pairs := $.AssetPairs }}
{{- range $p := $pairs }}
{{- channelName $.S -}} . {{- $p }}
{{- channelName $.S -}} . {{- fmt $p }}
{{- with $i := interval $.S -}} . {{- $i }}{{ end }}
{{- $.PairSeparator }}
{{- end }}

View File

@@ -22,16 +22,17 @@ const jobBuffer = 5000
// Public websocket errors
var (
ErrWebsocketNotEnabled = errors.New("websocket not enabled")
ErrSubscriptionFailure = errors.New("subscription failure")
ErrSubscriptionNotSupported = errors.New("subscription channel not supported ")
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
ErrAlreadyDisabled = errors.New("websocket already disabled")
ErrNotConnected = errors.New("websocket is not connected")
ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature")
ErrRequestRouteNotFound = errors.New("request route not found")
ErrSignatureNotSet = errors.New("signature not set")
ErrRequestPayloadNotSet = errors.New("request payload not set")
ErrWebsocketNotEnabled = errors.New("websocket not enabled")
ErrSubscriptionFailure = errors.New("subscription failure")
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
ErrSubscriptionsNotAdded = errors.New("subscriptions not added")
ErrSubscriptionsNotRemoved = errors.New("subscriptions not removed")
ErrAlreadyDisabled = errors.New("websocket already disabled")
ErrNotConnected = errors.New("websocket is not connected")
ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature")
ErrRequestRouteNotFound = errors.New("request route not found")
ErrSignatureNotSet = errors.New("signature not set")
ErrRequestPayloadNotSet = errors.New("request payload not set")
)
// Private websocket errors
@@ -372,6 +373,10 @@ func (w *Websocket) connect() error {
if err := w.SubscribeToChannels(nil, subs); err != nil {
return err
}
if missing := w.subscriptions.Missing(subs); len(missing) > 0 {
return fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotAdded, missing)
}
}
return nil
}
@@ -385,6 +390,9 @@ func (w *Websocket) connect() error {
// be shutdown and the websocket to be disconnected.
var multiConnectFatalError error
// subscriptionError is a non-fatal error that does not shutdown connections
var subscriptionError error
// TODO: Implement concurrency below.
for i := range w.connectionManager {
if w.connectionManager[i].Setup.GenerateSubscriptions == nil {
@@ -443,16 +451,20 @@ func (w *Websocket) connect() error {
if w.connectionManager[i].Setup.Authenticate != nil && w.CanUseAuthenticatedEndpoints() {
err = w.connectionManager[i].Setup.Authenticate(context.TODO(), conn)
if err != nil {
// Opted to not fail entirely here for POC. This should be
// revisited and handled more gracefully.
log.Errorf(log.WebsocketMgr, "%s websocket: [conn:%d] [URL:%s] failed to authenticate %v", w.exchangeName, i+1, conn.URL, err)
multiConnectFatalError = fmt.Errorf("%s websocket: [conn:%d] [URL:%s] failed to authenticate %w", w.exchangeName, i+1, conn.URL, err)
break
}
}
err = w.connectionManager[i].Setup.Subscriber(context.TODO(), conn, subs)
if err != nil {
multiConnectFatalError = fmt.Errorf("%v Error subscribing %w", w.exchangeName, err)
break
subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("%v Error subscribing %w", w.exchangeName, err))
continue
}
if missing := w.connectionManager[i].Subscriptions.Missing(subs); len(missing) > 0 {
subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotAdded, missing))
continue
}
if w.verbose {
@@ -496,7 +508,7 @@ func (w *Websocket) connect() error {
go w.monitorFrame(nil, w.monitorConnection)
}
return nil
return subscriptionError
}
// Disable disables the exchange websocket protocol
@@ -625,14 +637,7 @@ func (w *Websocket) FlushChannels() error {
if err != nil {
return err
}
subs, unsubs := w.GetChannelDifference(nil, newSubs)
if err := w.UnsubscribeChannels(nil, unsubs); err != nil {
return err
}
if len(subs) == 0 {
return nil
}
return w.SubscribeToChannels(nil, subs)
return w.updateChannelSubscriptions(nil, w.subscriptions, newSubs)
}
for x := range w.connectionManager {
@@ -658,17 +663,9 @@ func (w *Websocket) FlushChannels() error {
w.connectionManager[x].Connection = conn
}
subs, unsubs := w.GetChannelDifference(w.connectionManager[x].Connection, newSubs)
if len(unsubs) != 0 {
if err := w.UnsubscribeChannels(w.connectionManager[x].Connection, unsubs); err != nil {
return err
}
}
if len(subs) != 0 {
if err := w.SubscribeToChannels(w.connectionManager[x].Connection, subs); err != nil {
return err
}
err = w.updateChannelSubscriptions(w.connectionManager[x].Connection, w.connectionManager[x].Subscriptions, newSubs)
if err != nil {
return err
}
// If there are no subscriptions to subscribe to, close the connection as it is no longer needed.
@@ -683,6 +680,31 @@ func (w *Websocket) FlushChannels() error {
return nil
}
// updateChannelSubscriptions subscribes or unsubscribes from channels and checks that the correct number of channels
// have been subscribed to or unsubscribed from.
func (w *Websocket) updateChannelSubscriptions(c Connection, store *subscription.Store, incoming subscription.List) error {
subs, unsubs := store.Diff(incoming)
if len(unsubs) != 0 {
if err := w.UnsubscribeChannels(c, unsubs); err != nil {
return err
}
if contained := store.Contained(unsubs); len(contained) > 0 {
return fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotRemoved, contained)
}
}
if len(subs) != 0 {
if err := w.SubscribeToChannels(c, subs); err != nil {
return err
}
if missing := store.Missing(subs); len(missing) > 0 {
return fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotAdded, missing)
}
}
return nil
}
func (w *Websocket) setState(s uint32) {
w.state.Store(s)
}
@@ -830,21 +852,6 @@ func (w *Websocket) GetName() string {
return w.exchangeName
}
// GetChannelDifference finds the difference between the subscribed channels
// and the new subscription list when pairs are disabled or enabled.
func (w *Websocket) GetChannelDifference(conn Connection, newSubs subscription.List) (sub, unsub subscription.List) {
var subscriptionStore **subscription.Store
if wrapper, ok := w.connections[conn]; ok && conn != nil {
subscriptionStore = &wrapper.Subscriptions
} else {
subscriptionStore = &w.subscriptions
}
if *subscriptionStore == nil {
*subscriptionStore = subscription.NewStore()
}
return (*subscriptionStore).Diff(newSubs)
}
// UnsubscribeChannels unsubscribes from a list of websocket channel
func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.List) error {
if len(channels) == 0 {
@@ -855,6 +862,11 @@ func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.L
return wrapper.Setup.Unsubscriber(context.TODO(), conn, channels)
})
}
if w.Unsubscriber == nil {
return fmt.Errorf("%w: Global Unsubscriber not set", common.ErrNilPointer)
}
return w.unsubscribe(w.subscriptions, channels, func(channels subscription.List) error {
return w.Unsubscriber(channels)
})

View File

@@ -63,31 +63,33 @@ type testSubKey struct {
Mood string
}
var defaultSetup = &WebsocketSetup{
ExchangeConfig: &config.Exchange{
Features: &config.FeaturesConfig{
Enabled: config.FeaturesEnabledConfig{Websocket: true},
func newDefaultSetup() *WebsocketSetup {
return &WebsocketSetup{
ExchangeConfig: &config.Exchange{
Features: &config.FeaturesConfig{
Enabled: config.FeaturesEnabledConfig{Websocket: true},
},
API: config.APIConfig{
AuthenticatedWebsocketSupport: true,
},
WebsocketTrafficTimeout: time.Second * 5,
Name: "GTX",
},
API: config.APIConfig{
AuthenticatedWebsocketSupport: true,
DefaultURL: "testDefaultURL",
RunningURL: "wss://testRunningURL",
Connector: func() error { return nil },
Subscriber: func(subscription.List) error { return nil },
Unsubscriber: func(subscription.List) error { return nil },
GenerateSubscriptions: func() (subscription.List, error) {
return subscription.List{
{Channel: "TestSub"},
{Channel: "TestSub2", Key: "purple"},
{Channel: "TestSub3", Key: testSubKey{"mauve"}},
{Channel: "TestSub4", Key: 42},
}, nil
},
WebsocketTrafficTimeout: time.Second * 5,
Name: "GTX",
},
DefaultURL: "testDefaultURL",
RunningURL: "wss://testRunningURL",
Connector: func() error { return nil },
Subscriber: func(subscription.List) error { return nil },
Unsubscriber: func(subscription.List) error { return nil },
GenerateSubscriptions: func() (subscription.List, error) {
return subscription.List{
{Channel: "TestSub"},
{Channel: "TestSub2", Key: "purple"},
{Channel: "TestSub3", Key: testSubKey{"mauve"}},
{Channel: "TestSub4", Key: 42},
}, nil
},
Features: &protocol.Features{Subscribe: true, Unsubscribe: true},
Features: &protocol.Features{Subscribe: true, Unsubscribe: true},
}
}
func TestSetup(t *testing.T) {
@@ -186,13 +188,23 @@ func TestConnectionMessageErrors(t *testing.T) {
assert.ErrorIs(t, err, errDastardlyReason, "Connect should error correctly")
ws := NewWebsocket()
err = ws.Setup(defaultSetup)
err = ws.Setup(newDefaultSetup())
require.NoError(t, err, "Setup must not error")
ws.trafficTimeout = time.Minute
ws.connector = connect
err = ws.Connect()
require.NoError(t, err, "Connect must not error")
require.ErrorIs(t, ws.Connect(), ErrSubscriptionsNotAdded)
require.NoError(t, ws.Shutdown())
ws.Subscriber = func(subs subscription.List) error {
for _, sub := range subs {
if err := ws.subscriptions.Add(sub); err != nil {
return err
}
}
return nil
}
require.NoError(t, ws.Connect(), "Connect must not error")
checkToRoutineResult := func(t *testing.T) {
t.Helper()
@@ -240,7 +252,7 @@ func TestConnectionMessageErrors(t *testing.T) {
require.ErrorIs(t, err, errDastardlyReason)
ws.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) {
return subscription.List{{}}, nil
return subscription.List{{Channel: "test"}}, nil
}
err = ws.Connect()
require.ErrorIs(t, err, errNoConnectFunc)
@@ -275,10 +287,26 @@ func TestConnectionMessageErrors(t *testing.T) {
err = ws.Connect()
require.ErrorIs(t, err, errDastardlyReason)
ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error {
return errDastardlyReason
}
ws.connectionManager[0].Setup.Authenticate = nil
err = ws.Connect()
require.ErrorIs(t, err, errDastardlyReason)
require.NoError(t, ws.shutdown())
ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error {
return nil
}
err = ws.Connect()
require.ErrorIs(t, err, ErrSubscriptionsNotAdded)
require.NoError(t, ws.shutdown())
ws.connectionManager[0].Subscriptions = subscription.NewStore()
ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error {
return ws.connectionManager[0].Subscriptions.Add(&subscription.Subscription{Channel: "test"})
}
err = ws.Connect()
require.NoError(t, err)
err = ws.connectionManager[0].Connection.SendRawMessage(context.Background(), request.Unset, websocket.TextMessage, []byte("test"))
@@ -297,6 +325,7 @@ func TestWebsocket(t *testing.T) {
assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly")
ws.setEnabled(true)
defaultSetup := newDefaultSetup()
err = ws.Setup(defaultSetup) // Sets to enabled again
require.NoError(t, err, "Setup may not error")
@@ -340,8 +369,19 @@ func TestWebsocket(t *testing.T) {
assert.NoError(t, ws.Shutdown())
ws.connector = func() error { return nil }
err = ws.Connect()
assert.NoError(t, err, "Connect should not error")
require.ErrorIs(t, ws.Connect(), ErrSubscriptionsNotAdded)
require.NoError(t, ws.Shutdown())
ws.Subscriber = func(subs subscription.List) error {
for _, sub := range subs {
if err := ws.subscriptions.Add(sub); err != nil {
return err
}
}
return nil
}
assert.NoError(t, ws.Connect(), "Connect should not error")
ws.defaultURL = "ws://demos.kaazing.com/echo"
ws.defaultURLAuth = "ws://demos.kaazing.com/echo"
@@ -407,14 +447,14 @@ func currySimpleUnsubConn(w *Websocket) func(context.Context, Connection, subscr
func TestSubscribeUnsubscribe(t *testing.T) {
t.Parallel()
ws := NewWebsocket()
assert.NoError(t, ws.Setup(defaultSetup), "WS Setup should not error")
assert.NoError(t, ws.Setup(newDefaultSetup()), "WS Setup should not error")
ws.Subscriber = currySimpleSub(ws)
ws.Unsubscriber = currySimpleUnsub(ws)
subs, err := ws.GenerateSubs()
require.NoError(t, err, "Generating test subscriptions should not error")
assert.NoError(t, new(Websocket).UnsubscribeChannels(nil, subs), "Should not error when w.subscriptions is nil")
assert.ErrorIs(t, new(Websocket).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
assert.NoError(t, ws.UnsubscribeChannels(nil, nil), "Unsubscribing from nil should not error")
assert.ErrorIs(t, ws.UnsubscribeChannels(nil, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return")
@@ -447,9 +487,9 @@ func TestSubscribeUnsubscribe(t *testing.T) {
assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription")
multi := NewWebsocket()
set := *defaultSetup
set := newDefaultSetup()
set.UseMultiConnectionManagement = true
assert.NoError(t, multi.Setup(&set))
assert.NoError(t, multi.Setup(set))
amazingCandidate := &ConnectionSetup{
URL: "AMAZING",
@@ -472,8 +512,8 @@ func TestSubscribeUnsubscribe(t *testing.T) {
subs, err = amazingCandidate.GenerateSubscriptions()
require.NoError(t, err, "Generating test subscriptions should not error")
assert.NoError(t, new(Websocket).UnsubscribeChannels(nil, subs), "Should not error when w.subscriptions is nil")
assert.NoError(t, new(Websocket).UnsubscribeChannels(amazingConn, subs), "Should not error when w.subscriptions is nil")
assert.ErrorIs(t, new(Websocket).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
assert.ErrorIs(t, new(Websocket).UnsubscribeChannels(amazingConn, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
assert.NoError(t, multi.UnsubscribeChannels(amazingConn, nil), "Unsubscribing from nil should not error")
assert.ErrorIs(t, multi.UnsubscribeChannels(amazingConn, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
assert.Nil(t, multi.GetSubscription(42), "GetSubscription on empty internal map should return")
@@ -514,12 +554,12 @@ func TestResubscribe(t *testing.T) {
t.Parallel()
ws := NewWebsocket()
wackedOutSetup := *defaultSetup
wackedOutSetup := newDefaultSetup()
wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1
err := ws.Setup(&wackedOutSetup)
err := ws.Setup(wackedOutSetup)
assert.ErrorIs(t, err, errInvalidMaxSubscriptions, "Invalid MaxWebsocketSubscriptionsPerConnection should error")
err = ws.Setup(defaultSetup)
err = ws.Setup(newDefaultSetup())
assert.NoError(t, err, "WS Setup should not error")
ws.Subscriber = currySimpleSub(ws)
@@ -992,52 +1032,6 @@ func TestCheckWebsocketURL(t *testing.T) {
assert.NoError(t, err, "checkWebsocketURL should not error")
}
// TestGetChannelDifference exercises GetChannelDifference
// See subscription.TestStoreDiff for further testing
func TestGetChannelDifference(t *testing.T) {
t.Parallel()
w := &Websocket{}
assert.NotPanics(t, func() { w.GetChannelDifference(nil, subscription.List{}) }, "Should not panic when called without a store")
subs, unsubs := w.GetChannelDifference(nil, subscription.List{{Channel: subscription.CandlesChannel}})
require.Equal(t, 1, len(subs), "Should get the correct number of subs")
require.Empty(t, unsubs, "Should get no unsubs")
require.NoError(t, w.AddSubscriptions(nil, subs...), "AddSubscriptions must not error")
subs, unsubs = w.GetChannelDifference(nil, subscription.List{{Channel: subscription.TickerChannel}})
require.Equal(t, 1, len(subs), "Should get the correct number of subs")
assert.Equal(t, 1, len(unsubs), "Should get the correct number of unsubs")
w = &Websocket{}
sweetConn := &WebsocketConnection{}
subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}})
require.Equal(t, 1, len(subs))
require.Empty(t, unsubs, "Should get no unsubs")
w.connections = map[Connection]*ConnectionWrapper{
sweetConn: {Setup: &ConnectionSetup{URL: "ws://localhost:8080/ws"}},
}
naughtyConn := &WebsocketConnection{}
subs, unsubs = w.GetChannelDifference(naughtyConn, subscription.List{{Channel: subscription.CandlesChannel}})
require.Equal(t, 1, len(subs))
require.Empty(t, unsubs, "Should get no unsubs")
subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}})
require.Equal(t, 1, len(subs))
require.Empty(t, unsubs, "Should get no unsubs")
err := w.connections[sweetConn].Subscriptions.Add(&subscription.Subscription{Channel: subscription.CandlesChannel})
require.NoError(t, err)
subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}})
require.Empty(t, subs, "Should get no subs")
require.Empty(t, unsubs, "Should get no unsubs")
subs, unsubs = w.GetChannelDifference(sweetConn, nil)
require.Empty(t, subs, "Should get no subs")
require.Equal(t, 1, len(unsubs))
}
// GenSubs defines a theoretical exchange with pair management
type GenSubs struct {
EnabledPairs currency.Pairs
@@ -1111,16 +1105,37 @@ func TestFlushChannels(t *testing.T) {
// Disable pair and flush system
newgen.EnabledPairs = []currency.Pair{currency.NewPair(currency.BTC, currency.AUD)}
w.GenerateSubs = func() (subscription.List, error) { return subscription.List{{Channel: "test"}}, nil }
err = w.FlushChannels()
require.NoError(t, err, "Flush Channels must not error")
require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotAdded, "FlushChannels must error correctly on no subscriptions added")
w.Subscriber = func(subs subscription.List) error {
for _, sub := range subs {
if err := w.subscriptions.Add(sub); err != nil {
return err
}
}
return nil
}
require.NoError(t, w.FlushChannels(), "FlushChannels must not error")
w.GenerateSubs = func() (subscription.List, error) { return nil, errDastardlyReason } // error on generateSubs
err = w.FlushChannels() // error on full subscribeToChannels
assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly on GenerateSubs")
w.GenerateSubs = func() (subscription.List, error) { return nil, nil } // No subs to sub
err = w.FlushChannels() // No subs to sub
assert.NoError(t, err, "Flush Channels should not error")
require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotRemoved)
w.Unsubscriber = func(subs subscription.List) error {
for _, sub := range subs {
if err := w.subscriptions.Remove(sub); err != nil {
return err
}
}
return nil
}
assert.NoError(t, w.FlushChannels(), "FlushChannels should not error")
w.GenerateSubs = newgen.generateSubs
subs, err := w.GenerateSubs()
@@ -1156,21 +1171,24 @@ func TestFlushChannels(t *testing.T) {
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) }))
defer mock.Close()
w.subscriptions = subscription.NewStore()
amazingCandidate := &ConnectionSetup{
URL: "ws" + mock.URL[len("http"):] + "/ws",
Connector: func(ctx context.Context, conn Connection) error {
return conn.DialContext(ctx, websocket.DefaultDialer, nil)
},
GenerateSubscriptions: newgen.generateSubs,
Subscriber: func(ctx context.Context, c Connection, s subscription.List) error {
return currySimpleSubConn(w)(ctx, c, s)
},
Unsubscriber: func(ctx context.Context, c Connection, s subscription.List) error {
return currySimpleUnsubConn(w)(ctx, c, s)
},
Handler: func(context.Context, []byte) error { return nil },
Subscriber: func(context.Context, Connection, subscription.List) error { return nil },
Unsubscriber: func(context.Context, Connection, subscription.List) error { return nil },
Handler: func(context.Context, []byte) error { return nil },
}
require.NoError(t, w.SetupNewConnection(amazingCandidate))
require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotAdded, "Must error when no subscriptions are added to the subscription store")
w.connectionManager[0].Setup.Subscriber = func(ctx context.Context, c Connection, s subscription.List) error {
return currySimpleSubConn(w)(ctx, c, s)
}
require.NoError(t, w.FlushChannels(), "FlushChannels must not error")
// Forces full connection cycle (shutdown, connect, subscribe). This will also start monitoring routines.
@@ -1181,6 +1199,11 @@ func TestFlushChannels(t *testing.T) {
// of the connection from management.
w.features.Subscribe = true
w.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil }
require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotRemoved, "Must error when no subscriptions are removed from subscription store")
w.connectionManager[0].Setup.Unsubscriber = func(ctx context.Context, c Connection, s subscription.List) error {
return currySimpleUnsubConn(w)(ctx, c, s)
}
require.NoError(t, w.FlushChannels(), "FlushChannels must not error")
}
@@ -1224,7 +1247,7 @@ func TestSetupNewConnection(t *testing.T) {
web := NewWebsocket()
err = web.Setup(defaultSetup)
err = web.Setup(newDefaultSetup())
assert.NoError(t, err, "Setup should not error")
err = web.SetupNewConnection(&ConnectionSetup{URL: "urlstring"})
@@ -1235,9 +1258,9 @@ func TestSetupNewConnection(t *testing.T) {
// Test connection candidates for multi connection tracking.
multi := NewWebsocket()
set := *defaultSetup
set := newDefaultSetup()
set.UseMultiConnectionManagement = true
require.NoError(t, multi.Setup(&set))
require.NoError(t, multi.Setup(set))
err = multi.SetupNewConnection(nil)
require.ErrorIs(t, err, errExchangeConfigEmpty)
@@ -1566,3 +1589,43 @@ func TestGetConnection(t *testing.T) {
require.NoError(t, err)
assert.Same(t, expected, conn)
}
func TestUpdateChannelSubscriptions(t *testing.T) {
t.Parallel()
ws := Websocket{}
store := subscription.NewStore()
err := ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}})
require.ErrorIs(t, err, common.ErrNilPointer)
require.Zero(t, store.Len())
ws.Subscriber = func(subs subscription.List) error {
for _, sub := range subs {
if err := store.Add(sub); err != nil {
return err
}
}
return nil
}
ws.subscriptions = store
err = ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}})
require.NoError(t, err)
require.Equal(t, 1, store.Len())
err = ws.updateChannelSubscriptions(nil, store, subscription.List{})
require.ErrorIs(t, err, common.ErrNilPointer)
ws.Unsubscriber = func(subs subscription.List) error {
for _, sub := range subs {
if err := store.Remove(sub); err != nil {
return err
}
}
return nil
}
err = ws.updateChannelSubscriptions(nil, store, subscription.List{})
require.NoError(t, err)
require.Zero(t, store.Len())
}

View File

@@ -206,3 +206,33 @@ func (s *Store) Len() int {
defer s.mu.RUnlock()
return len(s.m)
}
// Contained returns a list of subscriptions in `compare` that are already in the store.
func (s *Store) Contained(compare List) (matched List) {
if s == nil || s.m == nil {
return nil
}
s.mu.RLock()
defer s.mu.RUnlock()
for _, sub := range compare {
if found := s.get(sub); found != nil {
matched = append(matched, found)
}
}
return matched
}
// Missing returns a list of subscriptions in `compare` that are not in the store.
func (s *Store) Missing(compare List) (missing List) {
if s == nil || s.m == nil {
return compare // All are missing
}
s.mu.RLock()
defer s.mu.RUnlock()
for _, sub := range compare {
if found := s.get(sub); found == nil {
missing = append(missing, sub)
}
}
return missing
}

View File

@@ -204,3 +204,56 @@ func EqualLists(tb testing.TB, a, b List) {
assert.Fail(tb, fail, "Subscriptions should be equal")
}
}
func TestContained(t *testing.T) {
t.Parallel()
var s *Store
matched := s.Contained(nil)
assert.Nil(t, matched)
matched = s.Contained(List{{Channel: TickerChannel}})
assert.Nil(t, matched)
s = NewStore()
matched = s.Contained(nil)
assert.Nil(t, matched)
matched = s.Contained(List{})
assert.Nil(t, matched)
matched = s.Contained(List{{Channel: TickerChannel}})
assert.Nil(t, matched)
require.NoError(t, s.add(&Subscription{Channel: TickerChannel}))
matched = s.Contained(List{{Channel: TickerChannel}})
assert.Len(t, matched, 1)
}
func TestMissing(t *testing.T) {
t.Parallel()
var s *Store
unmatched := s.Missing(nil)
assert.Nil(t, unmatched)
unmatched = s.Missing(List{{Channel: TickerChannel}})
assert.Len(t, unmatched, 1)
s = NewStore()
unmatched = s.Missing(nil)
assert.Nil(t, unmatched)
unmatched = s.Missing(List{})
assert.Nil(t, unmatched)
unmatched = s.Missing(List{{Channel: TickerChannel}})
assert.Len(t, unmatched, 1)
require.NoError(t, s.add(&Subscription{Channel: TickerChannel}))
unmatched = s.Missing(List{{Channel: TickerChannel}})
assert.Nil(t, unmatched)
}