Files
gocryptotrader/exchanges/stream/websocket.go
Gareth Kirwan 1199f38546 subscriptions: Encapsulate, replace Pair with Pairs and refactor; improve exchange support
* Websocket: Use ErrSubscribedAlready

instead of errChannelAlreadySubscribed

* Subscriptions: Replace Pair with Pairs

Given that some subscriptions have multiple pairs, support that as the
standard.

* Docs: Update subscriptions in add new exch

* RPC: Update Subscription Pairs

* Linter: Disable testifylint.Len

We deliberately use Equal over Len to avoid spamming the contents of large Slices

* Websocket: Add suffix to state consts

* Binance: Subscription Pairs support

* Bitfinex: Subscription Pairs support

* Bithumb: Subscription Pairs support

* Bitmex: Subscription Pairs support

* Bitstamp: Subscription Pairs support

* BTCMarkets: Subscription Pairs support

* BTSE: Subscription Pairs support

* Coinbase: Subscription Pairs support

* Coinut: Subscription Pairs support

* GateIO: Subscription Pairs support

* Gemini: Subscription Pairs support and improvement

* Hitbtc: Subscription Pairs support

* Huboi: Subscription Pairs support

* Kucoin: Subscription Pairs support

* Okcoin: Subscription Pairs support

* Poloniex: Subscription Pairs support

* Kraken: Add subscription Pairs support

Note: This is a naieve implementation because we want to rebase the
kraken websocket rewrite on top of this

* Bybit: Subscription Pairs support

* Okx: Subscription Pairs support

* Bitmex: Subsription configuration

* Fixes unauthenticated websocket left as CanUseAuth
* Fixes auth subs happening privately

* CoinbasePro: Subscription Configuration

* Consolidate ProductIDs when all subscriptions are for the same list

* Websocket: Log actual sent message when Verbose

* Subscriptions: Improve clarity of which key is which in Match

* Subscriptions: Lint fix for HugeParam

* Subscriptions: Add AddPairs and move keys from test

* Subscriptions: Simplify subscription keys and add key types

* Subscriptions: Add List.GroupPairs Rename sub.AddPairs

* Subscription: Fix ExactKey not matching 0 pairs

* Subscriptions: Remove unused IdentityKey and HasPairKey

* Subscriptions: Fix GetKey test

* Subscriptions: Test coverage improvements

* Websocket: Change State on Add/Remove

* Subscriptions: Improve error context

* Subscriptions: Fix Enable: false subs not ignored

* Bitfinex: Fix WsAuth test failing on DataHandler

DataHandler is eaten by dataMonitor now, so we need to use ToRoutine

* Deribit: Subscription Pairs support

* Websocket: Accept nil lists for checkSubscriptions

If the user passes in a nil (implicitly empty) list, we would not panic.
Therefore the burden of correctness about that data lies with them.
The list of subscriptions is empty, and that's okay, and possibly
convenient

* Websocket: Add context to NilPointer errors

* Subscriptions: Add context to nil errors

* Exchange: Fix error expectations in UnsubToWSChans
2024-06-07 11:54:08 +10:00

965 lines
28 KiB
Go

