Files
gocryptotrader/internal/exchange/websocket/subscriptions_test.go
Gareth Kirwan b4e45e9a1b Websocket: Restructure files and types (#1859)
* Websocket: Rename stream package

* Websocket: Rename Websocket to Manager

* Websocket: Replace explicit errs with common.NilGuard

* Websocket: Move websocket_types.go to types.go

* Websocket: Minor field comment and alignment in types

* Websocket: Rename WebsocketConnection to Connection

* Alphapoint: Make gorilla ws import explicit

Just to avoid confusion with our own packages.

* Websocket: Move stream_match to match

* Websocket: Move websocket_connection to connection

* Websocket: Move websocket.go to manager.go

* Websocket: Break out all subscription methods into subscriptions.go

* Websocket: Move connection type into its file

* Websocket: Remove PositionUpdated

Type is not used anywhere

* Kraken: Use local constant for pong

Was the only use of websocket.Pong and doesn't really feel right to
represent kraken's api resp in one of our packages

* Websocket: Move connection sub-types to connection package

* Websocket: Move manager types into manager

* Websocket: Move ConnectionWrapper into manager

* Websocket: Move websocket_test to manager_test

* Websocket: Privatise connectionWrapper

* Websocket: Remaining types into types.go

These really belong somewhere else mostly, but this will do for now

* Websocket: Tidy up connection method vars

* Gofumpt: Moving package imports around

* Websocket: Rename errDuplicateConnectionSetup

* Websocket: Fix duplicate import of gws

* Websocket: Fix gofumpt -extra

* Websocket: Standardise import of gws across other pkgs

* Kraken: Remove unused sub conf consts

These were replaced by the generic Levels and Depth fields on all subs

* Websocket: Privatise ConnectionWrapper fields

* Websocket: inline single use var WebsocketNotAuthenticatedUsingRest

* Websocket: Move documentation to template

* Bithumb: Assertify TestWsHandleData
2025-04-10 16:25:02 +10:00

306 lines
16 KiB
Go

package websocket
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
)
// TestSubscribeUnsubscribe exercises SubscribeToChannels, UnsubscribeChannels and the
// GetSubscription/GetSubscriptions lookups, first on a Manager built from newDefaultSetup
// and then on one with UseMultiConnectionManagement enabled and a single connection.
func TestSubscribeUnsubscribe(t *testing.T) {
	t.Parallel()
	ws := NewManager()
	// Every check below needs a configured Manager, so stop immediately if Setup fails.
	require.NoError(t, ws.Setup(newDefaultSetup()), "WS Setup must not error")
	ws.Subscriber = currySimpleSub(ws)
	ws.Unsubscriber = currySimpleUnsub(ws)
	subs, err := ws.GenerateSubs()
	require.NoError(t, err, "Generating test subscriptions must not error")
	assert.ErrorIs(t, new(Manager).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
	assert.NoError(t, ws.UnsubscribeChannels(nil, nil), "Unsubscribing from nil should not error")
	assert.ErrorIs(t, ws.UnsubscribeChannels(nil, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
	assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return nil")
	assert.NoError(t, ws.SubscribeToChannels(nil, subs), "Basic Subscribing should not error")
	assert.Len(t, ws.GetSubscriptions(), 4, "Should have 4 subscriptions")
	bySub := ws.GetSubscription(subscription.Subscription{Channel: "TestSub"})
	if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") {
		assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel")
		assert.Same(t, bySub, subs[0], "GetSubscription returns the same pointer")
	}
	if assert.NotNil(t, ws.GetSubscription("purple"), "GetSubscription by string key should find a channel") {
		assert.Equal(t, "TestSub2", ws.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel")
	}
	if assert.NotNil(t, ws.GetSubscription(testSubKey{"mauve"}), "GetSubscription by type key should find a channel") {
		assert.Equal(t, "TestSub3", ws.GetSubscription(testSubKey{"mauve"}).Channel, "GetSubscription by type key should return a pointer a copy of the right channel")
	}
	if assert.NotNil(t, ws.GetSubscription(42), "GetSubscription by int key should find a channel") {
		assert.Equal(t, "TestSub4", ws.GetSubscription(42).Channel, "GetSubscription by int key should return a pointer a copy of the right channel")
	}
	assert.Nil(t, ws.GetSubscription(nil), "GetSubscription by nil should return nil")
	assert.Nil(t, ws.GetSubscription(45), "GetSubscription by invalid key should return nil")
	assert.ErrorIs(t, ws.SubscribeToChannels(nil, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed")
	assert.NoError(t, ws.SubscribeToChannels(nil, nil), "Subscribe to a nil List should not error")
	assert.NoError(t, ws.UnsubscribeChannels(nil, subs), "Unsubscribing should not error")
	// An error from the Subscriber must be propagated to the caller.
	ws.Subscriber = func(subscription.List) error { return errDastardlyReason }
	assert.ErrorIs(t, ws.SubscribeToChannels(nil, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber")
	err = ws.SubscribeToChannels(nil, subscription.List{nil})
	assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription")

	// Repeat the same checks on a multi-connection Manager.
	multi := NewManager()
	set := newDefaultSetup()
	set.UseMultiConnectionManagement = true
	require.NoError(t, multi.Setup(set), "Multi-connection WS Setup must not error")
	amazingCandidate := &ConnectionSetup{
		URL:                   "AMAZING",
		Connector:             func(context.Context, Connection) error { return nil },
		GenerateSubscriptions: ws.GenerateSubs,
		Subscriber: func(ctx context.Context, c Connection, s subscription.List) error {
			return currySimpleSubConn(multi)(ctx, c, s)
		},
		Unsubscriber: func(ctx context.Context, c Connection, s subscription.List) error {
			return currySimpleUnsubConn(multi)(ctx, c, s)
		},
		Handler: func(context.Context, []byte) error { return nil },
	}
	require.NoError(t, multi.SetupNewConnection(amazingCandidate))
	amazingConn := multi.getConnectionFromSetup(amazingCandidate)
	// Associate amazingConn with the first connectionManager entry.
	multi.connections = map[Connection]*connectionWrapper{
		amazingConn: multi.connectionManager[0],
	}
	subs, err = amazingCandidate.GenerateSubscriptions()
	require.NoError(t, err, "Generating test subscriptions must not error")
	assert.ErrorIs(t, new(Manager).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
	assert.ErrorIs(t, new(Manager).UnsubscribeChannels(amazingConn, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
	assert.NoError(t, multi.UnsubscribeChannels(amazingConn, nil), "Unsubscribing from nil should not error")
	assert.ErrorIs(t, multi.UnsubscribeChannels(amazingConn, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
	assert.Nil(t, multi.GetSubscription(42), "GetSubscription on empty internal map should return nil")
	assert.ErrorIs(t, multi.SubscribeToChannels(nil, subs), common.ErrNilPointer, "If no connection is set, Subscribe should error")
	assert.NoError(t, multi.SubscribeToChannels(amazingConn, subs), "Basic Subscribing should not error")
	assert.Len(t, multi.GetSubscriptions(), 4, "Should have 4 subscriptions")
	bySub = multi.GetSubscription(subscription.Subscription{Channel: "TestSub"})
	if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") {
		assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel")
		assert.Same(t, bySub, subs[0], "GetSubscription returns the same pointer")
	}
	if assert.NotNil(t, multi.GetSubscription("purple"), "GetSubscription by string key should find a channel") {
		assert.Equal(t, "TestSub2", multi.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel")
	}
	if assert.NotNil(t, multi.GetSubscription(testSubKey{"mauve"}), "GetSubscription by type key should find a channel") {
		assert.Equal(t, "TestSub3", multi.GetSubscription(testSubKey{"mauve"}).Channel, "GetSubscription by type key should return a pointer a copy of the right channel")
	}
	if assert.NotNil(t, multi.GetSubscription(42), "GetSubscription by int key should find a channel") {
		assert.Equal(t, "TestSub4", multi.GetSubscription(42).Channel, "GetSubscription by int key should return a pointer a copy of the right channel")
	}
	assert.Nil(t, multi.GetSubscription(nil), "GetSubscription by nil should return nil")
	assert.Nil(t, multi.GetSubscription(45), "GetSubscription by invalid key should return nil")
	assert.ErrorIs(t, multi.SubscribeToChannels(amazingConn, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed")
	assert.NoError(t, multi.SubscribeToChannels(amazingConn, nil), "Subscribe to a nil List should not error")
	assert.NoError(t, multi.UnsubscribeChannels(amazingConn, subs), "Unsubscribing should not error")
	// Swap in a failing connection Subscriber; its error must reach the caller.
	amazingCandidate.Subscriber = func(context.Context, Connection, subscription.List) error { return errDastardlyReason }
	assert.ErrorIs(t, multi.SubscribeToChannels(amazingConn, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber")
	err = multi.SubscribeToChannels(amazingConn, subscription.List{nil})
	assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription")
}
// TestResubscribe checks that Setup rejects a negative MaxWebsocketSubscriptionsPerConnection
// with errInvalidMaxSubscriptions, and that ResubscribeToChannel returns subscription.ErrNotFound
// until the channel has been subscribed, after which it succeeds.
func TestResubscribe(t *testing.T) {
	t.Parallel()
	ws := NewManager()
	wackedOutSetup := newDefaultSetup()
	wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1
	err := ws.Setup(wackedOutSetup)
	assert.ErrorIs(t, err, errInvalidMaxSubscriptions, "Invalid MaxWebsocketSubscriptionsPerConnection should error")
	err = ws.Setup(newDefaultSetup())
	// The resubscribe checks below are meaningless without a configured Manager.
	require.NoError(t, err, "WS Setup must not error")
	ws.Subscriber = currySimpleSub(ws)
	ws.Unsubscriber = currySimpleUnsub(ws)
	channel := subscription.List{{Channel: "resubTest"}}
	assert.ErrorIs(t, ws.ResubscribeToChannel(nil, channel[0]), subscription.ErrNotFound, "Resubscribe should error when channel isn't subscribed yet")
	// The final assertion relies on this subscription being in place.
	require.NoError(t, ws.SubscribeToChannels(nil, channel), "Subscribe must not error")
	assert.NoError(t, ws.ResubscribeToChannel(nil, channel[0]), "Resubscribe should not error now the channel is subscribed")
}
// TestSubscriptions covers AddSubscriptions, GetSubscription and RemoveSubscriptions on a
// zero-value Manager, including the state each call leaves the subscription in and the
// fact that a Resubscribing subscription keeps that state when re-added.
func TestSubscriptions(t *testing.T) {
	t.Parallel()
	var nilManager *Manager
	assert.ErrorIs(t, nilManager.AddSubscriptions(nil), common.ErrNilPointer, "Should error correctly when nil websocket")

	m := new(Manager) // zero value on purpose: its subscription store starts out nil
	sub := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel}

	require.NoError(t, m.AddSubscriptions(nil, sub), "Adding first subscription must not error")
	assert.Same(t, sub, m.GetSubscription(42), "Get Subscription should retrieve the same subscription")
	assert.ErrorIs(t, m.AddSubscriptions(nil, sub), subscription.ErrDuplicate, "Adding same subscription should return error")
	assert.Equal(t, subscription.SubscribingState, sub.State(), "Should set state to Subscribing")

	require.NoError(t, m.RemoveSubscriptions(nil, sub), "RemoveSubscriptions must not error")
	assert.Nil(t, m.GetSubscription(42), "Remove should have removed the sub")
	assert.Equal(t, subscription.UnsubscribedState, sub.State(), "Should set state to Unsubscribed")

	require.NoError(t, sub.SetState(subscription.ResubscribingState), "SetState must not error")
	require.NoError(t, m.AddSubscriptions(nil, sub), "Adding first subscription must not error")
	assert.Equal(t, subscription.ResubscribingState, sub.State(), "Should not change resubscribing state")
}
// TestSuccessfulSubscriptions covers AddSuccessfulSubscriptions together with GetSubscription
// and RemoveSubscriptions. It also covers the errors returned for:
//   - an already-subscribed state
//   - a duplicate add
//   - a missing subscription
//   - a nil Manager
//   - a Manager whose subscription store is nil
func TestSuccessfulSubscriptions(t *testing.T) {
	t.Parallel()
	w := new(Manager) // Do not use NewManager; we want to exercise w.subscriptions == nil
	assert.ErrorIs(t, (*Manager)(nil).AddSuccessfulSubscriptions(nil, nil), common.ErrNilPointer, "Should error correctly when nil websocket")
	c := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel}
	require.NoError(t, w.AddSuccessfulSubscriptions(nil, c), "Adding first subscription must not error")
	assert.Same(t, c, w.GetSubscription(42), "Get Subscription should retrieve the same subscription")
	// Re-adding without a state change is rejected with ErrInStateAlready ...
	assert.ErrorIs(t, w.AddSuccessfulSubscriptions(nil, c), subscription.ErrInStateAlready, "Adding subscription in same state should return error")
	// ... and once the state differs, the store's duplicate check applies instead.
	require.NoError(t, c.SetState(subscription.SubscribingState), "SetState must not error")
	assert.ErrorIs(t, w.AddSuccessfulSubscriptions(nil, c), subscription.ErrDuplicate, "Adding same subscription should return error")
	err := w.RemoveSubscriptions(nil, c)
	require.NoError(t, err, "RemoveSubscriptions must not error")
	assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub")
	assert.ErrorIs(t, w.RemoveSubscriptions(nil, c), subscription.ErrNotFound, "Should error correctly when not found")
	assert.ErrorIs(t, (*Manager)(nil).RemoveSubscriptions(nil, nil), common.ErrNilPointer, "Should error correctly when nil websocket")
	// The Manager itself is non-nil here; only its subscription store has been cleared.
	w.subscriptions = nil
	assert.ErrorIs(t, w.RemoveSubscriptions(nil, c), common.ErrNilPointer, "Should error correctly when nil subscriptions store")
}
// TestGetSubscription logic test
func TestGetSubscription(t *testing.T) {
t.Parallel()
assert.Nil(t, (*Manager).GetSubscription(nil, "imaginary"), "GetSubscription on a nil Websocket should return nil")
assert.Nil(t, (&Manager{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub store should return nil")
w := NewManager()
assert.Nil(t, w.GetSubscription(nil), "GetSubscription with a nil key should return nil")
s := &subscription.Subscription{Key: 42, Channel: "hello3"}
require.NoError(t, w.AddSubscriptions(nil, s), "AddSubscriptions must not error")
assert.Same(t, s, w.GetSubscription(42), "GetSubscription should delegate to the store")
}
// TestGetSubscriptions logic test
func TestGetSubscriptions(t *testing.T) {
t.Parallel()
assert.Nil(t, (*Manager).GetSubscriptions(nil), "GetSubscription on a nil Websocket should return nil")
assert.Nil(t, (&Manager{}).GetSubscriptions(), "GetSubscription on a Websocket with no sub store should return nil")
w := NewManager()
s := subscription.List{
{Key: 42, Channel: "hello3"},
{Key: 45, Channel: "hello4"},
}
err := w.AddSubscriptions(nil, s...)
require.NoError(t, err, "AddSubscriptions must not error")
assert.ElementsMatch(t, s, w.GetSubscriptions(), "GetSubscriptions should return the correct channel details")
}
// TestCheckSubscriptions exercises checkSubscriptions against the following cases:
//   - a missing subscription store
//   - a nil list
//   - the MaxSubscriptionsPerConnection limit
//   - a subscription already present in the store
func TestCheckSubscriptions(t *testing.T) {
	t.Parallel()
	m := &Manager{}

	// No store at all: the error must be ErrNilPointer and name the missing field.
	err := m.checkSubscriptions(nil, nil)
	assert.ErrorIs(t, err, common.ErrNilPointer, "checkSubscriptions should error correctly on nil w.subscriptions")
	assert.ErrorContains(t, err, "Websocket.subscriptions", "checkSubscriptions should error giving context correctly on nil w.subscriptions")

	m.subscriptions = subscription.NewStore()
	assert.NoError(t, m.checkSubscriptions(nil, nil), "checkSubscriptions should not error on a nil list")

	// Limit of one: a single new subscription fits, two do not.
	m.MaxSubscriptionsPerConnection = 1
	assert.NoError(t, m.checkSubscriptions(nil, subscription.List{{}}), "checkSubscriptions should not error when subscriptions is empty")
	m.subscriptions = subscription.NewStore()
	assert.ErrorIs(t, m.checkSubscriptions(nil, subscription.List{{}, {}}), errSubscriptionsExceedsLimit, "checkSubscriptions should error correctly")

	// Limit of two with one stored: re-checking the stored key is a duplicate, a new one is fine.
	m.MaxSubscriptionsPerConnection = 2
	m.subscriptions = subscription.NewStore()
	require.NoError(t, m.subscriptions.Add(&subscription.Subscription{Key: 42, Channel: "test"}), "Add subscription must not error")
	assert.ErrorIs(t, m.checkSubscriptions(nil, subscription.List{{Key: 42, Channel: "test"}}), subscription.ErrDuplicate, "checkSubscriptions should error correctly")
	assert.NoError(t, m.checkSubscriptions(nil, subscription.List{{}}), "checkSubscriptions should not error")
}
func TestUpdateChannelSubscriptions(t *testing.T) {
t.Parallel()
ws := NewManager()
store := subscription.NewStore()
err := ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}})
require.ErrorIs(t, err, common.ErrNilPointer)
require.Zero(t, store.Len())
ws.Subscriber = func(subs subscription.List) error {
for _, sub := range subs {
if err := store.Add(sub); err != nil {
return err
}
}
return nil
}
ws.subscriptions = store
err = ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}})
require.NoError(t, err)
require.Equal(t, 1, store.Len())
err = ws.updateChannelSubscriptions(nil, store, subscription.List{})
require.ErrorIs(t, err, common.ErrNilPointer)
ws.Unsubscriber = func(subs subscription.List) error {
for _, sub := range subs {
if err := store.Remove(sub); err != nil {
return err
}
}
return nil
}
err = ws.updateChannelSubscriptions(nil, store, subscription.List{})
require.NoError(t, err)
require.Zero(t, store.Len())
}
// currySimpleSub builds a Subscriber for w that records every requested subscription
// as successful via AddSuccessfulSubscriptions, with no connection.
func currySimpleSub(w *Manager) func(subscription.List) error {
	subscriber := func(l subscription.List) error {
		return w.AddSuccessfulSubscriptions(nil, l...)
	}
	return subscriber
}
// currySimpleSubConn builds a connection-scoped Subscriber for w that records every
// requested subscription as successful on the given connection. The context is ignored.
func currySimpleSubConn(w *Manager) func(context.Context, Connection, subscription.List) error {
	subscriber := func(_ context.Context, c Connection, l subscription.List) error {
		return w.AddSuccessfulSubscriptions(c, l...)
	}
	return subscriber
}
// currySimpleUnsub builds an Unsubscriber for w that removes every requested
// subscription via RemoveSubscriptions, with no connection.
func currySimpleUnsub(w *Manager) func(subscription.List) error {
	unsubscriber := func(l subscription.List) error {
		return w.RemoveSubscriptions(nil, l...)
	}
	return unsubscriber
}
// currySimpleUnsubConn builds a connection-scoped Unsubscriber for w that removes every
// requested subscription from the given connection. The context is ignored.
func currySimpleUnsubConn(w *Manager) func(context.Context, Connection, subscription.List) error {
	unsubscriber := func(_ context.Context, c Connection, l subscription.List) error {
		return w.RemoveSubscriptions(c, l...)
	}
	return unsubscriber
}