Okx: Integrate websocket book5 processing and add configurable max subscriptions per connection (#1275)

* okx: books 5 (cherry-pick)

* okx: shift types to types file, remove commented code and updated field name to better reflect pushed type

* linter: fix

* remove slowness

* * Introduce function checksubscriptions and shift check of subscriptions to internal websocket package
* Shift Max websocket connection int to Websocket setup (temp) for this use case only.

* glorious: nits

* linter: fix

* websocket: don't try and subscribed with nothing to subscribe to.

* Update exchanges/stream/websocket_test.go

Co-authored-by: Scott <gloriousCode@users.noreply.github.com>

* glorious: nits

---------

Co-authored-by: Ryan O'Hara-Reid <ryan.oharareid@thrasher.io>
Co-authored-by: Scott <gloriousCode@users.noreply.github.com>
This commit is contained in:
Ryan O'Hara-Reid
2023-10-13 15:54:49 +11:00
committed by GitHub
parent 859c4512fb
commit 773441d5a7
9 changed files with 274 additions and 108 deletions

View File

@@ -10,6 +10,7 @@ import (
"time"
"github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/config"
"github.com/thrasher-corp/gocryptotrader/log"
)
@@ -48,6 +49,10 @@ var (
errWebsocketConnectorUnset = errors.New("websocket connector function not set")
errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set")
errClosedConnection = errors.New("use of closed network connection")
errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit")
errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0")
errNoSubscriptionsSupplied = errors.New("no subscriptions supplied")
errChannelSubscriptionAlreadySubscribed = errors.New("channel subscription already subscribed")
)
var globalReporter Reporter
@@ -167,6 +172,11 @@ func (w *Websocket) Setup(s *WebsocketSetup) error {
w.Trade.Setup(w.exchangeName, s.TradeFeed, w.DataHandler)
w.Fills.Setup(s.FillsFeed, w.DataHandler)
if s.MaxWebsocketSubscriptionsPerConnection < 0 {
return fmt.Errorf("%s %w", w.exchangeName, errInvalidMaxSubscriptions)
}
w.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection
return nil
}
@@ -275,11 +285,18 @@ func (w *Websocket) Connect() error {
subs, err := w.GenerateSubs() // regenerate state on new connection
if err != nil {
return fmt.Errorf("%v %w: %v", w.exchangeName, ErrSubscriptionFailure, err)
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
}
if len(subs) == 0 {
return nil
}
err = w.checkSubscriptions(subs)
if err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
}
err = w.Subscriber(subs)
if err != nil {
return fmt.Errorf("%v %w: %v", w.exchangeName, ErrSubscriptionFailure, err)
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
}
return nil
}
@@ -905,24 +922,13 @@ func (w *Websocket) ResubscribeToChannel(subscribedChannel *ChannelSubscription)
// SubscribeToChannels appends supplied channels to channelsToSubscribe
func (w *Websocket) SubscribeToChannels(channels []ChannelSubscription) error {
if len(channels) == 0 {
return fmt.Errorf("%s websocket: cannot subscribe no channels supplied",
w.exchangeName)
err := w.checkSubscriptions(channels)
if err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
}
w.subscriptionMutex.Lock()
for x := range channels {
for y := range w.subscriptions {
if channels[x].Equal(&w.subscriptions[y]) {
w.subscriptionMutex.Unlock()
return fmt.Errorf("%s websocket: %v already subscribed",
w.exchangeName,
channels[x])
}
}
}
w.subscriptionMutex.Unlock()
if err := w.Subscriber(channels); err != nil {
return fmt.Errorf("%v %w: %v", w.exchangeName, ErrSubscriptionFailure, err)
err = w.Subscriber(channels)
if err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
}
return nil
}
@@ -1004,3 +1010,31 @@ func checkWebsocketURL(s string) error {
}
return nil
}
// checkSubscriptions checks subscriptions against the max subscription limit
// and if the subscription already exists.
func (w *Websocket) checkSubscriptions(subs []ChannelSubscription) error {
if len(subs) == 0 {
return errNoSubscriptionsSupplied
}
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
if w.MaxSubscriptionsPerConnection > 0 && len(w.subscriptions)+len(subs) > w.MaxSubscriptionsPerConnection {
return fmt.Errorf("%w: current subscriptions: %v, incoming subscriptions: %v, max subscriptions per connection: %v - please reduce enabled pairs",
errSubscriptionsExceedsLimit,
len(w.subscriptions),
len(subs),
w.MaxSubscriptionsPerConnection)
}
for x := range subs {
for y := range w.subscriptions {
if subs[x].Equal(&w.subscriptions[y]) {
return fmt.Errorf("%w for %+v", errChannelSubscriptionAlreadySubscribed, subs[x])
}
}
}
return nil
}

View File

@@ -554,7 +554,15 @@ func TestSubscribeUnsubscribe(t *testing.T) {
func TestResubscribe(t *testing.T) {
t.Parallel()
ws := *New()
err := ws.Setup(defaultSetup)
wackedOutSetup := *defaultSetup
wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1
err := ws.Setup(&wackedOutSetup)
if !errors.Is(err, errInvalidMaxSubscriptions) {
t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidMaxSubscriptions)
}
err = ws.Setup(defaultSetup)
if err != nil {
t.Fatal(err)
}
@@ -1390,3 +1398,32 @@ func TestLatency(t *testing.T) {
t.Errorf("expected %v, got %v", exch, r.name)
}
}
func TestCheckSubscriptions(t *testing.T) {
t.Parallel()
ws := Websocket{}
err := ws.checkSubscriptions(nil)
if !errors.Is(err, errNoSubscriptionsSupplied) {
t.Fatalf("received: %v, but expected: %v", err, errNoSubscriptionsSupplied)
}
ws.MaxSubscriptionsPerConnection = 1
err = ws.checkSubscriptions([]ChannelSubscription{{}, {}})
if !errors.Is(err, errSubscriptionsExceedsLimit) {
t.Fatalf("received: %v, but expected: %v", err, errSubscriptionsExceedsLimit)
}
ws.MaxSubscriptionsPerConnection = 2
ws.subscriptions = []ChannelSubscription{{Channel: "test"}}
err = ws.checkSubscriptions([]ChannelSubscription{{Channel: "test"}})
if !errors.Is(err, errChannelSubscriptionAlreadySubscribed) {
t.Fatalf("received: %v, but expected: %v", err, errChannelSubscriptionAlreadySubscribed)
}
err = ws.checkSubscriptions([]ChannelSubscription{{}})
if !errors.Is(err, nil) {
t.Fatalf("received: %v, but expected: %v", err, nil)
}
}

View File

@@ -93,6 +93,10 @@ type Websocket struct {
// Latency reporter
ExchangeLevelReporter Reporter
// MaxSubScriptionsPerConnection defines the maximum number of
// subscriptions per connection that is allowed by the exchange.
MaxSubscriptionsPerConnection int
}
// WebsocketSetup defines variables for setting up a websocket connection
@@ -114,6 +118,10 @@ type WebsocketSetup struct {
// Fill data config values
FillsFeed bool
// MaxWebsocketSubscriptionsPerConnection defines the maximum number of
// subscriptions per connection that is allowed by the exchange.
MaxWebsocketSubscriptionsPerConnection int
}
// WebsocketConnection contains all the data needed to send a message to a WS