subscriptions: Encapsulate, replace Pair with Pairs and refactor; improve exchange support

* Websocket: Use ErrSubscribedAlready

instead of errChannelAlreadySubscribed

* Subscriptions: Replace Pair with Pairs

Given that some subscriptions have multiple pairs, support that as the
standard.

* Docs: Update subscriptions in add new exch

* RPC: Update Subscription Pairs

* Linter: Disable testifylint.Len

We deliberately use Equal over Len to avoid spamming the contents of large Slices

* Websocket: Add suffix to state consts

* Binance: Subscription Pairs support

* Bitfinex: Subscription Pairs support

* Bithumb: Subscription Pairs support

* Bitmex: Subscription Pairs support

* Bitstamp: Subscription Pairs support

* BTCMarkets: Subscription Pairs support

* BTSE: Subscription Pairs support

* Coinbase: Subscription Pairs support

* Coinut: Subscription Pairs support

* GateIO: Subscription Pairs support

* Gemini: Subscription Pairs support and improvement

* Hitbtc: Subscription Pairs support

* Huboi: Subscription Pairs support

* Kucoin: Subscription Pairs support

* Okcoin: Subscription Pairs support

* Poloniex: Subscription Pairs support

* Kraken: Add subscription Pairs support

Note: This is a naieve implementation because we want to rebase the
kraken websocket rewrite on top of this

* Bybit: Subscription Pairs support

* Okx: Subscription Pairs support

* Bitmex: Subsription configuration

* Fixes unauthenticated websocket left as CanUseAuth
* Fixes auth subs happening privately

* CoinbasePro: Subscription Configuration

* Consolidate ProductIDs when all subscriptions are for the same list

* Websocket: Log actual sent message when Verbose

* Subscriptions: Improve clarity of which key is which in Match

* Subscriptions: Lint fix for HugeParam

* Subscriptions: Add AddPairs and move keys from test

* Subscriptions: Simplify subscription keys and add key types

* Subscriptions: Add List.GroupPairs Rename sub.AddPairs

* Subscription: Fix ExactKey not matching 0 pairs

* Subscriptions: Remove unused IdentityKey and HasPairKey

* Subscriptions: Fix GetKey test

* Subscriptions: Test coverage improvements

* Websocket: Change State on Add/Remove

* Subscriptions: Improve error context

* Subscriptions: Fix Enable: false subs not ignored

* Bitfinex: Fix WsAuth test failing on DataHandler

DataHandler is eaten by dataMonitor now, so we need to use ToRoutine

* Deribit: Subscription Pairs support

* Websocket: Accept nil lists for checkSubscriptions

If the user passes in a nil (implicitly empty) list, we would not panic.
Therefore the burden of correctness about that data lies with them.
The list of subscriptions is empty, and that's okay, and possibly
convenient

* Websocket: Add context to NilPointer errors

* Subscriptions: Add context to nil errors

* Exchange: Fix error expectations in UnsubToWSChans
This commit is contained in:
Gareth Kirwan
2024-06-07 08:54:08 +07:00
committed by GitHub
parent afb6f75d88
commit 1199f38546
75 changed files with 4551 additions and 3891 deletions

View File

