Files
gocryptotrader/exchange/websocket/manager.go
cranktakular fd9aaf00a2 Coinbase: Update exchange implementation (#1480)
* Slight enhance of Coinbase tests

Continual enhance of Coinbase tests

The revamp continues

Oh jeez the Orderbook part's unfinished don't look

Coinbase revamp, Orderbook still unfinished

* Coinbase revamp; CreateReport is still WIP

* More coinbase improvements; onto sandbox testing

* Coinbase revamp continues

* Coinbase revamp continues

* Coinbasepro revamp is ceaseless

* Coinbase revamp, starting on advanced trade API

* Coinbase Advanced Trade Starts in Earnest

V3 done, onto V2

Coinbase revamp nears completion

Coinbase revamp nears completion

Test commit should fail

Coinbase revamp nears completion

* Coinbase revamp stage wrapper

* Coinbase wrapper coherence continues

* Coinbase wrapper continues writhing

* Coinbase wrapper & codebase cleanup

* Coinbase updates & wrap progress

* More Coinbase wrapper progress

* Wrapper is wrapped, kinda

* Test & type checking

* Coinbase REST revamp finished

* Post-merge fix

* WS revamp begins

* WS Main Revamp Done?

* CB websocket tidying up

* Coinbase WS wrapperupperer

* Coinbase revamp done??

* Linter progress

* Continued lint cleanup

* Further lint cleanup

* Increased lint coverage

* Does this fix all sloppy reassigns & shadowing?

* Undoing retry policy change

* Documentation regeneration

* Coinbase code improvements

* Providing warning about known issue

* Updating an error to new format

* Making gocritic happy

* Review adherence

* Endpoints moved to V3 & nil pointer fixes

* Removing seemingly superfluous constant

* Glorious improvements

* Removing unused error

* Partial public endpoint addition

* Slight improvements

* Wrapper improvements; still a few errors left in other packages

* A lil Coinbase progress

* Json cleaning

* Lint appeasement

* Config repair

* Config fix (real)

* Little fix

* New public endpoint incorporation

* Additional fixes

* Improvements & Appeasements

* LineSaver

* Additional fixes

* Another fix

* Fixing picked nits

* Quick fixies

* Lil fixes

* Subscriptions: Add List.Enabled

* CoinbasePro: Add subscription templating

* fixup! CoinbasePro: Add subscription templating

* fixup! CoinbasePro: Add subscription templating

* Comment fix

* Subsequent fixes

* Issues hopefully fixed

* Lint fix

* Glorious fixes

* Json formatting

* ShazNits

* (L/N)i(n/)t

* Adding a test

* Tiny test improvement

* Template patch testing

* Fixes

* Further shaznits

* Lint nit

* JWT move and other fixes

* Small nits

* Shaznit, singular

* Post-merge fix

* Post-merge fixes

* Typo fix

* Some glorious nits

* Required changes

* Stop going

* Alias attempt

* Alias fix & test cleanup

* Test fix

* GetDepositAddress logic improvement

* Status update: Fixed

* Lint fix

* Happy birthday to PR 1480

* Cleanups

* Necessary nit corrections

* Fixing sillybug

* As per request

* Programming progress

* Order fixes

* Further fixies

* Test fix

* Pre-merge fixes

* More shaznits

* Context

* Sonic error handling

* Import fix

* Better Sonic error handling

* Perfect Sonic error handling?

* F purge

* Coinbase improvements

* API Update Conformity

* Coinbase continuation

* Coinbase order improvements

* Coinbase order improvements

* CreateOrderConfig improvements

* Managing API updates

* Coinbase API update progression

* jwt rename

* Comment link fix

* Coinbase v2 cleanup

* Post-merge fixes

* Review fixes

* GK's suggestions

* Linter fix

* Minor gbjk fixes

* Nit fixes

* Merge fix

* Lint fixes

* Coinbase rename stage 1

* Coinbase rename stage 2

* Coinbase rename stage 3

* Coinbase rename stage 4

* Coinbase rename final fix

* Coinbase: PoC on converting to request structs

* Applying requested changes

* Many review fixes, handled

* Thrashed by nits

* More minor modifications

* The last nit!?

---------

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>
2025-09-16 13:37:00 +10:00

1072 lines
36 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package websocket
import (
"context"
"errors"
"fmt"
"net/url"
"reflect"
"sync"
"sync/atomic"
"time"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/config"
"github.com/thrasher-corp/gocryptotrader/exchange/websocket/buffer"
"github.com/thrasher-corp/gocryptotrader/exchanges/fill"
"github.com/thrasher-corp/gocryptotrader/exchanges/protocol"
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
"github.com/thrasher-corp/gocryptotrader/exchanges/trade"
"github.com/thrasher-corp/gocryptotrader/log"
)
// Public websocket errors
// These sentinels are exported so callers can match them with errors.Is; the
// manager wraps them with %w when adding context such as the exchange name.
var (
// ErrWebsocketNotEnabled is returned by connect when the websocket is not enabled.
ErrWebsocketNotEnabled = errors.New("websocket not enabled")
// ErrAlreadyDisabled is wrapped by Disable when the websocket is already disabled.
ErrAlreadyDisabled = errors.New("websocket already disabled")
// ErrWebsocketAlreadyEnabled is wrapped by Enable when the websocket is already enabled or connected.
ErrWebsocketAlreadyEnabled = errors.New("websocket already enabled")
// ErrNotConnected is wrapped by shutdown when the manager is not in the connected state.
ErrNotConnected = errors.New("websocket is not connected")
ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature")
ErrRequestRouteNotFound = errors.New("request route not found")
ErrSignatureNotSet = errors.New("signature not set")
)
// Private websocket errors
// These sentinels are internal to the package. Several of the configuration
// errors are wrapped together with errConnSetup when returned from Setup or
// SetupNewConnection, so both can be matched with errors.Is.
var (
errWebsocketAlreadyInitialised = errors.New("websocket already initialised")
errDefaultURLIsEmpty = errors.New("default url is empty")
errRunningURLIsEmpty = errors.New("running url cannot be empty")
errInvalidWebsocketURL = errors.New("invalid websocket url")
errExchangeConfigNameEmpty = errors.New("exchange config name empty")
errInvalidTrafficTimeout = errors.New("invalid traffic timeout")
errTrafficAlertNil = errors.New("traffic alert is nil")
errWebsocketSubscriberUnset = errors.New("websocket subscriber function needs to be set")
errWebsocketUnsubscriberUnset = errors.New("websocket unsubscriber functionality allowed but unsubscriber function not set")
errWebsocketConnectorUnset = errors.New("websocket connector function not set")
errWebsocketDataHandlerUnset = errors.New("websocket data handler not set")
errReadMessageErrorsNil = errors.New("read message errors is nil")
errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set")
errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0")
errSameProxyAddress = errors.New("cannot set proxy address to the same address")
errNoConnectFunc = errors.New("websocket connect func not set")
errAlreadyConnected = errors.New("websocket already connected")
errCannotShutdown = errors.New("websocket cannot shutdown")
errAlreadyReconnecting = errors.New("websocket in the process of reconnection")
// errConnSetup is the outer error wrapped around connection configuration failures.
errConnSetup = errors.New("error in connection setup")
errNoPendingConnections = errors.New("no pending connections, call SetupNewConnection first")
errDuplicateConnectionSetup = errors.New("duplicate connection setup")
errCannotChangeConnectionURL = errors.New("cannot change connection URL when using multi connection management")
errExchangeConfigEmpty = errors.New("exchange config is empty")
errCannotObtainOutboundConnection = errors.New("cannot obtain outbound connection")
errMessageFilterNotComparable = errors.New("message filter is not comparable")
)
// Websocket functionality list and state consts
const (
// UnhandledMessage is a message fragment for reporting websocket messages that no handler recognised.
UnhandledMessage = " - Unhandled websocket message: "
// jobBuffer is the buffer size NewManager uses for the DataHandler and ToRoutine channels.
jobBuffer = 5000
)
// Connection lifecycle states stored in Manager.state.
const (
// uninitialisedState is the zero value; IsInitialised reports false while the manager is in it.
uninitialisedState uint32 = iota
// disconnectedState is set at the end of Setup, after a failed connect and after shutdown.
disconnectedState
// connectingState is set while connect is establishing connections.
connectingState
// connectedState is set once connect has established its connection(s).
connectedState
)
// Manager provides connection and subscription management and routing
type Manager struct {
// enabled records whether the websocket protocol is switched on; see setEnabled and IsEnabled.
enabled atomic.Bool
// state holds one of uninitialisedState, disconnectedState, connectingState or connectedState.
state atomic.Uint32
verbose bool
canUseAuthenticatedEndpoints atomic.Bool
// connectionMonitorRunning ensures connect starts the connection monitor at most once at a time.
connectionMonitorRunning atomic.Bool
trafficTimeout time.Duration
connectionMonitorDelay time.Duration
proxyAddr string
defaultURL string
defaultURLAuth string
runningURL string
runningURLAuth string
exchangeName string
features *protocol.Features
// m is held by Setup, Connect, Shutdown and SetProxyAddress for their whole duration.
m sync.Mutex
// connections maps each live managed connection to its wrapper; connect fills it and shutdown clears it.
connections map[Connection]*connectionWrapper
subscriptions *subscription.Store
// connector is the connect function used when multi connection management is disabled.
connector func() error
rateLimitDefinitions request.RateLimitDefinitions // rate limiters shared between Websocket and REST connections
Subscriber func(subscription.List) error
Unsubscriber func(subscription.List) error
GenerateSubs func() (subscription.List, error)
useMultiConnectionManagement bool
DataHandler chan any
ToRoutine chan any
Match *Match
// ShutdownC is closed by shutdown to signal routines to exit and is then replaced with a fresh channel.
ShutdownC chan struct{}
// Wg tracks the monitor and reader goroutines that shutdown waits for.
Wg sync.WaitGroup
Orderbook buffer.Orderbook
Trade trade.Trade // Trade is a notifier for trades
Fills fill.Fills // Fills is a notifier for fills
TrafficAlert chan struct{}
ReadMessageErrors chan error
Conn Connection // Public connection
AuthConn Connection // Authenticated Private connection
ExchangeLevelReporter Reporter // Latency reporter
MaxSubscriptionsPerConnection int
// connectionManager stores all *potential* connections for the exchange, organised within connectionWrapper structs.
// Each connectionWrapper holds one connection (will be expanded soon) tailored for specific exchange functionalities or asset types. // TODO: Expand this to support multiple connections per connectionWrapper
// For example, separate connections can be used for Spot, Margin, and Futures trading. This structure is especially useful
// for exchanges that differentiate between trading pairs by using different connection endpoints or protocols for various asset classes.
// If an exchange does not require such differentiation, all connections may be managed under a single connectionWrapper.
connectionManager []*connectionWrapper
}
// ManagerSetup defines variables for setting up a websocket manager
type ManagerSetup struct {
// ExchangeConfig and its Features field must be non-nil; Setup rejects nil values.
ExchangeConfig *config.Exchange
// DefaultURL, RunningURL, Connector, Subscriber and GenerateSubscriptions are
// only validated by Setup when UseMultiConnectionManagement is false.
DefaultURL string
RunningURL string
RunningURLAuth string
Connector func() error
Subscriber func(subscription.List) error
// Unsubscriber is required by Setup only when Features.Unsubscribe is true.
Unsubscriber func(subscription.List) error
GenerateSubscriptions func() (subscription.List, error)
// Features must be non-nil.
Features *protocol.Features
OrderbookBufferConfig buffer.Config
// UseMultiConnectionManagement allows the connections to be managed by the
// connection manager. If false, this will default to the global fields
// provided in this struct.
UseMultiConnectionManagement bool
TradeFeed bool
FillsFeed bool
// MaxWebsocketSubscriptionsPerConnection must not be negative.
MaxWebsocketSubscriptionsPerConnection int
// RateLimitDefinitions contains the rate limiters shared between WebSocket and REST connections for all endpoints.
// These rate limits take precedence over any rate limits specified in individual connection configurations.
// If no connection-specific rate limit is provided and the endpoint does not match any of these definitions,
// an error will be returned. However, if a connection configuration includes its own rate limit,
// it will fall back to that configuration's rate limit without raising an error.
RateLimitDefinitions request.RateLimitDefinitions
}
// connectionWrapper contains the connection setup details to be used when
// attempting a new connection. It also contains the subscriptions that are
// associated with the specific connection.
type connectionWrapper struct {
// setup is the configuration supplied to SetupNewConnection.
setup *ConnectionSetup
// subscriptions is the per-connection subscription store; it is cleared on shutdown and on connect rollback.
subscriptions *subscription.Store
// connection is nil until connect establishes it, and is reset to nil on shutdown.
connection Connection
}
// globalReporter is the package-wide fallback Reporter. SetupNewConnection
// assigns it to a connection setup that has neither a connection-level nor an
// exchange-level reporter.
var globalReporter Reporter
// SetupGlobalReporter installs r as the package-level Reporter. SetupNewConnection
// falls back to it for any connection setup that has no connection-level
// reporter and whose Manager has no ExchangeLevelReporter.
func SetupGlobalReporter(r Reporter) {
	globalReporter = r
}
// NewManager allocates a Manager with its channels, matcher, subscription
// store, feature set and connection map ready for use. The returned Manager
// stays in the uninitialised state until Setup is called.
func NewManager() *Manager {
	m := new(Manager)

	m.DataHandler = make(chan any, jobBuffer)
	m.ToRoutine = make(chan any, jobBuffer)
	m.ShutdownC = make(chan struct{})
	m.TrafficAlert = make(chan struct{}, 1)
	// ReadMessageErrors carries a buffer of one for the edge case where
	// `Connect` fails after subscriptions are made but before the
	// connectionMonitor has started. The buffered error can then be picked up
	// by the connectionMonitor, which starts a new connection cycle.
	m.ReadMessageErrors = make(chan error, 1)

	m.Match = NewMatch()
	m.subscriptions = subscription.NewStore()
	m.features = &protocol.Features{}
	m.Orderbook = buffer.Orderbook{}
	m.connections = make(map[Connection]*connectionWrapper)

	return m
}
// Setup validates s and copies its configuration onto the Manager, moving it
// from the uninitialised state to the disconnected state on success.
//
// It returns an error when the manager, s, or a required pointer in s is nil,
// when the manager has already been initialised, or when a value fails
// validation. When UseMultiConnectionManagement is false the connector,
// subscriber, subscription generator and URLs in s are also required;
// otherwise those are supplied per connection through SetupNewConnection.
// The manager mutex is held for the duration of the call.
func (m *Manager) Setup(s *ManagerSetup) error {
	if err := common.NilGuard(m, s); err != nil {
		return err
	}
	switch {
	case s.ExchangeConfig == nil:
		return fmt.Errorf("%w: ManagerSetup.ExchangeConfig", common.ErrNilPointer)
	case s.ExchangeConfig.Features == nil:
		return fmt.Errorf("%w: ManagerSetup.ExchangeConfig.Features", common.ErrNilPointer)
	case s.Features == nil:
		return fmt.Errorf("%w: ManagerSetup.Features", common.ErrNilPointer)
	}

	m.m.Lock()
	defer m.m.Unlock()

	if m.IsInitialised() {
		return fmt.Errorf("%s %w", m.exchangeName, errWebsocketAlreadyInitialised)
	}

	cfg := s.ExchangeConfig
	if cfg.Name == "" {
		return errExchangeConfigNameEmpty
	}

	m.exchangeName = cfg.Name
	m.verbose = cfg.Verbose
	m.features = s.Features
	m.setEnabled(cfg.Features.Enabled.Websocket)
	m.useMultiConnectionManagement = s.UseMultiConnectionManagement

	if !m.useMultiConnectionManagement {
		// TODO: Remove this block when all exchanges are updated and backwards
		// compatibility is no longer required.
		switch {
		case s.Connector == nil:
			return fmt.Errorf("%w: %w", errConnSetup, errWebsocketConnectorUnset)
		case s.Subscriber == nil:
			return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriberUnset)
		case s.Unsubscriber == nil && m.features.Unsubscribe:
			return fmt.Errorf("%w: %w", errConnSetup, errWebsocketUnsubscriberUnset)
		case s.GenerateSubscriptions == nil:
			return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriptionsGeneratorUnset)
		case s.DefaultURL == "":
			return fmt.Errorf("%s websocket %w", m.exchangeName, errDefaultURLIsEmpty)
		}
		m.defaultURL = s.DefaultURL
		if s.RunningURL == "" {
			return fmt.Errorf("%s websocket %w", m.exchangeName, errRunningURLIsEmpty)
		}

		m.connector = s.Connector
		m.Subscriber = s.Subscriber
		m.Unsubscriber = s.Unsubscriber
		m.GenerateSubs = s.GenerateSubscriptions

		// reconnect is false for both calls: nothing is connected yet and the
		// manager mutex is already held here.
		if err := m.SetWebsocketURL(s.RunningURL, false, false); err != nil {
			return fmt.Errorf("%s %w", m.exchangeName, err)
		}
		if s.RunningURLAuth != "" {
			if err := m.SetWebsocketURL(s.RunningURLAuth, true, false); err != nil {
				return fmt.Errorf("%s %w", m.exchangeName, err)
			}
		}
	}

	// Fall back to the package default when no positive monitor delay is configured.
	m.connectionMonitorDelay = cfg.ConnectionMonitorDelay
	if m.connectionMonitorDelay <= 0 {
		m.connectionMonitorDelay = config.DefaultConnectionMonitorDelay
	}

	if cfg.WebsocketTrafficTimeout < time.Second {
		return fmt.Errorf("%s %w cannot be less than %s",
			m.exchangeName,
			errInvalidTrafficTimeout,
			time.Second)
	}
	m.trafficTimeout = cfg.WebsocketTrafficTimeout

	m.SetCanUseAuthenticatedEndpoints(cfg.API.AuthenticatedWebsocketSupport)

	if err := m.Orderbook.Setup(cfg, &s.OrderbookBufferConfig, m.DataHandler); err != nil {
		return err
	}
	m.Trade.Setup(s.TradeFeed, m.DataHandler)
	m.Fills.Setup(s.FillsFeed, m.DataHandler)

	if s.MaxWebsocketSubscriptionsPerConnection < 0 {
		return fmt.Errorf("%s %w", m.exchangeName, errInvalidMaxSubscriptions)
	}
	m.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection

	// Leaving the uninitialised state is what makes IsInitialised report true.
	m.setState(disconnectedState)
	m.rateLimitDefinitions = s.RateLimitDefinitions
	return nil
}
// SetupNewConnection registers a streaming connection described by c.
//
// With multi connection management enabled, c is validated and appended to
// the connection manager as a pending connection with its own subscription
// store; the connection itself is created later by connect. Otherwise a
// connection is built immediately and stored in AuthConn or Conn depending on
// c.Authenticated.
//
// If c has no ConnectionLevelReporter it is filled in, on c itself, from the
// Manager's ExchangeLevelReporter and then from the global reporter.
func (m *Manager) SetupNewConnection(c *ConnectionSetup) error {
	if err := common.NilGuard(m, c); err != nil {
		return err
	}

	emptySetup := c.ResponseCheckTimeout == 0 &&
		c.ResponseMaxLimit == 0 &&
		c.RateLimit == nil &&
		c.URL == "" &&
		c.ConnectionLevelReporter == nil &&
		c.RequestIDGenerator == nil
	switch {
	case emptySetup:
		return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigEmpty)
	case m.exchangeName == "":
		return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigNameEmpty)
	case m.TrafficAlert == nil:
		return fmt.Errorf("%w: %w", errConnSetup, errTrafficAlertNil)
	case m.ReadMessageErrors == nil:
		return fmt.Errorf("%w: %w", errConnSetup, errReadMessageErrorsNil)
	}

	// Reporter precedence: connection level, then exchange level, then global.
	if c.ConnectionLevelReporter == nil {
		c.ConnectionLevelReporter = m.ExchangeLevelReporter
	}
	if c.ConnectionLevelReporter == nil {
		c.ConnectionLevelReporter = globalReporter
	}

	if !m.useMultiConnectionManagement {
		conn := m.getConnectionFromSetup(c)
		if c.Authenticated {
			m.AuthConn = conn
		} else {
			m.Conn = conn
		}
		return nil
	}

	// The connection and supporting functions are defined per connection
	// and the connection wrapper is stored in the connection manager.
	switch {
	case c.URL == "":
		return fmt.Errorf("%w: %w", errConnSetup, errDefaultURLIsEmpty)
	case c.Connector == nil:
		return fmt.Errorf("%w: %w", errConnSetup, errWebsocketConnectorUnset)
	case c.GenerateSubscriptions == nil:
		return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriptionsGeneratorUnset)
	case c.Subscriber == nil:
		return fmt.Errorf("%w: %w", errConnSetup, errWebsocketSubscriberUnset)
	case c.Unsubscriber == nil && m.features.Unsubscribe:
		return fmt.Errorf("%w: %w", errConnSetup, errWebsocketUnsubscriberUnset)
	case c.Handler == nil:
		return fmt.Errorf("%w: %w", errConnSetup, errWebsocketDataHandlerUnset)
	case c.MessageFilter != nil && !reflect.TypeOf(c.MessageFilter).Comparable():
		// The filter is compared with == below, which would panic at runtime
		// for a non-comparable dynamic type.
		return errMessageFilterNotComparable
	}

	for _, existing := range m.connectionManager {
		// Below allows for multiple connections to the same URL with different outbound request signatures. This
		// allows for easier determination of inbound and outbound messages. e.g. Gateio cross_margin, margin on
		// a spot connection.
		if existing.setup.URL == c.URL && c.MessageFilter == existing.setup.MessageFilter {
			return fmt.Errorf("%w: %w", errConnSetup, errDuplicateConnectionSetup)
		}
	}

	m.connectionManager = append(m.connectionManager, &connectionWrapper{
		setup:         c,
		subscriptions: subscription.NewStore(),
	})
	return nil
}
// getConnectionFromSetup builds a connection from c and the Manager's shared
// settings. The setup's URL is used when set, otherwise the Manager's running
// URL. The connection shares the Manager's traffic alert, read-error and
// shutdown channels and its wait group.
func (m *Manager) getConnectionFromSetup(c *ConnectionSetup) *connection {
	connectionURL := c.URL
	if connectionURL == "" {
		connectionURL = m.GetWebsocketURL()
	}

	var match *Match
	if m.useMultiConnectionManagement {
		// If we are using multi connection management, we can decouple
		// the match from the global match and have a match per connection.
		match = NewMatch()
	} else {
		match = m.Match
	}

	conn := &connection{
		ExchangeName:         m.exchangeName,
		URL:                  connectionURL,
		ProxyURL:             m.GetProxyAddress(),
		Verbose:              m.verbose,
		ResponseMaxLimit:     c.ResponseMaxLimit,
		Traffic:              m.TrafficAlert,
		readMessageErrors:    m.ReadMessageErrors,
		shutdown:             m.ShutdownC,
		Wg:                   &m.Wg,
		Match:                match,
		RateLimit:            c.RateLimit,
		Reporter:             c.ConnectionLevelReporter,
		requestIDGenerator:   c.RequestIDGenerator,
		RateLimitDefinitions: m.rateLimitDefinitions,
	}
	return conn
}
// Connect acquires the manager mutex and runs the connection cycle
// implemented by connect, returning its error.
func (m *Manager) Connect() error {
	m.m.Lock()
	err := m.connect()
	m.m.Unlock()
	return err
}
// connect runs a full connection cycle. Connect and SetProxyAddress call it
// with the manager mutex already held.
//
// It refuses to run when the websocket is disabled, already connecting or
// already connected, or when the subscription store is nil. In single
// connection mode it calls the configured connector, then generates and
// subscribes to the default subscriptions. In multi connection mode it walks
// every pending connection wrapper: generates its subscriptions, connects,
// starts a Reader, optionally authenticates, and subscribes.
//
// Connection-level failures in multi connection mode are fatal: every
// established connection is shut down and the state returns to disconnected.
// Subscription failures are collected and returned without tearing down the
// connections.
//
// Mode preconditions (a connector in single mode, at least one pending
// connection in multi mode) are verified before the state becomes connecting
// and before the monitor goroutines start. A failed precondition therefore
// leaves the manager disconnected with no monitors running, rather than stuck
// in the connecting state.
func (m *Manager) connect() error {
	if !m.IsEnabled() {
		return ErrWebsocketNotEnabled
	}
	if m.IsConnecting() {
		return fmt.Errorf("%v %w", m.exchangeName, errAlreadyReconnecting)
	}
	if m.IsConnected() {
		return fmt.Errorf("%v %w", m.exchangeName, errAlreadyConnected)
	}

	if m.subscriptions == nil {
		return fmt.Errorf("%w: subscriptions", common.ErrNilPointer)
	}
	m.subscriptions.Clear()

	// Check what each mode needs up front. Nothing has been started yet, so a
	// failure here only has to record the disconnected state.
	if m.useMultiConnectionManagement {
		if len(m.connectionManager) == 0 {
			m.setState(disconnectedState)
			return fmt.Errorf("cannot connect: %w", errNoPendingConnections)
		}
	} else if m.connector == nil {
		m.setState(disconnectedState)
		return fmt.Errorf("%v %w", m.exchangeName, errNoConnectFunc)
	}

	m.setState(connectingState)

	m.Wg.Add(2)
	go m.monitorFrame(&m.Wg, m.monitorData)
	go m.monitorFrame(&m.Wg, m.monitorTraffic)

	if !m.useMultiConnectionManagement {
		if err := m.connector(); err != nil {
			m.setState(disconnectedState)
			return fmt.Errorf("%v Error connecting %w", m.exchangeName, err)
		}
		m.setState(connectedState)

		if m.connectionMonitorRunning.CompareAndSwap(false, true) {
			// This oversees all connections and does not need to be part of wait group management.
			go m.monitorFrame(nil, m.monitorConnection)
		}

		subs, err := m.GenerateSubs() // regenerate state on new connection
		if err != nil {
			return fmt.Errorf("%s websocket: %w", m.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
		}
		if len(subs) != 0 {
			if err := m.SubscribeToChannels(nil, subs); err != nil {
				return err
			}
			if missing := m.subscriptions.Missing(subs); len(missing) > 0 {
				return fmt.Errorf("%v %w %q", m.exchangeName, ErrSubscriptionsNotAdded, missing)
			}
		}
		return nil
	}

	// multiConnectFatalError is a fatal error that will cause all connections to
	// be shutdown and the websocket to be disconnected.
	var multiConnectFatalError error

	// subscriptionError is a non-fatal error that does not shutdown connections
	var subscriptionError error

	// TODO: Implement concurrency below.
	for i := range m.connectionManager {
		if m.connectionManager[i].setup.GenerateSubscriptions == nil {
			multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, m.connectionManager[i].setup.URL, errWebsocketSubscriptionsGeneratorUnset)
			break
		}

		subs, err := m.connectionManager[i].setup.GenerateSubscriptions() // regenerate state on new connection
		if err != nil {
			multiConnectFatalError = fmt.Errorf("%s websocket: %w", m.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
			break
		}

		if len(subs) == 0 {
			// If no subscriptions are generated, we skip the connection
			if m.verbose {
				log.Warnf(log.WebsocketMgr, "%s websocket: no subscriptions generated", m.exchangeName)
			}
			continue
		}

		if m.connectionManager[i].setup.Connector == nil {
			multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, m.connectionManager[i].setup.URL, errNoConnectFunc)
			break
		}
		if m.connectionManager[i].setup.Handler == nil {
			multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, m.connectionManager[i].setup.URL, errWebsocketDataHandlerUnset)
			break
		}
		if m.connectionManager[i].setup.Subscriber == nil {
			multiConnectFatalError = fmt.Errorf("cannot connect to [conn:%d] [URL:%s]: %w ", i+1, m.connectionManager[i].setup.URL, errWebsocketSubscriberUnset)
			break
		}

		// TODO: Add window for max subscriptions per connection, to spawn new connections if needed.

		conn := m.getConnectionFromSetup(m.connectionManager[i].setup)

		err = m.connectionManager[i].setup.Connector(context.TODO(), conn)
		if err != nil {
			multiConnectFatalError = fmt.Errorf("%v Error connecting %w", m.exchangeName, err)
			break
		}

		if !conn.IsConnected() {
			multiConnectFatalError = fmt.Errorf("%s websocket: [conn:%d] [URL:%s] failed to connect", m.exchangeName, i+1, conn.URL)
			break
		}

		// Record the live connection before starting its reader so that the
		// rollback below can find and shut it down if a later step fails.
		m.connections[conn] = m.connectionManager[i]
		m.connectionManager[i].connection = conn

		m.Wg.Add(1)
		go m.Reader(context.TODO(), conn, m.connectionManager[i].setup.Handler)

		if m.connectionManager[i].setup.Authenticate != nil && m.CanUseAuthenticatedEndpoints() {
			err = m.connectionManager[i].setup.Authenticate(context.TODO(), conn)
			if err != nil {
				multiConnectFatalError = fmt.Errorf("%s websocket: [conn:%d] [URL:%s] failed to authenticate %w", m.exchangeName, i+1, conn.URL, err)
				break
			}
		}

		err = m.connectionManager[i].setup.Subscriber(context.TODO(), conn, subs)
		if err != nil {
			subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("%v Error subscribing %w", m.exchangeName, err))
			continue
		}

		if missing := m.connectionManager[i].subscriptions.Missing(subs); len(missing) > 0 {
			subscriptionError = common.AppendError(subscriptionError, fmt.Errorf("%v %w %q", m.exchangeName, ErrSubscriptionsNotAdded, missing))
			continue
		}

		if m.verbose {
			log.Debugf(log.WebsocketMgr, "%s websocket: [conn:%d] [URL:%s] connected. [Subscribed: %d]",
				m.exchangeName,
				i+1,
				conn.URL,
				len(subs))
		}
	}

	if multiConnectFatalError != nil {
		// Roll back any successful connections and flush subscriptions
		for x := range m.connectionManager {
			if m.connectionManager[x].connection != nil {
				if err := m.connectionManager[x].connection.Shutdown(); err != nil {
					log.Errorln(log.WebsocketMgr, err)
				}
				m.connectionManager[x].connection = nil
			}
			m.connectionManager[x].subscriptions.Clear()
		}
		clear(m.connections)
		m.setState(disconnectedState) // Flip from connecting to disconnected.

		// Drain residual error in the single buffered channel, this mitigates
		// the cycle when `Connect` is called again and the connectionMonitor
		// starts but there is an old error in the channel.
		drain(m.ReadMessageErrors)

		return multiConnectFatalError
	}

	// Assume connected state here. All connections have been established.
	// All subscriptions have been sent and stored. All data received is being
	// handled by the appropriate data handler.
	m.setState(connectedState)

	if m.connectionMonitorRunning.CompareAndSwap(false, true) {
		// This oversees all connections and does not need to be part of wait group management.
		go m.monitorFrame(nil, m.monitorConnection)
	}

	return subscriptionError
}
// Disable marks the websocket protocol as disabled. It returns an error
// wrapping ErrAlreadyDisabled when the websocket is not currently enabled.
// Note that connectionMonitor will be responsible for shutting down the websocket after disabling
func (m *Manager) Disable() error {
	if m.IsEnabled() {
		m.setEnabled(false)
		return nil
	}
	return fmt.Errorf("%s %w", m.exchangeName, ErrAlreadyDisabled)
}
// Enable marks the websocket protocol as enabled and then connects. It
// returns an error wrapping ErrWebsocketAlreadyEnabled when the websocket is
// already connected or already enabled; otherwise it returns Connect's result.
func (m *Manager) Enable() error {
	if alreadyOn := m.IsConnected() || m.IsEnabled(); alreadyOn {
		return fmt.Errorf("%s %w", m.exchangeName, ErrWebsocketAlreadyEnabled)
	}

	m.setEnabled(true)
	return m.Connect()
}
// Shutdown acquires the manager mutex and tears down the websocket
// connection(s) and associated routines via shutdown, returning its error.
func (m *Manager) Shutdown() error {
	m.m.Lock()
	err := m.shutdown()
	m.m.Unlock()
	return err
}
// shutdown closes every connection, clears stored subscriptions, signals the
// running routines through ShutdownC and waits for them to finish. Shutdown
// and SetProxyAddress call it with the manager mutex already held.
//
// It returns an error only when the manager is not connected (or is
// connecting). Errors from closing individual connections are logged as a
// warning and do not fail the shutdown. The orderbook buffer is flushed when
// the function returns.
func (m *Manager) shutdown() error {
	if !m.IsConnected() {
		return fmt.Errorf("%v %w: %w", m.exchangeName, errCannotShutdown, ErrNotConnected)
	}

	// TODO: Interrupt connection and or close connection when it is re-established.
	if m.IsConnecting() {
		return fmt.Errorf("%v %w: %w ", m.exchangeName, errCannotShutdown, errAlreadyReconnecting)
	}

	if m.verbose {
		log.Debugf(log.WebsocketMgr, "%v websocket: shutting down websocket", m.exchangeName)
	}

	defer m.Orderbook.FlushBuffer()

	// During the shutdown process, all errors are treated as non-fatal to avoid issues when the connection has already
	// been closed. In such cases, attempting to close the connection may result in a
	// "failed to send closeNotify alert (but connection was closed anyway)" error. Treating these errors as non-fatal
	// prevents the shutdown process from being interrupted, which could otherwise trigger a continuous traffic monitor
	// cycle and potentially block the initiation of a new connection.
	var closeErrs error

	// Shutdown managed connections
	for _, wrapper := range m.connectionManager {
		if wrapper.connection == nil {
			continue
		}
		if err := wrapper.connection.Shutdown(); err != nil {
			closeErrs = common.AppendError(closeErrs, err)
		}
		wrapper.connection = nil
		// Flush any subscriptions from last connection across any managed connections
		wrapper.subscriptions.Clear()
	}
	// Clean map of old connections
	clear(m.connections)

	// Close the public connection first, then the authenticated one.
	for _, conn := range []Connection{m.Conn, m.AuthConn} {
		if conn == nil {
			continue
		}
		if err := conn.Shutdown(); err != nil {
			closeErrs = common.AppendError(closeErrs, err)
		}
	}

	// flush any subscriptions from last connection if needed
	m.subscriptions.Clear()

	m.setState(disconnectedState)

	// Closing ShutdownC tells the routines to stop; once they have all
	// returned a fresh channel is installed for the next connection cycle.
	close(m.ShutdownC)
	m.Wg.Wait()
	m.ShutdownC = make(chan struct{})

	if m.verbose {
		log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", m.exchangeName)
	}

	// Drain residual error in the single buffered channel, this mitigates
	// the cycle when `Connect` is called again and the connectionMonitor
	// starts but there is an old error in the channel.
	drain(m.ReadMessageErrors)

	if closeErrs != nil {
		log.Warnf(log.WebsocketMgr, "%v websocket: shutdown error: %v", m.exchangeName, closeErrs)
	}

	return nil
}
// setState atomically stores s as the manager's connection state. s is
// expected to be one of the *State constants.
func (m *Manager) setState(s uint32) {
	m.state.Store(s)
}
// IsInitialised reports whether the manager has left the uninitialised
// state, which Setup does on success.
func (m *Manager) IsInitialised() bool {
	current := m.state.Load()
	return current != uninitialisedState
}
// IsConnected reports whether the manager is currently in the connected state.
func (m *Manager) IsConnected() bool {
	current := m.state.Load()
	return current == connectedState
}
// IsConnecting reports whether the manager is currently in the connecting state.
func (m *Manager) IsConnecting() bool {
	current := m.state.Load()
	return current == connectingState
}
// setEnabled atomically records whether the websocket protocol is enabled.
func (m *Manager) setEnabled(b bool) {
	m.enabled.Store(b)
}
// IsEnabled reports whether the websocket protocol is currently enabled.
func (m *Manager) IsEnabled() bool {
	on := m.enabled.Load()
	return on
}
// CanUseAuthenticatedWebsocketForWrapper reports whether a wrapper may use an
// authenticated websocket endpoint: the websocket must be connected and
// authenticated endpoints must be usable. When connected but not
// authenticated it logs that REST will be used and returns false.
func (m *Manager) CanUseAuthenticatedWebsocketForWrapper() bool {
	if !m.IsConnected() {
		return false
	}
	if m.CanUseAuthenticatedEndpoints() {
		return true
	}
	log.Infof(log.WebsocketMgr, "%v - Websocket not authenticated, using REST\n", m.exchangeName)
	return false
}
// SetWebsocketURL sets the running websocket URL for either the authenticated
// (auth true) or the public connection, and updates the URL on the matching
// connection if one exists.
//
// An empty u, or the config placeholder for a non-default URL, selects the
// corresponding default URL. The URL must use the ws or wss scheme. When
// reconnect is true and the websocket is connected, the current connection is
// shut down through Shutdown, which acquires the manager mutex.
//
// Changing the URL is not supported under multi connection management and
// returns an error wrapping errCannotChangeConnectionURL.
func (m *Manager) SetWebsocketURL(u string, auth, reconnect bool) error {
	if m.useMultiConnectionManagement {
		// TODO: Add functionality for multi-connection management to change URL
		return fmt.Errorf("%s: %w", m.exchangeName, errCannotChangeConnectionURL)
	}

	useDefault := u == "" || u == config.WebsocketURLNonDefaultMessage

	if auth {
		if useDefault {
			u = m.defaultURLAuth
		}
		if err := checkWebsocketURL(u); err != nil {
			return err
		}
		m.runningURLAuth = u

		if m.verbose {
			log.Debugf(log.WebsocketMgr, "%s websocket: setting authenticated websocket URL: %s\n", m.exchangeName, u)
		}
		if m.AuthConn != nil {
			m.AuthConn.SetURL(u)
		}
	} else {
		if useDefault {
			u = m.defaultURL
		}
		if err := checkWebsocketURL(u); err != nil {
			return err
		}
		m.runningURL = u

		if m.verbose {
			log.Debugf(log.WebsocketMgr, "%s websocket: setting unauthenticated websocket URL: %s\n", m.exchangeName, u)
		}
		if m.Conn != nil {
			m.Conn.SetURL(u)
		}
	}

	if !reconnect || !m.IsConnected() {
		return nil
	}
	log.Debugf(log.WebsocketMgr, "%s websocket: flushing websocket connection to %s\n", m.exchangeName, u)
	return m.Shutdown()
}
// GetWebsocketURL returns the running URL of the unauthenticated websocket.
func (m *Manager) GetWebsocketURL() string {
	current := m.runningURL
	return current
}
// SetProxyAddress sets the proxy used by all websocket connections; an empty
// proxyAddr removes the proxy. A non-empty address must parse as a request
// URI and must differ from the current proxy. The new value is applied to
// every existing managed, public and authenticated connection. If the
// websocket is connected it is shut down and connected again so the change
// takes effect. The manager mutex is held for the duration of the call.
func (m *Manager) SetProxyAddress(proxyAddr string) error {
	m.m.Lock()
	defer m.m.Unlock()

	if proxyAddr == "" {
		log.Debugf(log.ExchangeSys, "%s websocket: removing websocket proxy", m.exchangeName)
	} else {
		if _, err := url.ParseRequestURI(proxyAddr); err != nil {
			return fmt.Errorf("%v websocket: cannot set proxy address: %w", m.exchangeName, err)
		}
		if m.proxyAddr == proxyAddr {
			return fmt.Errorf("%v websocket: %w '%v'", m.exchangeName, errSameProxyAddress, m.proxyAddr)
		}
		log.Debugf(log.ExchangeSys, "%s websocket: setting websocket proxy: %s", m.exchangeName, proxyAddr)
	}

	for i := range m.connectionManager {
		if managed := m.connectionManager[i].connection; managed != nil {
			managed.SetProxy(proxyAddr)
		}
	}
	// Public connection first, then the authenticated one.
	for _, conn := range []Connection{m.Conn, m.AuthConn} {
		if conn != nil {
			conn.SetProxy(proxyAddr)
		}
	}

	m.proxyAddr = proxyAddr

	if !m.IsConnected() {
		return nil
	}
	// Use the unexported variants: the mutex is already held here.
	if err := m.shutdown(); err != nil {
		return err
	}
	return m.connect()
}
// GetProxyAddress returns the proxy address currently configured for the
// websocket, or an empty string when none is set.
func (m *Manager) GetProxyAddress() string {
	addr := m.proxyAddr
	return addr
}
// GetName returns the exchange name recorded on the manager.
func (m *Manager) GetName() string {
	name := m.exchangeName
	return name
}
// SetCanUseAuthenticatedEndpoints atomically records whether authenticated
// websocket endpoints may be used.
func (m *Manager) SetCanUseAuthenticatedEndpoints(b bool) {
	m.canUseAuthenticatedEndpoints.Store(b)
}
// CanUseAuthenticatedEndpoints atomically reads whether authenticated
// websocket endpoints may be used.
func (m *Manager) CanUseAuthenticatedEndpoints() bool {
	allowed := m.canUseAuthenticatedEndpoints.Load()
	return allowed
}
// checkWebsocketURL verifies that s parses as a URL and uses the ws or wss
// scheme. It returns the parse error unchanged, or an error wrapping
// errInvalidWebsocketURL for any other scheme (including none).
func checkWebsocketURL(s string) error {
	parsed, err := url.Parse(s)
	if err != nil {
		return err
	}
	switch parsed.Scheme {
	case "ws", "wss":
		return nil
	default:
		return fmt.Errorf("cannot set %w %s", errInvalidWebsocketURL, s)
	}
}
// Reader pumps messages from conn into handler until the connection is closed.
//
// Each message returned by conn.ReadMessage is passed to handler together with ctx and conn. A handler error is wrapped
// with the connection URL and sent to m.DataHandler; reading then continues. A message whose Raw payload is nil marks a
// closed connection and ends the loop. m.Wg.Done is called on return, so the caller is expected to have added to m.Wg
// before starting Reader.
func (m *Manager) Reader(ctx context.Context, conn Connection, handler func(ctx context.Context, conn Connection, message []byte) error) {
	defer m.Wg.Done()
	for msg := conn.ReadMessage(); msg.Raw != nil; msg = conn.ReadMessage() {
		handleErr := handler(ctx, conn, msg.Raw)
		if handleErr == nil {
			continue
		}
		m.DataHandler <- fmt.Errorf("connection URL:[%v] error: %w", conn.GetURL(), handleErr)
	}
}
// drain discards every error that is immediately available on ch and returns as soon as a receive would block.
// It also returns when ch is closed: a receive on a closed channel always succeeds, so without the ok check the loop
// would never reach the default case and would spin forever.
func drain(ch <-chan error) {
	for {
		select {
		case _, ok := <-ch:
			if !ok {
				return // channel closed; nothing further can arrive
			}
		default:
			return
		}
	}
}
// ClosureFrame builds an observer for monitorFrame. Calling it once sets up whatever state the monitor needs (captured
// by the returned closure) and yields the observer function; monitorFrame then invokes that observer repeatedly and
// exits as soon as it returns true.
type ClosureFrame func() func() bool
// monitorFrame monitors a specific websocket component or critical system. It calls fn once to obtain the observer and
// then invokes the observer in a loop until it returns true.
// This is used for monitoring data throughput, connection status and other critical websocket components.
//
// wg is optional: when non-nil, Done is called on that wait group when the monitor exits, so the caller must have
// called Add beforehand. The supplied wait group itself is signalled (rather than unconditionally m.Wg) so the
// parameter is honoured; callers passing &m.Wg observe the same behaviour as before.
func (m *Manager) monitorFrame(wg *sync.WaitGroup, fn ClosureFrame) {
	if wg != nil {
		defer wg.Done()
	}
	observe := fn()
	for {
		if observe() {
			return
		}
	}
}
// monitorData returns an observer that relays data via observeData. The observer owns a private counter of dropped
// messages which persists across invocations.
func (m *Manager) monitorData() func() bool {
	var droppedCount int
	return func() bool {
		return m.observeData(&droppedCount)
	}
}
// observeData performs one iteration of the data monitor. It blocks until either m.ShutdownC yields (returns true so
// the monitor exits) or an item arrives on m.DataHandler (returns false so the monitor keeps running).
//
// An arriving item is forwarded to m.ToRoutine without blocking. If ToRoutine cannot accept it the item is dropped and
// *dropped is incremented; a warning is logged only for the first drop of a run. When forwarding succeeds again the
// number of dropped messages is logged and the counter is reset to zero.
//
// dropped must be non-nil; it carries the running drop count between calls.
func (m *Manager) observeData(dropped *int) (exit bool) {
	select {
	case <-m.ShutdownC:
		return true
	case d := <-m.DataHandler:
		select {
		case m.ToRoutine <- d:
			if *dropped != 0 {
				// Dereference the counter: passing the pointer to %d would print its address, not the count.
				log.Infof(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer recovered; %d messages were dropped", m.exchangeName, *dropped)
				*dropped = 0
			}
		default:
			if *dropped == 0 {
				// If this becomes prone to flapping we could drain the buffer, but that's extreme and we'd like to avoid it if possible
				log.Warnf(log.WebsocketMgr, "%s exchange websocket ToRoutine channel buffer full; dropping messages", m.exchangeName)
			}
			*dropped++
		}
		return false
	}
}
// monitorConnection returns an observer that watches the connection via observeConnection. A timer armed with
// m.connectionMonitorDelay is created once here and shared by every invocation of the observer.
func (m *Manager) monitorConnection() func() bool {
	cycleTimer := time.NewTimer(m.connectionMonitorDelay)
	return func() bool {
		return m.observeConnection(cycleTimer)
	}
}
// observeConnection runs a single iteration of the connection monitor. It blocks until either an error arrives on
// m.ReadMessageErrors or the monitor timer t fires, and returns true only when the websocket has been disabled and the
// monitor should stop.
//
// On a read error: if the error matches errConnectionFault and the manager reports connected, the websocket is shut
// down; then, if the websocket is enabled and neither connected nor connecting, a reconnect is attempted immediately;
// finally the error is forwarded to m.DataHandler.
//
// On a timer tick: if the websocket is disabled, any live connection is shut down, the timer is stopped,
// connectionMonitorRunning is cleared and true is returned. Otherwise a reconnect is attempted when the manager is
// neither connected nor connecting, and the timer is re-armed with m.connectionMonitorDelay.
//
// Shutdown and connect failures are logged, not returned.
func (m *Manager) observeConnection(t *time.Timer) (exit bool) {
select {
case err := <-m.ReadMessageErrors:
// Only a connection fault triggers a shutdown; other read errors are just forwarded below.
if errors.Is(err, errConnectionFault) {
log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", m.exchangeName, err)
if m.IsConnected() {
if shutdownErr := m.Shutdown(); shutdownErr != nil {
log.Errorf(log.WebsocketMgr, "%v websocket: connectionMonitor shutdown err: %s", m.exchangeName, shutdownErr)
}
}
}
// Speedier reconnection, instead of waiting for the next cycle.
if m.IsEnabled() && (!m.IsConnected() && !m.IsConnecting()) {
if connectErr := m.Connect(); connectErr != nil {
log.Errorln(log.WebsocketMgr, connectErr)
}
}
m.DataHandler <- err // hand over the error to the data handler (shutdown and reconnection is priority)
case <-t.C:
if m.verbose {
log.Debugf(log.WebsocketMgr, "%v websocket: running connection monitor cycle", m.exchangeName)
}
// A disabled websocket ends the monitor: tear down any live connection and signal exit.
if !m.IsEnabled() {
if m.verbose {
log.Debugf(log.WebsocketMgr, "%v websocket: connectionMonitor - websocket disabled, shutting down", m.exchangeName)
}
if m.IsConnected() {
if err := m.Shutdown(); err != nil {
log.Errorln(log.WebsocketMgr, err)
}
}
if m.verbose {
log.Debugf(log.WebsocketMgr, "%v websocket: connection monitor exiting", m.exchangeName)
}
t.Stop()
// Clear the running flag before returning true so the monitor is no longer reported as running.
m.connectionMonitorRunning.Store(false)
return true
}
// Enabled but idle: try to re-establish the connection on this cycle.
if !m.IsConnecting() && !m.IsConnected() {
err := m.Connect()
if err != nil {
log.Errorln(log.WebsocketMgr, err)
}
}
// Re-arm the timer for the next monitor cycle.
t.Reset(m.connectionMonitorDelay)
}
return false
}
// monitorTraffic returns an observer that checks, via observeTraffic, whether traffic has been seen within the
// trafficTimeout window. When no traffic is seen the connection is shut down and it is left to the connection monitor
// routine to reconnect. The timer armed with m.trafficTimeout is created once here and shared by every invocation.
func (m *Manager) monitorTraffic() func() bool {
	idleTimer := time.NewTimer(m.trafficTimeout)
	return func() bool {
		return m.observeTraffic(idleTimer)
	}
}
// observeTraffic runs a single iteration of the traffic monitor. It blocks until m.ShutdownC yields or the timer t
// fires and returns true when the monitor should exit.
//
// On shutdown the timer is stopped and true is returned. When the timer fires while the manager is connecting, or a
// traffic alert is pending on m.TrafficAlert, the timer is re-armed with m.trafficTimeout and false is returned.
// Otherwise the connection is considered stale: if connected, Shutdown is started on a separate goroutine, the timer
// is stopped and true is returned.
func (m *Manager) observeTraffic(t *time.Timer) bool {
	select {
	case <-m.ShutdownC:
		if m.verbose {
			log.Debugf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown message received", m.exchangeName)
		}
		t.Stop()
		return true
	case <-t.C:
		// Timer fired; handled below.
	}

	// IsConnecting is evaluated first so a pending traffic alert is only consumed when not connecting.
	if m.IsConnecting() || signalReceived(m.TrafficAlert) {
		t.Reset(m.trafficTimeout)
		return false
	}

	if m.verbose {
		log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", m.exchangeName, m.trafficTimeout)
	}

	if m.IsConnected() {
		// Shutdown must run on its own goroutine; calling it inline from the monitor would deadlock.
		asyncShutdown := func() {
			shutdownErr := m.Shutdown()
			if shutdownErr != nil {
				log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", m.exchangeName, shutdownErr)
			}
		}
		go asyncShutdown()
	}

	t.Stop()
	return true
}
// signalReceived performs a non-blocking receive on ch and reports whether a value was available.
// A successful receive consumes the signal, so a subsequent call reports false until another signal is sent.
func signalReceived(ch chan struct{}) bool {
	received := false
	select {
	case <-ch:
		received = true
	default:
	}
	return received
}
// GetConnection looks up the connection registered for messageFilter (the filter is defined in the exchange package's
// _wrapper.go websocket connection setup) so requests and responses can be routed in a multi connection context.
//
// Errors are returned when the manager is nil, messageFilter is nil, multi connection management is not enabled
// (errCannotObtainOutboundConnection), the manager is not connected (ErrNotConnected), the matching entry has no live
// connection (ErrNotConnected), or no entry matches the filter (ErrRequestRouteNotFound). The manager mutex is held
// during the lookup.
func (m *Manager) GetConnection(messageFilter any) (Connection, error) {
	if err := common.NilGuard(m); err != nil {
		return nil, err
	}
	if messageFilter == nil {
		return nil, fmt.Errorf("%w: messageFilter", common.ErrNilPointer)
	}

	m.m.Lock()
	defer m.m.Unlock()

	if !m.useMultiConnectionManagement {
		return nil, fmt.Errorf("%s: multi connection management not enabled %w please use exported Conn and AuthConn fields", m.exchangeName, errCannotObtainOutboundConnection)
	}
	if !m.IsConnected() {
		return nil, ErrNotConnected
	}

	for _, candidate := range m.connectionManager {
		if candidate.setup.MessageFilter != messageFilter {
			continue
		}
		// First matching filter decides the outcome, connected or not.
		if candidate.connection != nil {
			return candidate.connection, nil
		}
		return nil, fmt.Errorf("%s: %s %w associated with message filter: '%v'", m.exchangeName, candidate.setup.URL, ErrNotConnected, messageFilter)
	}

	return nil, fmt.Errorf("%s: %w associated with message filter: '%v'", m.exchangeName, ErrRequestRouteNotFound, messageFilter)
}