mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-06-07 15:11:03 +00:00
Websocket: Restructure files and types (#1859)
* Websocket: Rename stream package * Websocket: Rename Websocket to Manager * Websocket: Replace explicit errs with common.NilGuard * Websocket: Move websocket_types.go to types.go * Websocket: Minor field comment and alignment in types * Webosocket: Rename WebsocketConnection to Connection * Alphapoint: Make gorilla ws import explicit Just to avoid confusion with our own packages. * Websocket: Move stream_match to match * Websocket: Move websocket_connection to connection * Websocket: Move websocket.go to manager.go * Websocket: Break out all subscription methods into subscriptions.go * Websocket: Move connection type into its file * Websocket: Remove PositionUpdated Type is not used anywhere * Kraken: Use local constant for pong Was the only use of websocket.Pong and doesn't really feel right to represent kraken's api resp in one of our packages * Websocket: Move connection sub-types to connection package * Websocket: Move manager types into manager * Websocket: Move ConnectionWrapper into manager * Websocket: Move websocket_test to manager_test * Websocket: Privatise connectionWrapper * Websocket: Remaining types into types.go These really belong somewhere else mostly, but this will do for now * Websocket: Tidy up connection method vars * Gofumpt: Moving package imports around * Websocket: Rename errDuplicateConnectionSetup * Websocket: Fix duplicate import of gws * Websocket: Fix gofumpt -extra * Websocket: Standardise import of gws across other pkgs * Kraken: Remove unused sub conf consts These were replaced by the generic Levels and Depth fields on all subs * Websocket: Privitise ConnectioWrapper fields * Websocket: inline single use var WebsocketNotAuthenticatedUsingRest * Websocket: Move documentation to template * Bithumb: Assertify TestWsHandleData
This commit is contained in:
174
internal/exchange/websocket/README.md
Normal file
174
internal/exchange/websocket/README.md
Normal file
@@ -0,0 +1,174 @@
|
||||
# GoCryptoTrader package Websocket
|
||||
|
||||
<img src="/common/gctlogo.png?raw=true" width="350px" height="350px" hspace="70">
|
||||
|
||||
|
||||
[](https://github.com/thrasher-corp/gocryptotrader/actions/workflows/tests.yml)
|
||||
[](https://github.com/thrasher-corp/gocryptotrader/blob/master/LICENSE)
|
||||
[](https://godoc.org/github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket)
|
||||
[](https://codecov.io/gh/thrasher-corp/gocryptotrader)
|
||||
[](https://goreportcard.com/report/github.com/thrasher-corp/gocryptotrader)
|
||||
|
||||
|
||||
This websocket package is part of the GoCryptoTrader codebase.
|
||||
|
||||
## This is still in active development
|
||||
|
||||
You can track ideas, planned features and what's in progress on our [GoCryptoTrader Kanban board](https://github.com/orgs/thrasher-corp/projects/3).
|
||||
|
||||
Join our slack to discuss all things related to GoCryptoTrader! [GoCryptoTrader Slack](https://join.slack.com/t/gocryptotrader/shared_invite/enQtNTQ5NDAxMjA2Mjc5LTc5ZDE1ZTNiOGM3ZGMyMmY1NTAxYWZhODE0MWM5N2JlZDk1NDU0YTViYzk4NTk3OTRiMDQzNGQ1YTc4YmRlMTk)
|
||||
|
||||
## Overview
|
||||
|
||||
The `websocket` package provides methods to manage connections and subscriptions for exchange websockets.
|
||||
|
||||
## Features
|
||||
|
||||
- Handle real-time market data streams
|
||||
- Unified interface for managing data streams
|
||||
- Multi-connection management - a system that can be used to manage multiple connections to the same exchange
|
||||
- Connection monitoring - a system that can be used to monitor the health of the websocket connections. This can be used to check if the connection is still alive and if it is not, it will attempt to reconnect
|
||||
- Traffic monitoring - will reconnect if no message is sent for a period of time defined in your config
|
||||
- Subscription management - a system that can be used to manage subscriptions to various data streams
|
||||
- Rate limiting - a system that can be used to rate limit the number of requests sent to the exchange
|
||||
- Message ID generation - a system that can be used to generate message IDs for websocket requests
|
||||
- Websocket message response matching - can be used to match websocket responses to the requests that were sent
|
||||
|
||||
## Usage
|
||||
|
||||
### Default single websocket connection
|
||||
|
||||
Example setup for the `websocket` package connection:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket"
|
||||
exchange "github.com/thrasher-corp/gocryptotrader/exchanges"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
|
||||
)
|
||||
|
||||
type Exchange struct {
|
||||
exchange.Base
|
||||
}
|
||||
|
||||
// In the exchange wrapper this will set up the initial pointer field provided by exchange.Base
|
||||
func (e *Exchange) SetDefault() {
|
||||
e.Websocket = websocket.NewManager()
|
||||
e.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
|
||||
e.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
|
||||
e.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
|
||||
}
|
||||
|
||||
// In the exchange wrapper this is the original setup pattern for the websocket services
|
||||
func (e *Exchange) Setup(exch *config.Exchange) error {
|
||||
// This sets up global connection, sub, unsub and generate subscriptions for each connection defined below.
|
||||
if err := e.Websocket.Setup(&websocket.ManagerSetup{
|
||||
ExchangeConfig: exch,
|
||||
DefaultURL: connectionURLString,
|
||||
RunningURL: connectionURLString,
|
||||
Connector: e.WsConnect,
|
||||
Subscriber: e.Subscribe,
|
||||
Unsubscriber: e.Unsubscribe,
|
||||
GenerateSubscriptions: e.GenerateDefaultSubscriptions,
|
||||
Features: &e.Features.Supports.WebsocketCapabilities,
|
||||
MaxWebsocketSubscriptionsPerConnection: 240,
|
||||
OrderbookBufferConfig: buffer.Config{ Checksum: e.CalculateUpdateOrderbookChecksum },
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// This is a public websocket connection
|
||||
if err := ok.Websocket.SetupNewConnection(&websocket.ConnectionSetup{
|
||||
URL: connectionURLString,
|
||||
ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout,
|
||||
ResponseMaxLimit: exchangeWebsocketResponseMaxLimit,
|
||||
RateLimit: request.NewRateLimitWithWeight(time.Second, 2, 1),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// This is a private websocket connection
|
||||
return ok.Websocket.SetupNewConnection(&websocket.ConnectionSetup{
|
||||
URL: privateConnectionURLString,
|
||||
ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout,
|
||||
ResponseMaxLimit: exchangeWebsocketResponseMaxLimit,
|
||||
Authenticated: true,
|
||||
RateLimit: request.NewRateLimitWithWeight(time.Second, 2, 1),
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
### Multiple websocket connections
|
||||
The example below provides the now optional multi connection management system which allows for more connections
|
||||
to be maintained and established based off URL, connections types, asset types etc.
|
||||
```go
|
||||
func (e *Exchange) Setup(exch *config.Exchange) error {
|
||||
// This sets up global connection, sub, unsub and generate subscriptions for each connection defined below.
|
||||
if err := e.Websocket.Setup(&websocket.ManagerSetup{
|
||||
ExchangeConfig: exch,
|
||||
Features: &e.Features.Supports.WebsocketCapabilities,
|
||||
FillsFeed: e.Features.Enabled.FillsFeed,
|
||||
TradeFeed: e.Features.Enabled.TradeFeed,
|
||||
UseMultiConnectionManagement: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Spot connection
|
||||
err = g.Websocket.SetupNewConnection(&websocket.ConnectionSetup{
|
||||
URL: connectionURLStringForSpot,
|
||||
RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit),
|
||||
ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout,
|
||||
ResponseMaxLimit: exch.WebsocketResponseMaxLimit,
|
||||
// Custom handlers for the specific connection:
|
||||
Handler: e.WsHandleSpotData,
|
||||
Subscriber: e.SpotSubscribe,
|
||||
Unsubscriber: e.SpotUnsubscribe,
|
||||
GenerateSubscriptions: e.GenerateDefaultSubscriptionsSpot,
|
||||
Connector: e.WsConnectSpot,
|
||||
BespokeGenerateMessageID: e.GenerateWebsocketMessageID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Futures connection - USDT margined
|
||||
err = g.Websocket.SetupNewConnection(&websocket.ConnectionSetup{
|
||||
URL: connectionURLStringForSpotForFutures,
|
||||
RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit),
|
||||
ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout,
|
||||
ResponseMaxLimit: exch.WebsocketResponseMaxLimit,
|
||||
// Custom handlers for the specific connection:
|
||||
Handler: func(ctx context.Context, incoming []byte) error { return e.WsHandleFuturesData(ctx, incoming, asset.Futures) },
|
||||
Subscriber: e.FuturesSubscribe,
|
||||
Unsubscriber: e.FuturesUnsubscribe,
|
||||
GenerateSubscriptions: func() (subscription.List, error) { return e.GenerateFuturesDefaultSubscriptions(currency.USDT) },
|
||||
Connector: e.WsFuturesConnect,
|
||||
BespokeGenerateMessageID: e.GenerateWebsocketMessageID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## Contribution
|
||||
|
||||
Please feel free to submit any pull requests or suggest any desired features to be added.
|
||||
|
||||
When submitting a PR, please abide by our coding guidelines:
|
||||
|
||||
+ Code must adhere to the official Go [formatting](https://golang.org/doc/effective_go.html#formatting) guidelines (i.e. uses [gofmt](https://golang.org/cmd/gofmt/)).
|
||||
+ Code must be documented adhering to the official Go [commentary](https://golang.org/doc/effective_go.html#commentary) guidelines.
|
||||
+ Code must adhere to our [coding style](https://github.com/thrasher-corp/gocryptotrader/blob/master/doc/coding_style.md).
|
||||
+ Pull requests need to be based on and opened against the `master` branch.
|
||||
|
||||
## Donations
|
||||
|
||||
<img src="https://github.com/thrasher-corp/gocryptotrader/blob/master/web/src/assets/donate.png?raw=true" hspace="70">
|
||||
|
||||
If this framework helped you in any way, or you would like to support the developers working on it, please donate Bitcoin to:
|
||||
|
||||
***bc1qk0jareu4jytc0cfrhr5wgshsq8282awpavfahc***
|
||||
344
internal/exchange/websocket/buffer/buffer.go
Normal file
344
internal/exchange/websocket/buffer/buffer.go
Normal file
@@ -0,0 +1,344 @@
|
||||
package buffer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/thrasher-corp/gocryptotrader/common/key"
|
||||
"github.com/thrasher-corp/gocryptotrader/config"
|
||||
"github.com/thrasher-corp/gocryptotrader/currency"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/orderbook"
|
||||
"github.com/thrasher-corp/gocryptotrader/log"
|
||||
)
|
||||
|
||||
const packageError = "websocket orderbook buffer error: %w"
|
||||
|
||||
var (
|
||||
errExchangeConfigNil = errors.New("exchange config is nil")
|
||||
errBufferConfigNil = errors.New("buffer config is nil")
|
||||
errUnsetDataHandler = errors.New("datahandler unset")
|
||||
errIssueBufferEnabledButNoLimit = errors.New("buffer enabled but no limit set")
|
||||
errUpdateIsNil = errors.New("update is nil")
|
||||
errUpdateNoTargets = errors.New("update bid/ask targets cannot be nil")
|
||||
errDepthNotFound = errors.New("orderbook depth not found")
|
||||
errRESTOverwrite = errors.New("orderbook has been overwritten by REST protocol")
|
||||
errInvalidAction = errors.New("invalid action")
|
||||
errAmendFailure = errors.New("orderbook amend update failure")
|
||||
errDeleteFailure = errors.New("orderbook delete update failure")
|
||||
errInsertFailure = errors.New("orderbook insert update failure")
|
||||
errUpdateInsertFailure = errors.New("orderbook update/insert update failure")
|
||||
errRESTTimerLapse = errors.New("rest sync timer lapse with active websocket connection")
|
||||
errOrderbookFlushed = errors.New("orderbook flushed")
|
||||
)
|
||||
|
||||
// Setup sets private variables
|
||||
func (w *Orderbook) Setup(exchangeConfig *config.Exchange, c *Config, dataHandler chan<- any) error {
|
||||
if exchangeConfig == nil { // exchange config fields are checked in websocket package prior to calling this, so further checks are not needed
|
||||
return fmt.Errorf(packageError, errExchangeConfigNil)
|
||||
}
|
||||
if c == nil {
|
||||
return fmt.Errorf(packageError, errBufferConfigNil)
|
||||
}
|
||||
if dataHandler == nil {
|
||||
return fmt.Errorf(packageError, errUnsetDataHandler)
|
||||
}
|
||||
if exchangeConfig.Orderbook.WebsocketBufferEnabled &&
|
||||
exchangeConfig.Orderbook.WebsocketBufferLimit < 1 {
|
||||
return fmt.Errorf(packageError, errIssueBufferEnabledButNoLimit)
|
||||
}
|
||||
|
||||
// NOTE: These variables are set by config.json under "orderbook" for each individual exchange
|
||||
w.bufferEnabled = exchangeConfig.Orderbook.WebsocketBufferEnabled
|
||||
w.obBufferLimit = exchangeConfig.Orderbook.WebsocketBufferLimit
|
||||
|
||||
w.sortBuffer = c.SortBuffer
|
||||
w.sortBufferByUpdateIDs = c.SortBufferByUpdateIDs
|
||||
w.updateEntriesByID = c.UpdateEntriesByID
|
||||
w.exchangeName = exchangeConfig.Name
|
||||
w.dataHandler = dataHandler
|
||||
w.ob = make(map[key.PairAsset]*orderbookHolder)
|
||||
w.verbose = exchangeConfig.Verbose
|
||||
w.updateIDProgression = c.UpdateIDProgression
|
||||
w.checksum = c.Checksum
|
||||
return nil
|
||||
}
|
||||
|
||||
// validate validates update against setup values
|
||||
func (w *Orderbook) validate(u *orderbook.Update) error {
|
||||
if u == nil {
|
||||
return fmt.Errorf(packageError, errUpdateIsNil)
|
||||
}
|
||||
if len(u.Bids) == 0 && len(u.Asks) == 0 {
|
||||
return fmt.Errorf(packageError, errUpdateNoTargets)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update updates a stored pointer to an orderbook.Depth struct containing a
|
||||
// bid and ask Tranches, this switches between the usage of a buffered update
|
||||
func (w *Orderbook) Update(u *orderbook.Update) error {
|
||||
if err := w.validate(u); err != nil {
|
||||
return err
|
||||
}
|
||||
w.mtx.Lock()
|
||||
defer w.mtx.Unlock()
|
||||
book, ok := w.ob[key.PairAsset{Base: u.Pair.Base.Item, Quote: u.Pair.Quote.Item, Asset: u.Asset}]
|
||||
if !ok {
|
||||
return fmt.Errorf("%w for Exchange %s CurrencyPair: %s AssetType: %s",
|
||||
errDepthNotFound,
|
||||
w.exchangeName,
|
||||
u.Pair,
|
||||
u.Asset)
|
||||
}
|
||||
|
||||
// out of order update ID can be skipped
|
||||
if w.updateIDProgression && u.UpdateID <= book.updateID {
|
||||
if w.verbose {
|
||||
log.Warnf(log.WebsocketMgr,
|
||||
"Exchange %s CurrencyPair: %s AssetType: %s out of order websocket update received",
|
||||
w.exchangeName,
|
||||
u.Pair,
|
||||
u.Asset)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Checks for when the rest protocol overwrites a streaming dominated book
|
||||
// will stop updating book via incremental updates. This occurs because our
|
||||
// sync manager (engine/sync.go) timer has elapsed for streaming. Usually
|
||||
// because the book is highly illiquid.
|
||||
isREST, err := book.ob.IsRESTSnapshot()
|
||||
if err != nil {
|
||||
if !errors.Is(err, orderbook.ErrOrderbookInvalid) {
|
||||
return err
|
||||
}
|
||||
// In the event a checksum or processing error invalidates the book, all
|
||||
// updates that could be stored in the websocket buffer, skip applying
|
||||
// until a new snapshot comes through.
|
||||
if w.verbose {
|
||||
log.Warnf(log.WebsocketMgr,
|
||||
"Exchange %s CurrencyPair: %s AssetType: %s underlying book is invalid, cannot apply update.",
|
||||
w.exchangeName,
|
||||
u.Pair,
|
||||
u.Asset)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if isREST {
|
||||
if w.verbose {
|
||||
log.Warnf(log.WebsocketMgr,
|
||||
"%s for Exchange %s CurrencyPair: %s AssetType: %s consider extending synctimeoutwebsocket",
|
||||
errRESTOverwrite,
|
||||
w.exchangeName,
|
||||
u.Pair,
|
||||
u.Asset)
|
||||
}
|
||||
// Instance of illiquidity, this signal notifies that there is websocket
|
||||
// activity. We can invalidate the book and request a new snapshot. All
|
||||
// further updates through the websocket should be caught above in the
|
||||
// IsRestSnapshot() call.
|
||||
return book.ob.Invalidate(errRESTTimerLapse)
|
||||
}
|
||||
|
||||
if w.bufferEnabled {
|
||||
var processed bool
|
||||
processed, err = w.processBufferUpdate(book, u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !processed {
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
err = w.processObUpdate(book, u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Publish all state changes, disregarding verbosity or sync requirements.
|
||||
book.ob.Publish()
|
||||
w.dataHandler <- book.ob
|
||||
return nil
|
||||
}
|
||||
|
||||
// processBufferUpdate stores update into buffer, when buffer at capacity as
|
||||
// defined by w.obBufferLimit it well then sort and apply updates.
|
||||
func (w *Orderbook) processBufferUpdate(o *orderbookHolder, u *orderbook.Update) (bool, error) {
|
||||
*o.buffer = append(*o.buffer, *u)
|
||||
if len(*o.buffer) < w.obBufferLimit {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if w.sortBuffer {
|
||||
// sort by last updated to ensure each update is in order
|
||||
if w.sortBufferByUpdateIDs {
|
||||
sort.Slice(*o.buffer, func(i, j int) bool {
|
||||
return (*o.buffer)[i].UpdateID < (*o.buffer)[j].UpdateID
|
||||
})
|
||||
} else {
|
||||
sort.Slice(*o.buffer, func(i, j int) bool {
|
||||
return (*o.buffer)[i].UpdateTime.Before((*o.buffer)[j].UpdateTime)
|
||||
})
|
||||
}
|
||||
}
|
||||
for i := range *o.buffer {
|
||||
err := w.processObUpdate(o, &(*o.buffer)[i])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
// clear buffer of old updates
|
||||
*o.buffer = nil
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// processObUpdate processes updates either by its corresponding id or by price level
|
||||
func (w *Orderbook) processObUpdate(o *orderbookHolder, u *orderbook.Update) error {
|
||||
// Both update methods require post processing to ensure the orderbook is in a valid state.
|
||||
if w.updateEntriesByID {
|
||||
if err := o.updateByIDAndAction(u); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := o.updateByPrice(u); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if w.checksum != nil {
|
||||
compare, err := o.ob.Retrieve()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = w.checksum(compare, u.Checksum)
|
||||
if err != nil {
|
||||
return o.ob.Invalidate(err)
|
||||
}
|
||||
o.updateID = u.UpdateID
|
||||
} else if o.ob.VerifyOrderbook() {
|
||||
compare, err := o.ob.Retrieve()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = compare.Verify()
|
||||
if err != nil {
|
||||
return o.ob.Invalidate(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateByPrice amends amount if match occurs by price, deletes if amount is
|
||||
// zero or less and inserts if not found.
|
||||
func (o *orderbookHolder) updateByPrice(updts *orderbook.Update) error {
|
||||
return o.ob.UpdateBidAskByPrice(updts)
|
||||
}
|
||||
|
||||
// updateByIDAndAction will receive an action to execute against the orderbook
|
||||
// it will then match by IDs instead of price to perform the action
|
||||
func (o *orderbookHolder) updateByIDAndAction(updts *orderbook.Update) error {
|
||||
switch updts.Action {
|
||||
case orderbook.Amend:
|
||||
err := o.ob.UpdateBidAskByID(updts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w %w", errAmendFailure, err)
|
||||
}
|
||||
case orderbook.Delete:
|
||||
// edge case for Bitfinex as their streaming endpoint duplicates deletes
|
||||
bypassErr := o.ob.GetName() == "Bitfinex" && o.ob.IsFundingRate()
|
||||
err := o.ob.DeleteBidAskByID(updts, bypassErr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w %w", errDeleteFailure, err)
|
||||
}
|
||||
case orderbook.Insert:
|
||||
err := o.ob.InsertBidAskByID(updts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w %w", errInsertFailure, err)
|
||||
}
|
||||
case orderbook.UpdateInsert:
|
||||
err := o.ob.UpdateInsertByID(updts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w %w", errUpdateInsertFailure, err)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("%w [%d]", errInvalidAction, updts.Action)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadSnapshot loads initial snapshot of orderbook data from websocket
|
||||
func (w *Orderbook) LoadSnapshot(book *orderbook.Base) error {
|
||||
// Checks if book can deploy to depth
|
||||
err := book.Verify()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.mtx.Lock()
|
||||
defer w.mtx.Unlock()
|
||||
holder, ok := w.ob[key.PairAsset{Base: book.Pair.Base.Item, Quote: book.Pair.Quote.Item, Asset: book.Asset}]
|
||||
if !ok {
|
||||
// Associate orderbook pointer with local exchange depth map
|
||||
var depth *orderbook.Depth
|
||||
depth, err = orderbook.DeployDepth(book.Exchange, book.Pair, book.Asset)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
depth.AssignOptions(book)
|
||||
buffer := make([]orderbook.Update, w.obBufferLimit)
|
||||
|
||||
holder = &orderbookHolder{ob: depth, buffer: &buffer}
|
||||
w.ob[key.PairAsset{Base: book.Pair.Base.Item, Quote: book.Pair.Quote.Item, Asset: book.Asset}] = holder
|
||||
}
|
||||
|
||||
holder.updateID = book.LastUpdateID
|
||||
|
||||
err = holder.ob.LoadSnapshot(book.Bids, book.Asks, book.LastUpdateID, book.LastUpdated, book.UpdatePushedAt, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
holder.ob.Publish()
|
||||
w.dataHandler <- holder.ob
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOrderbook returns an orderbook copy as orderbook.Base
|
||||
func (w *Orderbook) GetOrderbook(p currency.Pair, a asset.Item) (*orderbook.Base, error) {
|
||||
w.mtx.Lock()
|
||||
defer w.mtx.Unlock()
|
||||
book, ok := w.ob[key.PairAsset{Base: p.Base.Item, Quote: p.Quote.Item, Asset: a}]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s %s %s %w", w.exchangeName, p, a, errDepthNotFound)
|
||||
}
|
||||
return book.ob.Retrieve()
|
||||
}
|
||||
|
||||
// FlushBuffer flushes w.ob data to be garbage collected and refreshed when a
|
||||
// connection is lost and reconnected
|
||||
func (w *Orderbook) FlushBuffer() {
|
||||
w.mtx.Lock()
|
||||
w.ob = make(map[key.PairAsset]*orderbookHolder)
|
||||
w.mtx.Unlock()
|
||||
}
|
||||
|
||||
// FlushOrderbook flushes independent orderbook
|
||||
func (w *Orderbook) FlushOrderbook(p currency.Pair, a asset.Item) error {
|
||||
w.mtx.Lock()
|
||||
defer w.mtx.Unlock()
|
||||
book, ok := w.ob[key.PairAsset{Base: p.Base.Item, Quote: p.Quote.Item, Asset: a}]
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot flush orderbook %s %s %s %w",
|
||||
w.exchangeName,
|
||||
p,
|
||||
a,
|
||||
errDepthNotFound)
|
||||
}
|
||||
// error not needed in this return
|
||||
_ = book.ob.Invalidate(errOrderbookFlushed)
|
||||
return nil
|
||||
}
|
||||
992
internal/exchange/websocket/buffer/buffer_test.go
Normal file
992
internal/exchange/websocket/buffer/buffer_test.go
Normal file
@@ -0,0 +1,992 @@
|
||||
package buffer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"slices"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
"github.com/thrasher-corp/gocryptotrader/common/key"
|
||||
"github.com/thrasher-corp/gocryptotrader/config"
|
||||
"github.com/thrasher-corp/gocryptotrader/currency"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/orderbook"
|
||||
)
|
||||
|
||||
var (
|
||||
itemArray = [][]orderbook.Tranche{
|
||||
{{Price: 1000, Amount: 1, ID: 1000}},
|
||||
{{Price: 2000, Amount: 1, ID: 2000}},
|
||||
{{Price: 3000, Amount: 1, ID: 3000}},
|
||||
{{Price: 3000, Amount: 2, ID: 4000}},
|
||||
{{Price: 4000, Amount: 0, ID: 6000}},
|
||||
{{Price: 5000, Amount: 1, ID: 5000}},
|
||||
}
|
||||
offset = common.Counter{}
|
||||
)
|
||||
|
||||
const exchangeName = "exchangeTest"
|
||||
|
||||
// getExclusivePair returns a currency pair with a unique ID for testing as books are centralised and changes will affect other tests
|
||||
func getExclusivePair() (currency.Pair, error) {
|
||||
return currency.NewPairFromStrings(currency.BTC.String(), currency.USDT.String()+strconv.FormatInt(offset.IncrementAndGet(), 10))
|
||||
}
|
||||
|
||||
func createSnapshot(pair currency.Pair, bookVerifiy ...bool) (holder *Orderbook, asks, bids orderbook.Tranches, err error) {
|
||||
asks = orderbook.Tranches{{Price: 4000, Amount: 1, ID: 6}}
|
||||
bids = orderbook.Tranches{{Price: 4000, Amount: 1, ID: 6}}
|
||||
|
||||
book := &orderbook.Base{
|
||||
Exchange: exchangeName,
|
||||
Asks: asks,
|
||||
Bids: bids,
|
||||
Asset: asset.Spot,
|
||||
Pair: pair,
|
||||
PriceDuplication: true,
|
||||
LastUpdated: time.Now(),
|
||||
VerifyOrderbook: len(bookVerifiy) > 0 && bookVerifiy[0],
|
||||
}
|
||||
|
||||
newBook := make(map[key.PairAsset]*orderbookHolder)
|
||||
|
||||
ch := make(chan any)
|
||||
go func(<-chan any) { // reader
|
||||
for range ch {
|
||||
continue
|
||||
}
|
||||
}(ch)
|
||||
holder = &Orderbook{
|
||||
exchangeName: exchangeName,
|
||||
dataHandler: ch,
|
||||
ob: newBook,
|
||||
}
|
||||
err = holder.LoadSnapshot(book)
|
||||
return holder, asks, bids, err
|
||||
}
|
||||
|
||||
func bidAskGenerator() []orderbook.Tranche {
|
||||
response := make([]orderbook.Tranche, 100)
|
||||
for i := range 100 {
|
||||
price := float64(rand.Intn(1000)) //nolint:gosec // no need to import crypo/rand for testing
|
||||
if price == 0 {
|
||||
price = 1
|
||||
}
|
||||
response[i] = orderbook.Tranche{
|
||||
Amount: float64(rand.Intn(10)), //nolint:gosec // no need to import crypo/rand for testing
|
||||
Price: price,
|
||||
ID: int64(i),
|
||||
}
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
func BenchmarkUpdateBidsByPrice(b *testing.B) {
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(b, err)
|
||||
|
||||
ob, _, _, err := createSnapshot(cp)
|
||||
require.NoError(b, err)
|
||||
|
||||
for b.Loop() {
|
||||
bidAsks := bidAskGenerator()
|
||||
update := &orderbook.Update{
|
||||
Bids: bidAsks,
|
||||
Asks: bidAsks,
|
||||
Pair: cp,
|
||||
UpdateTime: time.Now(),
|
||||
Asset: asset.Spot,
|
||||
}
|
||||
holder := ob.ob[key.PairAsset{Base: cp.Base.Item, Quote: cp.Quote.Item, Asset: asset.Spot}]
|
||||
require.NoError(b, holder.updateByPrice(update))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUpdateAsksByPrice(b *testing.B) {
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(b, err)
|
||||
|
||||
ob, _, _, err := createSnapshot(cp)
|
||||
require.NoError(b, err)
|
||||
|
||||
for b.Loop() {
|
||||
bidAsks := bidAskGenerator()
|
||||
update := &orderbook.Update{
|
||||
Bids: bidAsks,
|
||||
Asks: bidAsks,
|
||||
Pair: cp,
|
||||
UpdateTime: time.Now(),
|
||||
Asset: asset.Spot,
|
||||
}
|
||||
holder := ob.ob[key.PairAsset{Base: cp.Base.Item, Quote: cp.Quote.Item, Asset: asset.Spot}]
|
||||
require.NoError(b, holder.updateByPrice(update))
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkBufferPerformance demonstrates buffer more performant than multi
|
||||
// process calls
|
||||
// 890016 1688 ns/op 416 B/op 3 allocs/op
|
||||
func BenchmarkBufferPerformance(b *testing.B) {
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(b, err)
|
||||
|
||||
holder, asks, bids, err := createSnapshot(cp)
|
||||
require.NoError(b, err)
|
||||
|
||||
holder.bufferEnabled = true
|
||||
update := &orderbook.Update{
|
||||
Bids: bids,
|
||||
Asks: asks,
|
||||
Pair: cp,
|
||||
UpdateTime: time.Now(),
|
||||
Asset: asset.Spot,
|
||||
}
|
||||
for b.Loop() {
|
||||
randomIndex := rand.Intn(4) //nolint:gosec // no need to import crypo/rand for testing
|
||||
update.Asks = itemArray[randomIndex]
|
||||
update.Bids = itemArray[randomIndex]
|
||||
require.NoError(b, holder.Update(update))
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkBufferSortingPerformance benchmark
|
||||
//
|
||||
// 613964 2093 ns/op 440 B/op 4 allocs/op
|
||||
func BenchmarkBufferSortingPerformance(b *testing.B) {
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(b, err)
|
||||
|
||||
holder, asks, bids, err := createSnapshot(cp)
|
||||
require.NoError(b, err)
|
||||
|
||||
holder.bufferEnabled = true
|
||||
holder.sortBuffer = true
|
||||
update := &orderbook.Update{
|
||||
Bids: bids,
|
||||
Asks: asks,
|
||||
Pair: cp,
|
||||
UpdateTime: time.Now(),
|
||||
Asset: asset.Spot,
|
||||
}
|
||||
for b.Loop() {
|
||||
randomIndex := rand.Intn(4) //nolint:gosec // no need to import crypo/rand for testing
|
||||
update.Asks = itemArray[randomIndex]
|
||||
update.Bids = itemArray[randomIndex]
|
||||
require.NoError(b, holder.Update(update))
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkBufferSortingPerformance benchmark
|
||||
// 914500 1599 ns/op 440 B/op 4 allocs/op
|
||||
func BenchmarkBufferSortingByIDPerformance(b *testing.B) {
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(b, err)
|
||||
|
||||
holder, asks, bids, err := createSnapshot(cp)
|
||||
require.NoError(b, err)
|
||||
|
||||
holder.bufferEnabled = true
|
||||
holder.sortBuffer = true
|
||||
holder.sortBufferByUpdateIDs = true
|
||||
update := &orderbook.Update{
|
||||
Bids: bids,
|
||||
Asks: asks,
|
||||
Pair: cp,
|
||||
UpdateTime: time.Now(),
|
||||
Asset: asset.Spot,
|
||||
}
|
||||
|
||||
for b.Loop() {
|
||||
randomIndex := rand.Intn(4) //nolint:gosec // no need to import crypo/rand for testing
|
||||
update.Asks = itemArray[randomIndex]
|
||||
update.Bids = itemArray[randomIndex]
|
||||
require.NoError(b, holder.Update(update))
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNoBufferPerformance demonstrates orderbook process more performant
|
||||
// than buffer
|
||||
// 122659 12792 ns/op 972 B/op 7 allocs/op PRIOR
|
||||
// 1225924 1028 ns/op 240 B/op 2 allocs/op CURRENT
|
||||
|
||||
func BenchmarkNoBufferPerformance(b *testing.B) {
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(b, err)
|
||||
|
||||
obl, asks, bids, err := createSnapshot(cp)
|
||||
require.NoError(b, err)
|
||||
|
||||
update := &orderbook.Update{
|
||||
Bids: bids,
|
||||
Asks: asks,
|
||||
Pair: cp,
|
||||
UpdateTime: time.Now(),
|
||||
Asset: asset.Spot,
|
||||
}
|
||||
|
||||
for b.Loop() {
|
||||
randomIndex := rand.Intn(4) //nolint:gosec // no need to import crypo/rand for testing
|
||||
update.Asks = itemArray[randomIndex]
|
||||
update.Bids = itemArray[randomIndex]
|
||||
require.NoError(b, obl.Update(update))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdates(t *testing.T) {
|
||||
t.Parallel()
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
holder, _, _, err := createSnapshot(cp)
|
||||
require.NoError(t, err)
|
||||
|
||||
book := holder.ob[key.PairAsset{Base: cp.Base.Item, Quote: cp.Quote.Item, Asset: asset.Spot}]
|
||||
err = book.updateByPrice(&orderbook.Update{
|
||||
Bids: itemArray[5],
|
||||
Asks: itemArray[5],
|
||||
Pair: cp,
|
||||
UpdateTime: time.Now(),
|
||||
Asset: asset.Spot,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = book.updateByPrice(&orderbook.Update{
|
||||
Bids: itemArray[0],
|
||||
Asks: itemArray[0],
|
||||
Pair: cp,
|
||||
UpdateTime: time.Now(),
|
||||
Asset: asset.Spot,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
askLen, err := book.ob.GetAskLength()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, askLen)
|
||||
}
|
||||
|
||||
// TestHittingTheBuffer logic test
|
||||
func TestHittingTheBuffer(t *testing.T) {
|
||||
t.Parallel()
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
holder, _, _, err := createSnapshot(cp)
|
||||
require.NoError(t, err)
|
||||
|
||||
holder.bufferEnabled = true
|
||||
holder.obBufferLimit = 5
|
||||
for i := range itemArray {
|
||||
asks := itemArray[i]
|
||||
bids := itemArray[i]
|
||||
err = holder.Update(&orderbook.Update{
|
||||
Bids: bids,
|
||||
Asks: asks,
|
||||
Pair: cp,
|
||||
UpdateTime: time.Now(),
|
||||
Asset: asset.Spot,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
book := holder.ob[key.PairAsset{Base: cp.Base.Item, Quote: cp.Quote.Item, Asset: asset.Spot}]
|
||||
askLen, err := book.ob.GetAskLength()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, askLen)
|
||||
|
||||
bidLen, err := book.ob.GetBidLength()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, bidLen)
|
||||
}
|
||||
|
||||
// TestInsertWithIDs logic test
|
||||
func TestInsertWithIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
holder, _, _, err := createSnapshot(cp)
|
||||
require.NoError(t, err)
|
||||
|
||||
holder.bufferEnabled = true
|
||||
holder.updateEntriesByID = true
|
||||
holder.obBufferLimit = 5
|
||||
for i := range itemArray {
|
||||
asks := itemArray[i]
|
||||
if asks[0].Amount <= 0 {
|
||||
continue
|
||||
}
|
||||
bids := itemArray[i]
|
||||
err = holder.Update(&orderbook.Update{
|
||||
Bids: bids,
|
||||
Asks: asks,
|
||||
Pair: cp,
|
||||
UpdateTime: time.Now(),
|
||||
Asset: asset.Spot,
|
||||
Action: orderbook.UpdateInsert,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
book := holder.ob[key.PairAsset{Base: cp.Base.Item, Quote: cp.Quote.Item, Asset: asset.Spot}]
|
||||
askLen, err := book.ob.GetAskLength()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 6, askLen)
|
||||
|
||||
bidLen, err := book.ob.GetBidLength()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 6, bidLen)
|
||||
|
||||
cp, err = getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
holder, _, _, err = createSnapshot(cp, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
holder.checksum = nil
|
||||
holder.updateIDProgression = false
|
||||
err = holder.Update(&orderbook.Update{
|
||||
UpdateTime: time.Now(),
|
||||
Asset: asset.Spot,
|
||||
Asks: []orderbook.Tranche{{Price: 999999}},
|
||||
Pair: cp,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestSortIDs logic test
|
||||
func TestSortIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
holder, _, _, err := createSnapshot(cp)
|
||||
require.NoError(t, err)
|
||||
|
||||
holder.bufferEnabled = true
|
||||
holder.sortBufferByUpdateIDs = true
|
||||
holder.sortBuffer = true
|
||||
holder.obBufferLimit = 5
|
||||
for i := range itemArray {
|
||||
asks := itemArray[i]
|
||||
bids := itemArray[i]
|
||||
err = holder.Update(&orderbook.Update{
|
||||
Bids: bids,
|
||||
Asks: asks,
|
||||
Pair: cp,
|
||||
UpdateID: int64(i),
|
||||
Asset: asset.Spot,
|
||||
UpdateTime: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
book := holder.ob[key.PairAsset{Base: cp.Base.Item, Quote: cp.Quote.Item, Asset: asset.Spot}]
|
||||
askLen, err := book.ob.GetAskLength()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, askLen)
|
||||
|
||||
bidLen, err := book.ob.GetBidLength()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, bidLen)
|
||||
}
|
||||
|
||||
// TestOutOfOrderIDs logic test
|
||||
func TestOutOfOrderIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
holder, _, _, err := createSnapshot(cp)
|
||||
require.NoError(t, err)
|
||||
|
||||
outOFOrderIDs := []int64{2, 1, 5, 3, 4, 6, 7}
|
||||
assert.Equal(t, 1000., itemArray[0][0].Price)
|
||||
|
||||
holder.bufferEnabled = true
|
||||
holder.sortBuffer = true
|
||||
holder.obBufferLimit = 5
|
||||
for i := range itemArray {
|
||||
asks := itemArray[i]
|
||||
err = holder.Update(&orderbook.Update{
|
||||
Asks: asks,
|
||||
Pair: cp,
|
||||
UpdateID: outOFOrderIDs[i],
|
||||
Asset: asset.Spot,
|
||||
UpdateTime: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
book := holder.ob[key.PairAsset{Base: cp.Base.Item, Quote: cp.Quote.Item, Asset: asset.Spot}]
|
||||
cpy, err := book.ob.Retrieve()
|
||||
require.NoError(t, err)
|
||||
// Index 1 since index 0 is price 7000
|
||||
assert.Equal(t, 2000., cpy.Asks[1].Price)
|
||||
}
|
||||
|
||||
func TestOrderbookLastUpdateID(t *testing.T) {
|
||||
t.Parallel()
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
holder, _, _, err := createSnapshot(cp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1000., itemArray[0][0].Price)
|
||||
|
||||
holder.checksum = func(*orderbook.Base, uint32) error { return errors.New("testerino") }
|
||||
|
||||
// this update invalidates the book
|
||||
err = holder.Update(&orderbook.Update{
|
||||
Asks: []orderbook.Tranche{{Price: 999999}},
|
||||
Pair: cp,
|
||||
UpdateID: -1,
|
||||
Asset: asset.Spot,
|
||||
UpdateTime: time.Now(),
|
||||
})
|
||||
require.ErrorIs(t, err, orderbook.ErrOrderbookInvalid)
|
||||
|
||||
cp, err = getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
holder, _, _, err = createSnapshot(cp)
|
||||
require.NoError(t, err)
|
||||
|
||||
holder.checksum = func(*orderbook.Base, uint32) error { return nil }
|
||||
holder.updateIDProgression = true
|
||||
|
||||
for i := range itemArray {
|
||||
asks := itemArray[i]
|
||||
err = holder.Update(&orderbook.Update{
|
||||
Asks: asks,
|
||||
Pair: cp,
|
||||
UpdateID: int64(i) + 1,
|
||||
Asset: asset.Spot,
|
||||
UpdateTime: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// out of order
|
||||
err = holder.Update(&orderbook.Update{
|
||||
Asks: []orderbook.Tranche{{Price: 999999}},
|
||||
Pair: cp,
|
||||
UpdateID: 1,
|
||||
Asset: asset.Spot,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ob, err := holder.GetOrderbook(cp, asset.Spot)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(len(itemArray)), ob.LastUpdateID)
|
||||
}
|
||||
|
||||
// TestRunUpdateWithoutSnapshot logic test
|
||||
func TestRunUpdateWithoutSnapshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
var holder Orderbook
|
||||
asks := []orderbook.Tranche{{Price: 4000, Amount: 1, ID: 8}}
|
||||
bids := []orderbook.Tranche{{Price: 5999, Amount: 1, ID: 8}, {Price: 4000, Amount: 1, ID: 9}}
|
||||
holder.exchangeName = exchangeName
|
||||
err = holder.Update(&orderbook.Update{
|
||||
Bids: bids,
|
||||
Asks: asks,
|
||||
Pair: cp,
|
||||
UpdateTime: time.Now(),
|
||||
Asset: asset.Spot,
|
||||
})
|
||||
require.ErrorIs(t, err, errDepthNotFound)
|
||||
}
|
||||
|
||||
// TestRunUpdateWithoutAnyUpdates logic test
|
||||
func TestRunUpdateWithoutAnyUpdates(t *testing.T) {
|
||||
t.Parallel()
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
var obl Orderbook
|
||||
obl.exchangeName = exchangeName
|
||||
err = obl.Update(&orderbook.Update{
|
||||
Bids: []orderbook.Tranche{},
|
||||
Asks: []orderbook.Tranche{},
|
||||
Pair: cp,
|
||||
UpdateTime: time.Now(),
|
||||
Asset: asset.Spot,
|
||||
})
|
||||
require.ErrorIs(t, err, errUpdateNoTargets)
|
||||
}
|
||||
|
||||
// TestRunSnapshotWithNoData logic test
|
||||
func TestRunSnapshotWithNoData(t *testing.T) {
|
||||
t.Parallel()
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
var obl Orderbook
|
||||
obl.ob = make(map[key.PairAsset]*orderbookHolder)
|
||||
obl.dataHandler = make(chan any, 1)
|
||||
var snapShot1 orderbook.Base
|
||||
snapShot1.Asset = asset.Spot
|
||||
snapShot1.Pair = cp
|
||||
snapShot1.Exchange = "test"
|
||||
obl.exchangeName = "test"
|
||||
snapShot1.LastUpdated = time.Now()
|
||||
require.NoError(t, obl.LoadSnapshot(&snapShot1))
|
||||
}
|
||||
|
||||
// TestLoadSnapshot logic test
|
||||
func TestLoadSnapshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
var obl Orderbook
|
||||
obl.dataHandler = make(chan any, 100)
|
||||
obl.ob = make(map[key.PairAsset]*orderbookHolder)
|
||||
var snapShot1 orderbook.Base
|
||||
snapShot1.Exchange = "SnapshotWithOverride"
|
||||
asks := []orderbook.Tranche{{Price: 4000, Amount: 1, ID: 8}}
|
||||
bids := []orderbook.Tranche{{Price: 4000, Amount: 1, ID: 9}}
|
||||
snapShot1.Asks = asks
|
||||
snapShot1.Bids = bids
|
||||
snapShot1.Asset = asset.Spot
|
||||
snapShot1.Pair = cp
|
||||
snapShot1.LastUpdated = time.Now()
|
||||
require.NoError(t, obl.LoadSnapshot(&snapShot1))
|
||||
}
|
||||
|
||||
// TestFlushBuffer logic test
|
||||
func TestFlushBuffer(t *testing.T) {
|
||||
t.Parallel()
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
obl, _, _, err := createSnapshot(cp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, obl.ob)
|
||||
obl.FlushBuffer()
|
||||
assert.Empty(t, obl.ob)
|
||||
}
|
||||
|
||||
// TestInsertingSnapShots logic test
|
||||
func TestInsertingSnapShots(t *testing.T) {
|
||||
t.Parallel()
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
var holder Orderbook
|
||||
holder.dataHandler = make(chan any, 100)
|
||||
holder.ob = make(map[key.PairAsset]*orderbookHolder)
|
||||
var snapShot1 orderbook.Base
|
||||
snapShot1.Exchange = "WSORDERBOOKTEST1"
|
||||
asks := []orderbook.Tranche{
|
||||
{Price: 6000, Amount: 1, ID: 1},
|
||||
{Price: 6001, Amount: 0.5, ID: 2},
|
||||
{Price: 6002, Amount: 2, ID: 3},
|
||||
{Price: 6003, Amount: 3, ID: 4},
|
||||
{Price: 6004, Amount: 5, ID: 5},
|
||||
{Price: 6005, Amount: 2, ID: 6},
|
||||
{Price: 6006, Amount: 1.5, ID: 7},
|
||||
{Price: 6007, Amount: 0.5, ID: 8},
|
||||
{Price: 6008, Amount: 23, ID: 9},
|
||||
{Price: 6009, Amount: 9, ID: 10},
|
||||
{Price: 6010, Amount: 7, ID: 11},
|
||||
}
|
||||
|
||||
bids := []orderbook.Tranche{
|
||||
{Price: 5999, Amount: 1, ID: 12},
|
||||
{Price: 5998, Amount: 0.5, ID: 13},
|
||||
{Price: 5997, Amount: 2, ID: 14},
|
||||
{Price: 5996, Amount: 3, ID: 15},
|
||||
{Price: 5995, Amount: 5, ID: 16},
|
||||
{Price: 5994, Amount: 2, ID: 17},
|
||||
{Price: 5993, Amount: 1.5, ID: 18},
|
||||
{Price: 5992, Amount: 0.5, ID: 19},
|
||||
{Price: 5991, Amount: 23, ID: 20},
|
||||
{Price: 5990, Amount: 9, ID: 21},
|
||||
{Price: 5989, Amount: 7, ID: 22},
|
||||
}
|
||||
|
||||
snapShot1.Asks = asks
|
||||
snapShot1.Bids = bids
|
||||
snapShot1.Asset = asset.Spot
|
||||
snapShot1.Pair = cp
|
||||
snapShot1.LastUpdated = time.Now()
|
||||
require.NoError(t, holder.LoadSnapshot(&snapShot1))
|
||||
|
||||
var snapShot2 orderbook.Base
|
||||
snapShot2.Exchange = "WSORDERBOOKTEST2"
|
||||
asks = []orderbook.Tranche{
|
||||
{Price: 51, Amount: 1, ID: 1},
|
||||
{Price: 52, Amount: 0.5, ID: 2},
|
||||
{Price: 53, Amount: 2, ID: 3},
|
||||
{Price: 54, Amount: 3, ID: 4},
|
||||
{Price: 55, Amount: 5, ID: 5},
|
||||
{Price: 56, Amount: 2, ID: 6},
|
||||
{Price: 57, Amount: 1.5, ID: 7},
|
||||
{Price: 58, Amount: 0.5, ID: 8},
|
||||
{Price: 59, Amount: 23, ID: 9},
|
||||
{Price: 50, Amount: 9, ID: 10},
|
||||
{Price: 60, Amount: 7, ID: 11},
|
||||
}
|
||||
|
||||
bids = []orderbook.Tranche{
|
||||
{Price: 49, Amount: 1, ID: 12},
|
||||
{Price: 48, Amount: 0.5, ID: 13},
|
||||
{Price: 47, Amount: 2, ID: 14},
|
||||
{Price: 46, Amount: 3, ID: 15},
|
||||
{Price: 45, Amount: 5, ID: 16},
|
||||
{Price: 44, Amount: 2, ID: 17},
|
||||
{Price: 43, Amount: 1.5, ID: 18},
|
||||
{Price: 42, Amount: 0.5, ID: 19},
|
||||
{Price: 41, Amount: 23, ID: 20},
|
||||
{Price: 40, Amount: 9, ID: 21},
|
||||
{Price: 39, Amount: 7, ID: 22},
|
||||
}
|
||||
|
||||
snapShot2.Asks = asks
|
||||
snapShot2.Asks.SortAsks()
|
||||
snapShot2.Bids = bids
|
||||
snapShot2.Bids.SortBids()
|
||||
snapShot2.Asset = asset.Spot
|
||||
snapShot2.Pair, err = getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
snapShot2.LastUpdated = time.Now()
|
||||
require.NoError(t, holder.LoadSnapshot(&snapShot2))
|
||||
|
||||
var snapShot3 orderbook.Base
|
||||
snapShot3.Exchange = "WSORDERBOOKTEST3"
|
||||
asks = []orderbook.Tranche{
|
||||
{Price: 511, Amount: 1, ID: 1},
|
||||
{Price: 52, Amount: 0.5, ID: 2},
|
||||
{Price: 53, Amount: 2, ID: 3},
|
||||
{Price: 54, Amount: 3, ID: 4},
|
||||
{Price: 55, Amount: 5, ID: 5},
|
||||
{Price: 56, Amount: 2, ID: 6},
|
||||
{Price: 57, Amount: 1.5, ID: 7},
|
||||
{Price: 58, Amount: 0.5, ID: 8},
|
||||
{Price: 59, Amount: 23, ID: 9},
|
||||
{Price: 50, Amount: 9, ID: 10},
|
||||
{Price: 60, Amount: 7, ID: 11},
|
||||
}
|
||||
|
||||
bids = []orderbook.Tranche{
|
||||
{Price: 49, Amount: 1, ID: 12},
|
||||
{Price: 48, Amount: 0.5, ID: 13},
|
||||
{Price: 47, Amount: 2, ID: 14},
|
||||
{Price: 46, Amount: 3, ID: 15},
|
||||
{Price: 45, Amount: 5, ID: 16},
|
||||
{Price: 44, Amount: 2, ID: 17},
|
||||
{Price: 43, Amount: 1.5, ID: 18},
|
||||
{Price: 42, Amount: 0.5, ID: 19},
|
||||
{Price: 41, Amount: 23, ID: 20},
|
||||
{Price: 40, Amount: 9, ID: 21},
|
||||
{Price: 39, Amount: 7, ID: 22},
|
||||
}
|
||||
|
||||
snapShot3.Asks = asks
|
||||
snapShot3.Asks.SortAsks()
|
||||
snapShot3.Bids = bids
|
||||
snapShot3.Bids.SortBids()
|
||||
snapShot3.Asset = asset.Futures
|
||||
snapShot3.Pair, err = getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
snapShot3.LastUpdated = time.Now()
|
||||
require.NoError(t, holder.LoadSnapshot(&snapShot3))
|
||||
|
||||
ob, err := holder.GetOrderbook(snapShot1.Pair, snapShot1.Asset)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, snapShot1.Asks[0], ob.Asks[0])
|
||||
|
||||
ob, err = holder.GetOrderbook(snapShot2.Pair, snapShot2.Asset)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, snapShot2.Asks[0], ob.Asks[0])
|
||||
|
||||
ob, err = holder.GetOrderbook(snapShot3.Pair, snapShot3.Asset)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, snapShot3.Asks[0], ob.Asks[0])
|
||||
}
|
||||
|
||||
func TestGetOrderbook(t *testing.T) {
|
||||
t.Parallel()
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
holder, _, _, err := createSnapshot(cp)
|
||||
require.NoError(t, err)
|
||||
|
||||
ob, err := holder.GetOrderbook(cp, asset.Spot)
|
||||
require.NoError(t, err)
|
||||
|
||||
bufferOb := holder.ob[key.PairAsset{Base: cp.Base.Item, Quote: cp.Quote.Item, Asset: asset.Spot}]
|
||||
b, err := bufferOb.ob.Retrieve()
|
||||
require.NoError(t, err)
|
||||
|
||||
askLen, err := bufferOb.ob.GetAskLength()
|
||||
require.NoError(t, err)
|
||||
|
||||
bidLen, err := bufferOb.ob.GetBidLength()
|
||||
require.NoError(t, err)
|
||||
|
||||
if askLen != len(ob.Asks) ||
|
||||
bidLen != len(ob.Bids) ||
|
||||
b.Asset != ob.Asset ||
|
||||
b.Exchange != ob.Exchange ||
|
||||
b.LastUpdateID != ob.LastUpdateID ||
|
||||
b.PriceDuplication != ob.PriceDuplication ||
|
||||
b.Pair != ob.Pair {
|
||||
t.Fatal("data on both books should be the same")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
t.Parallel()
|
||||
w := Orderbook{}
|
||||
err := w.Setup(nil, nil, nil)
|
||||
require.ErrorIs(t, err, errExchangeConfigNil)
|
||||
|
||||
exchangeConfig := &config.Exchange{}
|
||||
err = w.Setup(exchangeConfig, nil, nil)
|
||||
require.ErrorIs(t, err, errBufferConfigNil)
|
||||
|
||||
bufferConf := &Config{}
|
||||
err = w.Setup(exchangeConfig, bufferConf, nil)
|
||||
require.ErrorIs(t, err, errUnsetDataHandler)
|
||||
|
||||
exchangeConfig.Orderbook.WebsocketBufferEnabled = true
|
||||
err = w.Setup(exchangeConfig, bufferConf, make(chan any))
|
||||
require.ErrorIs(t, err, errIssueBufferEnabledButNoLimit)
|
||||
|
||||
exchangeConfig.Orderbook.WebsocketBufferLimit = 1337
|
||||
exchangeConfig.Orderbook.WebsocketBufferEnabled = true
|
||||
exchangeConfig.Name = "test"
|
||||
bufferConf.SortBuffer = true
|
||||
bufferConf.SortBufferByUpdateIDs = true
|
||||
bufferConf.UpdateEntriesByID = true
|
||||
err = w.Setup(exchangeConfig, bufferConf, make(chan any))
|
||||
require.NoError(t, err)
|
||||
|
||||
if w.obBufferLimit != 1337 ||
|
||||
!w.bufferEnabled ||
|
||||
!w.sortBuffer ||
|
||||
!w.sortBufferByUpdateIDs ||
|
||||
!w.updateEntriesByID ||
|
||||
w.exchangeName != "test" {
|
||||
t.Errorf("Setup incorrectly loaded %s", w.exchangeName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
w := Orderbook{}
|
||||
err := w.validate(nil)
|
||||
require.ErrorIs(t, err, errUpdateIsNil)
|
||||
err = w.validate(&orderbook.Update{})
|
||||
require.ErrorIs(t, err, errUpdateNoTargets)
|
||||
}
|
||||
|
||||
func TestEnsureMultipleUpdatesViaPrice(t *testing.T) {
|
||||
t.Parallel()
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
holder, _, _, err := createSnapshot(cp)
|
||||
require.NoError(t, err)
|
||||
|
||||
asks := bidAskGenerator()
|
||||
book := holder.ob[key.PairAsset{Base: cp.Base.Item, Quote: cp.Quote.Item, Asset: asset.Spot}]
|
||||
err = book.updateByPrice(&orderbook.Update{
|
||||
Bids: asks,
|
||||
Asks: asks,
|
||||
Pair: cp,
|
||||
UpdateTime: time.Now(),
|
||||
Asset: asset.Spot,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
askLen, err := book.ob.GetAskLength()
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, 3, askLen)
|
||||
}
|
||||
|
||||
func deploySliceOrdered(size int) orderbook.Tranches {
|
||||
items := make([]orderbook.Tranche, size)
|
||||
for i := range size {
|
||||
items[i] = orderbook.Tranche{Amount: 1, Price: rand.Float64() + float64(i), ID: rand.Int63()} //nolint:gosec // Not needed for tests
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func TestUpdateByIDAndAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
asks := deploySliceOrdered(100)
|
||||
bids := slices.Clone(asks)
|
||||
bids.Reverse()
|
||||
|
||||
book, err := orderbook.DeployDepth("test", cp, asset.Spot)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = book.LoadSnapshot(slices.Clone(bids), slices.Clone(asks), 0, time.Now(), time.Now(), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
ob, err := book.Retrieve()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, ob.Verify())
|
||||
|
||||
holder := orderbookHolder{ob: book}
|
||||
err = holder.updateByIDAndAction(&orderbook.Update{})
|
||||
require.ErrorIs(t, err, errInvalidAction)
|
||||
|
||||
err = holder.updateByIDAndAction(&orderbook.Update{
|
||||
Action: orderbook.Amend,
|
||||
Bids: []orderbook.Tranche{{Price: 100, ID: 6969}},
|
||||
})
|
||||
require.ErrorIs(t, err, errAmendFailure)
|
||||
|
||||
err = book.LoadSnapshot(slices.Clone(bids), slices.Clone(asks), 0, time.Now(), time.Now(), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// append to slice
|
||||
err = holder.updateByIDAndAction(&orderbook.Update{
|
||||
Action: orderbook.UpdateInsert,
|
||||
Bids: []orderbook.Tranche{{Price: 0, ID: 1337, Amount: 1}},
|
||||
Asks: []orderbook.Tranche{{Price: 100, ID: 1337, Amount: 1}},
|
||||
UpdateTime: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
cpy, err := book.Retrieve()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0., cpy.Bids[len(cpy.Bids)-1].Price)
|
||||
require.Equal(t, 100., cpy.Asks[len(cpy.Asks)-1].Price)
|
||||
|
||||
// Change amount
|
||||
err = holder.updateByIDAndAction(&orderbook.Update{
|
||||
Action: orderbook.UpdateInsert,
|
||||
Bids: []orderbook.Tranche{{Price: 0, ID: 1337, Amount: 100}},
|
||||
Asks: []orderbook.Tranche{{Price: 100, ID: 1337, Amount: 100}},
|
||||
UpdateTime: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
cpy, err = book.Retrieve()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 100., cpy.Bids[len(cpy.Bids)-1].Amount)
|
||||
require.Equal(t, 100., cpy.Asks[len(cpy.Asks)-1].Amount)
|
||||
|
||||
// Change price level
|
||||
err = holder.updateByIDAndAction(&orderbook.Update{
|
||||
Action: orderbook.UpdateInsert,
|
||||
Bids: []orderbook.Tranche{{Price: 100, ID: 1337, Amount: 99}},
|
||||
Asks: []orderbook.Tranche{{Price: 0, ID: 1337, Amount: 99}},
|
||||
UpdateTime: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
cpy, err = book.Retrieve()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, 99., cpy.Bids[0].Amount)
|
||||
require.Equal(t, 100., cpy.Bids[0].Price)
|
||||
require.Equal(t, 99., cpy.Asks[0].Amount)
|
||||
require.Equal(t, 0., cpy.Asks[0].Price)
|
||||
|
||||
err = book.LoadSnapshot(slices.Clone(bids), slices.Clone(asks), 0, time.Now(), time.Now(), true)
|
||||
require.NoError(t, err)
|
||||
// Delete - not found
|
||||
err = holder.updateByIDAndAction(&orderbook.Update{
|
||||
Action: orderbook.Delete,
|
||||
Asks: []orderbook.Tranche{{Price: 0, ID: 1337, Amount: 99}},
|
||||
})
|
||||
require.ErrorIs(t, err, errDeleteFailure)
|
||||
|
||||
err = book.LoadSnapshot(slices.Clone(bids), slices.Clone(asks), 0, time.Now(), time.Now(), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete - found
|
||||
err = holder.updateByIDAndAction(&orderbook.Update{
|
||||
Action: orderbook.Delete,
|
||||
Asks: []orderbook.Tranche{asks[0]},
|
||||
UpdateTime: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
askLen, err := book.GetAskLength()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 99, askLen)
|
||||
|
||||
// Apply update
|
||||
err = holder.updateByIDAndAction(&orderbook.Update{
|
||||
Action: orderbook.Amend,
|
||||
Asks: []orderbook.Tranche{{ID: 123456}},
|
||||
})
|
||||
require.ErrorIs(t, err, errAmendFailure)
|
||||
|
||||
err = book.LoadSnapshot(slices.Clone(bids), slices.Clone(bids), 0, time.Now(), time.Now(), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
ob, err = book.Retrieve()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, ob.Asks)
|
||||
require.NotEmpty(t, ob.Bids)
|
||||
|
||||
update := ob.Asks[0]
|
||||
update.Amount = 1337
|
||||
|
||||
err = holder.updateByIDAndAction(&orderbook.Update{
|
||||
Action: orderbook.Amend,
|
||||
Asks: []orderbook.Tranche{update},
|
||||
UpdateTime: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ob, err = book.Retrieve()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1337., ob.Asks[0].Amount)
|
||||
}
|
||||
|
||||
func TestFlushOrderbook(t *testing.T) {
|
||||
t.Parallel()
|
||||
cp, err := getExclusivePair()
|
||||
require.NoError(t, err)
|
||||
|
||||
w := &Orderbook{}
|
||||
err = w.Setup(&config.Exchange{Name: "test"}, &Config{}, make(chan any, 2))
|
||||
require.NoError(t, err)
|
||||
|
||||
var snapShot1 orderbook.Base
|
||||
snapShot1.Exchange = "Snapshooooot"
|
||||
asks := []orderbook.Tranche{{Price: 4000, Amount: 1, ID: 8}}
|
||||
bids := []orderbook.Tranche{{Price: 4000, Amount: 1, ID: 9}}
|
||||
snapShot1.Asks = asks
|
||||
snapShot1.Bids = bids
|
||||
snapShot1.Asset = asset.Spot
|
||||
snapShot1.Pair = cp
|
||||
snapShot1.LastUpdated = time.Now()
|
||||
|
||||
err = w.FlushOrderbook(cp, asset.Spot)
|
||||
if err == nil {
|
||||
t.Fatal("book not loaded error cannot be nil")
|
||||
}
|
||||
|
||||
_, err = w.GetOrderbook(cp, asset.Spot)
|
||||
require.ErrorIs(t, err, errDepthNotFound)
|
||||
|
||||
require.NoError(t, w.LoadSnapshot(&snapShot1))
|
||||
require.NoError(t, w.FlushOrderbook(cp, asset.Spot))
|
||||
|
||||
_, err = w.GetOrderbook(cp, asset.Spot)
|
||||
require.ErrorIs(t, err, orderbook.ErrOrderbookInvalid)
|
||||
}
|
||||
59
internal/exchange/websocket/buffer/buffer_types.go
Normal file
59
internal/exchange/websocket/buffer/buffer_types.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package buffer
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/thrasher-corp/gocryptotrader/common/key"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/orderbook"
|
||||
)
|
||||
|
||||
// Config defines the configuration variables for the websocket buffer; snapshot
|
||||
// and incremental update orderbook processing.
|
||||
type Config struct {
|
||||
// SortBuffer enables a websocket to sort incoming updates before processing.
|
||||
SortBuffer bool
|
||||
// SortBufferByUpdateIDs allows the sorting of the buffered updates by their
|
||||
// corresponding update IDs.
|
||||
SortBufferByUpdateIDs bool
|
||||
// UpdateEntriesByID will match by IDs instead of price to perform the an
|
||||
// action. e.g. update, delete, insert.
|
||||
UpdateEntriesByID bool
|
||||
// UpdateIDProgression requires that the new update ID be greater than the
|
||||
// prior ID. This will skip processing and not error.
|
||||
UpdateIDProgression bool
|
||||
// Checksum is a package defined checksum calculation for updated books.
|
||||
Checksum func(state *orderbook.Base, checksum uint32) error
|
||||
}
|
||||
|
||||
// Orderbook defines a local cache of orderbooks for amending, appending
|
||||
// and deleting changes and updates the main store for a stream
|
||||
type Orderbook struct {
|
||||
ob map[key.PairAsset]*orderbookHolder
|
||||
obBufferLimit int
|
||||
bufferEnabled bool
|
||||
sortBuffer bool
|
||||
sortBufferByUpdateIDs bool // When timestamps aren't provided, an id can help sort
|
||||
updateEntriesByID bool // Use the update IDs to match ob entries
|
||||
exchangeName string
|
||||
dataHandler chan<- any
|
||||
verbose bool
|
||||
|
||||
// updateIDProgression requires that the new update ID be greater than the
|
||||
// prior ID. This will skip processing and not error.
|
||||
updateIDProgression bool
|
||||
// checksum is a package defined checksum calculation for updated books.
|
||||
checksum func(state *orderbook.Base, checksum uint32) error
|
||||
// TODO: sync.RWMutex. For the moment we process the orderbook in a single
|
||||
// thread. In future when there are workers directly involved this can be
|
||||
// can be improved with RW mechanics which will allow updates to occur at
|
||||
// the same time on different books.
|
||||
mtx sync.Mutex
|
||||
}
|
||||
|
||||
// orderbookHolder defines a store of pending updates and a pointer to the
|
||||
// orderbook depth
|
||||
type orderbookHolder struct {
|
||||
ob *orderbook.Depth
|
||||
buffer *[]orderbook.Update
|
||||
updateID int64
|
||||
}
|
||||
481
internal/exchange/websocket/connection.go
Normal file
481
internal/exchange/websocket/connection.go
Normal file
@@ -0,0 +1,481 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
gws "github.com/gorilla/websocket"
|
||||
"github.com/thrasher-corp/gocryptotrader/encoding/json"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
|
||||
"github.com/thrasher-corp/gocryptotrader/log"
|
||||
)
|
||||
|
||||
var (
|
||||
// errConnectionFault is a connection fault error which alerts the system that a connection cycle needs to take place.
|
||||
errConnectionFault = errors.New("connection fault")
|
||||
errWebsocketIsDisconnected = errors.New("websocket connection is disconnected")
|
||||
errRateLimitNotFound = errors.New("rate limit definition not found")
|
||||
)
|
||||
|
||||
// Connection defines the interface for websocket connections
|
||||
type Connection interface {
|
||||
Dial(*gws.Dialer, http.Header) error
|
||||
DialContext(context.Context, *gws.Dialer, http.Header) error
|
||||
ReadMessage() Response
|
||||
SetupPingHandler(request.EndpointLimit, PingHandler)
|
||||
// GenerateMessageID generates a message ID for the individual connection. If a bespoke function is set
|
||||
// (by using SetupNewConnection) it will use that, otherwise it will use the defaultGenerateMessageID function
|
||||
// defined in websocket_connection.go.
|
||||
GenerateMessageID(highPrecision bool) int64
|
||||
// SendMessageReturnResponse will send a WS message to the connection and wait for response
|
||||
SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature, request any) ([]byte, error)
|
||||
// SendMessageReturnResponses will send a WS message to the connection and wait for N responses
|
||||
SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature, request any, expected int) ([][]byte, error)
|
||||
// SendMessageReturnResponsesWithInspector will send a WS message to the connection and wait for N responses with message inspection
|
||||
SendMessageReturnResponsesWithInspector(ctx context.Context, epl request.EndpointLimit, signature, request any, expected int, messageInspector Inspector) ([][]byte, error)
|
||||
// SendRawMessage sends a message over the connection without JSON encoding it
|
||||
SendRawMessage(ctx context.Context, epl request.EndpointLimit, messageType int, message []byte) error
|
||||
// SendJSONMessage sends a JSON encoded message over the connection
|
||||
SendJSONMessage(ctx context.Context, epl request.EndpointLimit, payload any) error
|
||||
SetURL(string)
|
||||
SetProxy(string)
|
||||
GetURL() string
|
||||
Shutdown() error
|
||||
}
|
||||
|
||||
// ConnectionSetup defines variables for an individual stream connection
|
||||
type ConnectionSetup struct {
|
||||
ResponseCheckTimeout time.Duration
|
||||
ResponseMaxLimit time.Duration
|
||||
RateLimit *request.RateLimiterWithWeight
|
||||
Authenticated bool
|
||||
ConnectionLevelReporter Reporter
|
||||
|
||||
// URL defines the websocket server URL to connect to
|
||||
URL string
|
||||
// Connector is the function that will be called to connect to the
|
||||
// exchange's websocket server. This will be called once when the stream
|
||||
// service is started. Any bespoke connection logic should be handled here.
|
||||
Connector func(ctx context.Context, conn Connection) error
|
||||
// GenerateSubscriptions is a function that will be called to generate a
|
||||
// list of subscriptions to be made to the exchange's websocket server.
|
||||
GenerateSubscriptions func() (subscription.List, error)
|
||||
// Subscriber is a function that will be called to send subscription
|
||||
// messages based on the exchange's websocket server requirements to
|
||||
// subscribe to specific channels.
|
||||
Subscriber func(ctx context.Context, conn Connection, sub subscription.List) error
|
||||
// Unsubscriber is a function that will be called to send unsubscription
|
||||
// messages based on the exchange's websocket server requirements to
|
||||
// unsubscribe from specific channels. NOTE: IF THE FEATURE IS ENABLED.
|
||||
Unsubscriber func(ctx context.Context, conn Connection, unsub subscription.List) error
|
||||
// Handler defines the function that will be called when a message is
|
||||
// received from the exchange's websocket server. This function should
|
||||
// handle the incoming message and pass it to the appropriate data handler.
|
||||
Handler func(ctx context.Context, incoming []byte) error
|
||||
// BespokeGenerateMessageID is a function that returns a unique message ID.
|
||||
// This is useful for when an exchange connection requires a unique or
|
||||
// structured message ID for each message sent.
|
||||
BespokeGenerateMessageID func(highPrecision bool) int64
|
||||
Authenticate func(ctx context.Context, conn Connection) error
|
||||
// MessageFilter defines the criteria used to match messages to a specific connection.
|
||||
// The filter enables precise routing and handling of messages for distinct connection contexts.
|
||||
MessageFilter any
|
||||
}
|
||||
|
||||
// Inspector is used to verify messages via SendMessageReturnResponsesWithInspection
|
||||
// It inspects the []bytes websocket message and returns true if the message is the final message in a sequence of expected messages
|
||||
type Inspector interface {
|
||||
IsFinal([]byte) bool
|
||||
}
|
||||
|
||||
// Response defines generalised data from the websocket connection
|
||||
type Response struct {
|
||||
Type int
|
||||
Raw []byte
|
||||
}
|
||||
|
||||
// connection contains all the data needed to send a message to a websocket connection
|
||||
type connection struct {
|
||||
Verbose bool
|
||||
connected int32
|
||||
writeControl sync.Mutex // Gorilla websocket does not allow more than one goroutine to utilise write methods
|
||||
RateLimit *request.RateLimiterWithWeight // RateLimit is a rate limiter for the connection itself
|
||||
RateLimitDefinitions request.RateLimitDefinitions // RateLimitDefinitions contains the rate limiters shared between WebSocket and REST connections
|
||||
Reporter Reporter
|
||||
ExchangeName string
|
||||
URL string
|
||||
ProxyURL string
|
||||
Wg *sync.WaitGroup
|
||||
Connection *gws.Conn
|
||||
shutdown chan struct{}
|
||||
Match *Match
|
||||
ResponseMaxLimit time.Duration
|
||||
Traffic chan struct{}
|
||||
readMessageErrors chan error
|
||||
bespokeGenerateMessageID func(highPrecision bool) int64
|
||||
}
|
||||
|
||||
// Dial sets proxy urls and then connects to the websocket
|
||||
func (c *connection) Dial(dialer *gws.Dialer, headers http.Header) error {
|
||||
return c.DialContext(context.Background(), dialer, headers)
|
||||
}
|
||||
|
||||
// DialContext sets proxy urls and then connects to the websocket
|
||||
func (c *connection) DialContext(ctx context.Context, dialer *gws.Dialer, headers http.Header) error {
|
||||
if c.ProxyURL != "" {
|
||||
proxy, err := url.Parse(c.ProxyURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dialer.Proxy = http.ProxyURL(proxy)
|
||||
}
|
||||
|
||||
var err error
|
||||
var conStatus *http.Response
|
||||
c.Connection, conStatus, err = dialer.DialContext(ctx, c.URL, headers)
|
||||
if err != nil {
|
||||
if conStatus != nil {
|
||||
_ = conStatus.Body.Close()
|
||||
return fmt.Errorf("%s websocket connection: %v %v %v Error: %w", c.ExchangeName, c.URL, conStatus, conStatus.StatusCode, err)
|
||||
}
|
||||
return fmt.Errorf("%s websocket connection: %v Error: %w", c.ExchangeName, c.URL, err)
|
||||
}
|
||||
_ = conStatus.Body.Close()
|
||||
|
||||
if c.Verbose {
|
||||
log.Infof(log.WebsocketMgr, "%v Websocket connected to %s\n", c.ExchangeName, c.URL)
|
||||
}
|
||||
select {
|
||||
case c.Traffic <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
c.setConnectedStatus(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendJSONMessage sends a JSON encoded message over the connection
|
||||
func (c *connection) SendJSONMessage(ctx context.Context, epl request.EndpointLimit, data any) error {
|
||||
return c.writeToConn(ctx, epl, func() error {
|
||||
if request.IsVerbose(ctx, c.Verbose) {
|
||||
if msg, err := json.Marshal(data); err == nil { // WriteJSON will error for us anyway
|
||||
log.Debugf(log.WebsocketMgr, "%v %v: Sending message: %v", c.ExchangeName, removeURLQueryString(c.URL), string(msg))
|
||||
}
|
||||
}
|
||||
return c.Connection.WriteJSON(data)
|
||||
})
|
||||
}
|
||||
|
||||
// SendRawMessage sends a message over the connection without JSON encoding it
|
||||
func (c *connection) SendRawMessage(ctx context.Context, epl request.EndpointLimit, messageType int, message []byte) error {
|
||||
return c.writeToConn(ctx, epl, func() error {
|
||||
if request.IsVerbose(ctx, c.Verbose) {
|
||||
log.Debugf(log.WebsocketMgr, "%v %v: Sending message: %v", c.ExchangeName, removeURLQueryString(c.URL), string(message))
|
||||
}
|
||||
return c.Connection.WriteMessage(messageType, message)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *connection) writeToConn(ctx context.Context, epl request.EndpointLimit, writeConn func() error) error {
|
||||
if !c.IsConnected() {
|
||||
return fmt.Errorf("%v websocket connection: cannot send message %w", c.ExchangeName, errWebsocketIsDisconnected)
|
||||
}
|
||||
|
||||
var rl *request.RateLimiterWithWeight
|
||||
if c.RateLimitDefinitions != nil {
|
||||
var ok bool
|
||||
if rl, ok = c.RateLimitDefinitions[epl]; !ok && c.RateLimit == nil {
|
||||
// Return an error if no specific connection rate limit is found for the endpoint but a global rate limit is
|
||||
// set. This ensures the system attempts to apply rate limiting, prioritizing endpoint-specific limits
|
||||
// if they are defined.
|
||||
return fmt.Errorf("%s websocket connection: %w for %v", c.ExchangeName, errRateLimitNotFound, epl)
|
||||
}
|
||||
}
|
||||
|
||||
if rl == nil {
|
||||
// If a global rate limit definition is not found, use the connection rate limit as a fallback.
|
||||
rl = c.RateLimit
|
||||
}
|
||||
|
||||
if rl != nil {
|
||||
if err := request.RateLimit(ctx, rl); err != nil {
|
||||
return fmt.Errorf("%s websocket connection: rate limit error: %w", c.ExchangeName, err)
|
||||
}
|
||||
}
|
||||
// This lock acts as a rolling gate to prevent WriteMessage panics. Acquire after rate limit check.
|
||||
c.writeControl.Lock()
|
||||
defer c.writeControl.Unlock()
|
||||
return writeConn()
|
||||
}
|
||||
|
||||
// SetupPingHandler will automatically send ping or pong messages based on
|
||||
// WebsocketPingHandler configuration
|
||||
func (c *connection) SetupPingHandler(epl request.EndpointLimit, handler PingHandler) {
|
||||
if handler.UseGorillaHandler {
|
||||
c.Connection.SetPingHandler(func(msg string) error {
|
||||
err := c.Connection.WriteControl(handler.MessageType, []byte(msg), time.Now().Add(handler.Delay))
|
||||
if err == gws.ErrCloseSent {
|
||||
return nil
|
||||
} else if e, ok := err.(net.Error); ok && e.Timeout() {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Wg.Add(1)
|
||||
go func() {
|
||||
defer c.Wg.Done()
|
||||
ticker := time.NewTicker(handler.Delay)
|
||||
for {
|
||||
select {
|
||||
case <-c.shutdown:
|
||||
ticker.Stop()
|
||||
return
|
||||
case <-ticker.C:
|
||||
err := c.SendRawMessage(context.TODO(), epl, handler.MessageType, handler.Message)
|
||||
if err != nil {
|
||||
log.Errorf(log.WebsocketMgr, "%v websocket connection: ping handler failed to send message [%s]: %v", c.ExchangeName, handler.Message, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// setConnectedStatus sets connection status if changed it will return true.
|
||||
// TODO: Swap out these atomic switches and opt for sync.RWMutex.
|
||||
func (c *connection) setConnectedStatus(b bool) bool {
|
||||
if b {
|
||||
return atomic.SwapInt32(&c.connected, 1) == 0
|
||||
}
|
||||
return atomic.SwapInt32(&c.connected, 0) == 1
|
||||
}
|
||||
|
||||
// IsConnected exposes websocket connection status
|
||||
func (c *connection) IsConnected() bool {
|
||||
return atomic.LoadInt32(&c.connected) == 1
|
||||
}
|
||||
|
||||
// ReadMessage reads messages, can handle text, gzip and binary
|
||||
func (c *connection) ReadMessage() Response {
|
||||
mType, resp, err := c.Connection.ReadMessage()
|
||||
if err != nil {
|
||||
// If any error occurs, a Response{Raw: nil, Type: 0} is returned, causing the
|
||||
// reader routine to exit. This leaves the connection without an active reader,
|
||||
// leading to potential buffer issue from the ongoing websocket writes.
|
||||
// Such errors are passed to `c.readMessageErrors` when the connection is active.
|
||||
// The `connectionMonitor` handles these errors by flushing the buffer, reconnecting,
|
||||
// and resubscribing to the websocket to restore the connection.
|
||||
if c.setConnectedStatus(false) {
|
||||
// NOTE: When c.setConnectedStatus() returns true the underlying
|
||||
// state was changed and this infers that the connection was
|
||||
// externally closed and an error is reported else Shutdown()
|
||||
// method on WebsocketConnection type has been called and can
|
||||
// be skipped.
|
||||
select {
|
||||
case c.readMessageErrors <- fmt.Errorf("%w: %w", err, errConnectionFault):
|
||||
default:
|
||||
// bypass if there is no receiver, as this stops it returning
|
||||
// when shutdown is called.
|
||||
log.Warnf(log.WebsocketMgr, "%s failed to relay error: %v", c.ExchangeName, err)
|
||||
}
|
||||
}
|
||||
return Response{}
|
||||
}
|
||||
|
||||
select {
|
||||
case c.Traffic <- struct{}{}:
|
||||
default: // Non-Blocking write ensures 1 buffered signal per trafficCheckInterval to avoid flooding
|
||||
}
|
||||
|
||||
var standardMessage []byte
|
||||
switch mType {
|
||||
case gws.TextMessage:
|
||||
standardMessage = resp
|
||||
case gws.BinaryMessage:
|
||||
standardMessage, err = c.parseBinaryResponse(resp)
|
||||
if err != nil {
|
||||
log.Errorf(log.WebsocketMgr, "%v %v: Parse binary response error: %v", c.ExchangeName, removeURLQueryString(c.URL), err)
|
||||
return Response{Raw: []byte(``)} // Non-nil response to avoid the reader returning on this case.
|
||||
}
|
||||
}
|
||||
if c.Verbose {
|
||||
log.Debugf(log.WebsocketMgr, "%v %v: Message received: %v", c.ExchangeName, removeURLQueryString(c.URL), string(standardMessage))
|
||||
}
|
||||
return Response{Raw: standardMessage, Type: mType}
|
||||
}
|
||||
|
||||
// parseBinaryResponse parses a websocket binary response into a usable byte array
|
||||
func (c *connection) parseBinaryResponse(resp []byte) ([]byte, error) {
|
||||
var reader io.ReadCloser
|
||||
var err error
|
||||
if len(resp) >= 2 && resp[0] == 31 && resp[1] == 139 { // Detect GZIP
|
||||
reader, err = gzip.NewReader(bytes.NewReader(resp))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
reader = flate.NewReader(bytes.NewReader(resp))
|
||||
}
|
||||
standardMessage, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return standardMessage, reader.Close()
|
||||
}
|
||||
|
||||
// GenerateMessageID generates a message ID for the individual connection.
|
||||
// If a bespoke function is set (by using SetupNewConnection) it will use that,
|
||||
// otherwise it will use the defaultGenerateMessageID function.
|
||||
func (c *connection) GenerateMessageID(highPrec bool) int64 {
|
||||
if c.bespokeGenerateMessageID != nil {
|
||||
return c.bespokeGenerateMessageID(highPrec)
|
||||
}
|
||||
return c.defaultGenerateMessageID(highPrec)
|
||||
}
|
||||
|
||||
// defaultGenerateMessageID generates the default message ID
|
||||
func (c *connection) defaultGenerateMessageID(highPrec bool) int64 {
|
||||
var minValue int64 = 1e8
|
||||
var maxValue int64 = 2e8
|
||||
if highPrec {
|
||||
maxValue = 2e12
|
||||
minValue = 1e12
|
||||
}
|
||||
// utilization of hard coded positive numbers and default crypto/rand
|
||||
// io.reader will panic on error instead of returning
|
||||
randomNumber, err := rand.Int(rand.Reader, big.NewInt(maxValue-minValue+1))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return randomNumber.Int64() + minValue
|
||||
}
|
||||
|
||||
// Shutdown shuts down and closes specific connection
|
||||
func (c *connection) Shutdown() error {
|
||||
if c == nil || c.Connection == nil {
|
||||
return nil
|
||||
}
|
||||
c.setConnectedStatus(false)
|
||||
c.writeControl.Lock()
|
||||
defer c.writeControl.Unlock()
|
||||
return c.Connection.NetConn().Close()
|
||||
}
|
||||
|
||||
// SetURL sets connection URL
|
||||
func (c *connection) SetURL(url string) {
|
||||
c.URL = url
|
||||
}
|
||||
|
||||
// SetProxy sets connection proxy
|
||||
func (c *connection) SetProxy(proxy string) {
|
||||
c.ProxyURL = proxy
|
||||
}
|
||||
|
||||
// GetURL returns the connection URL
|
||||
func (c *connection) GetURL() string {
|
||||
return c.URL
|
||||
}
|
||||
|
||||
// SendMessageReturnResponse will send a WS message to the connection and wait for response
|
||||
func (c *connection) SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature, request any) ([]byte, error) {
|
||||
resps, err := c.SendMessageReturnResponses(ctx, epl, signature, request, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resps[0], nil
|
||||
}
|
||||
|
||||
// SendMessageReturnResponses will send a WS message to the connection and wait for N responses
|
||||
// An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked
|
||||
func (c *connection) SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int) ([][]byte, error) {
|
||||
return c.SendMessageReturnResponsesWithInspector(ctx, epl, signature, payload, expected, nil)
|
||||
}
|
||||
|
||||
// SendMessageReturnResponsesWithInspector will send a WS message to the connection and wait for N responses
|
||||
// An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked
|
||||
func (c *connection) SendMessageReturnResponsesWithInspector(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int, messageInspector Inspector) ([][]byte, error) {
|
||||
outbound, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err)
|
||||
}
|
||||
|
||||
ch, err := c.Match.Set(signature, expected)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
err = c.SendRawMessage(ctx, epl, gws.TextMessage, outbound)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resps, err := c.waitForResponses(ctx, signature, ch, expected, messageInspector)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c.Reporter != nil {
|
||||
c.Reporter.Latency(c.ExchangeName, outbound, time.Since(start))
|
||||
}
|
||||
|
||||
return resps, err
|
||||
}
|
||||
|
||||
// waitForResponses waits for N responses from a channel
|
||||
func (c *connection) waitForResponses(ctx context.Context, signature any, ch <-chan []byte, expected int, messageInspector Inspector) ([][]byte, error) {
|
||||
timeout := time.NewTimer(c.ResponseMaxLimit * time.Duration(expected))
|
||||
defer timeout.Stop()
|
||||
|
||||
resps := make([][]byte, 0, expected)
|
||||
inspection:
|
||||
for range expected {
|
||||
select {
|
||||
case resp := <-ch:
|
||||
resps = append(resps, resp)
|
||||
// Checks recently received message to determine if this is in fact the final message in a sequence of messages.
|
||||
if messageInspector != nil && messageInspector.IsFinal(resp) {
|
||||
c.Match.RemoveSignature(signature)
|
||||
break inspection
|
||||
}
|
||||
case <-timeout.C:
|
||||
c.Match.RemoveSignature(signature)
|
||||
return nil, fmt.Errorf("%s %w %v", c.ExchangeName, ErrSignatureTimeout, signature)
|
||||
case <-ctx.Done():
|
||||
c.Match.RemoveSignature(signature)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Only check context verbosity. If the exchange is verbose, it will log the responses in the ReadMessage() call.
|
||||
if request.IsVerbose(ctx, false) {
|
||||
for i := range resps {
|
||||
log.Debugf(log.WebsocketMgr, "%v %v: Received response [%d/%d]: %v", c.ExchangeName, removeURLQueryString(c.URL), i+1, len(resps), string(resps[i]))
|
||||
}
|
||||
}
|
||||
|
||||
return resps, nil
|
||||
}
|
||||
|
||||
func removeURLQueryString(url string) string {
|
||||
if index := strings.Index(url, "?"); index != -1 {
|
||||
return url[:index]
|
||||
}
|
||||
return url
|
||||
}
|
||||
1070
internal/exchange/websocket/manager.go
Normal file
1070
internal/exchange/websocket/manager.go
Normal file
File diff suppressed because it is too large
Load Diff
1338
internal/exchange/websocket/manager_test.go
Normal file
1338
internal/exchange/websocket/manager_test.go
Normal file
File diff suppressed because it is too large
Load Diff
86
internal/exchange/websocket/match.go
Normal file
86
internal/exchange/websocket/match.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ErrSignatureNotMatched is returned when a signature does not match a request
|
||||
var ErrSignatureNotMatched = errors.New("websocket response to request signature not matched")
|
||||
|
||||
var (
|
||||
errSignatureCollision = errors.New("signature collision")
|
||||
errInvalidBufferSize = errors.New("buffer size must be positive")
|
||||
)
|
||||
|
||||
// NewMatch returns a new Match
|
||||
func NewMatch() *Match {
|
||||
return &Match{m: make(map[any]*incoming)}
|
||||
}
|
||||
|
||||
// Match is a distributed subtype that handles the matching of requests and
|
||||
// responses in a timely manner, reducing the need to differentiate between
|
||||
// connections. Stream systems fan in all incoming payloads to one routine for
|
||||
// processing.
|
||||
type Match struct {
|
||||
m map[any]*incoming
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type incoming struct {
|
||||
expected int
|
||||
c chan<- []byte
|
||||
}
|
||||
|
||||
// IncomingWithData matches with requests and takes in the returned payload, to
|
||||
// be processed outside of a stream processing routine and returns true if a handler was found
|
||||
func (m *Match) IncomingWithData(signature any, data []byte) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
ch, ok := m.m[signature]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
ch.c <- data
|
||||
ch.expected--
|
||||
if ch.expected == 0 {
|
||||
close(ch.c)
|
||||
delete(m.m, signature)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// RequireMatchWithData validates that incoming data matches a request's signature.
|
||||
// If a match is found, the data is processed; otherwise, it returns an error.
|
||||
func (m *Match) RequireMatchWithData(signature any, data []byte) error {
|
||||
if m.IncomingWithData(signature, data) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("'%v' %w with data %v", signature, ErrSignatureNotMatched, string(data))
|
||||
}
|
||||
|
||||
// Set the signature response channel for incoming data
|
||||
func (m *Match) Set(signature any, bufSize int) (<-chan []byte, error) {
|
||||
if bufSize <= 0 {
|
||||
return nil, errInvalidBufferSize
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if _, ok := m.m[signature]; ok {
|
||||
return nil, errSignatureCollision
|
||||
}
|
||||
ch := make(chan []byte, bufSize)
|
||||
m.m[signature] = &incoming{expected: bufSize, c: ch}
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// RemoveSignature removes the signature response from map and closes the channel.
|
||||
func (m *Match) RemoveSignature(signature any) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if ch, ok := m.m[signature]; ok {
|
||||
close(ch.c)
|
||||
delete(m.m, signature)
|
||||
}
|
||||
}
|
||||
68
internal/exchange/websocket/match_test.go
Normal file
68
internal/exchange/websocket/match_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
load := []byte("42")
|
||||
assert.False(t, new(Match).IncomingWithData("hello", load), "Should not match an uninitialised Match")
|
||||
|
||||
match := NewMatch()
|
||||
assert.False(t, match.IncomingWithData("hello", load), "Should not match an empty signature")
|
||||
|
||||
_, err := match.Set("hello", 0)
|
||||
require.ErrorIs(t, err, errInvalidBufferSize, "Must error on zero buffer size")
|
||||
_, err = match.Set("hello", -1)
|
||||
require.ErrorIs(t, err, errInvalidBufferSize, "Must error on negative buffer size")
|
||||
ch, err := match.Set("hello", 2)
|
||||
require.NoError(t, err, "Set must not error")
|
||||
assert.True(t, match.IncomingWithData("hello", []byte("hello")))
|
||||
assert.Equal(t, "hello", string(<-ch))
|
||||
|
||||
_, err = match.Set("hello", 2)
|
||||
assert.ErrorIs(t, err, errSignatureCollision, "Should error on signature collision")
|
||||
|
||||
assert.True(t, match.IncomingWithData("hello", load), "Should match with matching message and signature")
|
||||
assert.False(t, match.IncomingWithData("hello", load), "Should not match with matching message and signature")
|
||||
|
||||
assert.Len(t, ch, 1, "Channel should have 1 items, 1 was already read above")
|
||||
}
|
||||
|
||||
func TestRemoveSignature(t *testing.T) {
|
||||
t.Parallel()
|
||||
match := NewMatch()
|
||||
ch, err := match.Set("masterblaster", 1)
|
||||
select {
|
||||
case <-ch:
|
||||
t.Fatal("Should not be able to read from an empty channel")
|
||||
default:
|
||||
}
|
||||
require.NoError(t, err)
|
||||
match.RemoveSignature("masterblaster")
|
||||
select {
|
||||
case garbage := <-ch:
|
||||
require.Empty(t, garbage)
|
||||
default:
|
||||
t.Fatal("Should be able to read from a closed channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireMatchWithData(t *testing.T) {
|
||||
t.Parallel()
|
||||
match := NewMatch()
|
||||
err := match.RequireMatchWithData("hello", []byte("world"))
|
||||
require.ErrorIs(t, err, ErrSignatureNotMatched, "Must error on unmatched signature")
|
||||
assert.Contains(t, err.Error(), "world", "Should contain the data in the error message")
|
||||
assert.Contains(t, err.Error(), "hello", "Should contain the signature in the error message")
|
||||
|
||||
ch, err := match.Set("hello", 1)
|
||||
require.NoError(t, err, "Set must not error")
|
||||
err = match.RequireMatchWithData("hello", []byte("world"))
|
||||
require.NoError(t, err, "Must not error on matched signature")
|
||||
assert.Equal(t, "world", string(<-ch))
|
||||
}
|
||||
349
internal/exchange/websocket/subscriptions.go
Normal file
349
internal/exchange/websocket/subscriptions.go
Normal file
@@ -0,0 +1,349 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
|
||||
"github.com/thrasher-corp/gocryptotrader/log"
|
||||
)
|
||||
|
||||
// Public subscription errors
|
||||
var (
|
||||
ErrSubscriptionFailure = errors.New("subscription failure")
|
||||
ErrSubscriptionsNotAdded = errors.New("subscriptions not added")
|
||||
ErrSubscriptionsNotRemoved = errors.New("subscriptions not removed")
|
||||
)
|
||||
|
||||
// Public subscription errors
|
||||
var (
|
||||
errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit")
|
||||
)
|
||||
|
||||
// UnsubscribeChannels unsubscribes from a list of websocket channel
|
||||
func (m *Manager) UnsubscribeChannels(conn Connection, channels subscription.List) error {
|
||||
if len(channels) == 0 {
|
||||
return nil // No channels to unsubscribe from is not an error
|
||||
}
|
||||
if wrapper, ok := m.connections[conn]; ok && conn != nil {
|
||||
return m.unsubscribe(wrapper.subscriptions, channels, func(channels subscription.List) error {
|
||||
return wrapper.setup.Unsubscriber(context.TODO(), conn, channels)
|
||||
})
|
||||
}
|
||||
|
||||
if m.Unsubscriber == nil {
|
||||
return fmt.Errorf("%w: Global Unsubscriber not set", common.ErrNilPointer)
|
||||
}
|
||||
|
||||
return m.unsubscribe(m.subscriptions, channels, func(channels subscription.List) error {
|
||||
return m.Unsubscriber(channels)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) unsubscribe(store *subscription.Store, channels subscription.List, unsub func(channels subscription.List) error) error {
|
||||
if store == nil {
|
||||
return nil // No channels to unsubscribe from is not an error
|
||||
}
|
||||
for _, s := range channels {
|
||||
if store.Get(s) == nil {
|
||||
return fmt.Errorf("%w: %s", subscription.ErrNotFound, s)
|
||||
}
|
||||
}
|
||||
return unsub(channels)
|
||||
}
|
||||
|
||||
// ResubscribeToChannel resubscribes to channel
|
||||
// Sets state to Resubscribing, and exchanges which want to maintain a lock on it can respect this state and not RemoveSubscription
|
||||
// Errors if subscription is already subscribing
|
||||
func (m *Manager) ResubscribeToChannel(conn Connection, s *subscription.Subscription) error {
|
||||
l := subscription.List{s}
|
||||
if err := s.SetState(subscription.ResubscribingState); err != nil {
|
||||
return fmt.Errorf("%w: %s", err, s)
|
||||
}
|
||||
if err := m.UnsubscribeChannels(conn, l); err != nil {
|
||||
return err
|
||||
}
|
||||
return m.SubscribeToChannels(conn, l)
|
||||
}
|
||||
|
||||
// SubscribeToChannels subscribes to websocket channels using the exchange specific Subscriber method
|
||||
// Errors are returned for duplicates or exceeding max Subscriptions
|
||||
func (m *Manager) SubscribeToChannels(conn Connection, subs subscription.List) error {
|
||||
if slices.Contains(subs, nil) {
|
||||
return fmt.Errorf("%w: List parameter contains an nil element", common.ErrNilPointer)
|
||||
}
|
||||
if err := m.checkSubscriptions(conn, subs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if wrapper, ok := m.connections[conn]; ok && conn != nil {
|
||||
return wrapper.setup.Subscriber(context.TODO(), conn, subs)
|
||||
}
|
||||
|
||||
if m.Subscriber == nil {
|
||||
return fmt.Errorf("%w: Global Subscriber not set", common.ErrNilPointer)
|
||||
}
|
||||
|
||||
if err := m.Subscriber(subs); err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrSubscriptionFailure, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddSubscriptions adds subscriptions to the subscription store
|
||||
// Sets state to Subscribing unless the state is already set
|
||||
func (m *Manager) AddSubscriptions(conn Connection, subs ...*subscription.Subscription) error {
|
||||
if m == nil {
|
||||
return fmt.Errorf("%w: AddSubscriptions called on nil Websocket", common.ErrNilPointer)
|
||||
}
|
||||
var subscriptionStore **subscription.Store
|
||||
if wrapper, ok := m.connections[conn]; ok && conn != nil {
|
||||
subscriptionStore = &wrapper.subscriptions
|
||||
} else {
|
||||
subscriptionStore = &m.subscriptions
|
||||
}
|
||||
|
||||
if *subscriptionStore == nil {
|
||||
*subscriptionStore = subscription.NewStore()
|
||||
}
|
||||
var errs error
|
||||
for _, s := range subs {
|
||||
if s.State() == subscription.InactiveState {
|
||||
if err := s.SetState(subscription.SubscribingState); err != nil {
|
||||
errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s))
|
||||
}
|
||||
}
|
||||
if err := (*subscriptionStore).Add(s); err != nil {
|
||||
errs = common.AppendError(errs, err)
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
// AddSuccessfulSubscriptions marks subscriptions as subscribed and adds them to the subscription store
|
||||
func (m *Manager) AddSuccessfulSubscriptions(conn Connection, subs ...*subscription.Subscription) error {
|
||||
if m == nil {
|
||||
return fmt.Errorf("%w: AddSuccessfulSubscriptions called on nil Websocket", common.ErrNilPointer)
|
||||
}
|
||||
|
||||
var subscriptionStore **subscription.Store
|
||||
if wrapper, ok := m.connections[conn]; ok && conn != nil {
|
||||
subscriptionStore = &wrapper.subscriptions
|
||||
} else {
|
||||
subscriptionStore = &m.subscriptions
|
||||
}
|
||||
|
||||
if *subscriptionStore == nil {
|
||||
*subscriptionStore = subscription.NewStore()
|
||||
}
|
||||
|
||||
var errs error
|
||||
for _, s := range subs {
|
||||
if err := s.SetState(subscription.SubscribedState); err != nil {
|
||||
errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s))
|
||||
}
|
||||
if err := (*subscriptionStore).Add(s); err != nil {
|
||||
errs = common.AppendError(errs, err)
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
// RemoveSubscriptions removes subscriptions from the subscription list and sets the status to Unsubscribed
|
||||
func (m *Manager) RemoveSubscriptions(conn Connection, subs ...*subscription.Subscription) error {
|
||||
if m == nil {
|
||||
return fmt.Errorf("%w: RemoveSubscriptions called on nil Websocket", common.ErrNilPointer)
|
||||
}
|
||||
|
||||
var subscriptionStore *subscription.Store
|
||||
if wrapper, ok := m.connections[conn]; ok && conn != nil {
|
||||
subscriptionStore = wrapper.subscriptions
|
||||
} else {
|
||||
subscriptionStore = m.subscriptions
|
||||
}
|
||||
|
||||
if subscriptionStore == nil {
|
||||
return fmt.Errorf("%w: RemoveSubscriptions called on uninitialised Websocket", common.ErrNilPointer)
|
||||
}
|
||||
|
||||
var errs error
|
||||
for _, s := range subs {
|
||||
if err := s.SetState(subscription.UnsubscribedState); err != nil {
|
||||
errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s))
|
||||
}
|
||||
if err := subscriptionStore.Remove(s); err != nil {
|
||||
errs = common.AppendError(errs, err)
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
// GetSubscription returns a subscription at the key provided
|
||||
// returns nil if no subscription is at that key or the key is nil
|
||||
// Keys can implement subscription.MatchableKey in order to provide custom matching logic
|
||||
func (m *Manager) GetSubscription(key any) *subscription.Subscription {
|
||||
if m == nil || key == nil {
|
||||
return nil
|
||||
}
|
||||
for _, c := range m.connectionManager {
|
||||
if c.subscriptions == nil {
|
||||
continue
|
||||
}
|
||||
sub := c.subscriptions.Get(key)
|
||||
if sub != nil {
|
||||
return sub
|
||||
}
|
||||
}
|
||||
if m.subscriptions == nil {
|
||||
return nil
|
||||
}
|
||||
return m.subscriptions.Get(key)
|
||||
}
|
||||
|
||||
// GetSubscriptions returns a new slice of the subscriptions
|
||||
func (m *Manager) GetSubscriptions() subscription.List {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
var subs subscription.List
|
||||
for _, c := range m.connectionManager {
|
||||
if c.subscriptions != nil {
|
||||
subs = append(subs, c.subscriptions.List()...)
|
||||
}
|
||||
}
|
||||
if m.subscriptions != nil {
|
||||
subs = append(subs, m.subscriptions.List()...)
|
||||
}
|
||||
return subs
|
||||
}
|
||||
|
||||
// checkSubscriptions checks subscriptions against the max subscription limit and if the subscription already exists
|
||||
// The subscription state is not considered when counting existing subscriptions
|
||||
func (m *Manager) checkSubscriptions(conn Connection, subs subscription.List) error {
|
||||
var subscriptionStore *subscription.Store
|
||||
if wrapper, ok := m.connections[conn]; ok && conn != nil {
|
||||
subscriptionStore = wrapper.subscriptions
|
||||
} else {
|
||||
subscriptionStore = m.subscriptions
|
||||
}
|
||||
if subscriptionStore == nil {
|
||||
return fmt.Errorf("%w: Websocket.subscriptions", common.ErrNilPointer)
|
||||
}
|
||||
|
||||
existing := subscriptionStore.Len()
|
||||
if m.MaxSubscriptionsPerConnection > 0 && existing+len(subs) > m.MaxSubscriptionsPerConnection {
|
||||
return fmt.Errorf("%w: current subscriptions: %v, incoming subscriptions: %v, max subscriptions per connection: %v - please reduce enabled pairs",
|
||||
errSubscriptionsExceedsLimit,
|
||||
existing,
|
||||
len(subs),
|
||||
m.MaxSubscriptionsPerConnection)
|
||||
}
|
||||
|
||||
for _, s := range subs {
|
||||
if s.State() == subscription.ResubscribingState {
|
||||
continue
|
||||
}
|
||||
if found := subscriptionStore.Get(s); found != nil {
|
||||
return fmt.Errorf("%w: %s", subscription.ErrDuplicate, s)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FlushChannels flushes channel subscriptions when there is a pair/asset change
|
||||
func (m *Manager) FlushChannels() error {
|
||||
if !m.IsEnabled() {
|
||||
return fmt.Errorf("%s %w", m.exchangeName, ErrWebsocketNotEnabled)
|
||||
}
|
||||
|
||||
if !m.IsConnected() {
|
||||
return fmt.Errorf("%s %w", m.exchangeName, ErrNotConnected)
|
||||
}
|
||||
|
||||
// If the exchange does not support subscribing and or unsubscribing the full connection needs to be flushed to
|
||||
// maintain consistency.
|
||||
if !m.features.Subscribe || !m.features.Unsubscribe {
|
||||
m.m.Lock()
|
||||
defer m.m.Unlock()
|
||||
if err := m.shutdown(); err != nil {
|
||||
return err
|
||||
}
|
||||
return m.connect()
|
||||
}
|
||||
|
||||
if !m.useMultiConnectionManagement {
|
||||
newSubs, err := m.GenerateSubs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return m.updateChannelSubscriptions(nil, m.subscriptions, newSubs)
|
||||
}
|
||||
|
||||
for x := range m.connectionManager {
|
||||
newSubs, err := m.connectionManager[x].setup.GenerateSubscriptions()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Case if there is nothing to unsubscribe from and the connection is nil
|
||||
if len(newSubs) == 0 && m.connectionManager[x].connection == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// If there are subscriptions to subscribe to but no connection to subscribe to, establish a new connection.
|
||||
if m.connectionManager[x].connection == nil {
|
||||
conn := m.getConnectionFromSetup(m.connectionManager[x].setup)
|
||||
if err := m.connectionManager[x].setup.Connector(context.TODO(), conn); err != nil {
|
||||
return err
|
||||
}
|
||||
m.Wg.Add(1)
|
||||
go m.Reader(context.TODO(), conn, m.connectionManager[x].setup.Handler)
|
||||
m.connections[conn] = m.connectionManager[x]
|
||||
m.connectionManager[x].connection = conn
|
||||
}
|
||||
|
||||
err = m.updateChannelSubscriptions(m.connectionManager[x].connection, m.connectionManager[x].subscriptions, newSubs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If there are no subscriptions to subscribe to, close the connection as it is no longer needed.
|
||||
if m.connectionManager[x].subscriptions.Len() == 0 {
|
||||
delete(m.connections, m.connectionManager[x].connection) // Remove from lookup map
|
||||
if err := m.connectionManager[x].connection.Shutdown(); err != nil {
|
||||
log.Warnf(log.WebsocketMgr, "%v websocket: failed to shutdown connection: %v", m.exchangeName, err)
|
||||
}
|
||||
m.connectionManager[x].connection = nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateChannelSubscriptions subscribes or unsubscribes from channels and checks that the correct number of channels
|
||||
// have been subscribed to or unsubscribed from.
|
||||
func (m *Manager) updateChannelSubscriptions(c Connection, store *subscription.Store, incoming subscription.List) error {
|
||||
subs, unsubs := store.Diff(incoming)
|
||||
if len(unsubs) != 0 {
|
||||
if err := m.UnsubscribeChannels(c, unsubs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if contained := store.Contained(unsubs); len(contained) > 0 {
|
||||
return fmt.Errorf("%v %w `%s`", m.exchangeName, ErrSubscriptionsNotRemoved, contained)
|
||||
}
|
||||
}
|
||||
if len(subs) != 0 {
|
||||
if err := m.SubscribeToChannels(c, subs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if missing := store.Missing(subs); len(missing) > 0 {
|
||||
return fmt.Errorf("%v %w `%s`", m.exchangeName, ErrSubscriptionsNotAdded, missing)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
305
internal/exchange/websocket/subscriptions_test.go
Normal file
305
internal/exchange/websocket/subscriptions_test.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
|
||||
)
|
||||
|
||||
// TestSubscribe logic test
|
||||
func TestSubscribeUnsubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
ws := NewManager()
|
||||
assert.NoError(t, ws.Setup(newDefaultSetup()), "WS Setup should not error")
|
||||
|
||||
ws.Subscriber = currySimpleSub(ws)
|
||||
ws.Unsubscriber = currySimpleUnsub(ws)
|
||||
|
||||
subs, err := ws.GenerateSubs()
|
||||
require.NoError(t, err, "Generating test subscriptions must not error")
|
||||
assert.ErrorIs(t, new(Manager).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
|
||||
assert.NoError(t, ws.UnsubscribeChannels(nil, nil), "Unsubscribing from nil should not error")
|
||||
assert.ErrorIs(t, ws.UnsubscribeChannels(nil, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
|
||||
assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return")
|
||||
assert.NoError(t, ws.SubscribeToChannels(nil, subs), "Basic Subscribing should not error")
|
||||
assert.Len(t, ws.GetSubscriptions(), 4, "Should have 4 subscriptions")
|
||||
bySub := ws.GetSubscription(subscription.Subscription{Channel: "TestSub"})
|
||||
if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") {
|
||||
assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel")
|
||||
assert.Same(t, bySub, subs[0], "GetSubscription returns the same pointer")
|
||||
}
|
||||
if assert.NotNil(t, ws.GetSubscription("purple"), "GetSubscription by string key should find a channel") {
|
||||
assert.Equal(t, "TestSub2", ws.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel")
|
||||
}
|
||||
if assert.NotNil(t, ws.GetSubscription(testSubKey{"mauve"}), "GetSubscription by type key should find a channel") {
|
||||
assert.Equal(t, "TestSub3", ws.GetSubscription(testSubKey{"mauve"}).Channel, "GetSubscription by type key should return a pointer a copy of the right channel")
|
||||
}
|
||||
if assert.NotNil(t, ws.GetSubscription(42), "GetSubscription by int key should find a channel") {
|
||||
assert.Equal(t, "TestSub4", ws.GetSubscription(42).Channel, "GetSubscription by int key should return a pointer a copy of the right channel")
|
||||
}
|
||||
assert.Nil(t, ws.GetSubscription(nil), "GetSubscription by nil should return nil")
|
||||
assert.Nil(t, ws.GetSubscription(45), "GetSubscription by invalid key should return nil")
|
||||
assert.ErrorIs(t, ws.SubscribeToChannels(nil, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed")
|
||||
assert.NoError(t, ws.SubscribeToChannels(nil, nil), "Subscribe to an nil List should not error")
|
||||
assert.NoError(t, ws.UnsubscribeChannels(nil, subs), "Unsubscribing should not error")
|
||||
|
||||
ws.Subscriber = func(subscription.List) error { return errDastardlyReason }
|
||||
assert.ErrorIs(t, ws.SubscribeToChannels(nil, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber")
|
||||
|
||||
err = ws.SubscribeToChannels(nil, subscription.List{nil})
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription")
|
||||
|
||||
multi := NewManager()
|
||||
set := newDefaultSetup()
|
||||
set.UseMultiConnectionManagement = true
|
||||
assert.NoError(t, multi.Setup(set))
|
||||
|
||||
amazingCandidate := &ConnectionSetup{
|
||||
URL: "AMAZING",
|
||||
Connector: func(context.Context, Connection) error { return nil },
|
||||
GenerateSubscriptions: ws.GenerateSubs,
|
||||
Subscriber: func(ctx context.Context, c Connection, s subscription.List) error {
|
||||
return currySimpleSubConn(multi)(ctx, c, s)
|
||||
},
|
||||
Unsubscriber: func(ctx context.Context, c Connection, s subscription.List) error {
|
||||
return currySimpleUnsubConn(multi)(ctx, c, s)
|
||||
},
|
||||
Handler: func(context.Context, []byte) error { return nil },
|
||||
}
|
||||
require.NoError(t, multi.SetupNewConnection(amazingCandidate))
|
||||
|
||||
amazingConn := multi.getConnectionFromSetup(amazingCandidate)
|
||||
multi.connections = map[Connection]*connectionWrapper{
|
||||
amazingConn: multi.connectionManager[0],
|
||||
}
|
||||
|
||||
subs, err = amazingCandidate.GenerateSubscriptions()
|
||||
require.NoError(t, err, "Generating test subscriptions must not error")
|
||||
assert.ErrorIs(t, new(Manager).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
|
||||
assert.ErrorIs(t, new(Manager).UnsubscribeChannels(amazingConn, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
|
||||
assert.NoError(t, multi.UnsubscribeChannels(amazingConn, nil), "Unsubscribing from nil should not error")
|
||||
assert.ErrorIs(t, multi.UnsubscribeChannels(amazingConn, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
|
||||
assert.Nil(t, multi.GetSubscription(42), "GetSubscription on empty internal map should return")
|
||||
|
||||
assert.ErrorIs(t, multi.SubscribeToChannels(nil, subs), common.ErrNilPointer, "If no connection is set, Subscribe should error")
|
||||
|
||||
assert.NoError(t, multi.SubscribeToChannels(amazingConn, subs), "Basic Subscribing should not error")
|
||||
assert.Len(t, multi.GetSubscriptions(), 4, "Should have 4 subscriptions")
|
||||
bySub = multi.GetSubscription(subscription.Subscription{Channel: "TestSub"})
|
||||
if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") {
|
||||
assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel")
|
||||
assert.Same(t, bySub, subs[0], "GetSubscription returns the same pointer")
|
||||
}
|
||||
if assert.NotNil(t, multi.GetSubscription("purple"), "GetSubscription by string key should find a channel") {
|
||||
assert.Equal(t, "TestSub2", multi.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel")
|
||||
}
|
||||
if assert.NotNil(t, multi.GetSubscription(testSubKey{"mauve"}), "GetSubscription by type key should find a channel") {
|
||||
assert.Equal(t, "TestSub3", multi.GetSubscription(testSubKey{"mauve"}).Channel, "GetSubscription by type key should return a pointer a copy of the right channel")
|
||||
}
|
||||
if assert.NotNil(t, multi.GetSubscription(42), "GetSubscription by int key should find a channel") {
|
||||
assert.Equal(t, "TestSub4", multi.GetSubscription(42).Channel, "GetSubscription by int key should return a pointer a copy of the right channel")
|
||||
}
|
||||
assert.Nil(t, multi.GetSubscription(nil), "GetSubscription by nil should return nil")
|
||||
assert.Nil(t, multi.GetSubscription(45), "GetSubscription by invalid key should return nil")
|
||||
assert.ErrorIs(t, multi.SubscribeToChannels(amazingConn, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed")
|
||||
assert.NoError(t, multi.SubscribeToChannels(amazingConn, nil), "Subscribe to an nil List should not error")
|
||||
assert.NoError(t, multi.UnsubscribeChannels(amazingConn, subs), "Unsubscribing should not error")
|
||||
|
||||
amazingCandidate.Subscriber = func(context.Context, Connection, subscription.List) error { return errDastardlyReason }
|
||||
assert.ErrorIs(t, multi.SubscribeToChannels(amazingConn, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber")
|
||||
|
||||
err = multi.SubscribeToChannels(amazingConn, subscription.List{nil})
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription")
|
||||
}
|
||||
|
||||
// TestResubscribe tests Resubscribing to existing subscriptions
|
||||
func TestResubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
ws := NewManager()
|
||||
|
||||
wackedOutSetup := newDefaultSetup()
|
||||
wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1
|
||||
err := ws.Setup(wackedOutSetup)
|
||||
assert.ErrorIs(t, err, errInvalidMaxSubscriptions, "Invalid MaxWebsocketSubscriptionsPerConnection should error")
|
||||
|
||||
err = ws.Setup(newDefaultSetup())
|
||||
assert.NoError(t, err, "WS Setup should not error")
|
||||
|
||||
ws.Subscriber = currySimpleSub(ws)
|
||||
ws.Unsubscriber = currySimpleUnsub(ws)
|
||||
|
||||
channel := subscription.List{{Channel: "resubTest"}}
|
||||
|
||||
assert.ErrorIs(t, ws.ResubscribeToChannel(nil, channel[0]), subscription.ErrNotFound, "Resubscribe should error when channel isn't subscribed yet")
|
||||
assert.NoError(t, ws.SubscribeToChannels(nil, channel), "Subscribe should not error")
|
||||
assert.NoError(t, ws.ResubscribeToChannel(nil, channel[0]), "Resubscribe should not error now the channel is subscribed")
|
||||
}
|
||||
|
||||
// TestSubscriptions tests adding, getting and removing subscriptions
|
||||
func TestSubscriptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
w := new(Manager) // Do not use NewManager; We want to exercise w.subs == nil
|
||||
assert.ErrorIs(t, (*Manager)(nil).AddSubscriptions(nil), common.ErrNilPointer, "Should error correctly when nil websocket")
|
||||
s := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel}
|
||||
require.NoError(t, w.AddSubscriptions(nil, s), "Adding first subscription must not error")
|
||||
assert.Same(t, s, w.GetSubscription(42), "Get Subscription should retrieve the same subscription")
|
||||
assert.ErrorIs(t, w.AddSubscriptions(nil, s), subscription.ErrDuplicate, "Adding same subscription should return error")
|
||||
assert.Equal(t, subscription.SubscribingState, s.State(), "Should set state to Subscribing")
|
||||
|
||||
err := w.RemoveSubscriptions(nil, s)
|
||||
require.NoError(t, err, "RemoveSubscriptions must not error")
|
||||
assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub")
|
||||
assert.Equal(t, subscription.UnsubscribedState, s.State(), "Should set state to Unsubscribed")
|
||||
|
||||
require.NoError(t, s.SetState(subscription.ResubscribingState), "SetState must not error")
|
||||
require.NoError(t, w.AddSubscriptions(nil, s), "Adding first subscription must not error")
|
||||
assert.Equal(t, subscription.ResubscribingState, s.State(), "Should not change resubscribing state")
|
||||
}
|
||||
|
||||
// TestSuccessfulSubscriptions tests adding, getting and removing subscriptions
|
||||
func TestSuccessfulSubscriptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
w := new(Manager) // Do not use NewManager; We want to exercise w.subs == nil
|
||||
assert.ErrorIs(t, (*Manager)(nil).AddSuccessfulSubscriptions(nil, nil), common.ErrNilPointer, "Should error correctly when nil websocket")
|
||||
c := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel}
|
||||
require.NoError(t, w.AddSuccessfulSubscriptions(nil, c), "Adding first subscription must not error")
|
||||
assert.Same(t, c, w.GetSubscription(42), "Get Subscription should retrieve the same subscription")
|
||||
assert.ErrorIs(t, w.AddSuccessfulSubscriptions(nil, c), subscription.ErrInStateAlready, "Adding subscription in same state should return error")
|
||||
require.NoError(t, c.SetState(subscription.SubscribingState), "SetState must not error")
|
||||
assert.ErrorIs(t, w.AddSuccessfulSubscriptions(nil, c), subscription.ErrDuplicate, "Adding same subscription should return error")
|
||||
|
||||
err := w.RemoveSubscriptions(nil, c)
|
||||
require.NoError(t, err, "RemoveSubscriptions must not error")
|
||||
assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub")
|
||||
assert.ErrorIs(t, w.RemoveSubscriptions(nil, c), subscription.ErrNotFound, "Should error correctly when not found")
|
||||
assert.ErrorIs(t, (*Manager)(nil).RemoveSubscriptions(nil, nil), common.ErrNilPointer, "Should error correctly when nil websocket")
|
||||
w.subscriptions = nil
|
||||
assert.ErrorIs(t, w.RemoveSubscriptions(nil, c), common.ErrNilPointer, "Should error correctly when nil websocket")
|
||||
}
|
||||
|
||||
// TestGetSubscription logic test
|
||||
func TestGetSubscription(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Nil(t, (*Manager).GetSubscription(nil, "imaginary"), "GetSubscription on a nil Websocket should return nil")
|
||||
assert.Nil(t, (&Manager{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub store should return nil")
|
||||
w := NewManager()
|
||||
assert.Nil(t, w.GetSubscription(nil), "GetSubscription with a nil key should return nil")
|
||||
s := &subscription.Subscription{Key: 42, Channel: "hello3"}
|
||||
require.NoError(t, w.AddSubscriptions(nil, s), "AddSubscriptions must not error")
|
||||
assert.Same(t, s, w.GetSubscription(42), "GetSubscription should delegate to the store")
|
||||
}
|
||||
|
||||
// TestGetSubscriptions logic test
|
||||
func TestGetSubscriptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Nil(t, (*Manager).GetSubscriptions(nil), "GetSubscription on a nil Websocket should return nil")
|
||||
assert.Nil(t, (&Manager{}).GetSubscriptions(), "GetSubscription on a Websocket with no sub store should return nil")
|
||||
w := NewManager()
|
||||
s := subscription.List{
|
||||
{Key: 42, Channel: "hello3"},
|
||||
{Key: 45, Channel: "hello4"},
|
||||
}
|
||||
err := w.AddSubscriptions(nil, s...)
|
||||
require.NoError(t, err, "AddSubscriptions must not error")
|
||||
assert.ElementsMatch(t, s, w.GetSubscriptions(), "GetSubscriptions should return the correct channel details")
|
||||
}
|
||||
|
||||
func TestCheckSubscriptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
ws := Manager{}
|
||||
err := ws.checkSubscriptions(nil, nil)
|
||||
assert.ErrorIs(t, err, common.ErrNilPointer, "checkSubscriptions should error correctly on nil w.subscriptions")
|
||||
assert.ErrorContains(t, err, "Websocket.subscriptions", "checkSubscriptions should error giving context correctly on nil w.subscriptions")
|
||||
|
||||
ws.subscriptions = subscription.NewStore()
|
||||
err = ws.checkSubscriptions(nil, nil)
|
||||
assert.NoError(t, err, "checkSubscriptions should not error on a nil list")
|
||||
|
||||
ws.MaxSubscriptionsPerConnection = 1
|
||||
|
||||
err = ws.checkSubscriptions(nil, subscription.List{{}})
|
||||
assert.NoError(t, err, "checkSubscriptions should not error when subscriptions is empty")
|
||||
|
||||
ws.subscriptions = subscription.NewStore()
|
||||
err = ws.checkSubscriptions(nil, subscription.List{{}, {}})
|
||||
assert.ErrorIs(t, err, errSubscriptionsExceedsLimit, "checkSubscriptions should error correctly")
|
||||
|
||||
ws.MaxSubscriptionsPerConnection = 2
|
||||
|
||||
ws.subscriptions = subscription.NewStore()
|
||||
err = ws.subscriptions.Add(&subscription.Subscription{Key: 42, Channel: "test"})
|
||||
require.NoError(t, err, "Add subscription must not error")
|
||||
err = ws.checkSubscriptions(nil, subscription.List{{Key: 42, Channel: "test"}})
|
||||
assert.ErrorIs(t, err, subscription.ErrDuplicate, "checkSubscriptions should error correctly")
|
||||
|
||||
err = ws.checkSubscriptions(nil, subscription.List{{}})
|
||||
assert.NoError(t, err, "checkSubscriptions should not error")
|
||||
}
|
||||
|
||||
func TestUpdateChannelSubscriptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ws := NewManager()
|
||||
store := subscription.NewStore()
|
||||
err := ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}})
|
||||
require.ErrorIs(t, err, common.ErrNilPointer)
|
||||
require.Zero(t, store.Len())
|
||||
|
||||
ws.Subscriber = func(subs subscription.List) error {
|
||||
for _, sub := range subs {
|
||||
if err := store.Add(sub); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
ws.subscriptions = store
|
||||
err = ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, store.Len())
|
||||
|
||||
err = ws.updateChannelSubscriptions(nil, store, subscription.List{})
|
||||
require.ErrorIs(t, err, common.ErrNilPointer)
|
||||
|
||||
ws.Unsubscriber = func(subs subscription.List) error {
|
||||
for _, sub := range subs {
|
||||
if err := store.Remove(sub); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err = ws.updateChannelSubscriptions(nil, store, subscription.List{})
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, store.Len())
|
||||
}
|
||||
|
||||
func currySimpleSub(w *Manager) func(subscription.List) error {
|
||||
return func(subs subscription.List) error {
|
||||
return w.AddSuccessfulSubscriptions(nil, subs...)
|
||||
}
|
||||
}
|
||||
|
||||
func currySimpleSubConn(w *Manager) func(context.Context, Connection, subscription.List) error {
|
||||
return func(_ context.Context, conn Connection, subs subscription.List) error {
|
||||
return w.AddSuccessfulSubscriptions(conn, subs...)
|
||||
}
|
||||
}
|
||||
|
||||
func currySimpleUnsub(w *Manager) func(subscription.List) error {
|
||||
return func(unsubs subscription.List) error {
|
||||
return w.RemoveSubscriptions(nil, unsubs...)
|
||||
}
|
||||
}
|
||||
|
||||
func currySimpleUnsubConn(w *Manager) func(context.Context, Connection, subscription.List) error {
|
||||
return func(_ context.Context, conn Connection, unsubs subscription.List) error {
|
||||
return w.RemoveSubscriptions(conn, unsubs...)
|
||||
}
|
||||
}
|
||||
56
internal/exchange/websocket/types.go
Normal file
56
internal/exchange/websocket/types.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/thrasher-corp/gocryptotrader/currency"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/order"
|
||||
)
|
||||
|
||||
// PingHandler container for ping handler settings
|
||||
type PingHandler struct {
|
||||
Websocket bool
|
||||
UseGorillaHandler bool
|
||||
MessageType int
|
||||
Message []byte
|
||||
Delay time.Duration
|
||||
}
|
||||
|
||||
// FundingData defines funding data
|
||||
type FundingData struct {
|
||||
Timestamp time.Time
|
||||
CurrencyPair currency.Pair
|
||||
AssetType asset.Item
|
||||
Exchange string
|
||||
Amount float64
|
||||
Rate float64
|
||||
Period int64
|
||||
Side order.Side
|
||||
}
|
||||
|
||||
// KlineData defines kline feed
|
||||
type KlineData struct {
|
||||
Timestamp time.Time
|
||||
Pair currency.Pair
|
||||
AssetType asset.Item
|
||||
Exchange string
|
||||
StartTime time.Time
|
||||
CloseTime time.Time
|
||||
Interval string
|
||||
OpenPrice float64
|
||||
ClosePrice float64
|
||||
HighPrice float64
|
||||
LowPrice float64
|
||||
Volume float64
|
||||
}
|
||||
|
||||
// UnhandledMessageWarning defines a container for unhandled message warnings
|
||||
type UnhandledMessageWarning struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
// Reporter interface groups observability functionality over Websocket request latency.
|
||||
type Reporter interface {
|
||||
Latency(name string, message []byte, t time.Duration)
|
||||
}
|
||||
@@ -3,7 +3,7 @@ package exchange
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
gws "github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/thrasher-corp/gocryptotrader/config"
|
||||
@@ -31,6 +31,6 @@ func TestMockHTTPInstance(t *testing.T) {
|
||||
|
||||
// TestMockWsInstance exercises MockWsInstance
|
||||
func TestMockWsInstance(t *testing.T) {
|
||||
b := MockWsInstance[binance.Binance](t, mockws.CurryWsMockUpgrader(t, func(_ testing.TB, _ []byte, _ *websocket.Conn) error { return nil }))
|
||||
b := MockWsInstance[binance.Binance](t, mockws.CurryWsMockUpgrader(t, func(_ testing.TB, _ []byte, _ *gws.Conn) error { return nil }))
|
||||
require.NotNil(t, b, "MockWsInstance must not be nil")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user