@@ -5,12 +5,14 @@ import (
"fmt"
"net"
"net/url"
"sync"
"slices"
"time"
"github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/config"
"github.com/thrasher-corp/gocryptotrader/exchanges/protocol"
"github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
"github.com/thrasher-corp/gocryptotrader/log"
)
@@ -22,12 +24,9 @@ const (
// Public websocket errors
var (
ErrWebsocketNotEnabled = errors.New("websocket not enabled")
ErrSubscriptionNotFound = errors.New("subscription not found")
ErrSubscribedAlready = errors.New("duplicate subscription")
ErrSubscriptionFailure = errors.New("subscription failure")
ErrSubscriptionNotSupported = errors.New("subscription channel not supported ")
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
ErrChannelInStateAlready = errors.New("channel already in state")
ErrAlreadyDisabled = errors.New("websocket already disabled")
ErrNotConnected = errors.New("websocket is not connected")
)
@@ -57,9 +56,6 @@ var (
errClosedConnection = errors.New("use of closed network connection")
errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit")
errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0")
errNoSubscriptionsSupplied = errors.New("no subscriptions supplied")
errChannelAlreadySubscribed = errors.New("channel already subscribed")
errInvalidChannelState = errors.New("invalid Channel state")
errSameProxyAddress = errors.New("cannot set proxy address to the same address")
errNoConnectFunc = errors.New("websocket connect func not set")
errAlreadyConnected = errors.New("websocket already connected")
@@ -84,11 +80,13 @@ func NewWebsocket() *Websocket {
return &Websocket{
DataHandler: make(chan interface{}, jobBuffer),
ToRoutine: make(chan interface{}, jobBuffer),
ShutdownC: make(chan struct{}),
TrafficAlert: make(chan struct{}, 1),
ReadMessageErrors: make(chan error),
Subscribe: make(chan []subscription.Subscription),
Unsubscribe: make(chan []subscription.Subscription),
Match: NewMatch(),
subscriptions: subscription.NewStore(),
features: &protocol.Features{},
Orderbook: buffer.Orderbook{},
}
}
@@ -181,7 +179,6 @@ func (w *Websocket) Setup(s *WebsocketSetup) error {
w.trafficTimeout = s.ExchangeConfig.WebsocketTrafficTimeout
w.ShutdownC = make(chan struct{})
w.Wg = new(sync.WaitGroup)
w.SetCanUseAuthenticatedEndpoints(s.ExchangeConfig.API.AuthenticatedWebsocketSupport)
if err := w.Orderbook.Setup(s.ExchangeConfig, &s.OrderbookBufferConfig, w.DataHandler); err != nil {
@@ -195,7 +192,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error {
return fmt.Errorf("%s %w", w.exchangeName, errInvalidMaxSubscriptions)
}
w.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection
w.setState(disconnected)
w.setState(disconnectedState)
return nil
}
@@ -243,7 +240,7 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error {
Traffic: w.TrafficAlert,
readMessageErrors: w.ReadMessageErrors,
ShutdownC: w.ShutdownC,
Wg: w.Wg,
Wg: &w.Wg,
Match: w.Match,
RateLimit: c.RateLimit,
Reporter: c.ConnectionLevelReporter,
@@ -277,20 +274,21 @@ func (w *Websocket) Connect() error {
return fmt.Errorf("%v %w", w.exchangeName, errAlreadyConnected)
}
w.subscriptionMutex.Lock()
w.subscriptions = subscriptionMap{}
w.subscriptionMutex.Unlock()
if w.subscriptions == nil {
return fmt.Errorf("%w: subscriptions", common.ErrNilPointer)
}
w.subscriptions.Clear()
w.dataMonitor()
w.trafficMonitor()
w.setState(connecting)
w.setState(connectingState)
err := w.connector()
if err != nil {
w.setState(disconnected)
w.setState(disconnectedState)
return fmt.Errorf("%v Error connecting %w", w.exchangeName, err)
}
w.setState(connected)
w.setState(connectedState)
if !w.IsConnectionMonitorRunning() {
err = w.connectionMonitor()
@@ -303,17 +301,12 @@ func (w *Websocket) Connect() error {
if err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
}
if len(subs) == 0 {
return nil
}
err = w.checkSubscriptions(subs)
if err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
}
err = w.Subscriber(subs)
if err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
if len(subs) != 0 {
if err := w.SubscribeToChannels(subs); err != nil {
return err
}
}
return nil
}
@@ -469,11 +462,9 @@ func (w *Websocket) Shutdown() error {
}
// flush any subscriptions from last connection if needed
w.subscriptionMutex.Lock()
w.subscriptions = subscriptionMap{}
w.subscriptionMutex.Unlock()
w.subscriptions.Clear()
w.setState(disconnected)
w.setState(disconnectedState)
close(w.ShutdownC)
w.Wg.Wait()
@@ -527,9 +518,7 @@ func (w *Websocket) FlushChannels() error {
if len(newsubs) != 0 {
// Purge subscription list as there will be conflicts
w.subscriptionMutex.Lock()
w.subscriptions = subscriptionMap{}
w.subscriptionMutex.Unlock()
w.subscriptions.Clear()
return w.SubscribeToChannels(newsubs)
}
return nil
@@ -606,17 +595,17 @@ func (w *Websocket) setState(s uint32) {
// IsInitialised returns whether the websocket has been Setup() already
func (w *Websocket) IsInitialised() bool {
return w.state.Load() != uninitialised
return w.state.Load() != uninitialisedState
}
// IsConnected returns whether the websocket is connected
func (w *Websocket) IsConnected() bool {
return w.state.Load() == connected
return w.state.Load() == connectedState
}
// IsConnecting returns whether the websocket is connecting
func (w *Websocket) IsConnecting() bool {
return w.state.Load() == connecting
return w.state.Load() == connectingState
}
func (w *Websocket) setEnabled(b bool) {
@@ -786,163 +775,134 @@ func (w *Websocket) GetName() string {
// GetChannelDifference finds the difference between the subscribed channels
// and the new subscription list when pairs are disabled or enabled.
func (w *Websocket) GetChannelDifference(genSubs []subscription.Subscription) (sub, unsub []subscription.Subscription) {
w.subscriptionMutex.RLock()
unsubMap := make(map[any]subscription.Subscription, len(w.subscriptions))
for k, c := range w.subscriptions {
unsubMap[k] = *c
func (w *Websocket) GetChannelDifference(newSubs subscription.List) (sub, unsub subscription.List) {
if w.subscriptions == nil {
w.subscriptions = subscription.NewStore()
}
w.subscriptionMutex.RUnlock()
for i := range genSubs {
key := genSubs[i].EnsureKeyed()
if _, ok := unsubMap[key]; ok {
delete(unsubMap, key) // If it's in both then we remove it from the unsubscribe list
} else {
sub = append(sub, genSubs[i]) // If it's in genSubs but not existing subs we want to subscribe
}
}
for x := range unsubMap {
unsub = append(unsub, unsubMap[x])
}
return
return w.subscriptions.Diff(newSubs)
}
// UnsubscribeChannels unsubscribes from a websocket channel
func (w *Websocket) UnsubscribeChannels(channels []subscription.Subscription) error {
if len(channels) == 0 {
return fmt.Errorf("%s websocket: %w", w.exchangeName, errNoSubscriptionsSupplied)
// UnsubscribeChannels unsubscribes from a list of websocket channel
func (w *Websocket) UnsubscribeChannels(channels subscription.List) error {
if w.subscriptions == nil || len(channels) == 0 {
return nil // No channels to unsubscribe from is not an error
}
w.subscriptionMutex.RLock()
for i := range channels {
key := channels[i].EnsureKeyed()
if _, ok := w.subscriptions[key]; !ok {
w.subscriptionMutex.RUnlock()
return fmt.Errorf("%s websocket: %w: %+v", w.exchangeName, ErrSubscriptionNotFound, channels[i])
for _, s := range channels {
if w.subscriptions.Get(s) == nil {
return fmt.Errorf("%w: %s", subscription.ErrNotFound, s)
}
}
w.subscriptionMutex.RUnlock()
return w.Unsubscriber(channels)
}
// ResubscribeToChannel resubscribes to channel
func (w *Websocket) ResubscribeToChannel(subscribedChannel *subscription.Subscription) error {
err := w.UnsubscribeChannels([]subscription.Subscription{*subscribedChannel})
if err != nil {
// Sets state to Resubscribing, and exchanges which want to maintain a lock on it can respect this state and not RemoveSubscription
// Errors if subscription is already subscribing
func (w *Websocket) ResubscribeToChannel(s *subscription.Subscription) error {
l := subscription.List{s}
if err := s.SetState(subscription.ResubscribingState); err != nil {
return fmt.Errorf("%w: %s", err, s)
}
if err := w.UnsubscribeChannels(l); err != nil {
return err
}
return w.SubscribeToChannels([]subscription.Subscription{*subscribedChannel})
return w.SubscribeToChannels(l)
}
// SubscribeToChannels appends supplied channels to channelsToSubscribe
func (w *Websocket) SubscribeToChannels(channels []subscription.Subscription) error {
if err := w.checkSubscriptions(channels); err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
// SubscribeToChannels subscribes to websocket channels using the exchange specific Subscriber method
// Errors are returned for duplicates or exceeding max Subscriptions
func (w *Websocket) SubscribeToChannels(subs subscription.List) error {
if slices.Contains(subs, nil) {
return fmt.Errorf("%w: List parameter contains an nil element", common.ErrNilPointer)
}
if err := w.Subscriber(channels); err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
if err := w.checkSubscriptions(subs); err != nil {
return err
}
if err := w.Subscriber(subs); err != nil {
return fmt.Errorf("%w: %w", ErrSubscriptionFailure, err)
}
return nil
}
// AddSubscription adds a subscription to the subscription lists
// Unlike AddSubscriptions this method will error if the subscription already exists
func (w *Websocket) AddSubscription(c *subscription.Subscription) error {
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
// AddSubscriptions adds subscriptions to the subscription store
// Sets state to Subscribing unless the state is already set
func (w *Websocket) AddSubscriptions(subs ...*subscription.Subscription) error {
if w == nil {
return fmt.Errorf("%w: AddSubscriptions called on nil Websocket", common.ErrNilPointer)
}
if w.subscriptions == nil {
w.subscriptions = subscriptionMap{}
w.subscriptions = subscription.NewStore()
}
key := c.EnsureKeyed()
if _, ok := w.subscriptions[key]; ok {
return ErrSubscribedAlready
var errs error
for _, s := range subs {
if s.State() == subscription.InactiveState {
if err := s.SetState(subscription.SubscribingState); err != nil {
errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s))
}
}
if err := w.subscriptions.Add(s); err != nil {
errs = common.AppendError(errs, err)
}
}
n := *c // Fresh copy; we don't want to use the pointer we were given and allow encapsulation/locks to be bypassed
w.subscriptions[key] = &n
return nil
return errs
}
// SetSubscriptionState sets an existing subscription state
// returns an error if the subscription is not found, or the new state is already set
func (w *Websocket) SetSubscriptionState(c *subscription.Subscription, state subscription.State) error {
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
// AddSuccessfulSubscriptions marks subscriptions as subscribed and adds them to the subscription store
func (w *Websocket) AddSuccessfulSubscriptions(subs ...*subscription.Subscription) error {
if w == nil {
return fmt.Errorf("%w: AddSuccessfulSubscriptions called on nil Websocket", common.ErrNilPointer)
}
if w.subscriptions == nil {
w.subscriptions = subscriptionMap{}
w.subscriptions = subscription.NewStore()
}
key := c.EnsureKeyed()
p, ok := w.subscriptions[key]
if !ok {
return ErrSubscriptionNotFound
var errs error
for _, s := range subs {
if err := s.SetState(subscription.SubscribedState); err != nil {
errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s))
}
if err := w.subscriptions.Add(s); err != nil {
errs = common.AppendError(errs, err)
}
}
if state == p.State {
return ErrChannelInStateAlready
}
if state > subscription.UnsubscribingState {
return errInvalidChannelState
}
p.State = state
return nil
return errs
}
// AddSuccessfulSubscriptions adds subscriptions to the subscription lists that
// has been successfully subscribed
func (w *Websocket) AddSuccessfulSubscriptions(channels ...subscription.Subscription) {
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
// RemoveSubscriptions removes subscriptions from the subscription list and sets the status to Unsubscribed
func (w *Websocket) RemoveSubscriptions(subs ...*subscription.Subscription) error {
if w == nil {
return fmt.Errorf("%w: RemoveSubscriptions called on nil Websocket", common.ErrNilPointer)
}
if w.subscriptions == nil {
w.subscriptions = subscriptionMap{}
return fmt.Errorf("%w: RemoveSubscriptions called on uninitialised Websocket", common.ErrNilPointer)
}
for _, cN := range channels { //nolint:gocritic // See below comment
c := cN // cN is an iteration var; Not safe to make a pointer to
key := c.EnsureKeyed()
c.State = subscription.SubscribedState
w.subscriptions[key] = &c
var errs error
for _, s := range subs {
if err := s.SetState(subscription.UnsubscribedState); err != nil {
errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s))
}
if err := w.subscriptions.Remove(s); err != nil {
errs = common.AppendError(errs, err)
}
}
return errs
}
// RemoveSubscriptions removes subscriptions from the subscription list
func (w *Websocket) RemoveSubscriptions(channels ...subscription.Subscription) {
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
if w.subscriptions == nil {
w.subscriptions = subscriptionMap{}
}
for i := range channels {
key := channels[i].EnsureKeyed()
delete(w.subscriptions, key)
}
}
// GetSubscription returns a pointer to a copy of the subscription at the key provided
// GetSubscription returns a subscription at the key provided
// returns nil if no subscription is at that key or the key is nil
// Keys can implement subscription.MatchableKey in order to provide custom matching logic
func (w *Websocket) GetSubscription(key any) *subscription.Subscription {
if key == nil || w == nil || w.subscriptions == nil {
if w == nil || w.subscriptions == nil || key == nil {
return nil
}
w.subscriptionMutex.RLock()
defer w.subscriptionMutex.RUnlock()
if s, ok := w.subscriptions[key]; ok {
c := *s
return &c
}
return nil
return w.subscriptions.Get(key)
}
// GetSubscriptions returns a new slice of the subscriptions
func (w *Websocket) GetSubscriptions() []subscription.Subscription {
w.subscriptionMutex.RLock()
defer w.subscriptionMutex.RUnlock()
subs := make([]subscription.Subscription, 0, len(w.subscriptions))
for _, c := range w.subscriptions {
subs = append(subs, *c)
func (w *Websocket) GetSubscriptions() subscription.List {
if w == nil || w.subscriptions == nil {
return nil
}
return subs
return w.subscriptions.List()
}
// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in a thread safe manner
@@ -978,28 +938,25 @@ func checkWebsocketURL(s string) error {
return nil
}
// checkSubscriptions checks subscriptions against the max subscription limit
// and if the subscription already exists.
func (w *Websocket) checkSubscriptions(subs []subscription.Subscription) error {
if len(subs) == 0 {
return errNoSubscriptionsSupplied
// checkSubscriptions checks subscriptions against the max subscription limit and if the subscription already exists
// The subscription state is not considered when counting existing subscriptions
func (w *Websocket) checkSubscriptions(subs subscription.List) error {
if w.subscriptions == nil {
return fmt.Errorf("%w: Websocket.subscriptions", common.ErrNilPointer)
}
w.subscriptionMutex.RLock()
defer w.subscriptionMutex.RUnlock()
if w.MaxSubscriptionsPerConnection > 0 && len(w.subscriptions)+len(subs) > w.MaxSubscriptionsPerConnection {
existing := w.subscriptions.Len()
if w.MaxSubscriptionsPerConnection > 0 && existing+len(subs) > w.MaxSubscriptionsPerConnection {
return fmt.Errorf("%w: current subscriptions: %v, incoming subscriptions: %v, max subscriptions per connection: %v - please reduce enabled pairs",
errSubscriptionsExceedsLimit,
len(w.subscriptions),
existing,
len(subs),
w.MaxSubscriptionsPerConnection)
}
for i := range subs {
key := subs[i].EnsureKeyed()
if _, ok := w.subscriptions[key]; ok {
return fmt.Errorf("%w for %+v", errChannelAlreadySubscribed, subs[i])
for _, s := range subs {
if found := w.subscriptions.Get(s); found != nil {
return fmt.Errorf("%w: %s", subscription.ErrDuplicate, s)
}
}

View File

@@ -90,25 +90,22 @@ func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header
// SendJSONMessage sends a JSON encoded message over the connection
func (w *WebsocketConnection) SendJSONMessage(data interface{}) error {
if !w.IsConnected() {
return fmt.Errorf("%s websocket connection: cannot send message to a disconnected websocket",
w.ExchangeName)
return fmt.Errorf("%s websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
}
w.writeControl.Lock()
defer w.writeControl.Unlock()
if w.Verbose {
log.Debugf(log.WebsocketMgr,
"%s websocket connection: sending message to websocket %+v\n",
w.ExchangeName,
data)
if msg, err := json.Marshal(data); err == nil { // WriteJSON will error for us anyway
log.Debugf(log.WebsocketMgr, "%s websocket connection: sending message: %s\n", w.ExchangeName, msg)
}
}
if w.RateLimit > 0 {
time.Sleep(time.Duration(w.RateLimit) * time.Millisecond)
if !w.IsConnected() {
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket",
w.ExchangeName)
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
}
}
return w.Connection.WriteJSON(data)
@@ -117,29 +114,23 @@ func (w *WebsocketConnection) SendJSONMessage(data interface{}) error {
// SendRawMessage sends a message over the connection without JSON encoding it
func (w *WebsocketConnection) SendRawMessage(messageType int, message []byte) error {
if !w.IsConnected() {
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket",
w.ExchangeName)
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
}
w.writeControl.Lock()
defer w.writeControl.Unlock()
if w.Verbose {
log.Debugf(log.WebsocketMgr,
"%v websocket connection: sending message [%s]\n",
w.ExchangeName,
message)
log.Debugf(log.WebsocketMgr, "%v websocket connection: sending message [%s]\n", w.ExchangeName, message)
}
if w.RateLimit > 0 {
time.Sleep(time.Duration(w.RateLimit) * time.Millisecond)
if !w.IsConnected() {
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket",
w.ExchangeName)
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
}
}
if !w.IsConnected() {
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket",
w.ExchangeName)
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
}
return w.Connection.WriteMessage(messageType, message)
}

View File

@@ -10,7 +10,6 @@ import (
"net"
"net/http"
"os"
"sort"
"strconv"
"strings"
"sync"
@@ -20,6 +19,7 @@ import (
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/config"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/protocol"
@@ -79,10 +79,10 @@ var defaultSetup = &WebsocketSetup{
DefaultURL: "testDefaultURL",
RunningURL: "wss://testRunningURL",
Connector: func() error { return nil },
Subscriber: func([]subscription.Subscription) error { return nil },
Unsubscriber: func([]subscription.Subscription) error { return nil },
GenerateSubscriptions: func() ([]subscription.Subscription, error) {
return []subscription.Subscription{
Subscriber: func(subscription.List) error { return nil },
Unsubscriber: func(subscription.List) error { return nil },
GenerateSubscriptions: func() (subscription.List, error) {
return subscription.List{
{Channel: "TestSub"},
{Channel: "TestSub2", Key: "purple"},
{Channel: "TestSub3", Key: testSubKey{"mauve"}},
@@ -147,16 +147,16 @@ func TestSetup(t *testing.T) {
err = w.Setup(websocketSetup)
assert.ErrorIs(t, err, errWebsocketSubscriberUnset, "Setup should error correctly")
websocketSetup.Subscriber = func([]subscription.Subscription) error { return nil }
websocketSetup.Subscriber = func(subscription.List) error { return nil }
websocketSetup.Features.Unsubscribe = true
err = w.Setup(websocketSetup)
assert.ErrorIs(t, err, errWebsocketUnsubscriberUnset, "Setup should error correctly")
websocketSetup.Unsubscriber = func([]subscription.Subscription) error { return nil }
websocketSetup.Unsubscriber = func(subscription.List) error { return nil }
err = w.Setup(websocketSetup)
assert.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset, "Setup should error correctly")
websocketSetup.GenerateSubscriptions = func() ([]subscription.Subscription, error) { return nil, nil }
websocketSetup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil }
err = w.Setup(websocketSetup)
assert.ErrorIs(t, err, errDefaultURLIsEmpty, "Setup should error correctly")
@@ -193,14 +193,13 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) {
signal := struct{}{}
patience := 10 * time.Millisecond
ws.trafficTimeout = 200 * time.Millisecond
ws.ShutdownC = make(chan struct{})
ws.state.Store(connected)
ws.state.Store(connectedState)
thenish := time.Now()
ws.trafficMonitor()
assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running")
require.Equal(t, connected, ws.state.Load(), "websocket must be connected")
require.Equal(t, connectedState, ws.state.Load(), "websocket must be connected")
for i := 0; i < 6; i++ { // Timeout will happen at 200ms so we want 6 * 50ms checks to pass
select {
@@ -226,7 +225,7 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) {
}
require.EventuallyWithT(t, func(c *assert.CollectT) {
assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected")
assert.Equal(c, disconnectedState, ws.state.Load(), "websocket must be disconnected")
assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down")
}, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown once we stop feeding trafficAlerts")
}
@@ -238,17 +237,16 @@ func TestTrafficMonitorConnecting(t *testing.T) {
err := ws.Setup(defaultSetup)
require.NoError(t, err, "Setup must not error")
ws.ShutdownC = make(chan struct{})
ws.state.Store(connecting)
ws.state.Store(connectingState)
ws.trafficTimeout = 50 * time.Millisecond
ws.trafficMonitor()
require.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running")
require.Equal(t, connecting, ws.state.Load(), "websocket must be connecting")
require.Equal(t, connectingState, ws.state.Load(), "websocket must be connecting")
<-time.After(4 * ws.trafficTimeout)
require.Equal(t, connecting, ws.state.Load(), "websocket must still be connecting after several checks")
ws.state.Store(connected)
require.Equal(t, connectingState, ws.state.Load(), "websocket must still be connecting after several checks")
ws.state.Store(connectedState)
require.EventuallyWithT(t, func(c *assert.CollectT) {
assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected")
assert.Equal(c, disconnectedState, ws.state.Load(), "websocket must be disconnected")
assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down")
}, 4*ws.trafficTimeout, 10*time.Millisecond, "trafficTimeout should trigger a shutdown after connecting status changes")
}
@@ -260,8 +258,7 @@ func TestTrafficMonitorShutdown(t *testing.T) {
err := ws.Setup(defaultSetup)
require.NoError(t, err, "Setup must not error")
ws.ShutdownC = make(chan struct{})
ws.state.Store(connected)
ws.state.Store(connectedState)
ws.trafficTimeout = time.Minute
ws.trafficMonitor()
assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running")
@@ -307,12 +304,17 @@ func TestConnectionMessageErrors(t *testing.T) {
assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "Connect should error correctly")
wsWrong.setEnabled(true)
wsWrong.setState(connecting)
wsWrong.Wg = &sync.WaitGroup{}
wsWrong.setState(connectingState)
err = wsWrong.Connect()
assert.ErrorIs(t, err, errAlreadyReconnecting, "Connect should error correctly")
wsWrong.setState(disconnected)
wsWrong.setState(disconnectedState)
err = wsWrong.Connect()
assert.ErrorIs(t, err, common.ErrNilPointer, "Connect should get a nil pointer error")
assert.ErrorContains(t, err, "subscriptions", "Connect should get a nil pointer error about subscriptions")
wsWrong.subscriptions = subscription.NewStore()
wsWrong.setState(disconnectedState)
wsWrong.connector = func() error { return errDastardlyReason }
err = wsWrong.Connect()
assert.ErrorIs(t, err, errDastardlyReason, "Connect should error correctly")
@@ -321,7 +323,7 @@ func TestConnectionMessageErrors(t *testing.T) {
err = ws.Setup(defaultSetup)
require.NoError(t, err, "Setup must not error")
ws.trafficTimeout = time.Minute
ws.connector = func() error { return nil }
ws.connector = connect
err = ws.Connect()
require.NoError(t, err, "Connect must not error")
@@ -381,7 +383,7 @@ func TestWebsocket(t *testing.T) {
err = ws.SetProxyAddress("https://192.168.0.1:1337")
assert.NoError(t, err, "SetProxyAddress should not error when not yet connected")
ws.setState(connected)
ws.setState(connectedState)
err = ws.SetProxyAddress("https://192.168.0.1:1336")
assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Connect and error from there")
@@ -404,14 +406,14 @@ func TestWebsocket(t *testing.T) {
assert.Equal(t, "wss://testRunningURL", ws.GetWebsocketURL(), "GetWebsocketURL should return correctly")
assert.Equal(t, time.Second*5, ws.trafficTimeout, "trafficTimeout should default correctly")
ws.setState(connected)
ws.setState(connectedState)
ws.AuthConn = &dodgyConnection{}
err = ws.Shutdown()
assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy authConn")
assert.ErrorIs(t, err, errCannotShutdown, "Shutdown should error correctly with a dodgy authConn")
ws.AuthConn = &WebsocketConnection{}
ws.setState(disconnected)
ws.setState(disconnectedState)
err = ws.Connect()
assert.NoError(t, err, "Connect should not error")
@@ -446,34 +448,39 @@ func TestWebsocket(t *testing.T) {
ws.Wg.Wait()
}
func currySimpleSub(w *Websocket) func(subscription.List) error {
return func(subs subscription.List) error {
return w.AddSuccessfulSubscriptions(subs...)
}
}
func currySimpleUnsub(w *Websocket) func(subscription.List) error {
return func(unsubs subscription.List) error {
return w.RemoveSubscriptions(unsubs...)
}
}
// TestSubscribe logic test
func TestSubscribeUnsubscribe(t *testing.T) {
t.Parallel()
ws := NewWebsocket()
assert.NoError(t, ws.Setup(defaultSetup), "WS Setup should not error")
fnSub := func(subs []subscription.Subscription) error {
ws.AddSuccessfulSubscriptions(subs...)
return nil
}
fnUnsub := func(unsubs []subscription.Subscription) error {
ws.RemoveSubscriptions(unsubs...)
return nil
}
ws.Subscriber = fnSub
ws.Unsubscriber = fnUnsub
ws.Subscriber = currySimpleSub(ws)
ws.Unsubscriber = currySimpleUnsub(ws)
subs, err := ws.GenerateSubs()
assert.NoError(t, err, "Generating test subscriptions should not error")
assert.ErrorIs(t, ws.UnsubscribeChannels(nil), errNoSubscriptionsSupplied, "Unsubscribing from nil should error")
assert.ErrorIs(t, ws.UnsubscribeChannels(subs), ErrSubscriptionNotFound, "Unsubscribing should error when not subscribed")
require.NoError(t, err, "Generating test subscriptions should not error")
assert.NoError(t, new(Websocket).UnsubscribeChannels(subs), "Should not error when w.subscriptions is nil")
assert.NoError(t, ws.UnsubscribeChannels(nil), "Unsubscribing from nil should not error")
assert.ErrorIs(t, ws.UnsubscribeChannels(subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return")
assert.NoError(t, ws.SubscribeToChannels(subs), "Basic Subscribing should not error")
assert.Len(t, ws.GetSubscriptions(), 4, "Should have 4 subscriptions")
byDefKey := ws.GetSubscription(subscription.DefaultKey{Channel: "TestSub"})
if assert.NotNil(t, byDefKey, "GetSubscription by default key should find a channel") {
assert.Equal(t, "TestSub", byDefKey.Channel, "GetSubscription by default key should return a pointer a copy of the right channel")
assert.NotSame(t, byDefKey, ws.subscriptions["TestSub"], "GetSubscription returns a fresh pointer")
bySub := ws.GetSubscription(subscription.Subscription{Channel: "TestSub"})
if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") {
assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel")
assert.Same(t, bySub, subs[0], "GetSubscription returns the same pointer")
}
if assert.NotNil(t, ws.GetSubscription("purple"), "GetSubscription by string key should find a channel") {
assert.Equal(t, "TestSub2", ws.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel")
@@ -486,9 +493,15 @@ func TestSubscribeUnsubscribe(t *testing.T) {
}
assert.Nil(t, ws.GetSubscription(nil), "GetSubscription by nil should return nil")
assert.Nil(t, ws.GetSubscription(45), "GetSubscription by invalid key should return nil")
assert.ErrorIs(t, ws.SubscribeToChannels(subs), errChannelAlreadySubscribed, "Subscribe should error when already subscribed")
assert.ErrorIs(t, ws.SubscribeToChannels(nil), errNoSubscriptionsSupplied, "Subscribe to nil should error")
assert.ErrorIs(t, ws.SubscribeToChannels(subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed")
assert.NoError(t, ws.SubscribeToChannels(nil), "Subscribe to an nil List should not error")
assert.NoError(t, ws.UnsubscribeChannels(subs), "Unsubscribing should not error")
ws.Subscriber = func(subscription.List) error { return errDastardlyReason }
assert.ErrorIs(t, ws.SubscribeToChannels(subs), errDastardlyReason, "Should error correctly when error returned from Subscriber")
err = ws.SubscribeToChannels(subscription.List{nil})
assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription")
}
// TestResubscribe tests Resubscribing to existing subscriptions
@@ -504,61 +517,56 @@ func TestResubscribe(t *testing.T) {
err = ws.Setup(defaultSetup)
assert.NoError(t, err, "WS Setup should not error")
fnSub := func(subs []subscription.Subscription) error {
ws.AddSuccessfulSubscriptions(subs...)
return nil
}
fnUnsub := func(unsubs []subscription.Subscription) error {
ws.RemoveSubscriptions(unsubs...)
return nil
}
ws.Subscriber = fnSub
ws.Unsubscriber = fnUnsub
ws.Subscriber = currySimpleSub(ws)
ws.Unsubscriber = currySimpleUnsub(ws)
channel := []subscription.Subscription{{Channel: "resubTest"}}
channel := subscription.List{{Channel: "resubTest"}}
assert.ErrorIs(t, ws.ResubscribeToChannel(&channel[0]), ErrSubscriptionNotFound, "Resubscribe should error when channel isn't subscribed yet")
assert.ErrorIs(t, ws.ResubscribeToChannel(channel[0]), subscription.ErrNotFound, "Resubscribe should error when channel isn't subscribed yet")
assert.NoError(t, ws.SubscribeToChannels(channel), "Subscribe should not error")
assert.NoError(t, ws.ResubscribeToChannel(&channel[0]), "Resubscribe should not error now the channel is subscribed")
assert.NoError(t, ws.ResubscribeToChannel(channel[0]), "Resubscribe should not error now the channel is subscribed")
}
// TestSubscriptionState tests Subscription state changes
func TestSubscriptionState(t *testing.T) {
// TestSubscriptions tests adding, getting and removing subscriptions
func TestSubscriptions(t *testing.T) {
t.Parallel()
ws := NewWebsocket()
w := new(Websocket) // Do not use NewWebsocket; We want to exercise w.subs == nil
assert.ErrorIs(t, (*Websocket)(nil).AddSubscriptions(nil), common.ErrNilPointer, "Should error correctly when nil websocket")
s := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel}
require.NoError(t, w.AddSubscriptions(s), "Adding first subscription should not error")
assert.Same(t, s, w.GetSubscription(42), "Get Subscription should retrieve the same subscription")
assert.ErrorIs(t, w.AddSubscriptions(s), subscription.ErrDuplicate, "Adding same subscription should return error")
assert.Equal(t, subscription.SubscribingState, s.State(), "Should set state to Subscribing")
c := &subscription.Subscription{Key: 42, Channel: "Gophers", State: subscription.SubscribingState}
assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), ErrSubscriptionNotFound, "Setting an imaginary sub should error")
err := w.RemoveSubscriptions(s)
require.NoError(t, err, "RemoveSubscriptions must not error")
assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub")
assert.Equal(t, subscription.UnsubscribedState, s.State(), "Should set state to Unsubscribed")
assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error")
found := ws.GetSubscription(42)
assert.NotNil(t, found, "Should find the subscription")
assert.Equal(t, subscription.SubscribingState, found.State, "Subscription should be Subscribing")
assert.ErrorIs(t, ws.AddSubscription(c), ErrSubscribedAlready, "Adding an already existing sub should error")
assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.SubscribingState), ErrChannelInStateAlready, "Setting Same state should error")
assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState+1), errInvalidChannelState, "Setting an invalid state should error")
ws.AddSuccessfulSubscriptions(*c)
found = ws.GetSubscription(42)
assert.NotNil(t, found, "Should find the subscription")
assert.Equal(t, subscription.SubscribedState, found.State, "Subscription should be subscribed state")
assert.NoError(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), "Setting Unsub state should not error")
found = ws.GetSubscription(42)
assert.Equal(t, subscription.UnsubscribingState, found.State, "Subscription should be unsubscribing state")
require.NoError(t, s.SetState(subscription.ResubscribingState), "SetState must not error")
require.NoError(t, w.AddSubscriptions(s), "Adding first subscription should not error")
assert.Equal(t, subscription.ResubscribingState, s.State(), "Should not change resubscribing state")
}
// TestRemoveSubscriptions tests removing a subscription
func TestRemoveSubscriptions(t *testing.T) {
// TestSuccessfulSubscriptions tests adding, getting and removing subscriptions
func TestSuccessfulSubscriptions(t *testing.T) {
t.Parallel()
ws := NewWebsocket()
w := new(Websocket) // Do not use NewWebsocket; We want to exercise w.subs == nil
assert.ErrorIs(t, (*Websocket)(nil).AddSuccessfulSubscriptions(nil), common.ErrNilPointer, "Should error correctly when nil websocket")
c := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel}
require.NoError(t, w.AddSuccessfulSubscriptions(c), "Adding first subscription should not error")
assert.Same(t, c, w.GetSubscription(42), "Get Subscription should retrieve the same subscription")
assert.ErrorIs(t, w.AddSuccessfulSubscriptions(c), subscription.ErrInStateAlready, "Adding subscription in same state should return error")
require.NoError(t, c.SetState(subscription.SubscribingState), "SetState must not error")
assert.ErrorIs(t, w.AddSuccessfulSubscriptions(c), subscription.ErrDuplicate, "Adding same subscription should return error")
c := &subscription.Subscription{Key: 42, Channel: "Unite!"}
assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error")
assert.NotNil(t, ws.GetSubscription(42), "Added subscription should be findable")
ws.RemoveSubscriptions(*c)
assert.Nil(t, ws.GetSubscription(42), "Remove should have removed the sub")
err := w.RemoveSubscriptions(c)
require.NoError(t, err, "RemoveSubscriptions must not error")
assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub")
assert.ErrorIs(t, w.RemoveSubscriptions(c), subscription.ErrNotFound, "Should error correctly when not found")
assert.ErrorIs(t, (*Websocket)(nil).RemoveSubscriptions(nil), common.ErrNilPointer, "Should error correctly when nil websocket")
w.subscriptions = nil
assert.ErrorIs(t, w.RemoveSubscriptions(c), common.ErrNilPointer, "Should error correctly when nil websocket")
}
// TestConnectionMonitorNoConnection logic test
@@ -566,10 +574,7 @@ func TestConnectionMonitorNoConnection(t *testing.T) {
t.Parallel()
ws := NewWebsocket()
ws.connectionMonitorDelay = 500
ws.DataHandler = make(chan interface{}, 1)
ws.ShutdownC = make(chan struct{}, 1)
ws.exchangeName = "hello"
ws.Wg = &sync.WaitGroup{}
ws.setEnabled(true)
err := ws.connectionMonitor()
require.NoError(t, err, "connectionMonitor must not error")
@@ -582,32 +587,27 @@ func TestConnectionMonitorNoConnection(t *testing.T) {
func TestGetSubscription(t *testing.T) {
t.Parallel()
assert.Nil(t, (*Websocket).GetSubscription(nil, "imaginary"), "GetSubscription on a nil Websocket should return nil")
assert.Nil(t, (&Websocket{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub map should return nil")
w := Websocket{
subscriptions: subscriptionMap{
42: {
Channel: "hello3",
},
},
}
assert.Nil(t, w.GetSubscription(43), "GetSubscription with an invalid key should return nil")
c := w.GetSubscription(42)
if assert.NotNil(t, c, "GetSubscription with an valid key should return a channel") {
assert.Equal(t, "hello3", c.Channel, "GetSubscription should return the correct channel details")
}
assert.Nil(t, (&Websocket{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub store should return nil")
w := NewWebsocket()
assert.Nil(t, w.GetSubscription(nil), "GetSubscription with a nil key should return nil")
s := &subscription.Subscription{Key: 42, Channel: "hello3"}
require.NoError(t, w.AddSubscriptions(s), "AddSubscriptions must not error")
assert.Same(t, s, w.GetSubscription(42), "GetSubscription should delegate to the store")
}
// TestGetSubscriptions logic test
func TestGetSubscriptions(t *testing.T) {
t.Parallel()
w := Websocket{
subscriptions: subscriptionMap{
42: {
Channel: "hello3",
},
},
assert.Nil(t, (*Websocket).GetSubscriptions(nil), "GetSubscription on a nil Websocket should return nil")
assert.Nil(t, (&Websocket{}).GetSubscriptions(), "GetSubscription on a Websocket with no sub store should return nil")
w := NewWebsocket()
s := subscription.List{
{Key: 42, Channel: "hello3"},
{Key: 45, Channel: "hello4"},
}
assert.Equal(t, "hello3", w.GetSubscriptions()[0].Channel, "GetSubscriptions should return the correct channel details")
err := w.AddSubscriptions(s...)
require.NoError(t, err, "AddSubscriptions must not error")
assert.ElementsMatch(t, s, w.GetSubscriptions(), "GetSubscriptions should return the correct channel details")
}
// TestSetCanUseAuthenticatedEndpoints logic test
@@ -883,7 +883,7 @@ func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) {
ws := &Websocket{}
assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false")
ws.setState(connected)
ws.setState(connectedState)
require.True(t, ws.IsConnected(), "IsConnected must return true")
assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false")
@@ -939,77 +939,42 @@ func TestCheckWebsocketURL(t *testing.T) {
assert.NoError(t, err, "checkWebsocketURL should not error")
}
// TestGetChannelDifference exercises GetChannelDifference
// See subscription.TestStoreDiff for further testing
func TestGetChannelDifference(t *testing.T) {
t.Parallel()
web := Websocket{}
newChans := []subscription.Subscription{
{
Channel: "Test1",
},
{
Channel: "Test2",
},
{
Channel: "Test3",
},
}
subs, unsubs := web.GetChannelDifference(newChans)
assert.Len(t, subs, 3, "Should get the correct number of subs")
assert.Empty(t, unsubs, "Should get the correct number of unsubs")
web.AddSuccessfulSubscriptions(subs...)
flushedSubs := []subscription.Subscription{
{
Channel: "Test2",
},
}
subs, unsubs = web.GetChannelDifference(flushedSubs)
assert.Empty(t, subs, "Should get the correct number of subs")
assert.Len(t, unsubs, 2, "Should get the correct number of unsubs")
flushedSubs = []subscription.Subscription{
{
Channel: "Test2",
},
{
Channel: "Test4",
},
}
subs, unsubs = web.GetChannelDifference(flushedSubs)
if assert.Len(t, subs, 1, "Should get the correct number of subs") {
assert.Equal(t, "Test4", subs[0].Channel, "Should subscribe to the right channel")
}
if assert.Len(t, unsubs, 2, "Should get the correct number of unsubs") {
sort.Slice(unsubs, func(i, j int) bool { return unsubs[i].Channel <= unsubs[j].Channel })
assert.Equal(t, "Test1", unsubs[0].Channel, "Should unsubscribe from the right channels")
assert.Equal(t, "Test3", unsubs[1].Channel, "Should unsubscribe from the right channels")
}
w := &Websocket{}
assert.NotPanics(t, func() { w.GetChannelDifference(subscription.List{}) }, "Should not panic when called without a store")
subs, unsubs := w.GetChannelDifference(subscription.List{{Channel: subscription.CandlesChannel}})
require.Equal(t, 1, len(subs), "Should get the correct number of subs")
require.Empty(t, unsubs, "Should get no unsubs")
require.NoError(t, w.AddSubscriptions(subs...), "AddSubscriptions must not error")
subs, unsubs = w.GetChannelDifference(subscription.List{{Channel: subscription.TickerChannel}})
require.Equal(t, 1, len(subs), "Should get the correct number of subs")
assert.Equal(t, 1, len(unsubs), "Should get the correct number of unsubs")
}
// GenSubs defines a theoretical exchange with pair management
type GenSubs struct {
EnabledPairs currency.Pairs
subscribos []subscription.Subscription
unsubscribos []subscription.Subscription
subscribos subscription.List
unsubscribos subscription.List
}
// generateSubs default subs created from the enabled pairs list
func (g *GenSubs) generateSubs() ([]subscription.Subscription, error) {
superduperchannelsubs := make([]subscription.Subscription, len(g.EnabledPairs))
func (g *GenSubs) generateSubs() (subscription.List, error) {
superduperchannelsubs := make(subscription.List, len(g.EnabledPairs))
for i := range g.EnabledPairs {
superduperchannelsubs[i] = subscription.Subscription{
superduperchannelsubs[i] = &subscription.Subscription{
Channel: "TEST:" + strconv.FormatInt(int64(i), 10),
Pair: g.EnabledPairs[i],
Pairs: currency.Pairs{g.EnabledPairs[i]},
}
}
return superduperchannelsubs, nil
}
func (g *GenSubs) SUBME(subs []subscription.Subscription) error {
func (g *GenSubs) SUBME(subs subscription.List) error {
if len(subs) == 0 {
return errors.New("WOW")
}
@@ -1017,7 +982,7 @@ func (g *GenSubs) SUBME(subs []subscription.Subscription) error {
return nil
}
func (g *GenSubs) UNSUBME(unsubs []subscription.Subscription) error {
func (g *GenSubs) UNSUBME(unsubs subscription.List) error {
if len(unsubs) == 0 {
return errors.New("WOW")
}
@@ -1044,87 +1009,70 @@ func TestFlushChannels(t *testing.T) {
err = dodgyWs.FlushChannels()
assert.ErrorIs(t, err, ErrNotConnected, "FlushChannels should error correctly")
w := Websocket{
connector: connect,
ShutdownC: make(chan struct{}),
Subscriber: newgen.SUBME,
Unsubscriber: newgen.UNSUBME,
Wg: new(sync.WaitGroup),
features: &protocol.Features{
// No features
},
trafficTimeout: time.Second * 30, // Added for when we utilise connect()
// in FlushChannels() so the traffic monitor doesn't time out and turn
// this to an unconnected state
}
w.setEnabled(true)
w.setState(connected)
w := NewWebsocket()
w.connector = connect
w.Subscriber = newgen.SUBME
w.Unsubscriber = newgen.UNSUBME
// Added for when we utilise connect() in FlushChannels() so the traffic monitor doesn't time out and turn this to an unconnected state
w.trafficTimeout = time.Second * 30
problemFunc := func() ([]subscription.Subscription, error) {
w.setEnabled(true)
w.setState(connectedState)
problemFunc := func() (subscription.List, error) {
return nil, errDastardlyReason
}
noSub := func() ([]subscription.Subscription, error) {
noSub := func() (subscription.List, error) {
return nil, nil
}
// Disable pair and flush system
newgen.EnabledPairs = []currency.Pair{
currency.NewPair(currency.BTC, currency.AUD)}
w.GenerateSubs = func() ([]subscription.Subscription, error) {
return []subscription.Subscription{{Channel: "test"}}, nil
w.GenerateSubs = func() (subscription.List, error) {
return subscription.List{{Channel: "test"}}, nil
}
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")
require.NoError(t, err, "Flush Channels must not error")
w.features.FullPayloadSubscribe = true
w.GenerateSubs = problemFunc
err = w.FlushChannels() // error on full subscribeToChannels
assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly")
assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly on GenerateSubs")
w.GenerateSubs = noSub
err = w.FlushChannels() // No subs to unsub
assert.NoError(t, err, "FlushChannels should not error")
err = w.FlushChannels() // No subs to sub
assert.NoError(t, err, "Flush Channels should not error")
w.GenerateSubs = newgen.generateSubs
subs, err := w.GenerateSubs()
require.NoError(t, err, "GenerateSubs must not error")
w.AddSuccessfulSubscriptions(subs...)
require.NoError(t, w.AddSubscriptions(subs...), "AddSubscriptions must not error")
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")
w.features.FullPayloadSubscribe = false
w.features.Subscribe = true
w.GenerateSubs = problemFunc
err = w.FlushChannels()
assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly")
w.GenerateSubs = newgen.generateSubs
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")
w.subscriptionMutex.Lock()
w.subscriptions = subscriptionMap{
41: {
Key: 41,
Channel: "match channel",
Pair: currency.NewPair(currency.BTC, currency.AUD),
},
42: {
Key: 42,
Channel: "unsub channel",
Pair: currency.NewPair(currency.THETA, currency.USDT),
},
}
w.subscriptionMutex.Unlock()
w.subscriptions = subscription.NewStore()
err = w.subscriptions.Add(&subscription.Subscription{
Key: 41,
Channel: "match channel",
Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.AUD)},
})
require.NoError(t, err, "AddSubscription must not error")
err = w.subscriptions.Add(&subscription.Subscription{
Key: 42,
Channel: "unsub channel",
Pairs: currency.Pairs{currency.NewPair(currency.THETA, currency.USDT)},
})
require.NoError(t, err, "AddSubscription must not error")
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")
w.setState(connected)
w.setState(connectedState)
w.features.Unsubscribe = true
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")
@@ -1132,27 +1080,20 @@ func TestFlushChannels(t *testing.T) {
func TestDisable(t *testing.T) {
t.Parallel()
w := Websocket{
ShutdownC: make(chan struct{}),
}
w := NewWebsocket()
w.setEnabled(true)
w.setState(connected)
w.setState(connectedState)
require.NoError(t, w.Disable(), "Disable must not error")
assert.ErrorIs(t, w.Disable(), ErrAlreadyDisabled, "Disable should error correctly")
}
func TestEnable(t *testing.T) {
t.Parallel()
w := Websocket{
connector: connect,
Wg: new(sync.WaitGroup),
ShutdownC: make(chan struct{}),
GenerateSubs: func() ([]subscription.Subscription, error) {
return []subscription.Subscription{{Channel: "test"}}, nil
},
Subscriber: func([]subscription.Subscription) error { return nil },
}
w := NewWebsocket()
w.connector = connect
w.Subscriber = func(subscription.List) error { return nil }
w.Unsubscriber = func(subscription.List) error { return nil }
w.GenerateSubs = func() (subscription.List, error) { return nil, nil }
require.NoError(t, w.Enable(), "Enable must not error")
assert.ErrorIs(t, w.Enable(), errWebsocketAlreadyEnabled, "Enable should error correctly")
}
@@ -1259,19 +1200,30 @@ func TestCheckSubscriptions(t *testing.T) {
t.Parallel()
ws := Websocket{}
err := ws.checkSubscriptions(nil)
assert.ErrorIs(t, err, errNoSubscriptionsSupplied, "checkSubscriptions should error correctly")
assert.ErrorIs(t, err, common.ErrNilPointer, "checkSubscriptions should error correctly on nil w.subscriptions")
assert.ErrorContains(t, err, "Websocket.subscriptions", "checkSubscriptions should error giving context correctly on nil w.subscriptions")
ws.subscriptions = subscription.NewStore()
err = ws.checkSubscriptions(nil)
assert.NoError(t, err, "checkSubscriptions should not error on a nil list")
ws.MaxSubscriptionsPerConnection = 1
err = ws.checkSubscriptions([]subscription.Subscription{{}, {}})
err = ws.checkSubscriptions(subscription.List{{}})
assert.NoError(t, err, "checkSubscriptions should not error when subscriptions is empty")
ws.subscriptions = subscription.NewStore()
err = ws.checkSubscriptions(subscription.List{{}, {}})
assert.ErrorIs(t, err, errSubscriptionsExceedsLimit, "checkSubscriptions should error correctly")
ws.MaxSubscriptionsPerConnection = 2
ws.subscriptions = subscriptionMap{42: {Key: 42, Channel: "test"}}
err = ws.checkSubscriptions([]subscription.Subscription{{Key: 42, Channel: "test"}})
assert.ErrorIs(t, err, errChannelAlreadySubscribed, "checkSubscriptions should error correctly")
ws.subscriptions = subscription.NewStore()
err = ws.subscriptions.Add(&subscription.Subscription{Key: 42, Channel: "test"})
require.NoError(t, err, "Add subscription must not error")
err = ws.checkSubscriptions(subscription.List{{Key: 42, Channel: "test"}})
assert.ErrorIs(t, err, subscription.ErrDuplicate, "checkSubscriptions should error correctly")
err = ws.checkSubscriptions([]subscription.Subscription{{}})
err = ws.checkSubscriptions(subscription.List{{}})
assert.NoError(t, err, "checkSubscriptions should not error")
}

View File

@@ -22,13 +22,11 @@ const (
UnhandledMessage = " - Unhandled websocket message: "
)
type subscriptionMap map[any]*subscription.Subscription
const (
uninitialised uint32 = iota
disconnected
connecting
connected
uninitialisedState uint32 = iota
disconnectedState
connectingState
connectedState
)
// Websocket defines a return type for websocket connections via the interface
@@ -52,20 +50,14 @@ type Websocket struct {
m sync.Mutex
connector func() error
subscriptionMutex sync.RWMutex
subscriptions subscriptionMap
Subscribe chan []subscription.Subscription
Unsubscribe chan []subscription.Subscription
subscriptions *subscription.Store
// Subscriber function for package defined websocket subscriber
// functionality
Subscriber func([]subscription.Subscription) error
// Unsubscriber function for packaged defined websocket unsubscriber
// functionality
Unsubscriber func([]subscription.Subscription) error
// GenerateSubs function for package defined websocket generate
// subscriptions functionality
GenerateSubs func() ([]subscription.Subscription, error)
// Subscriber function for exchange specific subscribe implementation
Subscriber func(subscription.List) error
// Subscriber function for exchange specific unsubscribe implementation
Unsubscriber func(subscription.List) error
// GenerateSubs function for exchange specific generating subscriptions from Features.Subscriptions, Pairs and Assets
GenerateSubs func() (subscription.List, error)
DataHandler chan interface{}
ToRoutine chan interface{}
@@ -74,7 +66,7 @@ type Websocket struct {
// shutdown synchronises shutdown event across routines
ShutdownC chan struct{}
Wg *sync.WaitGroup
Wg sync.WaitGroup
// Orderbook is a local buffer of orderbooks
Orderbook buffer.Orderbook
@@ -112,9 +104,9 @@ type WebsocketSetup struct {
RunningURL string
RunningURLAuth string
Connector func() error
Subscriber func([]subscription.Subscription) error
Unsubscriber func([]subscription.Subscription) error
GenerateSubscriptions func() ([]subscription.Subscription, error)
Subscriber func(subscription.List) error
Unsubscriber func(subscription.List) error
GenerateSubscriptions func() (subscription.List, error)
Features *protocol.Features
// Local orderbook buffer config values