mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-13 15:09:42 +00:00
stream: force subscription store check as stop gap for wrapper side implementation (#1717)
* Add check for subscription store insertion and validation with tests * bybit: fix subscriptions * deribit: fix subscriptions * linter: fix * glorious/nits: add test for updateChannelSubscriptions RM GetChannelDifference method as its only used locally and Diff method can be accessed directly * glorious/nits: add to store in the loop; add correct formatting to template for edge case perps with settlement * spelling: fix * glorious/nit: silly billy * gk: nits * gk/nits: split PartitionByPresence into Contained and Missing * gk/nits: formatPerpetualPairWithSettlement -> formatChannelPair * stream/websocket: stop full websocket disconnect on Connect when encountering subscription specific error paths * stream/websocket: rm nil assignment * glorious: nits * gk: niterinos * Update exchanges/stream/websocket_test.go Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io> * Update exchanges/stream/websocket_test.go Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io> * Update exchanges/stream/websocket_test.go Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io> * thrasher: nits --------- Co-authored-by: shazbert <ryan.oharareid@thrasher.io> Co-authored-by: Adrian Gallagher <adrian.gallagher@thrasher.io>
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/thrasher-corp/gocryptotrader/currency"
|
||||
"github.com/thrasher-corp/gocryptotrader/encoding/json"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/orderbook"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
|
||||
"github.com/thrasher-corp/gocryptotrader/types"
|
||||
)
|
||||
|
||||
@@ -33,10 +34,11 @@ type Authenticate struct {
|
||||
|
||||
// SubscriptionArgument represents a subscription arguments.
|
||||
type SubscriptionArgument struct {
|
||||
auth bool `json:"-"`
|
||||
RequestID string `json:"req_id"`
|
||||
Operation string `json:"op"`
|
||||
Arguments []string `json:"args"`
|
||||
auth bool `json:"-"`
|
||||
RequestID string `json:"req_id"`
|
||||
Operation string `json:"op"`
|
||||
Arguments []string `json:"args"`
|
||||
associatedSubs subscription.List `json:"-"`
|
||||
}
|
||||
|
||||
// Fee holds fee information
|
||||
|
||||
@@ -167,30 +167,19 @@ func (by *Bybit) handleSubscriptions(operation string, subs subscription.List) (
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
chans := []string{}
|
||||
authChans := []string{}
|
||||
for _, s := range subs {
|
||||
if s.Authenticated {
|
||||
authChans = append(authChans, s.QualifiedChannel)
|
||||
} else {
|
||||
chans = append(chans, s.QualifiedChannel)
|
||||
|
||||
for _, list := range []subscription.List{subs.Public(), subs.Private()} {
|
||||
for _, b := range common.Batch(list, 10) {
|
||||
args = append(args, SubscriptionArgument{
|
||||
auth: b[0].Authenticated,
|
||||
Operation: operation,
|
||||
RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10),
|
||||
Arguments: b.QualifiedChannels(),
|
||||
associatedSubs: b,
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, b := range common.Batch(chans, 10) {
|
||||
args = append(args, SubscriptionArgument{
|
||||
Operation: operation,
|
||||
RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10),
|
||||
Arguments: b,
|
||||
})
|
||||
}
|
||||
if len(authChans) != 0 {
|
||||
args = append(args, SubscriptionArgument{
|
||||
auth: true,
|
||||
Operation: operation,
|
||||
RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10),
|
||||
Arguments: authChans,
|
||||
})
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -225,6 +214,22 @@ func (by *Bybit) handleSpotSubscription(operation string, channelsToSubscribe su
|
||||
if !resp.Success {
|
||||
return fmt.Errorf("%s with request ID %s msg: %s", resp.Operation, resp.RequestID, resp.RetMsg)
|
||||
}
|
||||
|
||||
var conn stream.Connection
|
||||
if payloads[a].auth {
|
||||
conn = by.Websocket.AuthConn
|
||||
} else {
|
||||
conn = by.Websocket.Conn
|
||||
}
|
||||
|
||||
if operation == "unsubscribe" {
|
||||
err = by.Websocket.RemoveSubscriptions(conn, payloads[a].associatedSubs...)
|
||||
} else {
|
||||
err = by.Websocket.AddSubscriptions(conn, payloads[a].associatedSubs...)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4094,3 +4094,14 @@ func TestGetCurrencyTradeURL(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp)
|
||||
}
|
||||
|
||||
func TestFormatChannelPair(t *testing.T) {
|
||||
t.Parallel()
|
||||
pair := currency.NewPair(currency.BTC, currency.NewCode("USDC-PERPETUAL"))
|
||||
pair.Delimiter = "-"
|
||||
assert.Equal(t, "BTC_USDC-PERPETUAL", formatChannelPair(pair))
|
||||
|
||||
pair = currency.NewPair(currency.BTC, currency.NewCode("PERPETUAL"))
|
||||
pair.Delimiter = "-"
|
||||
assert.Equal(t, "BTC-PERPETUAL", formatChannelPair(pair))
|
||||
}
|
||||
|
||||
@@ -779,6 +779,7 @@ func (d *Deribit) GetSubscriptionTemplate(_ *subscription.Subscription) (*templa
|
||||
"channelName": channelName,
|
||||
"interval": channelInterval,
|
||||
"isSymbolChannel": isSymbolChannel,
|
||||
"fmt": formatChannelPair,
|
||||
}).
|
||||
Parse(subTplText)
|
||||
}
|
||||
@@ -805,18 +806,19 @@ func (d *Deribit) handleSubscription(method string, subs subscription.List) erro
|
||||
if err != nil || len(subs) == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
r := WsSubscriptionInput{
|
||||
JSONRPCVersion: rpcVersion,
|
||||
ID: d.Websocket.Conn.GenerateMessageID(false),
|
||||
Method: method,
|
||||
Params: map[string][]string{
|
||||
"channels": subs.QualifiedChannels(),
|
||||
},
|
||||
Params: map[string][]string{"channels": subs.QualifiedChannels()},
|
||||
}
|
||||
|
||||
data, err := d.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.Unset, r.ID, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var response wsSubscriptionResponse
|
||||
err = json.Unmarshal(data, &response)
|
||||
if err != nil {
|
||||
@@ -827,15 +829,25 @@ func (d *Deribit) handleSubscription(method string, subs subscription.List) erro
|
||||
subAck[c] = true
|
||||
}
|
||||
if len(subAck) != len(subs) {
|
||||
err = common.ErrUnknownError
|
||||
err = stream.ErrSubscriptionFailure
|
||||
}
|
||||
for _, s := range subs {
|
||||
if _, ok := subAck[s.QualifiedChannel]; ok {
|
||||
err = common.AppendError(err, d.Websocket.AddSuccessfulSubscriptions(d.Websocket.Conn, s))
|
||||
delete(subAck, s.QualifiedChannel)
|
||||
if !strings.Contains(method, "unsubscribe") {
|
||||
err = common.AppendError(err, d.Websocket.AddSuccessfulSubscriptions(d.Websocket.Conn, s))
|
||||
} else {
|
||||
err = common.AppendError(err, d.Websocket.RemoveSubscriptions(d.Websocket.Conn, s))
|
||||
}
|
||||
} else {
|
||||
err = common.AppendError(err, errors.New(s.String()))
|
||||
err = common.AppendError(err, errors.New(s.String()+" failed to "+method))
|
||||
}
|
||||
}
|
||||
|
||||
for key := range subAck {
|
||||
err = common.AppendError(err, fmt.Errorf("unexpected channel `%s` in result", key))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -899,11 +911,18 @@ func isSymbolChannel(s *subscription.Subscription) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func formatChannelPair(pair currency.Pair) string {
|
||||
if str := pair.Quote.String(); strings.Contains(str, "PERPETUAL") && strings.Contains(str, "-") {
|
||||
pair.Delimiter = "_"
|
||||
}
|
||||
return pair.String()
|
||||
}
|
||||
|
||||
const subTplText = `
|
||||
{{- if isSymbolChannel $.S -}}
|
||||
{{- range $asset, $pairs := $.AssetPairs }}
|
||||
{{- range $p := $pairs }}
|
||||
{{- channelName $.S -}} . {{- $p }}
|
||||
{{- channelName $.S -}} . {{- fmt $p }}
|
||||
{{- with $i := interval $.S -}} . {{- $i }}{{ end }}
|
||||
{{- $.PairSeparator }}
|
||||
{{- end }}
|
||||
|
||||
@@ -22,16 +22,17 @@ const jobBuffer = 5000
|
||||
|
||||
// Public websocket errors
|
||||
var (
|
||||
ErrWebsocketNotEnabled = errors.New("websocket not enabled")
|
||||
ErrSubscriptionFailure = errors.New("subscription failure")
|
||||
ErrSubscriptionNotSupported = errors.New("subscription channel not supported ")
|
||||
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
|
||||
ErrAlreadyDisabled = errors.New("websocket already disabled")
|
||||
ErrNotConnected = errors.New("websocket is not connected")
|
||||
ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature")
|
||||
ErrRequestRouteNotFound = errors.New("request route not found")
|
||||
ErrSignatureNotSet = errors.New("signature not set")
|
||||
ErrRequestPayloadNotSet = errors.New("request payload not set")
|
||||
ErrWebsocketNotEnabled = errors.New("websocket not enabled")
|
||||
ErrSubscriptionFailure = errors.New("subscription failure")
|
||||
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
|
||||
ErrSubscriptionsNotAdded = errors.New("subscriptions not added")
|
||||
ErrSubscriptionsNotRemoved = errors.New("subscriptions not removed")
|
||||
ErrAlreadyDisabled = errors.New("websocket already disabled")
|
||||
ErrNotConnected = errors.New("websocket is not connected")
|
||||
ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature")
|
||||
ErrRequestRouteNotFound = errors.New("request route not found")
|
||||
ErrSignatureNotSet = errors.New("signature not set")
|
||||
ErrRequestPayloadNotSet = errors.New("request payload not set")
|
||||
)
|
||||
|
||||
// Private websocket errors
|
||||
@@ -372,6 +373,10 @@ func (w *Websocket) connect() error {
|
||||
if err := w.SubscribeToChannels(nil, subs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if missing := w.subscriptions.Missing(subs); len(missing) > 0 {
|
||||
return fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotAdded, missing)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -385,6 +390,9 @@ func (w *Websocket) connect() error {
|
||||
// be shutdown and the websocket to be disconnected.
|
||||
var multiConnectFatalError error
|
||||
|
||||
// subscriptionError is a non-fatal error that does not shutdown connections
|
||||
var subscriptionError error
|
||||
|
||||
// TODO: Implement concurrency below.
|
||||
for i := range w.connectionManager {
|
||||
if w.connectionManager[i].Setup.GenerateSubscriptions == nil {
|
||||
@@ -443,16 +451,20 @@ func (w *Websocket) connect() error {
|
||||
if w.connectionManager[i].Setup.Authenticate != nil && w.CanUseAuthenticatedEndpoints() {
|
||||
err = w.connectionManager[i].Setup.Authenticate(context.TODO(), conn)
|
||||
if err != nil {
|
||||
// Opted to not fail entirely here for POC. This should be
|
||||
// revisited and handled more gracefully.
|
||||
log.Errorf(log.WebsocketMgr, "%s websocket: [conn:%d] [URL:%s] failed to authenticate %v", w.exchangeName, i+1, conn.URL, err)
|
||||
multiConnectFatalError = fmt.Errorf("%s websocket: [conn:%d] [URL:%s] failed to authenticate %w", w.exchangeName, i+1, conn.URL, err)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
err = w.connectionManager[i].Setup.Subscriber(context.TODO(), conn, subs)
|
||||
if err != nil {
|
||||
multiConnectFatalError = fmt.Errorf("%v Error subscribing %w", w.exchangeName, err)
|
||||
break
|
||||
subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("%v Error subscribing %w", w.exchangeName, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if missing := w.connectionManager[i].Subscriptions.Missing(subs); len(missing) > 0 {
|
||||
subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotAdded, missing))
|
||||
continue
|
||||
}
|
||||
|
||||
if w.verbose {
|
||||
@@ -496,7 +508,7 @@ func (w *Websocket) connect() error {
|
||||
go w.monitorFrame(nil, w.monitorConnection)
|
||||
}
|
||||
|
||||
return nil
|
||||
return subscriptionError
|
||||
}
|
||||
|
||||
// Disable disables the exchange websocket protocol
|
||||
@@ -625,14 +637,7 @@ func (w *Websocket) FlushChannels() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
subs, unsubs := w.GetChannelDifference(nil, newSubs)
|
||||
if err := w.UnsubscribeChannels(nil, unsubs); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(subs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return w.SubscribeToChannels(nil, subs)
|
||||
return w.updateChannelSubscriptions(nil, w.subscriptions, newSubs)
|
||||
}
|
||||
|
||||
for x := range w.connectionManager {
|
||||
@@ -658,17 +663,9 @@ func (w *Websocket) FlushChannels() error {
|
||||
w.connectionManager[x].Connection = conn
|
||||
}
|
||||
|
||||
subs, unsubs := w.GetChannelDifference(w.connectionManager[x].Connection, newSubs)
|
||||
|
||||
if len(unsubs) != 0 {
|
||||
if err := w.UnsubscribeChannels(w.connectionManager[x].Connection, unsubs); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(subs) != 0 {
|
||||
if err := w.SubscribeToChannels(w.connectionManager[x].Connection, subs); err != nil {
|
||||
return err
|
||||
}
|
||||
err = w.updateChannelSubscriptions(w.connectionManager[x].Connection, w.connectionManager[x].Subscriptions, newSubs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If there are no subscriptions to subscribe to, close the connection as it is no longer needed.
|
||||
@@ -683,6 +680,31 @@ func (w *Websocket) FlushChannels() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateChannelSubscriptions subscribes or unsubscribes from channels and checks that the correct number of channels
|
||||
// have been subscribed to or unsubscribed from.
|
||||
func (w *Websocket) updateChannelSubscriptions(c Connection, store *subscription.Store, incoming subscription.List) error {
|
||||
subs, unsubs := store.Diff(incoming)
|
||||
if len(unsubs) != 0 {
|
||||
if err := w.UnsubscribeChannels(c, unsubs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if contained := store.Contained(unsubs); len(contained) > 0 {
|
||||
return fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotRemoved, contained)
|
||||
}
|
||||
}
|
||||
if len(subs) != 0 {
|
||||
if err := w.SubscribeToChannels(c, subs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if missing := store.Missing(subs); len(missing) > 0 {
|
||||
return fmt.Errorf("%v %w `%s`", w.exchangeName, ErrSubscriptionsNotAdded, missing)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Websocket) setState(s uint32) {
|
||||
w.state.Store(s)
|
||||
}
|
||||
@@ -830,21 +852,6 @@ func (w *Websocket) GetName() string {
|
||||
return w.exchangeName
|
||||
}
|
||||
|
||||
// GetChannelDifference finds the difference between the subscribed channels
|
||||
// and the new subscription list when pairs are disabled or enabled.
|
||||
func (w *Websocket) GetChannelDifference(conn Connection, newSubs subscription.List) (sub, unsub subscription.List) {
|
||||
var subscriptionStore **subscription.Store
|
||||
if wrapper, ok := w.connections[conn]; ok && conn != nil {
|
||||
subscriptionStore = &wrapper.Subscriptions
|
||||
} else {
|
||||
subscriptionStore = &w.subscriptions
|
||||
}
|
||||
if *subscriptionStore == nil {
|
||||
*subscriptionStore = subscription.NewStore()
|
||||
}
|
||||
return (*subscriptionStore).Diff(newSubs)
|
||||
}
|
||||
|
||||
// UnsubscribeChannels unsubscribes from a list of websocket channel
|
||||
func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.List) error {
|
||||
if len(channels) == 0 {
|
||||
@@ -855,6 +862,11 @@ func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.L
|
||||
return wrapper.Setup.Unsubscriber(context.TODO(), conn, channels)
|
||||
})
|
||||
}
|
||||
|
||||
if w.Unsubscriber == nil {
|
||||
return fmt.Errorf("%w: Global Unsubscriber not set", common.ErrNilPointer)
|
||||
}
|
||||
|
||||
return w.unsubscribe(w.subscriptions, channels, func(channels subscription.List) error {
|
||||
return w.Unsubscriber(channels)
|
||||
})
|
||||
|
||||
@@ -63,31 +63,33 @@ type testSubKey struct {
|
||||
Mood string
|
||||
}
|
||||
|
||||
var defaultSetup = &WebsocketSetup{
|
||||
ExchangeConfig: &config.Exchange{
|
||||
Features: &config.FeaturesConfig{
|
||||
Enabled: config.FeaturesEnabledConfig{Websocket: true},
|
||||
func newDefaultSetup() *WebsocketSetup {
|
||||
return &WebsocketSetup{
|
||||
ExchangeConfig: &config.Exchange{
|
||||
Features: &config.FeaturesConfig{
|
||||
Enabled: config.FeaturesEnabledConfig{Websocket: true},
|
||||
},
|
||||
API: config.APIConfig{
|
||||
AuthenticatedWebsocketSupport: true,
|
||||
},
|
||||
WebsocketTrafficTimeout: time.Second * 5,
|
||||
Name: "GTX",
|
||||
},
|
||||
API: config.APIConfig{
|
||||
AuthenticatedWebsocketSupport: true,
|
||||
DefaultURL: "testDefaultURL",
|
||||
RunningURL: "wss://testRunningURL",
|
||||
Connector: func() error { return nil },
|
||||
Subscriber: func(subscription.List) error { return nil },
|
||||
Unsubscriber: func(subscription.List) error { return nil },
|
||||
GenerateSubscriptions: func() (subscription.List, error) {
|
||||
return subscription.List{
|
||||
{Channel: "TestSub"},
|
||||
{Channel: "TestSub2", Key: "purple"},
|
||||
{Channel: "TestSub3", Key: testSubKey{"mauve"}},
|
||||
{Channel: "TestSub4", Key: 42},
|
||||
}, nil
|
||||
},
|
||||
WebsocketTrafficTimeout: time.Second * 5,
|
||||
Name: "GTX",
|
||||
},
|
||||
DefaultURL: "testDefaultURL",
|
||||
RunningURL: "wss://testRunningURL",
|
||||
Connector: func() error { return nil },
|
||||
Subscriber: func(subscription.List) error { return nil },
|
||||
Unsubscriber: func(subscription.List) error { return nil },
|
||||
GenerateSubscriptions: func() (subscription.List, error) {
|
||||
return subscription.List{
|
||||
{Channel: "TestSub"},
|
||||
{Channel: "TestSub2", Key: "purple"},
|
||||
{Channel: "TestSub3", Key: testSubKey{"mauve"}},
|
||||
{Channel: "TestSub4", Key: 42},
|
||||
}, nil
|
||||
},
|
||||
Features: &protocol.Features{Subscribe: true, Unsubscribe: true},
|
||||
Features: &protocol.Features{Subscribe: true, Unsubscribe: true},
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
@@ -186,13 +188,23 @@ func TestConnectionMessageErrors(t *testing.T) {
|
||||
assert.ErrorIs(t, err, errDastardlyReason, "Connect should error correctly")
|
||||
|
||||
ws := NewWebsocket()
|
||||
err = ws.Setup(defaultSetup)
|
||||
err = ws.Setup(newDefaultSetup())
|
||||
require.NoError(t, err, "Setup must not error")
|
||||
ws.trafficTimeout = time.Minute
|
||||
ws.connector = connect
|
||||
|
||||
err = ws.Connect()
|
||||
require.NoError(t, err, "Connect must not error")
|
||||
require.ErrorIs(t, ws.Connect(), ErrSubscriptionsNotAdded)
|
||||
require.NoError(t, ws.Shutdown())
|
||||
|
||||
ws.Subscriber = func(subs subscription.List) error {
|
||||
for _, sub := range subs {
|
||||
if err := ws.subscriptions.Add(sub); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
require.NoError(t, ws.Connect(), "Connect must not error")
|
||||
|
||||
checkToRoutineResult := func(t *testing.T) {
|
||||
t.Helper()
|
||||
@@ -240,7 +252,7 @@ func TestConnectionMessageErrors(t *testing.T) {
|
||||
require.ErrorIs(t, err, errDastardlyReason)
|
||||
|
||||
ws.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) {
|
||||
return subscription.List{{}}, nil
|
||||
return subscription.List{{Channel: "test"}}, nil
|
||||
}
|
||||
err = ws.Connect()
|
||||
require.ErrorIs(t, err, errNoConnectFunc)
|
||||
@@ -275,10 +287,26 @@ func TestConnectionMessageErrors(t *testing.T) {
|
||||
err = ws.Connect()
|
||||
require.ErrorIs(t, err, errDastardlyReason)
|
||||
|
||||
ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error {
|
||||
return errDastardlyReason
|
||||
}
|
||||
ws.connectionManager[0].Setup.Authenticate = nil
|
||||
err = ws.Connect()
|
||||
require.ErrorIs(t, err, errDastardlyReason)
|
||||
require.NoError(t, ws.shutdown())
|
||||
|
||||
ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error {
|
||||
return nil
|
||||
}
|
||||
err = ws.Connect()
|
||||
require.ErrorIs(t, err, ErrSubscriptionsNotAdded)
|
||||
require.NoError(t, ws.shutdown())
|
||||
|
||||
ws.connectionManager[0].Subscriptions = subscription.NewStore()
|
||||
ws.connectionManager[0].Setup.Subscriber = func(context.Context, Connection, subscription.List) error {
|
||||
return ws.connectionManager[0].Subscriptions.Add(&subscription.Subscription{Channel: "test"})
|
||||
}
|
||||
err = ws.Connect()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ws.connectionManager[0].Connection.SendRawMessage(context.Background(), request.Unset, websocket.TextMessage, []byte("test"))
|
||||
@@ -297,6 +325,7 @@ func TestWebsocket(t *testing.T) {
|
||||
assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly")
|
||||
|
||||
ws.setEnabled(true)
|
||||
defaultSetup := newDefaultSetup()
|
||||
err = ws.Setup(defaultSetup) // Sets to enabled again
|
||||
require.NoError(t, err, "Setup may not error")
|
||||
|
||||
@@ -340,8 +369,19 @@ func TestWebsocket(t *testing.T) {
|
||||
assert.NoError(t, ws.Shutdown())
|
||||
|
||||
ws.connector = func() error { return nil }
|
||||
err = ws.Connect()
|
||||
assert.NoError(t, err, "Connect should not error")
|
||||
|
||||
require.ErrorIs(t, ws.Connect(), ErrSubscriptionsNotAdded)
|
||||
require.NoError(t, ws.Shutdown())
|
||||
|
||||
ws.Subscriber = func(subs subscription.List) error {
|
||||
for _, sub := range subs {
|
||||
if err := ws.subscriptions.Add(sub); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
assert.NoError(t, ws.Connect(), "Connect should not error")
|
||||
|
||||
ws.defaultURL = "ws://demos.kaazing.com/echo"
|
||||
ws.defaultURLAuth = "ws://demos.kaazing.com/echo"
|
||||
@@ -407,14 +447,14 @@ func currySimpleUnsubConn(w *Websocket) func(context.Context, Connection, subscr
|
||||
func TestSubscribeUnsubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
ws := NewWebsocket()
|
||||
assert.NoError(t, ws.Setup(defaultSetup), "WS Setup should not error")
|
||||
assert.NoError(t, ws.Setup(newDefaultSetup()), "WS Setup should not error")
|
||||
|
||||
ws.Subscriber = currySimpleSub(ws)
|
||||
ws.Unsubscriber = currySimpleUnsub(ws)
|
||||
|
||||
subs, err := ws.GenerateSubs()
|
||||
require.NoError(t, err, "Generating test subscriptions should not error")
|
||||
assert.NoError(t, new(Websocket).UnsubscribeChannels(nil, subs), "Should not error when w.subscriptions is nil")
|
||||
assert.ErrorIs(t, new(Websocket).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
|
||||
assert.NoError(t, ws.UnsubscribeChannels(nil, nil), "Unsubscribing from nil should not error")
|
||||
assert.ErrorIs(t, ws.UnsubscribeChannels(nil, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
|
||||
assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return")
|
||||
@@ -447,9 +487,9 @@ func TestSubscribeUnsubscribe(t *testing.T) {
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription")
|
||||
|
||||
multi := NewWebsocket()
|
||||
set := *defaultSetup
|
||||
set := newDefaultSetup()
|
||||
set.UseMultiConnectionManagement = true
|
||||
assert.NoError(t, multi.Setup(&set))
|
||||
assert.NoError(t, multi.Setup(set))
|
||||
|
||||
amazingCandidate := &ConnectionSetup{
|
||||
URL: "AMAZING",
|
||||
@@ -472,8 +512,8 @@ func TestSubscribeUnsubscribe(t *testing.T) {
|
||||
|
||||
subs, err = amazingCandidate.GenerateSubscriptions()
|
||||
require.NoError(t, err, "Generating test subscriptions should not error")
|
||||
assert.NoError(t, new(Websocket).UnsubscribeChannels(nil, subs), "Should not error when w.subscriptions is nil")
|
||||
assert.NoError(t, new(Websocket).UnsubscribeChannels(amazingConn, subs), "Should not error when w.subscriptions is nil")
|
||||
assert.ErrorIs(t, new(Websocket).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
|
||||
assert.ErrorIs(t, new(Websocket).UnsubscribeChannels(amazingConn, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
|
||||
assert.NoError(t, multi.UnsubscribeChannels(amazingConn, nil), "Unsubscribing from nil should not error")
|
||||
assert.ErrorIs(t, multi.UnsubscribeChannels(amazingConn, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
|
||||
assert.Nil(t, multi.GetSubscription(42), "GetSubscription on empty internal map should return")
|
||||
@@ -514,12 +554,12 @@ func TestResubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
ws := NewWebsocket()
|
||||
|
||||
wackedOutSetup := *defaultSetup
|
||||
wackedOutSetup := newDefaultSetup()
|
||||
wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1
|
||||
err := ws.Setup(&wackedOutSetup)
|
||||
err := ws.Setup(wackedOutSetup)
|
||||
assert.ErrorIs(t, err, errInvalidMaxSubscriptions, "Invalid MaxWebsocketSubscriptionsPerConnection should error")
|
||||
|
||||
err = ws.Setup(defaultSetup)
|
||||
err = ws.Setup(newDefaultSetup())
|
||||
assert.NoError(t, err, "WS Setup should not error")
|
||||
|
||||
ws.Subscriber = currySimpleSub(ws)
|
||||
@@ -992,52 +1032,6 @@ func TestCheckWebsocketURL(t *testing.T) {
|
||||
assert.NoError(t, err, "checkWebsocketURL should not error")
|
||||
}
|
||||
|
||||
// TestGetChannelDifference exercises GetChannelDifference
|
||||
// See subscription.TestStoreDiff for further testing
|
||||
func TestGetChannelDifference(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
w := &Websocket{}
|
||||
assert.NotPanics(t, func() { w.GetChannelDifference(nil, subscription.List{}) }, "Should not panic when called without a store")
|
||||
subs, unsubs := w.GetChannelDifference(nil, subscription.List{{Channel: subscription.CandlesChannel}})
|
||||
require.Equal(t, 1, len(subs), "Should get the correct number of subs")
|
||||
require.Empty(t, unsubs, "Should get no unsubs")
|
||||
require.NoError(t, w.AddSubscriptions(nil, subs...), "AddSubscriptions must not error")
|
||||
subs, unsubs = w.GetChannelDifference(nil, subscription.List{{Channel: subscription.TickerChannel}})
|
||||
require.Equal(t, 1, len(subs), "Should get the correct number of subs")
|
||||
assert.Equal(t, 1, len(unsubs), "Should get the correct number of unsubs")
|
||||
|
||||
w = &Websocket{}
|
||||
sweetConn := &WebsocketConnection{}
|
||||
subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}})
|
||||
require.Equal(t, 1, len(subs))
|
||||
require.Empty(t, unsubs, "Should get no unsubs")
|
||||
|
||||
w.connections = map[Connection]*ConnectionWrapper{
|
||||
sweetConn: {Setup: &ConnectionSetup{URL: "ws://localhost:8080/ws"}},
|
||||
}
|
||||
|
||||
naughtyConn := &WebsocketConnection{}
|
||||
subs, unsubs = w.GetChannelDifference(naughtyConn, subscription.List{{Channel: subscription.CandlesChannel}})
|
||||
require.Equal(t, 1, len(subs))
|
||||
require.Empty(t, unsubs, "Should get no unsubs")
|
||||
|
||||
subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}})
|
||||
require.Equal(t, 1, len(subs))
|
||||
require.Empty(t, unsubs, "Should get no unsubs")
|
||||
|
||||
err := w.connections[sweetConn].Subscriptions.Add(&subscription.Subscription{Channel: subscription.CandlesChannel})
|
||||
require.NoError(t, err)
|
||||
|
||||
subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}})
|
||||
require.Empty(t, subs, "Should get no subs")
|
||||
require.Empty(t, unsubs, "Should get no unsubs")
|
||||
|
||||
subs, unsubs = w.GetChannelDifference(sweetConn, nil)
|
||||
require.Empty(t, subs, "Should get no subs")
|
||||
require.Equal(t, 1, len(unsubs))
|
||||
}
|
||||
|
||||
// GenSubs defines a theoretical exchange with pair management
|
||||
type GenSubs struct {
|
||||
EnabledPairs currency.Pairs
|
||||
@@ -1111,16 +1105,37 @@ func TestFlushChannels(t *testing.T) {
|
||||
// Disable pair and flush system
|
||||
newgen.EnabledPairs = []currency.Pair{currency.NewPair(currency.BTC, currency.AUD)}
|
||||
w.GenerateSubs = func() (subscription.List, error) { return subscription.List{{Channel: "test"}}, nil }
|
||||
err = w.FlushChannels()
|
||||
require.NoError(t, err, "Flush Channels must not error")
|
||||
|
||||
require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotAdded, "FlushChannels must error correctly on no subscriptions added")
|
||||
|
||||
w.Subscriber = func(subs subscription.List) error {
|
||||
for _, sub := range subs {
|
||||
if err := w.subscriptions.Add(sub); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
require.NoError(t, w.FlushChannels(), "FlushChannels must not error")
|
||||
|
||||
w.GenerateSubs = func() (subscription.List, error) { return nil, errDastardlyReason } // error on generateSubs
|
||||
err = w.FlushChannels() // error on full subscribeToChannels
|
||||
assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly on GenerateSubs")
|
||||
|
||||
w.GenerateSubs = func() (subscription.List, error) { return nil, nil } // No subs to sub
|
||||
err = w.FlushChannels() // No subs to sub
|
||||
assert.NoError(t, err, "Flush Channels should not error")
|
||||
|
||||
require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotRemoved)
|
||||
|
||||
w.Unsubscriber = func(subs subscription.List) error {
|
||||
for _, sub := range subs {
|
||||
if err := w.subscriptions.Remove(sub); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
assert.NoError(t, w.FlushChannels(), "FlushChannels should not error")
|
||||
|
||||
w.GenerateSubs = newgen.generateSubs
|
||||
subs, err := w.GenerateSubs()
|
||||
@@ -1156,21 +1171,24 @@ func TestFlushChannels(t *testing.T) {
|
||||
mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) }))
|
||||
defer mock.Close()
|
||||
|
||||
w.subscriptions = subscription.NewStore()
|
||||
|
||||
amazingCandidate := &ConnectionSetup{
|
||||
URL: "ws" + mock.URL[len("http"):] + "/ws",
|
||||
Connector: func(ctx context.Context, conn Connection) error {
|
||||
return conn.DialContext(ctx, websocket.DefaultDialer, nil)
|
||||
},
|
||||
GenerateSubscriptions: newgen.generateSubs,
|
||||
Subscriber: func(ctx context.Context, c Connection, s subscription.List) error {
|
||||
return currySimpleSubConn(w)(ctx, c, s)
|
||||
},
|
||||
Unsubscriber: func(ctx context.Context, c Connection, s subscription.List) error {
|
||||
return currySimpleUnsubConn(w)(ctx, c, s)
|
||||
},
|
||||
Handler: func(context.Context, []byte) error { return nil },
|
||||
Subscriber: func(context.Context, Connection, subscription.List) error { return nil },
|
||||
Unsubscriber: func(context.Context, Connection, subscription.List) error { return nil },
|
||||
Handler: func(context.Context, []byte) error { return nil },
|
||||
}
|
||||
require.NoError(t, w.SetupNewConnection(amazingCandidate))
|
||||
require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotAdded, "Must error when no subscriptions are added to the subscription store")
|
||||
|
||||
w.connectionManager[0].Setup.Subscriber = func(ctx context.Context, c Connection, s subscription.List) error {
|
||||
return currySimpleSubConn(w)(ctx, c, s)
|
||||
}
|
||||
require.NoError(t, w.FlushChannels(), "FlushChannels must not error")
|
||||
|
||||
// Forces full connection cycle (shutdown, connect, subscribe). This will also start monitoring routines.
|
||||
@@ -1181,6 +1199,11 @@ func TestFlushChannels(t *testing.T) {
|
||||
// of the connection from management.
|
||||
w.features.Subscribe = true
|
||||
w.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil }
|
||||
require.ErrorIs(t, w.FlushChannels(), ErrSubscriptionsNotRemoved, "Must error when no subscriptions are removed from subscription store")
|
||||
|
||||
w.connectionManager[0].Setup.Unsubscriber = func(ctx context.Context, c Connection, s subscription.List) error {
|
||||
return currySimpleUnsubConn(w)(ctx, c, s)
|
||||
}
|
||||
require.NoError(t, w.FlushChannels(), "FlushChannels must not error")
|
||||
}
|
||||
|
||||
@@ -1224,7 +1247,7 @@ func TestSetupNewConnection(t *testing.T) {
|
||||
|
||||
web := NewWebsocket()
|
||||
|
||||
err = web.Setup(defaultSetup)
|
||||
err = web.Setup(newDefaultSetup())
|
||||
assert.NoError(t, err, "Setup should not error")
|
||||
|
||||
err = web.SetupNewConnection(&ConnectionSetup{URL: "urlstring"})
|
||||
@@ -1235,9 +1258,9 @@ func TestSetupNewConnection(t *testing.T) {
|
||||
|
||||
// Test connection candidates for multi connection tracking.
|
||||
multi := NewWebsocket()
|
||||
set := *defaultSetup
|
||||
set := newDefaultSetup()
|
||||
set.UseMultiConnectionManagement = true
|
||||
require.NoError(t, multi.Setup(&set))
|
||||
require.NoError(t, multi.Setup(set))
|
||||
|
||||
err = multi.SetupNewConnection(nil)
|
||||
require.ErrorIs(t, err, errExchangeConfigEmpty)
|
||||
@@ -1566,3 +1589,43 @@ func TestGetConnection(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Same(t, expected, conn)
|
||||
}
|
||||
|
||||
func TestUpdateChannelSubscriptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ws := Websocket{}
|
||||
store := subscription.NewStore()
|
||||
err := ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}})
|
||||
require.ErrorIs(t, err, common.ErrNilPointer)
|
||||
require.Zero(t, store.Len())
|
||||
|
||||
ws.Subscriber = func(subs subscription.List) error {
|
||||
for _, sub := range subs {
|
||||
if err := store.Add(sub); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
ws.subscriptions = store
|
||||
err = ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, store.Len())
|
||||
|
||||
err = ws.updateChannelSubscriptions(nil, store, subscription.List{})
|
||||
require.ErrorIs(t, err, common.ErrNilPointer)
|
||||
|
||||
ws.Unsubscriber = func(subs subscription.List) error {
|
||||
for _, sub := range subs {
|
||||
if err := store.Remove(sub); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err = ws.updateChannelSubscriptions(nil, store, subscription.List{})
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, store.Len())
|
||||
}
|
||||
|
||||
@@ -206,3 +206,33 @@ func (s *Store) Len() int {
|
||||
defer s.mu.RUnlock()
|
||||
return len(s.m)
|
||||
}
|
||||
|
||||
// Contained returns a list of subscriptions in `compare` that are already in the store.
|
||||
func (s *Store) Contained(compare List) (matched List) {
|
||||
if s == nil || s.m == nil {
|
||||
return nil
|
||||
}
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
for _, sub := range compare {
|
||||
if found := s.get(sub); found != nil {
|
||||
matched = append(matched, found)
|
||||
}
|
||||
}
|
||||
return matched
|
||||
}
|
||||
|
||||
// Missing returns a list of subscriptions in `compare` that are not in the store.
|
||||
func (s *Store) Missing(compare List) (missing List) {
|
||||
if s == nil || s.m == nil {
|
||||
return compare // All are missing
|
||||
}
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
for _, sub := range compare {
|
||||
if found := s.get(sub); found == nil {
|
||||
missing = append(missing, sub)
|
||||
}
|
||||
}
|
||||
return missing
|
||||
}
|
||||
|
||||
@@ -204,3 +204,56 @@ func EqualLists(tb testing.TB, a, b List) {
|
||||
assert.Fail(tb, fail, "Subscriptions should be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContained(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var s *Store
|
||||
matched := s.Contained(nil)
|
||||
assert.Nil(t, matched)
|
||||
|
||||
matched = s.Contained(List{{Channel: TickerChannel}})
|
||||
assert.Nil(t, matched)
|
||||
|
||||
s = NewStore()
|
||||
matched = s.Contained(nil)
|
||||
assert.Nil(t, matched)
|
||||
|
||||
matched = s.Contained(List{})
|
||||
assert.Nil(t, matched)
|
||||
|
||||
matched = s.Contained(List{{Channel: TickerChannel}})
|
||||
assert.Nil(t, matched)
|
||||
|
||||
require.NoError(t, s.add(&Subscription{Channel: TickerChannel}))
|
||||
|
||||
matched = s.Contained(List{{Channel: TickerChannel}})
|
||||
assert.Len(t, matched, 1)
|
||||
}
|
||||
|
||||
func TestMissing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var s *Store
|
||||
|
||||
unmatched := s.Missing(nil)
|
||||
assert.Nil(t, unmatched)
|
||||
|
||||
unmatched = s.Missing(List{{Channel: TickerChannel}})
|
||||
assert.Len(t, unmatched, 1)
|
||||
|
||||
s = NewStore()
|
||||
unmatched = s.Missing(nil)
|
||||
assert.Nil(t, unmatched)
|
||||
|
||||
unmatched = s.Missing(List{})
|
||||
assert.Nil(t, unmatched)
|
||||
|
||||
unmatched = s.Missing(List{{Channel: TickerChannel}})
|
||||
assert.Len(t, unmatched, 1)
|
||||
|
||||
require.NoError(t, s.add(&Subscription{Channel: TickerChannel}))
|
||||
|
||||
unmatched = s.Missing(List{{Channel: TickerChannel}})
|
||||
assert.Nil(t, unmatched)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user