package stream
import (
"errors"
"fmt"
"net"
"net/url"
"slices"
"time"
"github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/config"
"github.com/thrasher-corp/gocryptotrader/exchanges/protocol"
"github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
"github.com/thrasher-corp/gocryptotrader/log"
)
const (
jobBuffer = 5000
)
// Public websocket errors
var (
ErrWebsocketNotEnabled = errors.New("websocket not enabled")
ErrSubscriptionFailure = errors.New("subscription failure")
ErrSubscriptionNotSupported = errors.New("subscription channel not supported ")
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
ErrAlreadyDisabled = errors.New("websocket already disabled")
ErrNotConnected = errors.New("websocket is not connected")
)
// Private websocket errors
var (
errAlreadyRunning = errors.New("connection monitor is already running")
errExchangeConfigIsNil = errors.New("exchange config is nil")
errExchangeConfigEmpty = errors.New("exchange config is empty")
errWebsocketIsNil = errors.New("websocket is nil")
errWebsocketSetupIsNil = errors.New("websocket setup is nil")
errWebsocketAlreadyInitialised = errors.New("websocket already initialised")
errWebsocketAlreadyEnabled = errors.New("websocket already enabled")
errWebsocketFeaturesIsUnset = errors.New("websocket features is unset")
errConfigFeaturesIsNil = errors.New("exchange config features is nil")
errDefaultURLIsEmpty = errors.New("default url is empty")
errRunningURLIsEmpty = errors.New("running url cannot be empty")
errInvalidWebsocketURL = errors.New("invalid websocket url")
errExchangeConfigNameEmpty = errors.New("exchange config name empty")
errInvalidTrafficTimeout = errors.New("invalid traffic timeout")
errTrafficAlertNil = errors.New("traffic alert is nil")
errWebsocketSubscriberUnset = errors.New("websocket subscriber function needs to be set")
errWebsocketUnsubscriberUnset = errors.New("websocket unsubscriber functionality allowed but unsubscriber function not set")
errWebsocketConnectorUnset = errors.New("websocket connector function not set")
errReadMessageErrorsNil = errors.New("read message errors is nil")
errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set")
errClosedConnection = errors.New("use of closed network connection")
errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit")
errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0")
errSameProxyAddress = errors.New("cannot set proxy address to the same address")
errNoConnectFunc = errors.New("websocket connect func not set")
errAlreadyConnected = errors.New("websocket already connected")
errCannotShutdown = errors.New("websocket cannot shutdown")
errAlreadyReconnecting = errors.New("websocket in the process of reconnection")
errConnSetup = errors.New("error in connection setup")
)
var (
globalReporter Reporter
trafficCheckInterval = 100 * time.Millisecond
)
// SetupGlobalReporter sets a reporter interface to be used
// for all exchange requests
func SetupGlobalReporter(r Reporter) {
globalReporter = r
}
// NewWebsocket initialises the websocket struct
func NewWebsocket() *Websocket {
return &Websocket{
DataHandler: make(chan interface{}, jobBuffer),
ToRoutine: make(chan interface{}, jobBuffer),
ShutdownC: make(chan struct{}),
TrafficAlert: make(chan struct{}, 1),
ReadMessageErrors: make(chan error),
Match: NewMatch(),
subscriptions: subscription.NewStore(),
features: &protocol.Features{},
Orderbook: buffer.Orderbook{},
}
}
// Setup sets main variables for websocket connection
func (w *Websocket) Setup(s *WebsocketSetup) error {
if w == nil {
return errWebsocketIsNil
}
if s == nil {
return errWebsocketSetupIsNil
}
w.m.Lock()
defer w.m.Unlock()
if w.IsInitialised() {
return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyInitialised)
}
if s.ExchangeConfig == nil {
return errExchangeConfigIsNil
}
if s.ExchangeConfig.Name == "" {
return errExchangeConfigNameEmpty
}
w.exchangeName = s.ExchangeConfig.Name
w.verbose = s.ExchangeConfig.Verbose
if s.Features == nil {
return fmt.Errorf("%s %w", w.exchangeName, errWebsocketFeaturesIsUnset)
}
w.features = s.Features
if s.ExchangeConfig.Features == nil {
return fmt.Errorf("%s %w", w.exchangeName, errConfigFeaturesIsNil)
}
w.setEnabled(s.ExchangeConfig.Features.Enabled.Websocket)
if s.Connector == nil {
return fmt.Errorf("%s %w", w.exchangeName, errWebsocketConnectorUnset)
}
w.connector = s.Connector
if s.Subscriber == nil {
return fmt.Errorf("%s %w", w.exchangeName, errWebsocketSubscriberUnset)
}
w.Subscriber = s.Subscriber
if w.features.Unsubscribe && s.Unsubscriber == nil {
return fmt.Errorf("%s %w", w.exchangeName, errWebsocketUnsubscriberUnset)
}
w.connectionMonitorDelay = s.ExchangeConfig.ConnectionMonitorDelay
if w.connectionMonitorDelay <= 0 {
w.connectionMonitorDelay = config.DefaultConnectionMonitorDelay
}
w.Unsubscriber = s.Unsubscriber
if s.GenerateSubscriptions == nil {
return fmt.Errorf("%s %w", w.exchangeName, errWebsocketSubscriptionsGeneratorUnset)
}
w.GenerateSubs = s.GenerateSubscriptions
if s.DefaultURL == "" {
return fmt.Errorf("%s websocket %w", w.exchangeName, errDefaultURLIsEmpty)
}
w.defaultURL = s.DefaultURL
if s.RunningURL == "" {
return fmt.Errorf("%s websocket %w", w.exchangeName, errRunningURLIsEmpty)
}
err := w.SetWebsocketURL(s.RunningURL, false, false)
if err != nil {
return fmt.Errorf("%s %w", w.exchangeName, err)
}
if s.RunningURLAuth != "" {
err = w.SetWebsocketURL(s.RunningURLAuth, true, false)
if err != nil {
return fmt.Errorf("%s %w", w.exchangeName, err)
}
}
if s.ExchangeConfig.WebsocketTrafficTimeout < time.Second {
return fmt.Errorf("%s %w cannot be less than %s",
w.exchangeName,
errInvalidTrafficTimeout,
time.Second)
}
w.trafficTimeout = s.ExchangeConfig.WebsocketTrafficTimeout
w.ShutdownC = make(chan struct{})
w.SetCanUseAuthenticatedEndpoints(s.ExchangeConfig.API.AuthenticatedWebsocketSupport)
if err := w.Orderbook.Setup(s.ExchangeConfig, &s.OrderbookBufferConfig, w.DataHandler); err != nil {
return err
}
w.Trade.Setup(w.exchangeName, s.TradeFeed, w.DataHandler)
w.Fills.Setup(s.FillsFeed, w.DataHandler)
if s.MaxWebsocketSubscriptionsPerConnection < 0 {
return fmt.Errorf("%s %w", w.exchangeName, errInvalidMaxSubscriptions)
}
w.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection
w.setState(disconnectedState)
return nil
}
// SetupNewConnection sets up an auth or unauth streaming connection
func (w *Websocket) SetupNewConnection(c ConnectionSetup) error {
if w == nil {
return fmt.Errorf("%w: %w", errConnSetup, errWebsocketIsNil)
}
if c == (ConnectionSetup{}) {
return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigEmpty)
}
if w.exchangeName == "" {
return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigNameEmpty)
}
if w.TrafficAlert == nil {
return fmt.Errorf("%w: %w", errConnSetup, errTrafficAlertNil)
}
if w.ReadMessageErrors == nil {
return fmt.Errorf("%w: %w", errConnSetup, errReadMessageErrorsNil)
}
connectionURL := w.GetWebsocketURL()
if c.URL != "" {
connectionURL = c.URL
}
if c.ConnectionLevelReporter == nil {
c.ConnectionLevelReporter = w.ExchangeLevelReporter
}
if c.ConnectionLevelReporter == nil {
c.ConnectionLevelReporter = globalReporter
}
newConn := &WebsocketConnection{
ExchangeName: w.exchangeName,
URL: connectionURL,
ProxyURL: w.GetProxyAddress(),
Verbose: w.verbose,
ResponseMaxLimit: c.ResponseMaxLimit,
Traffic: w.TrafficAlert,
readMessageErrors: w.ReadMessageErrors,
ShutdownC: w.ShutdownC,
Wg: &w.Wg,
Match: w.Match,
RateLimit: c.RateLimit,
Reporter: c.ConnectionLevelReporter,
}
if c.Authenticated {
w.AuthConn = newConn
} else {
w.Conn = newConn
}
return nil
}
// Connect initiates a websocket connection by using a package defined connection
// function
func (w *Websocket) Connect() error {
if w.connector == nil {
return errNoConnectFunc
}
w.m.Lock()
defer w.m.Unlock()
if !w.IsEnabled() {
return ErrWebsocketNotEnabled
}
if w.IsConnecting() {
return fmt.Errorf("%v %w", w.exchangeName, errAlreadyReconnecting)
}
if w.IsConnected() {
return fmt.Errorf("%v %w", w.exchangeName, errAlreadyConnected)
}
if w.subscriptions == nil {
return fmt.Errorf("%w: subscriptions", common.ErrNilPointer)
}
w.subscriptions.Clear()
w.dataMonitor()
w.trafficMonitor()
w.setState(connectingState)
err := w.connector()
if err != nil {
w.setState(disconnectedState)
return fmt.Errorf("%v Error connecting %w", w.exchangeName, err)
}
w.setState(connectedState)
if !w.IsConnectionMonitorRunning() {
err = w.connectionMonitor()
if err != nil {
log.Errorf(log.WebsocketMgr, "%s cannot start websocket connection monitor %v", w.GetName(), err)
}
}
subs, err := w.GenerateSubs() // regenerate state on new connection
if err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
}
if len(subs) != 0 {
if err := w.SubscribeToChannels(subs); err != nil {
return err
}
}
return nil
}
// Disable disables the exchange websocket protocol
// Note that connectionMonitor will be responsible for shutting down the websocket after disabling
func (w *Websocket) Disable() error {
if !w.IsEnabled() {
return fmt.Errorf("%s %w", w.exchangeName, ErrAlreadyDisabled)
}
w.setEnabled(false)
return nil
}
// Enable enables the exchange websocket protocol
func (w *Websocket) Enable() error {
if w.IsConnected() || w.IsEnabled() {
return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyEnabled)
}
w.setEnabled(true)
return w.Connect()
}
// dataMonitor monitors job throughput and logs if there is a back log of data
func (w *Websocket) dataMonitor() {
if w.IsDataMonitorRunning() {
return
}
w.setDataMonitorRunning(true)
w.Wg.Add(1)
go func() {
defer func() {
w.setDataMonitorRunning(false)
w.Wg.Done()
}()
dropped := 0
for {
select {
case <-w.ShutdownC:
return
case d := <-w.DataHandler:
select {
case w.ToRoutine <- d:
if dropped != 0 {
log.Infof(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer recovered; %d messages were dropped", w.exchangeName, dropped)
dropped = 0
}
default:
if dropped == 0 {
// If this becomes prone to flapping we could drain the buffer, but that's extreme and we'd like to avoid it if possible
log.Warnf(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer full; dropping messages", w.exchangeName)
}
dropped++
}
}
}
}()
}
// connectionMonitor ensures that the WS keeps connecting
func (w *Websocket) connectionMonitor() error {
if w.checkAndSetMonitorRunning() {
return errAlreadyRunning
}
delay := w.connectionMonitorDelay
go func() {
timer := time.NewTimer(delay)
for {
if w.verbose {
log.Debugf(log.WebsocketMgr, "%v websocket: running connection monitor cycle", w.exchangeName)
}
if !w.IsEnabled() {
if w.verbose {
log.Debugf(log.WebsocketMgr, "%v websocket: connectionMonitor - websocket disabled, shutting down", w.exchangeName)
}
if w.IsConnected() {
if err := w.Shutdown(); err != nil {
log.Errorln(log.WebsocketMgr, err)
}
}
if w.verbose {
log.Debugf(log.WebsocketMgr, "%v websocket: connection monitor exiting", w.exchangeName)
}
timer.Stop()
w.setConnectionMonitorRunning(false)
return
}
select {
case err := <-w.ReadMessageErrors:
w.DataHandler <- err
if IsDisconnectionError(err) {
log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err)
if w.IsConnected() {
if shutdownErr := w.Shutdown(); shutdownErr != nil {
log.Errorf(log.WebsocketMgr, "%v websocket: connectionMonitor shutdown err: %s", w.exchangeName, shutdownErr)
}
}
}
case <-timer.C:
if !w.IsConnecting() && !w.IsConnected() {
err := w.Connect()
if err != nil {
log.Errorln(log.WebsocketMgr, err)
}
}
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timer.Reset(delay)
}
}
}()
return nil
}
// Shutdown attempts to shut down a websocket connection and associated routines
// by using a package defined shutdown function
func (w *Websocket) Shutdown() error {
w.m.Lock()
defer w.m.Unlock()
if !w.IsConnected() {
return fmt.Errorf("%v %w: %w", w.exchangeName, errCannotShutdown, ErrNotConnected)
}
// TODO: Interrupt connection and or close connection when it is re-established.
if w.IsConnecting() {
return fmt.Errorf("%v %w: %w ", w.exchangeName, errCannotShutdown, errAlreadyReconnecting)
}
if w.verbose {
log.Debugf(log.WebsocketMgr, "%v websocket: shutting down websocket", w.exchangeName)
}
defer w.Orderbook.FlushBuffer()
if w.Conn != nil {
if err := w.Conn.Shutdown(); err != nil {
return err
}
}
if w.AuthConn != nil {
if err := w.AuthConn.Shutdown(); err != nil {
return err
}
}
// flush any subscriptions from last connection if needed
w.subscriptions.Clear()
w.setState(disconnectedState)
close(w.ShutdownC)
w.Wg.Wait()
w.ShutdownC = make(chan struct{})
if w.verbose {
log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", w.exchangeName)
}
return nil
}
// FlushChannels flushes channel subscriptions when there is a pair/asset change
func (w *Websocket) FlushChannels() error {
if !w.IsEnabled() {
return fmt.Errorf("%s %w", w.exchangeName, ErrWebsocketNotEnabled)
}
if !w.IsConnected() {
return fmt.Errorf("%s %w", w.exchangeName, ErrNotConnected)
}
if w.features.Subscribe {
newsubs, err := w.GenerateSubs()
if err != nil {
return err
}
subs, unsubs := w.GetChannelDifference(newsubs)
if w.features.Unsubscribe {
if len(unsubs) != 0 {
err := w.UnsubscribeChannels(unsubs)
if err != nil {
return err
}
}
}
if len(subs) < 1 {
return nil
}
return w.SubscribeToChannels(subs)
} else if w.features.FullPayloadSubscribe {
// FullPayloadSubscribe means that the endpoint requires all
// subscriptions to be sent via the websocket connection e.g. if you are
// subscribed to ticker and orderbook but require trades as well, you
// would need to send ticker, orderbook and trades channel subscription
// messages.
newsubs, err := w.GenerateSubs()
if err != nil {
return err
}
if len(newsubs) != 0 {
// Purge subscription list as there will be conflicts
w.subscriptions.Clear()
return w.SubscribeToChannels(newsubs)
}
return nil
}
if err := w.Shutdown(); err != nil {
return err
}
return w.Connect()
}
// trafficMonitor waits trafficCheckInterval before checking for a trafficAlert
// 1 slot buffer means that connection will only write to trafficAlert once per trafficCheckInterval to avoid read/write flood in high traffic
// Otherwise we Shutdown the connection after trafficTimeout, unless it's connecting. connectionMonitor is responsible for Connecting again
func (w *Websocket) trafficMonitor() {
if w.IsTrafficMonitorRunning() {
return
}
w.setTrafficMonitorRunning(true)
w.Wg.Add(1)
go func() {
t := time.NewTimer(w.trafficTimeout)
for {
select {
case <-w.ShutdownC:
if w.verbose {
log.Debugf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown message received", w.exchangeName)
}
t.Stop()
w.setTrafficMonitorRunning(false)
w.Wg.Done()
return
case <-time.After(trafficCheckInterval):
select {
case <-w.TrafficAlert:
if !t.Stop() {
<-t.C
}
t.Reset(w.trafficTimeout)
default:
}
case <-t.C:
checkAgain := w.IsConnecting()
select {
case <-w.TrafficAlert:
checkAgain = true
default:
}
if checkAgain {
t.Reset(w.trafficTimeout)
break
}
if w.verbose {
log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", w.exchangeName, w.trafficTimeout)
}
w.setTrafficMonitorRunning(false) // Cannot defer lest Connect is called after Shutdown but before deferred call
w.Wg.Done() // Without this the w.Shutdown() call below will deadlock
if w.IsConnected() {
err := w.Shutdown()
if err != nil {
log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", w.exchangeName, err)
}
}
return
}
}
}()
}
func (w *Websocket) setState(s uint32) {
w.state.Store(s)
}
// IsInitialised returns whether the websocket has been Setup() already
func (w *Websocket) IsInitialised() bool {
return w.state.Load() != uninitialisedState
}
// IsConnected returns whether the websocket is connected
func (w *Websocket) IsConnected() bool {
return w.state.Load() == connectedState
}
// IsConnecting returns whether the websocket is connecting
func (w *Websocket) IsConnecting() bool {
return w.state.Load() == connectingState
}
func (w *Websocket) setEnabled(b bool) {
w.enabled.Store(b)
}
// IsEnabled returns whether the websocket is enabled
func (w *Websocket) IsEnabled() bool {
return w.enabled.Load()
}
func (w *Websocket) setTrafficMonitorRunning(b bool) {
w.trafficMonitorRunning.Store(b)
}
// IsTrafficMonitorRunning returns status of the traffic monitor
func (w *Websocket) IsTrafficMonitorRunning() bool {
return w.trafficMonitorRunning.Load()
}
func (w *Websocket) checkAndSetMonitorRunning() (alreadyRunning bool) {
return !w.connectionMonitorRunning.CompareAndSwap(false, true)
}
func (w *Websocket) setConnectionMonitorRunning(b bool) {
w.connectionMonitorRunning.Store(b)
}
// IsConnectionMonitorRunning returns status of connection monitor
func (w *Websocket) IsConnectionMonitorRunning() bool {
return w.connectionMonitorRunning.Load()
}
func (w *Websocket) setDataMonitorRunning(b bool) {
w.dataMonitorRunning.Store(b)
}
// IsDataMonitorRunning returns status of data monitor
func (w *Websocket) IsDataMonitorRunning() bool {
return w.dataMonitorRunning.Load()
}
// CanUseAuthenticatedWebsocketForWrapper Handles a common check to
// verify whether a wrapper can use an authenticated websocket endpoint
func (w *Websocket) CanUseAuthenticatedWebsocketForWrapper() bool {
if w.IsConnected() {
if w.CanUseAuthenticatedEndpoints() {
return true
}
log.Infof(log.WebsocketMgr, WebsocketNotAuthenticatedUsingRest, w.exchangeName)
}
return false
}
// SetWebsocketURL sets websocket URL and can refresh underlying connections
func (w *Websocket) SetWebsocketURL(url string, auth, reconnect bool) error {
defaultVals := url == "" || url == config.WebsocketURLNonDefaultMessage
if auth {
if defaultVals {
url = w.defaultURLAuth
}
err := checkWebsocketURL(url)
if err != nil {
return err
}
w.runningURLAuth = url
if w.verbose {
log.Debugf(log.WebsocketMgr,
"%s websocket: setting authenticated websocket URL: %s\n",
w.exchangeName,
url)
}
if w.AuthConn != nil {
w.AuthConn.SetURL(url)
}
} else {
if defaultVals {
url = w.defaultURL
}
err := checkWebsocketURL(url)
if err != nil {
return err
}
w.runningURL = url
if w.verbose {
log.Debugf(log.WebsocketMgr,
"%s websocket: setting unauthenticated websocket URL: %s\n",
w.exchangeName,
url)
}
if w.Conn != nil {
w.Conn.SetURL(url)
}
}
if w.IsConnected() && reconnect {
log.Debugf(log.WebsocketMgr,
"%s websocket: flushing websocket connection to %s\n",
w.exchangeName,
url)
return w.Shutdown()
}
return nil
}
// GetWebsocketURL returns the running websocket URL
func (w *Websocket) GetWebsocketURL() string {
return w.runningURL
}
// SetProxyAddress sets websocket proxy address
func (w *Websocket) SetProxyAddress(proxyAddr string) error {
w.m.Lock()
if proxyAddr != "" {
if _, err := url.ParseRequestURI(proxyAddr); err != nil {
w.m.Unlock()
return fmt.Errorf("%v websocket: cannot set proxy address: %w", w.exchangeName, err)
}
if w.proxyAddr == proxyAddr {
w.m.Unlock()
return fmt.Errorf("%v websocket: %w '%v'", w.exchangeName, errSameProxyAddress, w.proxyAddr)
}
log.Debugf(log.ExchangeSys, "%s websocket: setting websocket proxy: %s", w.exchangeName, proxyAddr)
} else {
log.Debugf(log.ExchangeSys, "%s websocket: removing websocket proxy", w.exchangeName)
}
if w.Conn != nil {
w.Conn.SetProxy(proxyAddr)
}
if w.AuthConn != nil {
w.AuthConn.SetProxy(proxyAddr)
}
w.proxyAddr = proxyAddr
if w.IsConnected() {
w.m.Unlock()
if err := w.Shutdown(); err != nil {
return err
}
return w.Connect()
}
w.m.Unlock()
return nil
}
// GetProxyAddress returns the current websocket proxy
func (w *Websocket) GetProxyAddress() string {
return w.proxyAddr
}
// GetName returns exchange name
func (w *Websocket) GetName() string {
return w.exchangeName
}
// GetChannelDifference finds the difference between the subscribed channels
// and the new subscription list when pairs are disabled or enabled.
func (w *Websocket) GetChannelDifference(newSubs subscription.List) (sub, unsub subscription.List) {
if w.subscriptions == nil {
w.subscriptions = subscription.NewStore()
}
return w.subscriptions.Diff(newSubs)
}
// UnsubscribeChannels unsubscribes from a list of websocket channel
func (w *Websocket) UnsubscribeChannels(channels subscription.List) error {
if w.subscriptions == nil || len(channels) == 0 {
return nil // No channels to unsubscribe from is not an error
}
for _, s := range channels {
if w.subscriptions.Get(s) == nil {
return fmt.Errorf("%w: %s", subscription.ErrNotFound, s)
}
}
return w.Unsubscriber(channels)
}
// ResubscribeToChannel resubscribes to channel
// Sets state to Resubscribing, and exchanges which want to maintain a lock on it can respect this state and not RemoveSubscription
// Errors if subscription is already subscribing
func (w *Websocket) ResubscribeToChannel(s *subscription.Subscription) error {
l := subscription.List{s}
if err := s.SetState(subscription.ResubscribingState); err != nil {
return fmt.Errorf("%w: %s", err, s)
}
if err := w.UnsubscribeChannels(l); err != nil {
return err
}
return w.SubscribeToChannels(l)
}
// SubscribeToChannels subscribes to websocket channels using the exchange specific Subscriber method
// Errors are returned for duplicates or exceeding max Subscriptions
func (w *Websocket) SubscribeToChannels(subs subscription.List) error {
if slices.Contains(subs, nil) {
return fmt.Errorf("%w: List parameter contains an nil element", common.ErrNilPointer)
}
if err := w.checkSubscriptions(subs); err != nil {
return err
}
if err := w.Subscriber(subs); err != nil {
return fmt.Errorf("%w: %w", ErrSubscriptionFailure, err)
}
return nil
}
// AddSubscriptions adds subscriptions to the subscription store
// Sets state to Subscribing unless the state is already set
func (w *Websocket) AddSubscriptions(subs ...*subscription.Subscription) error {
if w == nil {
return fmt.Errorf("%w: AddSubscriptions called on nil Websocket", common.ErrNilPointer)
}
if w.subscriptions == nil {
w.subscriptions = subscription.NewStore()
}
var errs error
for _, s := range subs {
if s.State() == subscription.InactiveState {
if err := s.SetState(subscription.SubscribingState); err != nil {
errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s))
}
}
if err := w.subscriptions.Add(s); err != nil {
errs = common.AppendError(errs, err)
}
}
return errs
}
// AddSuccessfulSubscriptions marks subscriptions as subscribed and adds them to the subscription store
func (w *Websocket) AddSuccessfulSubscriptions(subs ...*subscription.Subscription) error {
if w == nil {
return fmt.Errorf("%w: AddSuccessfulSubscriptions called on nil Websocket", common.ErrNilPointer)
}
if w.subscriptions == nil {
w.subscriptions = subscription.NewStore()
}
var errs error
for _, s := range subs {
if err := s.SetState(subscription.SubscribedState); err != nil {
errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s))
}
if err := w.subscriptions.Add(s); err != nil {
errs = common.AppendError(errs, err)
}
}
return errs
}
// RemoveSubscriptions removes subscriptions from the subscription list and sets the status to Unsubscribed
func (w *Websocket) RemoveSubscriptions(subs ...*subscription.Subscription) error {
if w == nil {
return fmt.Errorf("%w: RemoveSubscriptions called on nil Websocket", common.ErrNilPointer)
}
if w.subscriptions == nil {
return fmt.Errorf("%w: RemoveSubscriptions called on uninitialised Websocket", common.ErrNilPointer)
}
var errs error
for _, s := range subs {
if err := s.SetState(subscription.UnsubscribedState); err != nil {
errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s))
}
if err := w.subscriptions.Remove(s); err != nil {
errs = common.AppendError(errs, err)
}
}
return errs
}
// GetSubscription returns a subscription at the key provided
// returns nil if no subscription is at that key or the key is nil
// Keys can implement subscription.MatchableKey in order to provide custom matching logic
func (w *Websocket) GetSubscription(key any) *subscription.Subscription {
if w == nil || w.subscriptions == nil || key == nil {
return nil
}
return w.subscriptions.Get(key)
}
// GetSubscriptions returns a new slice of the subscriptions
func (w *Websocket) GetSubscriptions() subscription.List {
if w == nil || w.subscriptions == nil {
return nil
}
return w.subscriptions.List()
}
// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in a thread safe manner
func (w *Websocket) SetCanUseAuthenticatedEndpoints(b bool) {
w.canUseAuthenticatedEndpoints.Store(b)
}
// CanUseAuthenticatedEndpoints gets canUseAuthenticatedEndpoints val in a thread safe manner
func (w *Websocket) CanUseAuthenticatedEndpoints() bool {
return w.canUseAuthenticatedEndpoints.Load()
}
// IsDisconnectionError Determines if the error sent over chan ReadMessageErrors is a disconnection error
func IsDisconnectionError(err error) bool {
if websocket.IsUnexpectedCloseError(err) {
return true
}
if _, ok := err.(*net.OpError); ok {
return !errors.Is(err, errClosedConnection)
}
return false
}
// checkWebsocketURL checks for a valid websocket url
func checkWebsocketURL(s string) error {
u, err := url.Parse(s)
if err != nil {
return err
}
if u.Scheme != "ws" && u.Scheme != "wss" {
return fmt.Errorf("cannot set %w %s", errInvalidWebsocketURL, s)
}
return nil
}
// checkSubscriptions checks subscriptions against the max subscription limit and if the subscription already exists
// The subscription state is not considered when counting existing subscriptions
func (w *Websocket) checkSubscriptions(subs subscription.List) error {
if w.subscriptions == nil {
return fmt.Errorf("%w: Websocket.subscriptions", common.ErrNilPointer)
}
existing := w.subscriptions.Len()
if w.MaxSubscriptionsPerConnection > 0 && existing+len(subs) > w.MaxSubscriptionsPerConnection {
return fmt.Errorf("%w: current subscriptions: %v, incoming subscriptions: %v, max subscriptions per connection: %v - please reduce enabled pairs",
errSubscriptionsExceedsLimit,
existing,
len(subs),
w.MaxSubscriptionsPerConnection)
}
for _, s := range subs {
if found := w.subscriptions.Get(s); found != nil {
return fmt.Errorf("%w: %s", subscription.ErrDuplicate, s)
}
}
return nil
}