package websocket

import (
	"context"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/thrasher-corp/gocryptotrader/common"
	"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
)

// TestSubscribeUnsubscribe exercises SubscribeToChannels, UnsubscribeChannels and GetSubscription,
// first on a Manager using the Manager level Subscriber/Unsubscriber with a nil connection,
// and then on a Manager with UseMultiConnectionManagement enabled and a connection specific setup.
// The assertions are order dependent: each step relies on the subscription state left by the previous one.
func TestSubscribeUnsubscribe(t *testing.T) {
	t.Parallel()

	ws := NewManager()
	// Everything below depends on a configured manager, so stop immediately if Setup fails
	require.NoError(t, ws.Setup(newDefaultSetup()), "WS Setup must not error")

	ws.Subscriber = currySimpleSub(ws)
	ws.Unsubscriber = currySimpleUnsub(ws)

	subs, err := ws.GenerateSubs()
	require.NoError(t, err, "Generating test subscriptions must not error")
	assert.ErrorIs(t, new(Manager).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
	assert.NoError(t, ws.UnsubscribeChannels(nil, nil), "Unsubscribing from nil should not error")
	assert.ErrorIs(t, ws.UnsubscribeChannels(nil, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
	assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return nil")
	assert.NoError(t, ws.SubscribeToChannels(nil, subs), "Basic Subscribing should not error")
	assert.Len(t, ws.GetSubscriptions(), 4, "Should have 4 subscriptions")

	// Look up each generated subscription by the kind of key it is stored under
	bySub := ws.GetSubscription(subscription.Subscription{Channel: "TestSub"})
	if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") {
		assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer to the right channel")
		assert.Same(t, subs[0], bySub, "GetSubscription returns the same pointer")
	}
	if assert.NotNil(t, ws.GetSubscription("purple"), "GetSubscription by string key should find a channel") {
		assert.Equal(t, "TestSub2", ws.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer to the right channel")
	}
	if assert.NotNil(t, ws.GetSubscription(testSubKey{"mauve"}), "GetSubscription by type key should find a channel") {
		assert.Equal(t, "TestSub3", ws.GetSubscription(testSubKey{"mauve"}).Channel, "GetSubscription by type key should return a pointer to the right channel")
	}
	if assert.NotNil(t, ws.GetSubscription(42), "GetSubscription by int key should find a channel") {
		assert.Equal(t, "TestSub4", ws.GetSubscription(42).Channel, "GetSubscription by int key should return a pointer to the right channel")
	}
	assert.Nil(t, ws.GetSubscription(nil), "GetSubscription by nil should return nil")
	assert.Nil(t, ws.GetSubscription(45), "GetSubscription by invalid key should return nil")
	assert.ErrorIs(t, ws.SubscribeToChannels(nil, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed")
	assert.NoError(t, ws.SubscribeToChannels(nil, nil), "Subscribe to a nil List should not error")
	assert.NoError(t, ws.UnsubscribeChannels(nil, subs), "Unsubscribing should not error")

	// Replace the Subscriber with one that always fails to check the error is passed back to the caller
	ws.Subscriber = func(subscription.List) error { return errDastardlyReason }
	assert.ErrorIs(t, ws.SubscribeToChannels(nil, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber")

	err = ws.SubscribeToChannels(nil, subscription.List{nil})
	assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription")

	// Repeat the same sequence with multi connection management and an explicit connection
	multi := NewManager()
	set := newDefaultSetup()
	set.UseMultiConnectionManagement = true
	require.NoError(t, multi.Setup(set), "Multi connection WS Setup must not error")

	amazingCandidate := &ConnectionSetup{
		URL:                   "AMAZING",
		Connector:             func(context.Context, Connection) error { return nil },
		GenerateSubscriptions: ws.GenerateSubs,
		Subscriber: func(ctx context.Context, c Connection, s subscription.List) error {
			return currySimpleSubConn(multi)(ctx, c, s)
		},
		Unsubscriber: func(ctx context.Context, c Connection, s subscription.List) error {
			return currySimpleUnsubConn(multi)(ctx, c, s)
		},
		Handler: func(context.Context, Connection, []byte) error { return nil },
	}
	require.NoError(t, multi.SetupNewConnection(amazingCandidate), "SetupNewConnection must not error")

	// Register the connection by hand so that it can be passed to the subscribe and unsubscribe calls below
	amazingConn := multi.getConnectionFromSetup(amazingCandidate)
	multi.connections = map[Connection]*connectionWrapper{
		amazingConn: multi.connectionManager[0],
	}

	subs, err = amazingCandidate.GenerateSubscriptions()
	require.NoError(t, err, "Generating test subscriptions must not error")
	assert.ErrorIs(t, new(Manager).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
	assert.ErrorIs(t, new(Manager).UnsubscribeChannels(amazingConn, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
	assert.NoError(t, multi.UnsubscribeChannels(amazingConn, nil), "Unsubscribing from nil should not error")
	assert.ErrorIs(t, multi.UnsubscribeChannels(amazingConn, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
	assert.Nil(t, multi.GetSubscription(42), "GetSubscription on empty internal map should return nil")
	assert.ErrorIs(t, multi.SubscribeToChannels(nil, subs), common.ErrNilPointer, "If no connection is set, Subscribe should error")
	assert.NoError(t, multi.SubscribeToChannels(amazingConn, subs), "Basic Subscribing should not error")
	assert.Len(t, multi.GetSubscriptions(), 4, "Should have 4 subscriptions")

	bySub = multi.GetSubscription(subscription.Subscription{Channel: "TestSub"})
	if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") {
		assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer to the right channel")
		assert.Same(t, subs[0], bySub, "GetSubscription returns the same pointer")
	}
	if assert.NotNil(t, multi.GetSubscription("purple"), "GetSubscription by string key should find a channel") {
		assert.Equal(t, "TestSub2", multi.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer to the right channel")
	}
	if assert.NotNil(t, multi.GetSubscription(testSubKey{"mauve"}), "GetSubscription by type key should find a channel") {
		assert.Equal(t, "TestSub3", multi.GetSubscription(testSubKey{"mauve"}).Channel, "GetSubscription by type key should return a pointer to the right channel")
	}
	if assert.NotNil(t, multi.GetSubscription(42), "GetSubscription by int key should find a channel") {
		assert.Equal(t, "TestSub4", multi.GetSubscription(42).Channel, "GetSubscription by int key should return a pointer to the right channel")
	}
	assert.Nil(t, multi.GetSubscription(nil), "GetSubscription by nil should return nil")
	assert.Nil(t, multi.GetSubscription(45), "GetSubscription by invalid key should return nil")
	assert.ErrorIs(t, multi.SubscribeToChannels(amazingConn, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed")
	assert.NoError(t, multi.SubscribeToChannels(amazingConn, nil), "Subscribe to a nil List should not error")
	assert.NoError(t, multi.UnsubscribeChannels(amazingConn, subs), "Unsubscribing should not error")

	// Replace the connection Subscriber with one that always fails to check the error is passed back to the caller
	amazingCandidate.Subscriber = func(context.Context, Connection, subscription.List) error { return errDastardlyReason }
	assert.ErrorIs(t, multi.SubscribeToChannels(amazingConn, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber")

	err = multi.SubscribeToChannels(amazingConn, subscription.List{nil})
	assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription")
}

// TestResubscribe tests Resubscribing to existing subscriptions
// It also checks that Setup rejects a negative MaxWebsocketSubscriptionsPerConnection
func TestResubscribe(t *testing.T) {
	t.Parallel()

	ws := NewManager()

	wackedOutSetup := newDefaultSetup()
	wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1
	err := ws.Setup(wackedOutSetup)
	assert.ErrorIs(t, err, errInvalidMaxSubscriptions, "Invalid MaxWebsocketSubscriptionsPerConnection should error")

	// The remaining assertions need a configured manager, so stop immediately if Setup fails
	err = ws.Setup(newDefaultSetup())
	require.NoError(t, err, "WS Setup must not error")

	ws.Subscriber = currySimpleSub(ws)
	ws.Unsubscriber = currySimpleUnsub(ws)

	channel := subscription.List{{Channel: "resubTest"}}

	assert.ErrorIs(t, ws.ResubscribeToChannel(nil, channel[0]), subscription.ErrNotFound, "Resubscribe should error when channel isn't subscribed yet")
	assert.NoError(t, ws.SubscribeToChannels(nil, channel), "Subscribe should not error")
	assert.NoError(t, ws.ResubscribeToChannel(nil, channel[0]), "Resubscribe should not error now the channel is subscribed")
}

// TestSubscriptions tests adding, getting and removing subscriptions
// It also checks the state each step leaves the subscription in
func TestSubscriptions(t *testing.T) {
	t.Parallel()

	w := new(Manager) // Do not use NewManager; We want to exercise w.subscriptions == nil
	assert.ErrorIs(t, (*Manager)(nil).AddSubscriptions(nil), common.ErrNilPointer, "Should error correctly when nil websocket")

	s := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel}
	require.NoError(t, w.AddSubscriptions(nil, s), "Adding first subscription must not error")
	assert.Same(t, s, w.GetSubscription(42), "Get Subscription should retrieve the same subscription")
	assert.ErrorIs(t, w.AddSubscriptions(nil, s), subscription.ErrDuplicate, "Adding same subscription should return error")
	assert.Equal(t, subscription.SubscribingState, s.State(), "Should set state to Subscribing")

	err := w.RemoveSubscriptions(nil, s)
	require.NoError(t, err, "RemoveSubscriptions must not error")
	assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub")
	assert.Equal(t, subscription.UnsubscribedState, s.State(), "Should set state to Unsubscribed")

	// A subscription already marked as resubscribing should keep that state when it is added again
	require.NoError(t, s.SetState(subscription.ResubscribingState), "SetState must not error")
	require.NoError(t, w.AddSubscriptions(nil, s), "Re-adding a resubscribing subscription must not error")
	assert.Equal(t, subscription.ResubscribingState, s.State(), "Should not change resubscribing state")
}

// TestSuccessfulSubscriptions tests AddSuccessfulSubscriptions and RemoveSubscriptions,
// including their errors for nil receivers, repeated states, duplicates and missing subscriptions
func TestSuccessfulSubscriptions(t *testing.T) {
	t.Parallel()

	w := new(Manager) // Do not use NewManager; We want to exercise w.subscriptions == nil
	assert.ErrorIs(t, (*Manager)(nil).AddSuccessfulSubscriptions(nil, nil), common.ErrNilPointer, "Should error correctly when nil websocket")

	c := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel}
	require.NoError(t, w.AddSuccessfulSubscriptions(nil, c), "Adding first subscription must not error")
	assert.Same(t, c, w.GetSubscription(42), "Get Subscription should retrieve the same subscription")
	assert.ErrorIs(t, w.AddSuccessfulSubscriptions(nil, c), subscription.ErrInStateAlready, "Adding subscription in same state should return error")

	// With the state moved back to Subscribing the same call now reports a duplicate instead
	require.NoError(t, c.SetState(subscription.SubscribingState), "SetState must not error")
	assert.ErrorIs(t, w.AddSuccessfulSubscriptions(nil, c), subscription.ErrDuplicate, "Adding same subscription should return error")

	err := w.RemoveSubscriptions(nil, c)
	require.NoError(t, err, "RemoveSubscriptions must not error")
	assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub")
	assert.ErrorIs(t, w.RemoveSubscriptions(nil, c), subscription.ErrNotFound, "Should error correctly when not found")
	assert.ErrorIs(t, (*Manager)(nil).RemoveSubscriptions(nil, nil), common.ErrNilPointer, "Should error correctly when nil websocket")

	w.subscriptions = nil
	assert.ErrorIs(t, w.RemoveSubscriptions(nil, c), common.ErrNilPointer, "Should error correctly when nil subscriptions store")
}

// TestGetSubscription tests GetSubscription with a nil Manager, an empty Manager, a nil key and a stored key
func TestGetSubscription(t *testing.T) {
	t.Parallel()

	assert.Nil(t, (*Manager).GetSubscription(nil, "imaginary"), "GetSubscription on a nil Websocket should return nil")
	assert.Nil(t, (&Manager{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub store should return nil")

	w := NewManager()
	assert.Nil(t, w.GetSubscription(nil), "GetSubscription with a nil key should return nil")

	s := &subscription.Subscription{Key: 42, Channel: "hello3"}
	require.NoError(t, w.AddSubscriptions(nil, s), "AddSubscriptions must not error")
	assert.Same(t, s, w.GetSubscription(42), "GetSubscription should delegate to the store")
}

// TestGetSubscriptions tests GetSubscriptions with a nil Manager, an empty Manager and a populated Manager
func TestGetSubscriptions(t *testing.T) {
	t.Parallel()

	assert.Nil(t, (*Manager).GetSubscriptions(nil), "GetSubscriptions on a nil Websocket should return nil")
	assert.Nil(t, (&Manager{}).GetSubscriptions(), "GetSubscriptions on a Websocket with no sub store should return nil")

	w := NewManager()
	s := subscription.List{
		{Key: 42, Channel: "hello3"},
		{Key: 45, Channel: "hello4"},
	}
	err := w.AddSubscriptions(nil, s...)
	require.NoError(t, err, "AddSubscriptions must not error")
	assert.ElementsMatch(t, s, w.GetSubscriptions(), "GetSubscriptions should return the correct channel details")
}

// TestCheckSubscriptions tests checkSubscriptions with a nil subscription store, a nil list,
// lists at and over MaxSubscriptionsPerConnection, and a list duplicating a stored subscription
func TestCheckSubscriptions(t *testing.T) {
	t.Parallel()

	ws := Manager{}
	err := ws.checkSubscriptions(nil, nil)
	assert.ErrorIs(t, err, common.ErrNilPointer, "checkSubscriptions should error correctly on nil w.subscriptions")
	assert.ErrorContains(t, err, "Websocket.subscriptions", "checkSubscriptions should error giving context correctly on nil w.subscriptions")

	ws.subscriptions = subscription.NewStore()
	err = ws.checkSubscriptions(nil, nil)
	assert.NoError(t, err, "checkSubscriptions should not error on a nil list")

	ws.MaxSubscriptionsPerConnection = 1

	err = ws.checkSubscriptions(nil, subscription.List{{}})
	assert.NoError(t, err, "checkSubscriptions should not error when subscriptions is empty")

	ws.subscriptions = subscription.NewStore()
	err = ws.checkSubscriptions(nil, subscription.List{{}, {}})
	assert.ErrorIs(t, err, errSubscriptionsExceedsLimit, "checkSubscriptions should error correctly")

	ws.MaxSubscriptionsPerConnection = 2

	ws.subscriptions = subscription.NewStore()
	err = ws.subscriptions.Add(&subscription.Subscription{Key: 42, Channel: "test"})
	require.NoError(t, err, "Add subscription must not error")
	err = ws.checkSubscriptions(nil, subscription.List{{Key: 42, Channel: "test"}})
	assert.ErrorIs(t, err, subscription.ErrDuplicate, "checkSubscriptions should error correctly")

	err = ws.checkSubscriptions(nil, subscription.List{{}})
	assert.NoError(t, err, "checkSubscriptions should not error")
}

// TestUpdateChannelSubscriptions tests updateChannelSubscriptions before and after a Subscriber and
// an Unsubscriber are assigned, checking the store length after each successful call
func TestUpdateChannelSubscriptions(t *testing.T) {
	t.Parallel()

	ws := NewManager()
	store := subscription.NewStore()

	// No Subscriber has been assigned by this test yet
	err := ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}})
	require.ErrorIs(t, err, common.ErrNilPointer, "updateChannelSubscriptions must error correctly before a Subscriber is assigned")
	require.Zero(t, store.Len(), "Store must remain empty after a failed update")

	ws.Subscriber = func(subs subscription.List) error {
		for _, sub := range subs {
			if err := store.Add(sub); err != nil {
				return err
			}
		}
		return nil
	}
	ws.subscriptions = store
	err = ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}})
	require.NoError(t, err, "updateChannelSubscriptions must not error with a Subscriber assigned")
	require.Equal(t, 1, store.Len(), "Store must contain the new subscription")

	// No Unsubscriber has been assigned by this test yet
	err = ws.updateChannelSubscriptions(nil, store, subscription.List{})
	require.ErrorIs(t, err, common.ErrNilPointer, "updateChannelSubscriptions must error correctly before an Unsubscriber is assigned")

	ws.Unsubscriber = func(subs subscription.List) error {
		for _, sub := range subs {
			if err := store.Remove(sub); err != nil {
				return err
			}
		}
		return nil
	}
	err = ws.updateChannelSubscriptions(nil, store, subscription.List{})
	require.NoError(t, err, "updateChannelSubscriptions must not error with an Unsubscriber assigned")
	require.Zero(t, store.Len(), "Store must be empty after updating to an empty list")
}

// currySimpleSub returns a subscriber function which adds the subs it is given to w
// as successful subscriptions with a nil connection
func currySimpleSub(w *Manager) func(subscription.List) error {
	return func(subs subscription.List) error {
		return w.AddSuccessfulSubscriptions(nil, subs...)
	}
}

// currySimpleSubConn returns a connection aware subscriber function which adds the subs it is given to w
// as successful subscriptions for that connection; the context is ignored
func currySimpleSubConn(w *Manager) func(context.Context, Connection, subscription.List) error {
	return func(_ context.Context, conn Connection, subs subscription.List) error {
		return w.AddSuccessfulSubscriptions(conn, subs...)
	}
}

// currySimpleUnsub returns an unsubscriber function which removes the subs it is given from w
// with a nil connection
func currySimpleUnsub(w *Manager) func(subscription.List) error {
	return func(unsubs subscription.List) error {
		return w.RemoveSubscriptions(nil, unsubs...)
	}
}

// currySimpleUnsubConn returns a connection aware unsubscriber function which removes the subs it is given
// from w for that connection; the context is ignored
func currySimpleUnsubConn(w *Manager) func(context.Context, Connection, subscription.List) error {
	return func(_ context.Context, conn Connection, unsubs subscription.List) error {
		return w.RemoveSubscriptions(conn, unsubs...)
	}
}