mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-14 07:26:47 +00:00
* stream: rate limiter definitions * Update exchanges/request/request_test.go Co-authored-by: Scott <gloriousCode@users.noreply.github.com> --------- Co-authored-by: Ryan O'Hara-Reid <ryan.oharareid@thrasher.io> Co-authored-by: Scott <gloriousCode@users.noreply.github.com>
976 lines
28 KiB
Go
976 lines
28 KiB
Go
package stream
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"slices"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/thrasher-corp/gocryptotrader/common"
|
|
"github.com/thrasher-corp/gocryptotrader/config"
|
|
"github.com/thrasher-corp/gocryptotrader/exchanges/protocol"
|
|
"github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer"
|
|
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
|
|
"github.com/thrasher-corp/gocryptotrader/log"
|
|
)
|
|
|
|
const (
|
|
jobBuffer = 5000
|
|
)
|
|
|
|
// Public websocket errors
|
|
var (
|
|
ErrWebsocketNotEnabled = errors.New("websocket not enabled")
|
|
ErrSubscriptionFailure = errors.New("subscription failure")
|
|
ErrSubscriptionNotSupported = errors.New("subscription channel not supported ")
|
|
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
|
|
ErrAlreadyDisabled = errors.New("websocket already disabled")
|
|
ErrNotConnected = errors.New("websocket is not connected")
|
|
ErrNoMessageListener = errors.New("websocket listener not found for message")
|
|
ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature")
|
|
)
|
|
|
|
// Private websocket errors
|
|
var (
|
|
errAlreadyRunning = errors.New("connection monitor is already running")
|
|
errExchangeConfigIsNil = errors.New("exchange config is nil")
|
|
errExchangeConfigEmpty = errors.New("exchange config is empty")
|
|
errWebsocketIsNil = errors.New("websocket is nil")
|
|
errWebsocketSetupIsNil = errors.New("websocket setup is nil")
|
|
errWebsocketAlreadyInitialised = errors.New("websocket already initialised")
|
|
errWebsocketAlreadyEnabled = errors.New("websocket already enabled")
|
|
errWebsocketFeaturesIsUnset = errors.New("websocket features is unset")
|
|
errConfigFeaturesIsNil = errors.New("exchange config features is nil")
|
|
errDefaultURLIsEmpty = errors.New("default url is empty")
|
|
errRunningURLIsEmpty = errors.New("running url cannot be empty")
|
|
errInvalidWebsocketURL = errors.New("invalid websocket url")
|
|
errExchangeConfigNameEmpty = errors.New("exchange config name empty")
|
|
errInvalidTrafficTimeout = errors.New("invalid traffic timeout")
|
|
errTrafficAlertNil = errors.New("traffic alert is nil")
|
|
errWebsocketSubscriberUnset = errors.New("websocket subscriber function needs to be set")
|
|
errWebsocketUnsubscriberUnset = errors.New("websocket unsubscriber functionality allowed but unsubscriber function not set")
|
|
errWebsocketConnectorUnset = errors.New("websocket connector function not set")
|
|
errReadMessageErrorsNil = errors.New("read message errors is nil")
|
|
errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set")
|
|
errClosedConnection = errors.New("use of closed network connection")
|
|
errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit")
|
|
errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0")
|
|
errSameProxyAddress = errors.New("cannot set proxy address to the same address")
|
|
errNoConnectFunc = errors.New("websocket connect func not set")
|
|
errAlreadyConnected = errors.New("websocket already connected")
|
|
errCannotShutdown = errors.New("websocket cannot shutdown")
|
|
errAlreadyReconnecting = errors.New("websocket in the process of reconnection")
|
|
errConnSetup = errors.New("error in connection setup")
|
|
)
|
|
|
|
var (
|
|
globalReporter Reporter
|
|
trafficCheckInterval = 100 * time.Millisecond
|
|
)
|
|
|
|
// SetupGlobalReporter sets a reporter interface to be used
|
|
// for all exchange requests
|
|
func SetupGlobalReporter(r Reporter) {
|
|
globalReporter = r
|
|
}
|
|
|
|
// NewWebsocket initialises the websocket struct
|
|
func NewWebsocket() *Websocket {
|
|
return &Websocket{
|
|
DataHandler: make(chan interface{}, jobBuffer),
|
|
ToRoutine: make(chan interface{}, jobBuffer),
|
|
ShutdownC: make(chan struct{}),
|
|
TrafficAlert: make(chan struct{}, 1),
|
|
ReadMessageErrors: make(chan error),
|
|
Match: NewMatch(),
|
|
subscriptions: subscription.NewStore(),
|
|
features: &protocol.Features{},
|
|
Orderbook: buffer.Orderbook{},
|
|
}
|
|
}
|
|
|
|
// Setup sets main variables for websocket connection
|
|
func (w *Websocket) Setup(s *WebsocketSetup) error {
|
|
if w == nil {
|
|
return errWebsocketIsNil
|
|
}
|
|
|
|
if s == nil {
|
|
return errWebsocketSetupIsNil
|
|
}
|
|
|
|
w.m.Lock()
|
|
defer w.m.Unlock()
|
|
|
|
if w.IsInitialised() {
|
|
return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyInitialised)
|
|
}
|
|
|
|
if s.ExchangeConfig == nil {
|
|
return errExchangeConfigIsNil
|
|
}
|
|
|
|
if s.ExchangeConfig.Name == "" {
|
|
return errExchangeConfigNameEmpty
|
|
}
|
|
w.exchangeName = s.ExchangeConfig.Name
|
|
w.verbose = s.ExchangeConfig.Verbose
|
|
|
|
if s.Features == nil {
|
|
return fmt.Errorf("%s %w", w.exchangeName, errWebsocketFeaturesIsUnset)
|
|
}
|
|
w.features = s.Features
|
|
|
|
if s.ExchangeConfig.Features == nil {
|
|
return fmt.Errorf("%s %w", w.exchangeName, errConfigFeaturesIsNil)
|
|
}
|
|
w.setEnabled(s.ExchangeConfig.Features.Enabled.Websocket)
|
|
|
|
if s.Connector == nil {
|
|
return fmt.Errorf("%s %w", w.exchangeName, errWebsocketConnectorUnset)
|
|
}
|
|
w.connector = s.Connector
|
|
|
|
if s.Subscriber == nil {
|
|
return fmt.Errorf("%s %w", w.exchangeName, errWebsocketSubscriberUnset)
|
|
}
|
|
w.Subscriber = s.Subscriber
|
|
|
|
if w.features.Unsubscribe && s.Unsubscriber == nil {
|
|
return fmt.Errorf("%s %w", w.exchangeName, errWebsocketUnsubscriberUnset)
|
|
}
|
|
w.connectionMonitorDelay = s.ExchangeConfig.ConnectionMonitorDelay
|
|
if w.connectionMonitorDelay <= 0 {
|
|
w.connectionMonitorDelay = config.DefaultConnectionMonitorDelay
|
|
}
|
|
w.Unsubscriber = s.Unsubscriber
|
|
|
|
if s.GenerateSubscriptions == nil {
|
|
return fmt.Errorf("%s %w", w.exchangeName, errWebsocketSubscriptionsGeneratorUnset)
|
|
}
|
|
w.GenerateSubs = s.GenerateSubscriptions
|
|
|
|
if s.DefaultURL == "" {
|
|
return fmt.Errorf("%s websocket %w", w.exchangeName, errDefaultURLIsEmpty)
|
|
}
|
|
w.defaultURL = s.DefaultURL
|
|
if s.RunningURL == "" {
|
|
return fmt.Errorf("%s websocket %w", w.exchangeName, errRunningURLIsEmpty)
|
|
}
|
|
err := w.SetWebsocketURL(s.RunningURL, false, false)
|
|
if err != nil {
|
|
return fmt.Errorf("%s %w", w.exchangeName, err)
|
|
}
|
|
|
|
if s.RunningURLAuth != "" {
|
|
err = w.SetWebsocketURL(s.RunningURLAuth, true, false)
|
|
if err != nil {
|
|
return fmt.Errorf("%s %w", w.exchangeName, err)
|
|
}
|
|
}
|
|
|
|
if s.ExchangeConfig.WebsocketTrafficTimeout < time.Second {
|
|
return fmt.Errorf("%s %w cannot be less than %s",
|
|
w.exchangeName,
|
|
errInvalidTrafficTimeout,
|
|
time.Second)
|
|
}
|
|
w.trafficTimeout = s.ExchangeConfig.WebsocketTrafficTimeout
|
|
|
|
w.ShutdownC = make(chan struct{})
|
|
w.SetCanUseAuthenticatedEndpoints(s.ExchangeConfig.API.AuthenticatedWebsocketSupport)
|
|
|
|
if err := w.Orderbook.Setup(s.ExchangeConfig, &s.OrderbookBufferConfig, w.DataHandler); err != nil {
|
|
return err
|
|
}
|
|
|
|
w.Trade.Setup(w.exchangeName, s.TradeFeed, w.DataHandler)
|
|
w.Fills.Setup(s.FillsFeed, w.DataHandler)
|
|
|
|
if s.MaxWebsocketSubscriptionsPerConnection < 0 {
|
|
return fmt.Errorf("%s %w", w.exchangeName, errInvalidMaxSubscriptions)
|
|
}
|
|
w.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection
|
|
w.setState(disconnectedState)
|
|
|
|
w.rateLimitDefinitions = s.RateLimitDefinitions
|
|
return nil
|
|
}
|
|
|
|
// SetupNewConnection sets up an auth or unauth streaming connection
|
|
func (w *Websocket) SetupNewConnection(c ConnectionSetup) error {
|
|
if w == nil {
|
|
return fmt.Errorf("%w: %w", errConnSetup, errWebsocketIsNil)
|
|
}
|
|
|
|
if c.ResponseCheckTimeout == 0 &&
|
|
c.ResponseMaxLimit == 0 &&
|
|
c.RateLimit == nil &&
|
|
c.URL == "" &&
|
|
c.ConnectionLevelReporter == nil &&
|
|
c.BespokeGenerateMessageID == nil {
|
|
return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigEmpty)
|
|
}
|
|
|
|
if w.exchangeName == "" {
|
|
return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigNameEmpty)
|
|
}
|
|
|
|
if w.TrafficAlert == nil {
|
|
return fmt.Errorf("%w: %w", errConnSetup, errTrafficAlertNil)
|
|
}
|
|
|
|
if w.ReadMessageErrors == nil {
|
|
return fmt.Errorf("%w: %w", errConnSetup, errReadMessageErrorsNil)
|
|
}
|
|
|
|
connectionURL := w.GetWebsocketURL()
|
|
if c.URL != "" {
|
|
connectionURL = c.URL
|
|
}
|
|
|
|
if c.ConnectionLevelReporter == nil {
|
|
c.ConnectionLevelReporter = w.ExchangeLevelReporter
|
|
}
|
|
|
|
if c.ConnectionLevelReporter == nil {
|
|
c.ConnectionLevelReporter = globalReporter
|
|
}
|
|
|
|
newConn := &WebsocketConnection{
|
|
ExchangeName: w.exchangeName,
|
|
URL: connectionURL,
|
|
ProxyURL: w.GetProxyAddress(),
|
|
Verbose: w.verbose,
|
|
ResponseMaxLimit: c.ResponseMaxLimit,
|
|
Traffic: w.TrafficAlert,
|
|
readMessageErrors: w.ReadMessageErrors,
|
|
ShutdownC: w.ShutdownC,
|
|
Wg: &w.Wg,
|
|
Match: w.Match,
|
|
RateLimit: c.RateLimit,
|
|
Reporter: c.ConnectionLevelReporter,
|
|
bespokeGenerateMessageID: c.BespokeGenerateMessageID,
|
|
RateLimitDefinitions: w.rateLimitDefinitions,
|
|
}
|
|
|
|
if c.Authenticated {
|
|
w.AuthConn = newConn
|
|
} else {
|
|
w.Conn = newConn
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Connect initiates a websocket connection by using a package defined connection
|
|
// function
|
|
func (w *Websocket) Connect() error {
|
|
if w.connector == nil {
|
|
return errNoConnectFunc
|
|
}
|
|
w.m.Lock()
|
|
defer w.m.Unlock()
|
|
|
|
if !w.IsEnabled() {
|
|
return ErrWebsocketNotEnabled
|
|
}
|
|
if w.IsConnecting() {
|
|
return fmt.Errorf("%v %w", w.exchangeName, errAlreadyReconnecting)
|
|
}
|
|
if w.IsConnected() {
|
|
return fmt.Errorf("%v %w", w.exchangeName, errAlreadyConnected)
|
|
}
|
|
|
|
if w.subscriptions == nil {
|
|
return fmt.Errorf("%w: subscriptions", common.ErrNilPointer)
|
|
}
|
|
w.subscriptions.Clear()
|
|
|
|
w.dataMonitor()
|
|
w.trafficMonitor()
|
|
w.setState(connectingState)
|
|
|
|
err := w.connector()
|
|
if err != nil {
|
|
w.setState(disconnectedState)
|
|
return fmt.Errorf("%v Error connecting %w", w.exchangeName, err)
|
|
}
|
|
w.setState(connectedState)
|
|
|
|
if !w.IsConnectionMonitorRunning() {
|
|
err = w.connectionMonitor()
|
|
if err != nil {
|
|
log.Errorf(log.WebsocketMgr, "%s cannot start websocket connection monitor %v", w.GetName(), err)
|
|
}
|
|
}
|
|
|
|
subs, err := w.GenerateSubs() // regenerate state on new connection
|
|
if err != nil {
|
|
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
|
|
}
|
|
if len(subs) != 0 {
|
|
if err := w.SubscribeToChannels(subs); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Disable disables the exchange websocket protocol
|
|
// Note that connectionMonitor will be responsible for shutting down the websocket after disabling
|
|
func (w *Websocket) Disable() error {
|
|
if !w.IsEnabled() {
|
|
return fmt.Errorf("%s %w", w.exchangeName, ErrAlreadyDisabled)
|
|
}
|
|
|
|
w.setEnabled(false)
|
|
return nil
|
|
}
|
|
|
|
// Enable enables the exchange websocket protocol
|
|
func (w *Websocket) Enable() error {
|
|
if w.IsConnected() || w.IsEnabled() {
|
|
return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyEnabled)
|
|
}
|
|
|
|
w.setEnabled(true)
|
|
return w.Connect()
|
|
}
|
|
|
|
// dataMonitor monitors job throughput and logs if there is a back log of data
|
|
func (w *Websocket) dataMonitor() {
|
|
if w.IsDataMonitorRunning() {
|
|
return
|
|
}
|
|
w.setDataMonitorRunning(true)
|
|
w.Wg.Add(1)
|
|
|
|
go func() {
|
|
defer func() {
|
|
w.setDataMonitorRunning(false)
|
|
w.Wg.Done()
|
|
}()
|
|
dropped := 0
|
|
for {
|
|
select {
|
|
case <-w.ShutdownC:
|
|
return
|
|
case d := <-w.DataHandler:
|
|
select {
|
|
case w.ToRoutine <- d:
|
|
if dropped != 0 {
|
|
log.Infof(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer recovered; %d messages were dropped", w.exchangeName, dropped)
|
|
dropped = 0
|
|
}
|
|
default:
|
|
if dropped == 0 {
|
|
// If this becomes prone to flapping we could drain the buffer, but that's extreme and we'd like to avoid it if possible
|
|
log.Warnf(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer full; dropping messages", w.exchangeName)
|
|
}
|
|
dropped++
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
// connectionMonitor ensures that the WS keeps connecting
|
|
func (w *Websocket) connectionMonitor() error {
|
|
if w.checkAndSetMonitorRunning() {
|
|
return errAlreadyRunning
|
|
}
|
|
delay := w.connectionMonitorDelay
|
|
|
|
go func() {
|
|
timer := time.NewTimer(delay)
|
|
for {
|
|
if w.verbose {
|
|
log.Debugf(log.WebsocketMgr, "%v websocket: running connection monitor cycle", w.exchangeName)
|
|
}
|
|
if !w.IsEnabled() {
|
|
if w.verbose {
|
|
log.Debugf(log.WebsocketMgr, "%v websocket: connectionMonitor - websocket disabled, shutting down", w.exchangeName)
|
|
}
|
|
if w.IsConnected() {
|
|
if err := w.Shutdown(); err != nil {
|
|
log.Errorln(log.WebsocketMgr, err)
|
|
}
|
|
}
|
|
if w.verbose {
|
|
log.Debugf(log.WebsocketMgr, "%v websocket: connection monitor exiting", w.exchangeName)
|
|
}
|
|
timer.Stop()
|
|
w.setConnectionMonitorRunning(false)
|
|
return
|
|
}
|
|
select {
|
|
case err := <-w.ReadMessageErrors:
|
|
w.DataHandler <- err
|
|
if IsDisconnectionError(err) {
|
|
log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err)
|
|
if w.IsConnected() {
|
|
if shutdownErr := w.Shutdown(); shutdownErr != nil {
|
|
log.Errorf(log.WebsocketMgr, "%v websocket: connectionMonitor shutdown err: %s", w.exchangeName, shutdownErr)
|
|
}
|
|
}
|
|
}
|
|
case <-timer.C:
|
|
if !w.IsConnecting() && !w.IsConnected() {
|
|
err := w.Connect()
|
|
if err != nil {
|
|
log.Errorln(log.WebsocketMgr, err)
|
|
}
|
|
}
|
|
if !timer.Stop() {
|
|
select {
|
|
case <-timer.C:
|
|
default:
|
|
}
|
|
}
|
|
timer.Reset(delay)
|
|
}
|
|
}
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
// Shutdown attempts to shut down a websocket connection and associated routines
|
|
// by using a package defined shutdown function
|
|
func (w *Websocket) Shutdown() error {
|
|
w.m.Lock()
|
|
defer w.m.Unlock()
|
|
|
|
if !w.IsConnected() {
|
|
return fmt.Errorf("%v %w: %w", w.exchangeName, errCannotShutdown, ErrNotConnected)
|
|
}
|
|
|
|
// TODO: Interrupt connection and or close connection when it is re-established.
|
|
if w.IsConnecting() {
|
|
return fmt.Errorf("%v %w: %w ", w.exchangeName, errCannotShutdown, errAlreadyReconnecting)
|
|
}
|
|
|
|
if w.verbose {
|
|
log.Debugf(log.WebsocketMgr, "%v websocket: shutting down websocket", w.exchangeName)
|
|
}
|
|
|
|
defer w.Orderbook.FlushBuffer()
|
|
|
|
if w.Conn != nil {
|
|
if err := w.Conn.Shutdown(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if w.AuthConn != nil {
|
|
if err := w.AuthConn.Shutdown(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// flush any subscriptions from last connection if needed
|
|
w.subscriptions.Clear()
|
|
|
|
w.setState(disconnectedState)
|
|
|
|
close(w.ShutdownC)
|
|
w.Wg.Wait()
|
|
w.ShutdownC = make(chan struct{})
|
|
if w.verbose {
|
|
log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", w.exchangeName)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// FlushChannels flushes channel subscriptions when there is a pair/asset change
|
|
func (w *Websocket) FlushChannels() error {
|
|
if !w.IsEnabled() {
|
|
return fmt.Errorf("%s %w", w.exchangeName, ErrWebsocketNotEnabled)
|
|
}
|
|
|
|
if !w.IsConnected() {
|
|
return fmt.Errorf("%s %w", w.exchangeName, ErrNotConnected)
|
|
}
|
|
|
|
if w.features.Subscribe {
|
|
newsubs, err := w.GenerateSubs()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
subs, unsubs := w.GetChannelDifference(newsubs)
|
|
if w.features.Unsubscribe {
|
|
if len(unsubs) != 0 {
|
|
err := w.UnsubscribeChannels(unsubs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(subs) < 1 {
|
|
return nil
|
|
}
|
|
return w.SubscribeToChannels(subs)
|
|
} else if w.features.FullPayloadSubscribe {
|
|
// FullPayloadSubscribe means that the endpoint requires all
|
|
// subscriptions to be sent via the websocket connection e.g. if you are
|
|
// subscribed to ticker and orderbook but require trades as well, you
|
|
// would need to send ticker, orderbook and trades channel subscription
|
|
// messages.
|
|
newsubs, err := w.GenerateSubs()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(newsubs) != 0 {
|
|
// Purge subscription list as there will be conflicts
|
|
w.subscriptions.Clear()
|
|
return w.SubscribeToChannels(newsubs)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
if err := w.Shutdown(); err != nil {
|
|
return err
|
|
}
|
|
return w.Connect()
|
|
}
|
|
|
|
// trafficMonitor waits trafficCheckInterval before checking for a trafficAlert
|
|
// 1 slot buffer means that connection will only write to trafficAlert once per trafficCheckInterval to avoid read/write flood in high traffic
|
|
// Otherwise we Shutdown the connection after trafficTimeout, unless it's connecting. connectionMonitor is responsible for Connecting again
|
|
func (w *Websocket) trafficMonitor() {
|
|
if w.IsTrafficMonitorRunning() {
|
|
return
|
|
}
|
|
w.setTrafficMonitorRunning(true)
|
|
w.Wg.Add(1)
|
|
|
|
go func() {
|
|
t := time.NewTimer(w.trafficTimeout)
|
|
for {
|
|
select {
|
|
case <-w.ShutdownC:
|
|
if w.verbose {
|
|
log.Debugf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown message received", w.exchangeName)
|
|
}
|
|
t.Stop()
|
|
w.setTrafficMonitorRunning(false)
|
|
w.Wg.Done()
|
|
return
|
|
case <-time.After(trafficCheckInterval):
|
|
select {
|
|
case <-w.TrafficAlert:
|
|
if !t.Stop() {
|
|
<-t.C
|
|
}
|
|
t.Reset(w.trafficTimeout)
|
|
default:
|
|
}
|
|
case <-t.C:
|
|
checkAgain := w.IsConnecting()
|
|
select {
|
|
case <-w.TrafficAlert:
|
|
checkAgain = true
|
|
default:
|
|
}
|
|
if checkAgain {
|
|
t.Reset(w.trafficTimeout)
|
|
break
|
|
}
|
|
if w.verbose {
|
|
log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", w.exchangeName, w.trafficTimeout)
|
|
}
|
|
w.setTrafficMonitorRunning(false) // Cannot defer lest Connect is called after Shutdown but before deferred call
|
|
w.Wg.Done() // Without this the w.Shutdown() call below will deadlock
|
|
if w.IsConnected() {
|
|
err := w.Shutdown()
|
|
if err != nil {
|
|
log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", w.exchangeName, err)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (w *Websocket) setState(s uint32) {
|
|
w.state.Store(s)
|
|
}
|
|
|
|
// IsInitialised returns whether the websocket has been Setup() already
|
|
func (w *Websocket) IsInitialised() bool {
|
|
return w.state.Load() != uninitialisedState
|
|
}
|
|
|
|
// IsConnected returns whether the websocket is connected
|
|
func (w *Websocket) IsConnected() bool {
|
|
return w.state.Load() == connectedState
|
|
}
|
|
|
|
// IsConnecting returns whether the websocket is connecting
|
|
func (w *Websocket) IsConnecting() bool {
|
|
return w.state.Load() == connectingState
|
|
}
|
|
|
|
func (w *Websocket) setEnabled(b bool) {
|
|
w.enabled.Store(b)
|
|
}
|
|
|
|
// IsEnabled returns whether the websocket is enabled
|
|
func (w *Websocket) IsEnabled() bool {
|
|
return w.enabled.Load()
|
|
}
|
|
|
|
func (w *Websocket) setTrafficMonitorRunning(b bool) {
|
|
w.trafficMonitorRunning.Store(b)
|
|
}
|
|
|
|
// IsTrafficMonitorRunning returns status of the traffic monitor
|
|
func (w *Websocket) IsTrafficMonitorRunning() bool {
|
|
return w.trafficMonitorRunning.Load()
|
|
}
|
|
|
|
func (w *Websocket) checkAndSetMonitorRunning() (alreadyRunning bool) {
|
|
return !w.connectionMonitorRunning.CompareAndSwap(false, true)
|
|
}
|
|
|
|
func (w *Websocket) setConnectionMonitorRunning(b bool) {
|
|
w.connectionMonitorRunning.Store(b)
|
|
}
|
|
|
|
// IsConnectionMonitorRunning returns status of connection monitor
|
|
func (w *Websocket) IsConnectionMonitorRunning() bool {
|
|
return w.connectionMonitorRunning.Load()
|
|
}
|
|
|
|
func (w *Websocket) setDataMonitorRunning(b bool) {
|
|
w.dataMonitorRunning.Store(b)
|
|
}
|
|
|
|
// IsDataMonitorRunning returns status of data monitor
|
|
func (w *Websocket) IsDataMonitorRunning() bool {
|
|
return w.dataMonitorRunning.Load()
|
|
}
|
|
|
|
// CanUseAuthenticatedWebsocketForWrapper Handles a common check to
|
|
// verify whether a wrapper can use an authenticated websocket endpoint
|
|
func (w *Websocket) CanUseAuthenticatedWebsocketForWrapper() bool {
|
|
if w.IsConnected() {
|
|
if w.CanUseAuthenticatedEndpoints() {
|
|
return true
|
|
}
|
|
log.Infof(log.WebsocketMgr, WebsocketNotAuthenticatedUsingRest, w.exchangeName)
|
|
}
|
|
return false
|
|
}
|
|
|
|
// SetWebsocketURL sets websocket URL and can refresh underlying connections
|
|
func (w *Websocket) SetWebsocketURL(url string, auth, reconnect bool) error {
|
|
defaultVals := url == "" || url == config.WebsocketURLNonDefaultMessage
|
|
if auth {
|
|
if defaultVals {
|
|
url = w.defaultURLAuth
|
|
}
|
|
|
|
err := checkWebsocketURL(url)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
w.runningURLAuth = url
|
|
|
|
if w.verbose {
|
|
log.Debugf(log.WebsocketMgr,
|
|
"%s websocket: setting authenticated websocket URL: %s\n",
|
|
w.exchangeName,
|
|
url)
|
|
}
|
|
|
|
if w.AuthConn != nil {
|
|
w.AuthConn.SetURL(url)
|
|
}
|
|
} else {
|
|
if defaultVals {
|
|
url = w.defaultURL
|
|
}
|
|
err := checkWebsocketURL(url)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
w.runningURL = url
|
|
|
|
if w.verbose {
|
|
log.Debugf(log.WebsocketMgr,
|
|
"%s websocket: setting unauthenticated websocket URL: %s\n",
|
|
w.exchangeName,
|
|
url)
|
|
}
|
|
|
|
if w.Conn != nil {
|
|
w.Conn.SetURL(url)
|
|
}
|
|
}
|
|
|
|
if w.IsConnected() && reconnect {
|
|
log.Debugf(log.WebsocketMgr,
|
|
"%s websocket: flushing websocket connection to %s\n",
|
|
w.exchangeName,
|
|
url)
|
|
return w.Shutdown()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetWebsocketURL returns the running websocket URL
|
|
func (w *Websocket) GetWebsocketURL() string {
|
|
return w.runningURL
|
|
}
|
|
|
|
// SetProxyAddress sets websocket proxy address
|
|
func (w *Websocket) SetProxyAddress(proxyAddr string) error {
|
|
w.m.Lock()
|
|
|
|
if proxyAddr != "" {
|
|
if _, err := url.ParseRequestURI(proxyAddr); err != nil {
|
|
w.m.Unlock()
|
|
return fmt.Errorf("%v websocket: cannot set proxy address: %w", w.exchangeName, err)
|
|
}
|
|
|
|
if w.proxyAddr == proxyAddr {
|
|
w.m.Unlock()
|
|
return fmt.Errorf("%v websocket: %w '%v'", w.exchangeName, errSameProxyAddress, w.proxyAddr)
|
|
}
|
|
|
|
log.Debugf(log.ExchangeSys, "%s websocket: setting websocket proxy: %s", w.exchangeName, proxyAddr)
|
|
} else {
|
|
log.Debugf(log.ExchangeSys, "%s websocket: removing websocket proxy", w.exchangeName)
|
|
}
|
|
|
|
if w.Conn != nil {
|
|
w.Conn.SetProxy(proxyAddr)
|
|
}
|
|
if w.AuthConn != nil {
|
|
w.AuthConn.SetProxy(proxyAddr)
|
|
}
|
|
|
|
w.proxyAddr = proxyAddr
|
|
|
|
if w.IsConnected() {
|
|
w.m.Unlock()
|
|
if err := w.Shutdown(); err != nil {
|
|
return err
|
|
}
|
|
return w.Connect()
|
|
}
|
|
|
|
w.m.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetProxyAddress returns the current websocket proxy
|
|
func (w *Websocket) GetProxyAddress() string {
|
|
return w.proxyAddr
|
|
}
|
|
|
|
// GetName returns exchange name
|
|
func (w *Websocket) GetName() string {
|
|
return w.exchangeName
|
|
}
|
|
|
|
// GetChannelDifference finds the difference between the subscribed channels
|
|
// and the new subscription list when pairs are disabled or enabled.
|
|
func (w *Websocket) GetChannelDifference(newSubs subscription.List) (sub, unsub subscription.List) {
|
|
if w.subscriptions == nil {
|
|
w.subscriptions = subscription.NewStore()
|
|
}
|
|
return w.subscriptions.Diff(newSubs)
|
|
}
|
|
|
|
// UnsubscribeChannels unsubscribes from a list of websocket channel
|
|
func (w *Websocket) UnsubscribeChannels(channels subscription.List) error {
|
|
if w.subscriptions == nil || len(channels) == 0 {
|
|
return nil // No channels to unsubscribe from is not an error
|
|
}
|
|
for _, s := range channels {
|
|
if w.subscriptions.Get(s) == nil {
|
|
return fmt.Errorf("%w: %s", subscription.ErrNotFound, s)
|
|
}
|
|
}
|
|
return w.Unsubscriber(channels)
|
|
}
|
|
|
|
// ResubscribeToChannel resubscribes to channel
|
|
// Sets state to Resubscribing, and exchanges which want to maintain a lock on it can respect this state and not RemoveSubscription
|
|
// Errors if subscription is already subscribing
|
|
func (w *Websocket) ResubscribeToChannel(s *subscription.Subscription) error {
|
|
l := subscription.List{s}
|
|
if err := s.SetState(subscription.ResubscribingState); err != nil {
|
|
return fmt.Errorf("%w: %s", err, s)
|
|
}
|
|
if err := w.UnsubscribeChannels(l); err != nil {
|
|
return err
|
|
}
|
|
return w.SubscribeToChannels(l)
|
|
}
|
|
|
|
// SubscribeToChannels subscribes to websocket channels using the exchange specific Subscriber method
|
|
// Errors are returned for duplicates or exceeding max Subscriptions
|
|
func (w *Websocket) SubscribeToChannels(subs subscription.List) error {
|
|
if slices.Contains(subs, nil) {
|
|
return fmt.Errorf("%w: List parameter contains an nil element", common.ErrNilPointer)
|
|
}
|
|
if err := w.checkSubscriptions(subs); err != nil {
|
|
return err
|
|
}
|
|
if err := w.Subscriber(subs); err != nil {
|
|
return fmt.Errorf("%w: %w", ErrSubscriptionFailure, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// AddSubscriptions adds subscriptions to the subscription store
|
|
// Sets state to Subscribing unless the state is already set
|
|
func (w *Websocket) AddSubscriptions(subs ...*subscription.Subscription) error {
|
|
if w == nil {
|
|
return fmt.Errorf("%w: AddSubscriptions called on nil Websocket", common.ErrNilPointer)
|
|
}
|
|
if w.subscriptions == nil {
|
|
w.subscriptions = subscription.NewStore()
|
|
}
|
|
var errs error
|
|
for _, s := range subs {
|
|
if s.State() == subscription.InactiveState {
|
|
if err := s.SetState(subscription.SubscribingState); err != nil {
|
|
errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s))
|
|
}
|
|
}
|
|
if err := w.subscriptions.Add(s); err != nil {
|
|
errs = common.AppendError(errs, err)
|
|
}
|
|
}
|
|
return errs
|
|
}
|
|
|
|
// AddSuccessfulSubscriptions marks subscriptions as subscribed and adds them to the subscription store
|
|
func (w *Websocket) AddSuccessfulSubscriptions(subs ...*subscription.Subscription) error {
|
|
if w == nil {
|
|
return fmt.Errorf("%w: AddSuccessfulSubscriptions called on nil Websocket", common.ErrNilPointer)
|
|
}
|
|
if w.subscriptions == nil {
|
|
w.subscriptions = subscription.NewStore()
|
|
}
|
|
var errs error
|
|
for _, s := range subs {
|
|
if err := s.SetState(subscription.SubscribedState); err != nil {
|
|
errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s))
|
|
}
|
|
if err := w.subscriptions.Add(s); err != nil {
|
|
errs = common.AppendError(errs, err)
|
|
}
|
|
}
|
|
return errs
|
|
}
|
|
|
|
// RemoveSubscriptions removes subscriptions from the subscription list and sets the status to Unsubscribed
|
|
func (w *Websocket) RemoveSubscriptions(subs ...*subscription.Subscription) error {
|
|
if w == nil {
|
|
return fmt.Errorf("%w: RemoveSubscriptions called on nil Websocket", common.ErrNilPointer)
|
|
}
|
|
if w.subscriptions == nil {
|
|
return fmt.Errorf("%w: RemoveSubscriptions called on uninitialised Websocket", common.ErrNilPointer)
|
|
}
|
|
var errs error
|
|
for _, s := range subs {
|
|
if err := s.SetState(subscription.UnsubscribedState); err != nil {
|
|
errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s))
|
|
}
|
|
if err := w.subscriptions.Remove(s); err != nil {
|
|
errs = common.AppendError(errs, err)
|
|
}
|
|
}
|
|
return errs
|
|
}
|
|
|
|
// GetSubscription returns a subscription at the key provided
|
|
// returns nil if no subscription is at that key or the key is nil
|
|
// Keys can implement subscription.MatchableKey in order to provide custom matching logic
|
|
func (w *Websocket) GetSubscription(key any) *subscription.Subscription {
|
|
if w == nil || w.subscriptions == nil || key == nil {
|
|
return nil
|
|
}
|
|
return w.subscriptions.Get(key)
|
|
}
|
|
|
|
// GetSubscriptions returns a new slice of the subscriptions
|
|
func (w *Websocket) GetSubscriptions() subscription.List {
|
|
if w == nil || w.subscriptions == nil {
|
|
return nil
|
|
}
|
|
return w.subscriptions.List()
|
|
}
|
|
|
|
// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in a thread safe manner
|
|
func (w *Websocket) SetCanUseAuthenticatedEndpoints(b bool) {
|
|
w.canUseAuthenticatedEndpoints.Store(b)
|
|
}
|
|
|
|
// CanUseAuthenticatedEndpoints gets canUseAuthenticatedEndpoints val in a thread safe manner
|
|
func (w *Websocket) CanUseAuthenticatedEndpoints() bool {
|
|
return w.canUseAuthenticatedEndpoints.Load()
|
|
}
|
|
|
|
// IsDisconnectionError Determines if the error sent over chan ReadMessageErrors is a disconnection error
|
|
func IsDisconnectionError(err error) bool {
|
|
if websocket.IsUnexpectedCloseError(err) {
|
|
return true
|
|
}
|
|
if _, ok := err.(*net.OpError); ok {
|
|
return !errors.Is(err, errClosedConnection)
|
|
}
|
|
return false
|
|
}
|
|
|
|
// checkWebsocketURL checks for a valid websocket url
|
|
func checkWebsocketURL(s string) error {
|
|
u, err := url.Parse(s)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if u.Scheme != "ws" && u.Scheme != "wss" {
|
|
return fmt.Errorf("cannot set %w %s", errInvalidWebsocketURL, s)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// checkSubscriptions checks subscriptions against the max subscription limit and if the subscription already exists
|
|
// The subscription state is not considered when counting existing subscriptions
|
|
func (w *Websocket) checkSubscriptions(subs subscription.List) error {
|
|
if w.subscriptions == nil {
|
|
return fmt.Errorf("%w: Websocket.subscriptions", common.ErrNilPointer)
|
|
}
|
|
|
|
existing := w.subscriptions.Len()
|
|
if w.MaxSubscriptionsPerConnection > 0 && existing+len(subs) > w.MaxSubscriptionsPerConnection {
|
|
return fmt.Errorf("%w: current subscriptions: %v, incoming subscriptions: %v, max subscriptions per connection: %v - please reduce enabled pairs",
|
|
errSubscriptionsExceedsLimit,
|
|
existing,
|
|
len(subs),
|
|
w.MaxSubscriptionsPerConnection)
|
|
}
|
|
|
|
for _, s := range subs {
|
|
if found := w.subscriptions.Get(s); found != nil {
|
|
return fmt.Errorf("%w: %s", subscription.ErrDuplicate, s)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|