Websocket: Restructure files and types (#1859)

* Websocket: Rename stream package

* Websocket: Rename Websocket to Manager

* Websocket: Replace explicit errs with common.NilGuard

* Websocket: Move websocket_types.go to types.go

* Websocket: Minor field comment and alignment in types

* Webosocket: Rename WebsocketConnection to Connection

* Alphapoint: Make gorilla ws import explicit

Just to avoid confusion with our own packages.

* Websocket: Move stream_match to match

* Websocket: Move websocket_connection to connection

* Websocket: Move websocket.go to manager.go

* Websocket: Break out all subscription methods into subscriptions.go

* Websocket: Move connection type into its file

* Websocket: Remove PositionUpdated

Type is not used anywhere

* Kraken: Use local constant for pong

Was the only use of websocket.Pong and doesn't really feel right to
represent kraken's api resp in one of our packages

* Websocket: Move connection sub-types to connection package

* Websocket: Move manager types into manager

* Websocket: Move ConnectionWrapper into manager

* Websocket: Move websocket_test to manager_test

* Websocket: Privatise connectionWrapper

* Websocket: Remaining types into types.go

These really belong somewhere else mostly, but this will do for now

* Websocket: Tidy up connection method vars

* Gofumpt: Moving package imports around

* Websocket: Rename errDuplicateConnectionSetup

* Websocket: Fix duplicate import of gws

* Websocket: Fix gofumpt -extra

* Websocket: Standardise import of gws across other pkgs

* Kraken: Remove unused sub conf consts

These were replaced by the generic Levels and Depth fields on all subs

* Websocket: Privitise ConnectioWrapper fields

* Websocket: inline single use var WebsocketNotAuthenticatedUsingRest

* Websocket: Move documentation to template

* Bithumb: Assertify TestWsHandleData
This commit is contained in:
Gareth Kirwan
2025-04-10 08:25:02 +02:00
committed by GitHub
parent 676b2e0367
commit b4e45e9a1b
119 changed files with 3169 additions and 3056 deletions

View File

@@ -0,0 +1,174 @@
# GoCryptoTrader package Websocket
<img src="/common/gctlogo.png?raw=true" width="350px" height="350px" hspace="70">
[![Build Status](https://github.com/thrasher-corp/gocryptotrader/actions/workflows/tests.yml/badge.svg?branch=master)](https://github.com/thrasher-corp/gocryptotrader/actions/workflows/tests.yml)
[![Software License](https://img.shields.io/badge/License-MIT-orange.svg?style=flat-square)](https://github.com/thrasher-corp/gocryptotrader/blob/master/LICENSE)
[![GoDoc](https://godoc.org/github.com/thrasher-corp/gocryptotrader?status.svg)](https://godoc.org/github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket)
[![Coverage Status](https://codecov.io/gh/thrasher-corp/gocryptotrader/graph/badge.svg?token=41784B23TS)](https://codecov.io/gh/thrasher-corp/gocryptotrader)
[![Go Report Card](https://goreportcard.com/badge/github.com/thrasher-corp/gocryptotrader)](https://goreportcard.com/report/github.com/thrasher-corp/gocryptotrader)
This websocket package is part of the GoCryptoTrader codebase.
## This is still in active development
You can track ideas, planned features and what's in progress on our [GoCryptoTrader Kanban board](https://github.com/orgs/thrasher-corp/projects/3).
Join our slack to discuss all things related to GoCryptoTrader! [GoCryptoTrader Slack](https://join.slack.com/t/gocryptotrader/shared_invite/enQtNTQ5NDAxMjA2Mjc5LTc5ZDE1ZTNiOGM3ZGMyMmY1NTAxYWZhODE0MWM5N2JlZDk1NDU0YTViYzk4NTk3OTRiMDQzNGQ1YTc4YmRlMTk)
## Overview
The `websocket` package provides methods to manage connections and subscriptions for exchange websockets.
## Features
- Handle real-time market data streams
- Unified interface for managing data streams
- Multi-connection management - a system that can be used to manage multiple connections to the same exchange
- Connection monitoring - a system that can be used to monitor the health of the websocket connections. This can be used to check if the connection is still alive and if it is not, it will attempt to reconnect
- Traffic monitoring - will reconnect if no message is sent for a period of time defined in your config
- Subscription management - a system that can be used to manage subscriptions to various data streams
- Rate limiting - a system that can be used to rate limit the number of requests sent to the exchange
- Message ID generation - a system that can be used to generate message IDs for websocket requests
- Websocket message response matching - can be used to match websocket responses to the requests that were sent
## Usage
### Default single websocket connection
Example setup for the `websocket` package connection:
```go
package main
import (
"github.com/thrasher-corp/gocryptotrader/internal/exchange/websocket"
exchange "github.com/thrasher-corp/gocryptotrader/exchanges"
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
)
type Exchange struct {
exchange.Base
}
// In the exchange wrapper this will set up the initial pointer field provided by exchange.Base
func (e *Exchange) SetDefault() {
e.Websocket = websocket.NewManager()
e.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit
e.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout
e.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit
}
// In the exchange wrapper this is the original setup pattern for the websocket services
func (e *Exchange) Setup(exch *config.Exchange) error {
// This sets up global connection, sub, unsub and generate subscriptions for each connection defined below.
if err := e.Websocket.Setup(&websocket.ManagerSetup{
ExchangeConfig: exch,
DefaultURL: connectionURLString,
RunningURL: connectionURLString,
Connector: e.WsConnect,
Subscriber: e.Subscribe,
Unsubscriber: e.Unsubscribe,
GenerateSubscriptions: e.GenerateDefaultSubscriptions,
Features: &e.Features.Supports.WebsocketCapabilities,
MaxWebsocketSubscriptionsPerConnection: 240,
OrderbookBufferConfig: buffer.Config{ Checksum: e.CalculateUpdateOrderbookChecksum },
}); err != nil {
return err
}
// This is a public websocket connection
if err := ok.Websocket.SetupNewConnection(&websocket.ConnectionSetup{
URL: connectionURLString,
ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout,
ResponseMaxLimit: exchangeWebsocketResponseMaxLimit,
RateLimit: request.NewRateLimitWithWeight(time.Second, 2, 1),
}); err != nil {
return err
}
// This is a private websocket connection
return ok.Websocket.SetupNewConnection(&websocket.ConnectionSetup{
URL: privateConnectionURLString,
ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout,
ResponseMaxLimit: exchangeWebsocketResponseMaxLimit,
Authenticated: true,
RateLimit: request.NewRateLimitWithWeight(time.Second, 2, 1),
})
}
```
### Multiple websocket connections
The example below provides the now optional multi connection management system which allows for more connections
to be maintained and established based off URL, connections types, asset types etc.
```go
func (e *Exchange) Setup(exch *config.Exchange) error {
// This sets up global connection, sub, unsub and generate subscriptions for each connection defined below.
if err := e.Websocket.Setup(&websocket.ManagerSetup{
ExchangeConfig: exch,
Features: &e.Features.Supports.WebsocketCapabilities,
FillsFeed: e.Features.Enabled.FillsFeed,
TradeFeed: e.Features.Enabled.TradeFeed,
UseMultiConnectionManagement: true,
})
if err != nil {
return err
}
// Spot connection
err = g.Websocket.SetupNewConnection(&websocket.ConnectionSetup{
URL: connectionURLStringForSpot,
RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit),
ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout,
ResponseMaxLimit: exch.WebsocketResponseMaxLimit,
// Custom handlers for the specific connection:
Handler: e.WsHandleSpotData,
Subscriber: e.SpotSubscribe,
Unsubscriber: e.SpotUnsubscribe,
GenerateSubscriptions: e.GenerateDefaultSubscriptionsSpot,
Connector: e.WsConnectSpot,
BespokeGenerateMessageID: e.GenerateWebsocketMessageID,
})
if err != nil {
return err
}
// Futures connection - USDT margined
err = g.Websocket.SetupNewConnection(&websocket.ConnectionSetup{
URL: connectionURLStringForSpotForFutures,
RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit),
ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout,
ResponseMaxLimit: exch.WebsocketResponseMaxLimit,
// Custom handlers for the specific connection:
Handler: func(ctx context.Context, incoming []byte) error { return e.WsHandleFuturesData(ctx, incoming, asset.Futures) },
Subscriber: e.FuturesSubscribe,
Unsubscriber: e.FuturesUnsubscribe,
GenerateSubscriptions: func() (subscription.List, error) { return e.GenerateFuturesDefaultSubscriptions(currency.USDT) },
Connector: e.WsFuturesConnect,
BespokeGenerateMessageID: e.GenerateWebsocketMessageID,
})
if err != nil {
return err
}
}
```
## Contribution
Please feel free to submit any pull requests or suggest any desired features to be added.
When submitting a PR, please abide by our coding guidelines:
+ Code must adhere to the official Go [formatting](https://golang.org/doc/effective_go.html#formatting) guidelines (i.e. uses [gofmt](https://golang.org/cmd/gofmt/)).
+ Code must be documented adhering to the official Go [commentary](https://golang.org/doc/effective_go.html#commentary) guidelines.
+ Code must adhere to our [coding style](https://github.com/thrasher-corp/gocryptotrader/blob/master/doc/coding_style.md).
+ Pull requests need to be based on and opened against the `master` branch.
## Donations
<img src="https://github.com/thrasher-corp/gocryptotrader/blob/master/web/src/assets/donate.png?raw=true" hspace="70">
If this framework helped you in any way, or you would like to support the developers working on it, please donate Bitcoin to:
***bc1qk0jareu4jytc0cfrhr5wgshsq8282awpavfahc***

View File

@@ -0,0 +1,344 @@
package buffer
import (
"errors"
"fmt"
"sort"
"github.com/thrasher-corp/gocryptotrader/common/key"
"github.com/thrasher-corp/gocryptotrader/config"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
"github.com/thrasher-corp/gocryptotrader/exchanges/orderbook"
"github.com/thrasher-corp/gocryptotrader/log"
)
const packageError = "websocket orderbook buffer error: %w"
var (
errExchangeConfigNil = errors.New("exchange config is nil")
errBufferConfigNil = errors.New("buffer config is nil")
errUnsetDataHandler = errors.New("datahandler unset")
errIssueBufferEnabledButNoLimit = errors.New("buffer enabled but no limit set")
errUpdateIsNil = errors.New("update is nil")
errUpdateNoTargets = errors.New("update bid/ask targets cannot be nil")
errDepthNotFound = errors.New("orderbook depth not found")
errRESTOverwrite = errors.New("orderbook has been overwritten by REST protocol")
errInvalidAction = errors.New("invalid action")
errAmendFailure = errors.New("orderbook amend update failure")
errDeleteFailure = errors.New("orderbook delete update failure")
errInsertFailure = errors.New("orderbook insert update failure")
errUpdateInsertFailure = errors.New("orderbook update/insert update failure")
errRESTTimerLapse = errors.New("rest sync timer lapse with active websocket connection")
errOrderbookFlushed = errors.New("orderbook flushed")
)
// Setup sets private variables
func (w *Orderbook) Setup(exchangeConfig *config.Exchange, c *Config, dataHandler chan<- any) error {
if exchangeConfig == nil { // exchange config fields are checked in websocket package prior to calling this, so further checks are not needed
return fmt.Errorf(packageError, errExchangeConfigNil)
}
if c == nil {
return fmt.Errorf(packageError, errBufferConfigNil)
}
if dataHandler == nil {
return fmt.Errorf(packageError, errUnsetDataHandler)
}
if exchangeConfig.Orderbook.WebsocketBufferEnabled &&
exchangeConfig.Orderbook.WebsocketBufferLimit < 1 {
return fmt.Errorf(packageError, errIssueBufferEnabledButNoLimit)
}
// NOTE: These variables are set by config.json under "orderbook" for each individual exchange
w.bufferEnabled = exchangeConfig.Orderbook.WebsocketBufferEnabled
w.obBufferLimit = exchangeConfig.Orderbook.WebsocketBufferLimit
w.sortBuffer = c.SortBuffer
w.sortBufferByUpdateIDs = c.SortBufferByUpdateIDs
w.updateEntriesByID = c.UpdateEntriesByID
w.exchangeName = exchangeConfig.Name
w.dataHandler = dataHandler
w.ob = make(map[key.PairAsset]*orderbookHolder)
w.verbose = exchangeConfig.Verbose
w.updateIDProgression = c.UpdateIDProgression
w.checksum = c.Checksum
return nil
}
// validate validates update against setup values
func (w *Orderbook) validate(u *orderbook.Update) error {
if u == nil {
return fmt.Errorf(packageError, errUpdateIsNil)
}
if len(u.Bids) == 0 && len(u.Asks) == 0 {
return fmt.Errorf(packageError, errUpdateNoTargets)
}
return nil
}
// Update updates a stored pointer to an orderbook.Depth struct containing a
// bid and ask Tranches, this switches between the usage of a buffered update
func (w *Orderbook) Update(u *orderbook.Update) error {
if err := w.validate(u); err != nil {
return err
}
w.mtx.Lock()
defer w.mtx.Unlock()
book, ok := w.ob[key.PairAsset{Base: u.Pair.Base.Item, Quote: u.Pair.Quote.Item, Asset: u.Asset}]
if !ok {
return fmt.Errorf("%w for Exchange %s CurrencyPair: %s AssetType: %s",
errDepthNotFound,
w.exchangeName,
u.Pair,
u.Asset)
}
// out of order update ID can be skipped
if w.updateIDProgression && u.UpdateID <= book.updateID {
if w.verbose {
log.Warnf(log.WebsocketMgr,
"Exchange %s CurrencyPair: %s AssetType: %s out of order websocket update received",
w.exchangeName,
u.Pair,
u.Asset)
}
return nil
}
// Checks for when the rest protocol overwrites a streaming dominated book
// will stop updating book via incremental updates. This occurs because our
// sync manager (engine/sync.go) timer has elapsed for streaming. Usually
// because the book is highly illiquid.
isREST, err := book.ob.IsRESTSnapshot()
if err != nil {
if !errors.Is(err, orderbook.ErrOrderbookInvalid) {
return err
}
// In the event a checksum or processing error invalidates the book, all
// updates that could be stored in the websocket buffer, skip applying
// until a new snapshot comes through.
if w.verbose {
log.Warnf(log.WebsocketMgr,
"Exchange %s CurrencyPair: %s AssetType: %s underlying book is invalid, cannot apply update.",
w.exchangeName,
u.Pair,
u.Asset)
}
return nil
}
if isREST {
if w.verbose {
log.Warnf(log.WebsocketMgr,
"%s for Exchange %s CurrencyPair: %s AssetType: %s consider extending synctimeoutwebsocket",
errRESTOverwrite,
w.exchangeName,
u.Pair,
u.Asset)
}
// Instance of illiquidity, this signal notifies that there is websocket
// activity. We can invalidate the book and request a new snapshot. All
// further updates through the websocket should be caught above in the
// IsRestSnapshot() call.
return book.ob.Invalidate(errRESTTimerLapse)
}
if w.bufferEnabled {
var processed bool
processed, err = w.processBufferUpdate(book, u)
if err != nil {
return err
}
if !processed {
return nil
}
} else {
err = w.processObUpdate(book, u)
if err != nil {
return err
}
}
// Publish all state changes, disregarding verbosity or sync requirements.
book.ob.Publish()
w.dataHandler <- book.ob
return nil
}
// processBufferUpdate stores update into buffer, when buffer at capacity as
// defined by w.obBufferLimit it well then sort and apply updates.
func (w *Orderbook) processBufferUpdate(o *orderbookHolder, u *orderbook.Update) (bool, error) {
*o.buffer = append(*o.buffer, *u)
if len(*o.buffer) < w.obBufferLimit {
return false, nil
}
if w.sortBuffer {
// sort by last updated to ensure each update is in order
if w.sortBufferByUpdateIDs {
sort.Slice(*o.buffer, func(i, j int) bool {
return (*o.buffer)[i].UpdateID < (*o.buffer)[j].UpdateID
})
} else {
sort.Slice(*o.buffer, func(i, j int) bool {
return (*o.buffer)[i].UpdateTime.Before((*o.buffer)[j].UpdateTime)
})
}
}
for i := range *o.buffer {
err := w.processObUpdate(o, &(*o.buffer)[i])
if err != nil {
return false, err
}
}
// clear buffer of old updates
*o.buffer = nil
return true, nil
}
// processObUpdate processes updates either by its corresponding id or by price level
func (w *Orderbook) processObUpdate(o *orderbookHolder, u *orderbook.Update) error {
// Both update methods require post processing to ensure the orderbook is in a valid state.
if w.updateEntriesByID {
if err := o.updateByIDAndAction(u); err != nil {
return err
}
} else {
if err := o.updateByPrice(u); err != nil {
return err
}
}
if w.checksum != nil {
compare, err := o.ob.Retrieve()
if err != nil {
return err
}
err = w.checksum(compare, u.Checksum)
if err != nil {
return o.ob.Invalidate(err)
}
o.updateID = u.UpdateID
} else if o.ob.VerifyOrderbook() {
compare, err := o.ob.Retrieve()
if err != nil {
return err
}
err = compare.Verify()
if err != nil {
return o.ob.Invalidate(err)
}
}
return nil
}
// updateByPrice amends amount if match occurs by price, deletes if amount is
// zero or less and inserts if not found.
func (o *orderbookHolder) updateByPrice(updts *orderbook.Update) error {
return o.ob.UpdateBidAskByPrice(updts)
}
// updateByIDAndAction will receive an action to execute against the orderbook
// it will then match by IDs instead of price to perform the action
func (o *orderbookHolder) updateByIDAndAction(updts *orderbook.Update) error {
switch updts.Action {
case orderbook.Amend:
err := o.ob.UpdateBidAskByID(updts)
if err != nil {
return fmt.Errorf("%w %w", errAmendFailure, err)
}
case orderbook.Delete:
// edge case for Bitfinex as their streaming endpoint duplicates deletes
bypassErr := o.ob.GetName() == "Bitfinex" && o.ob.IsFundingRate()
err := o.ob.DeleteBidAskByID(updts, bypassErr)
if err != nil {
return fmt.Errorf("%w %w", errDeleteFailure, err)
}
case orderbook.Insert:
err := o.ob.InsertBidAskByID(updts)
if err != nil {
return fmt.Errorf("%w %w", errInsertFailure, err)
}
case orderbook.UpdateInsert:
err := o.ob.UpdateInsertByID(updts)
if err != nil {
return fmt.Errorf("%w %w", errUpdateInsertFailure, err)
}
default:
return fmt.Errorf("%w [%d]", errInvalidAction, updts.Action)
}
return nil
}
// LoadSnapshot loads initial snapshot of orderbook data from websocket
func (w *Orderbook) LoadSnapshot(book *orderbook.Base) error {
// Checks if book can deploy to depth
err := book.Verify()
if err != nil {
return err
}
w.mtx.Lock()
defer w.mtx.Unlock()
holder, ok := w.ob[key.PairAsset{Base: book.Pair.Base.Item, Quote: book.Pair.Quote.Item, Asset: book.Asset}]
if !ok {
// Associate orderbook pointer with local exchange depth map
var depth *orderbook.Depth
depth, err = orderbook.DeployDepth(book.Exchange, book.Pair, book.Asset)
if err != nil {
return err
}
depth.AssignOptions(book)
buffer := make([]orderbook.Update, w.obBufferLimit)
holder = &orderbookHolder{ob: depth, buffer: &buffer}
w.ob[key.PairAsset{Base: book.Pair.Base.Item, Quote: book.Pair.Quote.Item, Asset: book.Asset}] = holder
}
holder.updateID = book.LastUpdateID
err = holder.ob.LoadSnapshot(book.Bids, book.Asks, book.LastUpdateID, book.LastUpdated, book.UpdatePushedAt, false)
if err != nil {
return err
}
holder.ob.Publish()
w.dataHandler <- holder.ob
return nil
}
// GetOrderbook returns an orderbook copy as orderbook.Base
func (w *Orderbook) GetOrderbook(p currency.Pair, a asset.Item) (*orderbook.Base, error) {
w.mtx.Lock()
defer w.mtx.Unlock()
book, ok := w.ob[key.PairAsset{Base: p.Base.Item, Quote: p.Quote.Item, Asset: a}]
if !ok {
return nil, fmt.Errorf("%s %s %s %w", w.exchangeName, p, a, errDepthNotFound)
}
return book.ob.Retrieve()
}
// FlushBuffer flushes w.ob data to be garbage collected and refreshed when a
// connection is lost and reconnected
func (w *Orderbook) FlushBuffer() {
w.mtx.Lock()
w.ob = make(map[key.PairAsset]*orderbookHolder)
w.mtx.Unlock()
}
// FlushOrderbook flushes independent orderbook
func (w *Orderbook) FlushOrderbook(p currency.Pair, a asset.Item) error {
w.mtx.Lock()
defer w.mtx.Unlock()
book, ok := w.ob[key.PairAsset{Base: p.Base.Item, Quote: p.Quote.Item, Asset: a}]
if !ok {
return fmt.Errorf("cannot flush orderbook %s %s %s %w",
w.exchangeName,
p,
a,
errDepthNotFound)
}
// error not needed in this return
_ = book.ob.Invalidate(errOrderbookFlushed)
return nil
}

View File

@@ -0,0 +1,992 @@
package buffer
import (
"errors"
"math/rand"
"slices"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/key"
"github.com/thrasher-corp/gocryptotrader/config"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
"github.com/thrasher-corp/gocryptotrader/exchanges/orderbook"
)
var (
itemArray = [][]orderbook.Tranche{
{{Price: 1000, Amount: 1, ID: 1000}},
{{Price: 2000, Amount: 1, ID: 2000}},
{{Price: 3000, Amount: 1, ID: 3000}},
{{Price: 3000, Amount: 2, ID: 4000}},
{{Price: 4000, Amount: 0, ID: 6000}},
{{Price: 5000, Amount: 1, ID: 5000}},
}
offset = common.Counter{}
)
const exchangeName = "exchangeTest"
// getExclusivePair returns a currency pair with a unique ID for testing as books are centralised and changes will affect other tests
func getExclusivePair() (currency.Pair, error) {
return currency.NewPairFromStrings(currency.BTC.String(), currency.USDT.String()+strconv.FormatInt(offset.IncrementAndGet(), 10))
}
func createSnapshot(pair currency.Pair, bookVerifiy ...bool) (holder *Orderbook, asks, bids orderbook.Tranches, err error) {
asks = orderbook.Tranches{{Price: 4000, Amount: 1, ID: 6}}
bids = orderbook.Tranches{{Price: 4000, Amount: 1, ID: 6}}
book := &orderbook.Base{
Exchange: exchangeName,
Asks: asks,
Bids: bids,
Asset: asset.Spot,
Pair: pair,
PriceDuplication: true,
LastUpdated: time.Now(),
VerifyOrderbook: len(bookVerifiy) > 0 && bookVerifiy[0],
}
newBook := make(map[key.PairAsset]*orderbookHolder)
ch := make(chan any)
go func(<-chan any) { // reader
for range ch {
continue
}
}(ch)
holder = &Orderbook{
exchangeName: exchangeName,
dataHandler: ch,
ob: newBook,
}
err = holder.LoadSnapshot(book)
return holder, asks, bids, err
}
func bidAskGenerator() []orderbook.Tranche {
response := make([]orderbook.Tranche, 100)
for i := range 100 {
price := float64(rand.Intn(1000)) //nolint:gosec // no need to import crypo/rand for testing
if price == 0 {
price = 1
}
response[i] = orderbook.Tranche{
Amount: float64(rand.Intn(10)), //nolint:gosec // no need to import crypo/rand for testing
Price: price,
ID: int64(i),
}
}
return response
}
func BenchmarkUpdateBidsByPrice(b *testing.B) {
cp, err := getExclusivePair()
require.NoError(b, err)
ob, _, _, err := createSnapshot(cp)
require.NoError(b, err)
for b.Loop() {
bidAsks := bidAskGenerator()
update := &orderbook.Update{
Bids: bidAsks,
Asks: bidAsks,
Pair: cp,
UpdateTime: time.Now(),
Asset: asset.Spot,
}
holder := ob.ob[key.PairAsset{Base: cp.Base.Item, Quote: cp.Quote.Item, Asset: asset.Spot}]
require.NoError(b, holder.updateByPrice(update))
}
}
func BenchmarkUpdateAsksByPrice(b *testing.B) {
cp, err := getExclusivePair()
require.NoError(b, err)
ob, _, _, err := createSnapshot(cp)
require.NoError(b, err)
for b.Loop() {
bidAsks := bidAskGenerator()
update := &orderbook.Update{
Bids: bidAsks,
Asks: bidAsks,
Pair: cp,
UpdateTime: time.Now(),
Asset: asset.Spot,
}
holder := ob.ob[key.PairAsset{Base: cp.Base.Item, Quote: cp.Quote.Item, Asset: asset.Spot}]
require.NoError(b, holder.updateByPrice(update))
}
}
// BenchmarkBufferPerformance demonstrates buffer more performant than multi
// process calls
// 890016 1688 ns/op 416 B/op 3 allocs/op
func BenchmarkBufferPerformance(b *testing.B) {
cp, err := getExclusivePair()
require.NoError(b, err)
holder, asks, bids, err := createSnapshot(cp)
require.NoError(b, err)
holder.bufferEnabled = true
update := &orderbook.Update{
Bids: bids,
Asks: asks,
Pair: cp,
UpdateTime: time.Now(),
Asset: asset.Spot,
}
for b.Loop() {
randomIndex := rand.Intn(4) //nolint:gosec // no need to import crypo/rand for testing
update.Asks = itemArray[randomIndex]
update.Bids = itemArray[randomIndex]
require.NoError(b, holder.Update(update))
}
}
// BenchmarkBufferSortingPerformance benchmark
//
// 613964 2093 ns/op 440 B/op 4 allocs/op
func BenchmarkBufferSortingPerformance(b *testing.B) {
cp, err := getExclusivePair()
require.NoError(b, err)
holder, asks, bids, err := createSnapshot(cp)
require.NoError(b, err)
holder.bufferEnabled = true
holder.sortBuffer = true
update := &orderbook.Update{
Bids: bids,
Asks: asks,
Pair: cp,
UpdateTime: time.Now(),
Asset: asset.Spot,
}
for b.Loop() {
randomIndex := rand.Intn(4) //nolint:gosec // no need to import crypo/rand for testing
update.Asks = itemArray[randomIndex]
update.Bids = itemArray[randomIndex]
require.NoError(b, holder.Update(update))
}
}
// BenchmarkBufferSortingPerformance benchmark
// 914500 1599 ns/op 440 B/op 4 allocs/op
func BenchmarkBufferSortingByIDPerformance(b *testing.B) {
cp, err := getExclusivePair()
require.NoError(b, err)
holder, asks, bids, err := createSnapshot(cp)
require.NoError(b, err)
holder.bufferEnabled = true
holder.sortBuffer = true
holder.sortBufferByUpdateIDs = true
update := &orderbook.Update{
Bids: bids,
Asks: asks,
Pair: cp,
UpdateTime: time.Now(),
Asset: asset.Spot,
}
for b.Loop() {
randomIndex := rand.Intn(4) //nolint:gosec // no need to import crypo/rand for testing
update.Asks = itemArray[randomIndex]
update.Bids = itemArray[randomIndex]
require.NoError(b, holder.Update(update))
}
}
// BenchmarkNoBufferPerformance demonstrates orderbook process more performant
// than buffer
// 122659 12792 ns/op 972 B/op 7 allocs/op PRIOR
// 1225924 1028 ns/op 240 B/op 2 allocs/op CURRENT
func BenchmarkNoBufferPerformance(b *testing.B) {
cp, err := getExclusivePair()
require.NoError(b, err)
obl, asks, bids, err := createSnapshot(cp)
require.NoError(b, err)
update := &orderbook.Update{
Bids: bids,
Asks: asks,
Pair: cp,
UpdateTime: time.Now(),
Asset: asset.Spot,
}
for b.Loop() {
randomIndex := rand.Intn(4) //nolint:gosec // no need to import crypo/rand for testing
update.Asks = itemArray[randomIndex]
update.Bids = itemArray[randomIndex]
require.NoError(b, obl.Update(update))
}
}
func TestUpdates(t *testing.T) {
t.Parallel()
cp, err := getExclusivePair()
require.NoError(t, err)
holder, _, _, err := createSnapshot(cp)
require.NoError(t, err)
book := holder.ob[key.PairAsset{Base: cp.Base.Item, Quote: cp.Quote.Item, Asset: asset.Spot}]
err = book.updateByPrice(&orderbook.Update{
Bids: itemArray[5],
Asks: itemArray[5],
Pair: cp,
UpdateTime: time.Now(),
Asset: asset.Spot,
})
assert.NoError(t, err)
err = book.updateByPrice(&orderbook.Update{
Bids: itemArray[0],
Asks: itemArray[0],
Pair: cp,
UpdateTime: time.Now(),
Asset: asset.Spot,
})
assert.NoError(t, err)
askLen, err := book.ob.GetAskLength()
require.NoError(t, err)
assert.Equal(t, 3, askLen)
}
// TestHittingTheBuffer logic test
func TestHittingTheBuffer(t *testing.T) {
t.Parallel()
cp, err := getExclusivePair()
require.NoError(t, err)
holder, _, _, err := createSnapshot(cp)
require.NoError(t, err)
holder.bufferEnabled = true
holder.obBufferLimit = 5
for i := range itemArray {
asks := itemArray[i]
bids := itemArray[i]
err = holder.Update(&orderbook.Update{
Bids: bids,
Asks: asks,
Pair: cp,
UpdateTime: time.Now(),
Asset: asset.Spot,
})
require.NoError(t, err)
}
book := holder.ob[key.PairAsset{Base: cp.Base.Item, Quote: cp.Quote.Item, Asset: asset.Spot}]
askLen, err := book.ob.GetAskLength()
require.NoError(t, err)
assert.Equal(t, 3, askLen)
bidLen, err := book.ob.GetBidLength()
require.NoError(t, err)
assert.Equal(t, 3, bidLen)
}
// TestInsertWithIDs logic test
func TestInsertWithIDs(t *testing.T) {
t.Parallel()
cp, err := getExclusivePair()
require.NoError(t, err)
holder, _, _, err := createSnapshot(cp)
require.NoError(t, err)
holder.bufferEnabled = true
holder.updateEntriesByID = true
holder.obBufferLimit = 5
for i := range itemArray {
asks := itemArray[i]
if asks[0].Amount <= 0 {
continue
}
bids := itemArray[i]
err = holder.Update(&orderbook.Update{
Bids: bids,
Asks: asks,
Pair: cp,
UpdateTime: time.Now(),
Asset: asset.Spot,
Action: orderbook.UpdateInsert,
})
require.NoError(t, err)
}
book := holder.ob[key.PairAsset{Base: cp.Base.Item, Quote: cp.Quote.Item, Asset: asset.Spot}]
askLen, err := book.ob.GetAskLength()
require.NoError(t, err)
assert.Equal(t, 6, askLen)
bidLen, err := book.ob.GetBidLength()
require.NoError(t, err)
assert.Equal(t, 6, bidLen)
cp, err = getExclusivePair()
require.NoError(t, err)
holder, _, _, err = createSnapshot(cp, true)
require.NoError(t, err)
holder.checksum = nil
holder.updateIDProgression = false
err = holder.Update(&orderbook.Update{
UpdateTime: time.Now(),
Asset: asset.Spot,
Asks: []orderbook.Tranche{{Price: 999999}},
Pair: cp,
})
require.NoError(t, err)
}
// TestSortIDs logic test
func TestSortIDs(t *testing.T) {
t.Parallel()
cp, err := getExclusivePair()
require.NoError(t, err)
holder, _, _, err := createSnapshot(cp)
require.NoError(t, err)
holder.bufferEnabled = true
holder.sortBufferByUpdateIDs = true
holder.sortBuffer = true
holder.obBufferLimit = 5
for i := range itemArray {
asks := itemArray[i]
bids := itemArray[i]
err = holder.Update(&orderbook.Update{
Bids: bids,
Asks: asks,
Pair: cp,
UpdateID: int64(i),
Asset: asset.Spot,
UpdateTime: time.Now(),
})
require.NoError(t, err)
}
book := holder.ob[key.PairAsset{Base: cp.Base.Item, Quote: cp.Quote.Item, Asset: asset.Spot}]
askLen, err := book.ob.GetAskLength()
require.NoError(t, err)
assert.Equal(t, 3, askLen)
bidLen, err := book.ob.GetBidLength()
require.NoError(t, err)
assert.Equal(t, 3, bidLen)
}
// TestOutOfOrderIDs logic test
func TestOutOfOrderIDs(t *testing.T) {
t.Parallel()
cp, err := getExclusivePair()
require.NoError(t, err)
holder, _, _, err := createSnapshot(cp)
require.NoError(t, err)
outOFOrderIDs := []int64{2, 1, 5, 3, 4, 6, 7}
assert.Equal(t, 1000., itemArray[0][0].Price)
holder.bufferEnabled = true
holder.sortBuffer = true
holder.obBufferLimit = 5
for i := range itemArray {
asks := itemArray[i]
err = holder.Update(&orderbook.Update{
Asks: asks,
Pair: cp,
UpdateID: outOFOrderIDs[i],
Asset: asset.Spot,
UpdateTime: time.Now(),
})
require.NoError(t, err)
}
book := holder.ob[key.PairAsset{Base: cp.Base.Item, Quote: cp.Quote.Item, Asset: asset.Spot}]
cpy, err := book.ob.Retrieve()
require.NoError(t, err)
// Index 1 since index 0 is price 7000
assert.Equal(t, 2000., cpy.Asks[1].Price)
}
func TestOrderbookLastUpdateID(t *testing.T) {
t.Parallel()
cp, err := getExclusivePair()
require.NoError(t, err)
holder, _, _, err := createSnapshot(cp)
require.NoError(t, err)
assert.Equal(t, 1000., itemArray[0][0].Price)
holder.checksum = func(*orderbook.Base, uint32) error { return errors.New("testerino") }
// this update invalidates the book
err = holder.Update(&orderbook.Update{
Asks: []orderbook.Tranche{{Price: 999999}},
Pair: cp,
UpdateID: -1,
Asset: asset.Spot,
UpdateTime: time.Now(),
})
require.ErrorIs(t, err, orderbook.ErrOrderbookInvalid)
cp, err = getExclusivePair()
require.NoError(t, err)
holder, _, _, err = createSnapshot(cp)
require.NoError(t, err)
holder.checksum = func(*orderbook.Base, uint32) error { return nil }
holder.updateIDProgression = true
for i := range itemArray {
asks := itemArray[i]
err = holder.Update(&orderbook.Update{
Asks: asks,
Pair: cp,
UpdateID: int64(i) + 1,
Asset: asset.Spot,
UpdateTime: time.Now(),
})
require.NoError(t, err)
}
// out of order
err = holder.Update(&orderbook.Update{
Asks: []orderbook.Tranche{{Price: 999999}},
Pair: cp,
UpdateID: 1,
Asset: asset.Spot,
})
require.NoError(t, err)
ob, err := holder.GetOrderbook(cp, asset.Spot)
require.NoError(t, err)
assert.Equal(t, int64(len(itemArray)), ob.LastUpdateID)
}
// TestRunUpdateWithoutSnapshot logic test
func TestRunUpdateWithoutSnapshot(t *testing.T) {
t.Parallel()
cp, err := getExclusivePair()
require.NoError(t, err)
var holder Orderbook
asks := []orderbook.Tranche{{Price: 4000, Amount: 1, ID: 8}}
bids := []orderbook.Tranche{{Price: 5999, Amount: 1, ID: 8}, {Price: 4000, Amount: 1, ID: 9}}
holder.exchangeName = exchangeName
err = holder.Update(&orderbook.Update{
Bids: bids,
Asks: asks,
Pair: cp,
UpdateTime: time.Now(),
Asset: asset.Spot,
})
require.ErrorIs(t, err, errDepthNotFound)
}
// TestRunUpdateWithoutAnyUpdates logic test
func TestRunUpdateWithoutAnyUpdates(t *testing.T) {
t.Parallel()
cp, err := getExclusivePair()
require.NoError(t, err)
var obl Orderbook
obl.exchangeName = exchangeName
err = obl.Update(&orderbook.Update{
Bids: []orderbook.Tranche{},
Asks: []orderbook.Tranche{},
Pair: cp,
UpdateTime: time.Now(),
Asset: asset.Spot,
})
require.ErrorIs(t, err, errUpdateNoTargets)
}
// TestRunSnapshotWithNoData logic test
func TestRunSnapshotWithNoData(t *testing.T) {
t.Parallel()
cp, err := getExclusivePair()
require.NoError(t, err)
var obl Orderbook
obl.ob = make(map[key.PairAsset]*orderbookHolder)
obl.dataHandler = make(chan any, 1)
var snapShot1 orderbook.Base
snapShot1.Asset = asset.Spot
snapShot1.Pair = cp
snapShot1.Exchange = "test"
obl.exchangeName = "test"
snapShot1.LastUpdated = time.Now()
require.NoError(t, obl.LoadSnapshot(&snapShot1))
}
// TestLoadSnapshot logic test
func TestLoadSnapshot(t *testing.T) {
t.Parallel()
cp, err := getExclusivePair()
require.NoError(t, err)
var obl Orderbook
obl.dataHandler = make(chan any, 100)
obl.ob = make(map[key.PairAsset]*orderbookHolder)
var snapShot1 orderbook.Base
snapShot1.Exchange = "SnapshotWithOverride"
asks := []orderbook.Tranche{{Price: 4000, Amount: 1, ID: 8}}
bids := []orderbook.Tranche{{Price: 4000, Amount: 1, ID: 9}}
snapShot1.Asks = asks
snapShot1.Bids = bids
snapShot1.Asset = asset.Spot
snapShot1.Pair = cp
snapShot1.LastUpdated = time.Now()
require.NoError(t, obl.LoadSnapshot(&snapShot1))
}
// TestFlushBuffer logic test
func TestFlushBuffer(t *testing.T) {
t.Parallel()
cp, err := getExclusivePair()
require.NoError(t, err)
obl, _, _, err := createSnapshot(cp)
require.NoError(t, err)
assert.NotEmpty(t, obl.ob)
obl.FlushBuffer()
assert.Empty(t, obl.ob)
}
// TestInsertingSnapShots logic test
func TestInsertingSnapShots(t *testing.T) {
t.Parallel()
cp, err := getExclusivePair()
require.NoError(t, err)
var holder Orderbook
holder.dataHandler = make(chan any, 100)
holder.ob = make(map[key.PairAsset]*orderbookHolder)
var snapShot1 orderbook.Base
snapShot1.Exchange = "WSORDERBOOKTEST1"
asks := []orderbook.Tranche{
{Price: 6000, Amount: 1, ID: 1},
{Price: 6001, Amount: 0.5, ID: 2},
{Price: 6002, Amount: 2, ID: 3},
{Price: 6003, Amount: 3, ID: 4},
{Price: 6004, Amount: 5, ID: 5},
{Price: 6005, Amount: 2, ID: 6},
{Price: 6006, Amount: 1.5, ID: 7},
{Price: 6007, Amount: 0.5, ID: 8},
{Price: 6008, Amount: 23, ID: 9},
{Price: 6009, Amount: 9, ID: 10},
{Price: 6010, Amount: 7, ID: 11},
}
bids := []orderbook.Tranche{
{Price: 5999, Amount: 1, ID: 12},
{Price: 5998, Amount: 0.5, ID: 13},
{Price: 5997, Amount: 2, ID: 14},
{Price: 5996, Amount: 3, ID: 15},
{Price: 5995, Amount: 5, ID: 16},
{Price: 5994, Amount: 2, ID: 17},
{Price: 5993, Amount: 1.5, ID: 18},
{Price: 5992, Amount: 0.5, ID: 19},
{Price: 5991, Amount: 23, ID: 20},
{Price: 5990, Amount: 9, ID: 21},
{Price: 5989, Amount: 7, ID: 22},
}
snapShot1.Asks = asks
snapShot1.Bids = bids
snapShot1.Asset = asset.Spot
snapShot1.Pair = cp
snapShot1.LastUpdated = time.Now()
require.NoError(t, holder.LoadSnapshot(&snapShot1))
var snapShot2 orderbook.Base
snapShot2.Exchange = "WSORDERBOOKTEST2"
asks = []orderbook.Tranche{
{Price: 51, Amount: 1, ID: 1},
{Price: 52, Amount: 0.5, ID: 2},
{Price: 53, Amount: 2, ID: 3},
{Price: 54, Amount: 3, ID: 4},
{Price: 55, Amount: 5, ID: 5},
{Price: 56, Amount: 2, ID: 6},
{Price: 57, Amount: 1.5, ID: 7},
{Price: 58, Amount: 0.5, ID: 8},
{Price: 59, Amount: 23, ID: 9},
{Price: 50, Amount: 9, ID: 10},
{Price: 60, Amount: 7, ID: 11},
}
bids = []orderbook.Tranche{
{Price: 49, Amount: 1, ID: 12},
{Price: 48, Amount: 0.5, ID: 13},
{Price: 47, Amount: 2, ID: 14},
{Price: 46, Amount: 3, ID: 15},
{Price: 45, Amount: 5, ID: 16},
{Price: 44, Amount: 2, ID: 17},
{Price: 43, Amount: 1.5, ID: 18},
{Price: 42, Amount: 0.5, ID: 19},
{Price: 41, Amount: 23, ID: 20},
{Price: 40, Amount: 9, ID: 21},
{Price: 39, Amount: 7, ID: 22},
}
snapShot2.Asks = asks
snapShot2.Asks.SortAsks()
snapShot2.Bids = bids
snapShot2.Bids.SortBids()
snapShot2.Asset = asset.Spot
snapShot2.Pair, err = getExclusivePair()
require.NoError(t, err)
snapShot2.LastUpdated = time.Now()
require.NoError(t, holder.LoadSnapshot(&snapShot2))
var snapShot3 orderbook.Base
snapShot3.Exchange = "WSORDERBOOKTEST3"
asks = []orderbook.Tranche{
{Price: 511, Amount: 1, ID: 1},
{Price: 52, Amount: 0.5, ID: 2},
{Price: 53, Amount: 2, ID: 3},
{Price: 54, Amount: 3, ID: 4},
{Price: 55, Amount: 5, ID: 5},
{Price: 56, Amount: 2, ID: 6},
{Price: 57, Amount: 1.5, ID: 7},
{Price: 58, Amount: 0.5, ID: 8},
{Price: 59, Amount: 23, ID: 9},
{Price: 50, Amount: 9, ID: 10},
{Price: 60, Amount: 7, ID: 11},
}
bids = []orderbook.Tranche{
{Price: 49, Amount: 1, ID: 12},
{Price: 48, Amount: 0.5, ID: 13},
{Price: 47, Amount: 2, ID: 14},
{Price: 46, Amount: 3, ID: 15},
{Price: 45, Amount: 5, ID: 16},
{Price: 44, Amount: 2, ID: 17},
{Price: 43, Amount: 1.5, ID: 18},
{Price: 42, Amount: 0.5, ID: 19},
{Price: 41, Amount: 23, ID: 20},
{Price: 40, Amount: 9, ID: 21},
{Price: 39, Amount: 7, ID: 22},
}
snapShot3.Asks = asks
snapShot3.Asks.SortAsks()
snapShot3.Bids = bids
snapShot3.Bids.SortBids()
snapShot3.Asset = asset.Futures
snapShot3.Pair, err = getExclusivePair()
require.NoError(t, err)
snapShot3.LastUpdated = time.Now()
require.NoError(t, holder.LoadSnapshot(&snapShot3))
ob, err := holder.GetOrderbook(snapShot1.Pair, snapShot1.Asset)
require.NoError(t, err)
assert.Equal(t, snapShot1.Asks[0], ob.Asks[0])
ob, err = holder.GetOrderbook(snapShot2.Pair, snapShot2.Asset)
require.NoError(t, err)
assert.Equal(t, snapShot2.Asks[0], ob.Asks[0])
ob, err = holder.GetOrderbook(snapShot3.Pair, snapShot3.Asset)
require.NoError(t, err)
assert.Equal(t, snapShot3.Asks[0], ob.Asks[0])
}
func TestGetOrderbook(t *testing.T) {
t.Parallel()
cp, err := getExclusivePair()
require.NoError(t, err)
holder, _, _, err := createSnapshot(cp)
require.NoError(t, err)
ob, err := holder.GetOrderbook(cp, asset.Spot)
require.NoError(t, err)
bufferOb := holder.ob[key.PairAsset{Base: cp.Base.Item, Quote: cp.Quote.Item, Asset: asset.Spot}]
b, err := bufferOb.ob.Retrieve()
require.NoError(t, err)
askLen, err := bufferOb.ob.GetAskLength()
require.NoError(t, err)
bidLen, err := bufferOb.ob.GetBidLength()
require.NoError(t, err)
if askLen != len(ob.Asks) ||
bidLen != len(ob.Bids) ||
b.Asset != ob.Asset ||
b.Exchange != ob.Exchange ||
b.LastUpdateID != ob.LastUpdateID ||
b.PriceDuplication != ob.PriceDuplication ||
b.Pair != ob.Pair {
t.Fatal("data on both books should be the same")
}
}
func TestSetup(t *testing.T) {
t.Parallel()
w := Orderbook{}
err := w.Setup(nil, nil, nil)
require.ErrorIs(t, err, errExchangeConfigNil)
exchangeConfig := &config.Exchange{}
err = w.Setup(exchangeConfig, nil, nil)
require.ErrorIs(t, err, errBufferConfigNil)
bufferConf := &Config{}
err = w.Setup(exchangeConfig, bufferConf, nil)
require.ErrorIs(t, err, errUnsetDataHandler)
exchangeConfig.Orderbook.WebsocketBufferEnabled = true
err = w.Setup(exchangeConfig, bufferConf, make(chan any))
require.ErrorIs(t, err, errIssueBufferEnabledButNoLimit)
exchangeConfig.Orderbook.WebsocketBufferLimit = 1337
exchangeConfig.Orderbook.WebsocketBufferEnabled = true
exchangeConfig.Name = "test"
bufferConf.SortBuffer = true
bufferConf.SortBufferByUpdateIDs = true
bufferConf.UpdateEntriesByID = true
err = w.Setup(exchangeConfig, bufferConf, make(chan any))
require.NoError(t, err)
if w.obBufferLimit != 1337 ||
!w.bufferEnabled ||
!w.sortBuffer ||
!w.sortBufferByUpdateIDs ||
!w.updateEntriesByID ||
w.exchangeName != "test" {
t.Errorf("Setup incorrectly loaded %s", w.exchangeName)
}
}
func TestValidate(t *testing.T) {
t.Parallel()
w := Orderbook{}
err := w.validate(nil)
require.ErrorIs(t, err, errUpdateIsNil)
err = w.validate(&orderbook.Update{})
require.ErrorIs(t, err, errUpdateNoTargets)
}
func TestEnsureMultipleUpdatesViaPrice(t *testing.T) {
t.Parallel()
cp, err := getExclusivePair()
require.NoError(t, err)
holder, _, _, err := createSnapshot(cp)
require.NoError(t, err)
asks := bidAskGenerator()
book := holder.ob[key.PairAsset{Base: cp.Base.Item, Quote: cp.Quote.Item, Asset: asset.Spot}]
err = book.updateByPrice(&orderbook.Update{
Bids: asks,
Asks: asks,
Pair: cp,
UpdateTime: time.Now(),
Asset: asset.Spot,
})
require.NoError(t, err)
askLen, err := book.ob.GetAskLength()
require.NoError(t, err)
assert.LessOrEqual(t, 3, askLen)
}
func deploySliceOrdered(size int) orderbook.Tranches {
items := make([]orderbook.Tranche, size)
for i := range size {
items[i] = orderbook.Tranche{Amount: 1, Price: rand.Float64() + float64(i), ID: rand.Int63()} //nolint:gosec // Not needed for tests
}
return items
}
func TestUpdateByIDAndAction(t *testing.T) {
t.Parallel()
cp, err := getExclusivePair()
require.NoError(t, err)
asks := deploySliceOrdered(100)
bids := slices.Clone(asks)
bids.Reverse()
book, err := orderbook.DeployDepth("test", cp, asset.Spot)
require.NoError(t, err)
err = book.LoadSnapshot(slices.Clone(bids), slices.Clone(asks), 0, time.Now(), time.Now(), true)
require.NoError(t, err)
ob, err := book.Retrieve()
require.NoError(t, err)
require.NoError(t, ob.Verify())
holder := orderbookHolder{ob: book}
err = holder.updateByIDAndAction(&orderbook.Update{})
require.ErrorIs(t, err, errInvalidAction)
err = holder.updateByIDAndAction(&orderbook.Update{
Action: orderbook.Amend,
Bids: []orderbook.Tranche{{Price: 100, ID: 6969}},
})
require.ErrorIs(t, err, errAmendFailure)
err = book.LoadSnapshot(slices.Clone(bids), slices.Clone(asks), 0, time.Now(), time.Now(), true)
require.NoError(t, err)
// append to slice
err = holder.updateByIDAndAction(&orderbook.Update{
Action: orderbook.UpdateInsert,
Bids: []orderbook.Tranche{{Price: 0, ID: 1337, Amount: 1}},
Asks: []orderbook.Tranche{{Price: 100, ID: 1337, Amount: 1}},
UpdateTime: time.Now(),
})
require.NoError(t, err)
cpy, err := book.Retrieve()
require.NoError(t, err)
require.Equal(t, 0., cpy.Bids[len(cpy.Bids)-1].Price)
require.Equal(t, 100., cpy.Asks[len(cpy.Asks)-1].Price)
// Change amount
err = holder.updateByIDAndAction(&orderbook.Update{
Action: orderbook.UpdateInsert,
Bids: []orderbook.Tranche{{Price: 0, ID: 1337, Amount: 100}},
Asks: []orderbook.Tranche{{Price: 100, ID: 1337, Amount: 100}},
UpdateTime: time.Now(),
})
require.NoError(t, err)
cpy, err = book.Retrieve()
require.NoError(t, err)
require.Equal(t, 100., cpy.Bids[len(cpy.Bids)-1].Amount)
require.Equal(t, 100., cpy.Asks[len(cpy.Asks)-1].Amount)
// Change price level
err = holder.updateByIDAndAction(&orderbook.Update{
Action: orderbook.UpdateInsert,
Bids: []orderbook.Tranche{{Price: 100, ID: 1337, Amount: 99}},
Asks: []orderbook.Tranche{{Price: 0, ID: 1337, Amount: 99}},
UpdateTime: time.Now(),
})
require.NoError(t, err)
cpy, err = book.Retrieve()
require.NoError(t, err)
require.Equal(t, 99., cpy.Bids[0].Amount)
require.Equal(t, 100., cpy.Bids[0].Price)
require.Equal(t, 99., cpy.Asks[0].Amount)
require.Equal(t, 0., cpy.Asks[0].Price)
err = book.LoadSnapshot(slices.Clone(bids), slices.Clone(asks), 0, time.Now(), time.Now(), true)
require.NoError(t, err)
// Delete - not found
err = holder.updateByIDAndAction(&orderbook.Update{
Action: orderbook.Delete,
Asks: []orderbook.Tranche{{Price: 0, ID: 1337, Amount: 99}},
})
require.ErrorIs(t, err, errDeleteFailure)
err = book.LoadSnapshot(slices.Clone(bids), slices.Clone(asks), 0, time.Now(), time.Now(), true)
require.NoError(t, err)
// Delete - found
err = holder.updateByIDAndAction(&orderbook.Update{
Action: orderbook.Delete,
Asks: []orderbook.Tranche{asks[0]},
UpdateTime: time.Now(),
})
require.NoError(t, err)
askLen, err := book.GetAskLength()
require.NoError(t, err)
require.Equal(t, 99, askLen)
// Apply update
err = holder.updateByIDAndAction(&orderbook.Update{
Action: orderbook.Amend,
Asks: []orderbook.Tranche{{ID: 123456}},
})
require.ErrorIs(t, err, errAmendFailure)
err = book.LoadSnapshot(slices.Clone(bids), slices.Clone(bids), 0, time.Now(), time.Now(), true)
require.NoError(t, err)
ob, err = book.Retrieve()
require.NoError(t, err)
require.NotEmpty(t, ob.Asks)
require.NotEmpty(t, ob.Bids)
update := ob.Asks[0]
update.Amount = 1337
err = holder.updateByIDAndAction(&orderbook.Update{
Action: orderbook.Amend,
Asks: []orderbook.Tranche{update},
UpdateTime: time.Now(),
})
require.NoError(t, err)
ob, err = book.Retrieve()
require.NoError(t, err)
require.Equal(t, 1337., ob.Asks[0].Amount)
}
func TestFlushOrderbook(t *testing.T) {
t.Parallel()
cp, err := getExclusivePair()
require.NoError(t, err)
w := &Orderbook{}
err = w.Setup(&config.Exchange{Name: "test"}, &Config{}, make(chan any, 2))
require.NoError(t, err)
var snapShot1 orderbook.Base
snapShot1.Exchange = "Snapshooooot"
asks := []orderbook.Tranche{{Price: 4000, Amount: 1, ID: 8}}
bids := []orderbook.Tranche{{Price: 4000, Amount: 1, ID: 9}}
snapShot1.Asks = asks
snapShot1.Bids = bids
snapShot1.Asset = asset.Spot
snapShot1.Pair = cp
snapShot1.LastUpdated = time.Now()
err = w.FlushOrderbook(cp, asset.Spot)
if err == nil {
t.Fatal("book not loaded error cannot be nil")
}
_, err = w.GetOrderbook(cp, asset.Spot)
require.ErrorIs(t, err, errDepthNotFound)
require.NoError(t, w.LoadSnapshot(&snapShot1))
require.NoError(t, w.FlushOrderbook(cp, asset.Spot))
_, err = w.GetOrderbook(cp, asset.Spot)
require.ErrorIs(t, err, orderbook.ErrOrderbookInvalid)
}

View File

@@ -0,0 +1,59 @@
package buffer
import (
"sync"
"github.com/thrasher-corp/gocryptotrader/common/key"
"github.com/thrasher-corp/gocryptotrader/exchanges/orderbook"
)
// Config defines the configuration variables for the websocket buffer; snapshot
// and incremental update orderbook processing.
type Config struct {
// SortBuffer enables a websocket to sort incoming updates before processing.
SortBuffer bool
// SortBufferByUpdateIDs allows the sorting of the buffered updates by their
// corresponding update IDs.
SortBufferByUpdateIDs bool
// UpdateEntriesByID will match by IDs instead of price to perform the an
// action. e.g. update, delete, insert.
UpdateEntriesByID bool
// UpdateIDProgression requires that the new update ID be greater than the
// prior ID. This will skip processing and not error.
UpdateIDProgression bool
// Checksum is a package defined checksum calculation for updated books.
Checksum func(state *orderbook.Base, checksum uint32) error
}
// Orderbook defines a local cache of orderbooks for amending, appending
// and deleting changes and updates the main store for a stream
type Orderbook struct {
ob map[key.PairAsset]*orderbookHolder
obBufferLimit int
bufferEnabled bool
sortBuffer bool
sortBufferByUpdateIDs bool // When timestamps aren't provided, an id can help sort
updateEntriesByID bool // Use the update IDs to match ob entries
exchangeName string
dataHandler chan<- any
verbose bool
// updateIDProgression requires that the new update ID be greater than the
// prior ID. This will skip processing and not error.
updateIDProgression bool
// checksum is a package defined checksum calculation for updated books.
checksum func(state *orderbook.Base, checksum uint32) error
// TODO: sync.RWMutex. For the moment we process the orderbook in a single
// thread. In future when there are workers directly involved this can be
// can be improved with RW mechanics which will allow updates to occur at
// the same time on different books.
mtx sync.Mutex
}
// orderbookHolder defines a store of pending updates and a pointer to the
// orderbook depth
type orderbookHolder struct {
ob *orderbook.Depth
buffer *[]orderbook.Update
updateID int64
}

View File

@@ -0,0 +1,481 @@
package websocket
import (
"bytes"
"compress/flate"
"compress/gzip"
"context"
"crypto/rand"
"errors"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
gws "github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/encoding/json"
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
"github.com/thrasher-corp/gocryptotrader/log"
)
var (
// errConnectionFault is a connection fault error which alerts the system that a connection cycle needs to take place.
errConnectionFault = errors.New("connection fault")
errWebsocketIsDisconnected = errors.New("websocket connection is disconnected")
errRateLimitNotFound = errors.New("rate limit definition not found")
)
// Connection defines the interface for websocket connections
type Connection interface {
Dial(*gws.Dialer, http.Header) error
DialContext(context.Context, *gws.Dialer, http.Header) error
ReadMessage() Response
SetupPingHandler(request.EndpointLimit, PingHandler)
// GenerateMessageID generates a message ID for the individual connection. If a bespoke function is set
// (by using SetupNewConnection) it will use that, otherwise it will use the defaultGenerateMessageID function
// defined in websocket_connection.go.
GenerateMessageID(highPrecision bool) int64
// SendMessageReturnResponse will send a WS message to the connection and wait for response
SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature, request any) ([]byte, error)
// SendMessageReturnResponses will send a WS message to the connection and wait for N responses
SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature, request any, expected int) ([][]byte, error)
// SendMessageReturnResponsesWithInspector will send a WS message to the connection and wait for N responses with message inspection
SendMessageReturnResponsesWithInspector(ctx context.Context, epl request.EndpointLimit, signature, request any, expected int, messageInspector Inspector) ([][]byte, error)
// SendRawMessage sends a message over the connection without JSON encoding it
SendRawMessage(ctx context.Context, epl request.EndpointLimit, messageType int, message []byte) error
// SendJSONMessage sends a JSON encoded message over the connection
SendJSONMessage(ctx context.Context, epl request.EndpointLimit, payload any) error
SetURL(string)
SetProxy(string)
GetURL() string
Shutdown() error
}
// ConnectionSetup defines variables for an individual stream connection
type ConnectionSetup struct {
ResponseCheckTimeout time.Duration
ResponseMaxLimit time.Duration
RateLimit *request.RateLimiterWithWeight
Authenticated bool
ConnectionLevelReporter Reporter
// URL defines the websocket server URL to connect to
URL string
// Connector is the function that will be called to connect to the
// exchange's websocket server. This will be called once when the stream
// service is started. Any bespoke connection logic should be handled here.
Connector func(ctx context.Context, conn Connection) error
// GenerateSubscriptions is a function that will be called to generate a
// list of subscriptions to be made to the exchange's websocket server.
GenerateSubscriptions func() (subscription.List, error)
// Subscriber is a function that will be called to send subscription
// messages based on the exchange's websocket server requirements to
// subscribe to specific channels.
Subscriber func(ctx context.Context, conn Connection, sub subscription.List) error
// Unsubscriber is a function that will be called to send unsubscription
// messages based on the exchange's websocket server requirements to
// unsubscribe from specific channels. NOTE: IF THE FEATURE IS ENABLED.
Unsubscriber func(ctx context.Context, conn Connection, unsub subscription.List) error
// Handler defines the function that will be called when a message is
// received from the exchange's websocket server. This function should
// handle the incoming message and pass it to the appropriate data handler.
Handler func(ctx context.Context, incoming []byte) error
// BespokeGenerateMessageID is a function that returns a unique message ID.
// This is useful for when an exchange connection requires a unique or
// structured message ID for each message sent.
BespokeGenerateMessageID func(highPrecision bool) int64
Authenticate func(ctx context.Context, conn Connection) error
// MessageFilter defines the criteria used to match messages to a specific connection.
// The filter enables precise routing and handling of messages for distinct connection contexts.
MessageFilter any
}
// Inspector is used to verify messages via SendMessageReturnResponsesWithInspection
// It inspects the []bytes websocket message and returns true if the message is the final message in a sequence of expected messages
type Inspector interface {
IsFinal([]byte) bool
}
// Response defines generalised data from the websocket connection
type Response struct {
Type int
Raw []byte
}
// connection contains all the data needed to send a message to a websocket connection
type connection struct {
Verbose bool
connected int32
writeControl sync.Mutex // Gorilla websocket does not allow more than one goroutine to utilise write methods
RateLimit *request.RateLimiterWithWeight // RateLimit is a rate limiter for the connection itself
RateLimitDefinitions request.RateLimitDefinitions // RateLimitDefinitions contains the rate limiters shared between WebSocket and REST connections
Reporter Reporter
ExchangeName string
URL string
ProxyURL string
Wg *sync.WaitGroup
Connection *gws.Conn
shutdown chan struct{}
Match *Match
ResponseMaxLimit time.Duration
Traffic chan struct{}
readMessageErrors chan error
bespokeGenerateMessageID func(highPrecision bool) int64
}
// Dial sets proxy urls and then connects to the websocket
func (c *connection) Dial(dialer *gws.Dialer, headers http.Header) error {
return c.DialContext(context.Background(), dialer, headers)
}
// DialContext sets proxy urls and then connects to the websocket
func (c *connection) DialContext(ctx context.Context, dialer *gws.Dialer, headers http.Header) error {
if c.ProxyURL != "" {
proxy, err := url.Parse(c.ProxyURL)
if err != nil {
return err
}
dialer.Proxy = http.ProxyURL(proxy)
}
var err error
var conStatus *http.Response
c.Connection, conStatus, err = dialer.DialContext(ctx, c.URL, headers)
if err != nil {
if conStatus != nil {
_ = conStatus.Body.Close()
return fmt.Errorf("%s websocket connection: %v %v %v Error: %w", c.ExchangeName, c.URL, conStatus, conStatus.StatusCode, err)
}
return fmt.Errorf("%s websocket connection: %v Error: %w", c.ExchangeName, c.URL, err)
}
_ = conStatus.Body.Close()
if c.Verbose {
log.Infof(log.WebsocketMgr, "%v Websocket connected to %s\n", c.ExchangeName, c.URL)
}
select {
case c.Traffic <- struct{}{}:
default:
}
c.setConnectedStatus(true)
return nil
}
// SendJSONMessage sends a JSON encoded message over the connection
func (c *connection) SendJSONMessage(ctx context.Context, epl request.EndpointLimit, data any) error {
return c.writeToConn(ctx, epl, func() error {
if request.IsVerbose(ctx, c.Verbose) {
if msg, err := json.Marshal(data); err == nil { // WriteJSON will error for us anyway
log.Debugf(log.WebsocketMgr, "%v %v: Sending message: %v", c.ExchangeName, removeURLQueryString(c.URL), string(msg))
}
}
return c.Connection.WriteJSON(data)
})
}
// SendRawMessage sends a message over the connection without JSON encoding it
func (c *connection) SendRawMessage(ctx context.Context, epl request.EndpointLimit, messageType int, message []byte) error {
return c.writeToConn(ctx, epl, func() error {
if request.IsVerbose(ctx, c.Verbose) {
log.Debugf(log.WebsocketMgr, "%v %v: Sending message: %v", c.ExchangeName, removeURLQueryString(c.URL), string(message))
}
return c.Connection.WriteMessage(messageType, message)
})
}
func (c *connection) writeToConn(ctx context.Context, epl request.EndpointLimit, writeConn func() error) error {
if !c.IsConnected() {
return fmt.Errorf("%v websocket connection: cannot send message %w", c.ExchangeName, errWebsocketIsDisconnected)
}
var rl *request.RateLimiterWithWeight
if c.RateLimitDefinitions != nil {
var ok bool
if rl, ok = c.RateLimitDefinitions[epl]; !ok && c.RateLimit == nil {
// Return an error if no specific connection rate limit is found for the endpoint but a global rate limit is
// set. This ensures the system attempts to apply rate limiting, prioritizing endpoint-specific limits
// if they are defined.
return fmt.Errorf("%s websocket connection: %w for %v", c.ExchangeName, errRateLimitNotFound, epl)
}
}
if rl == nil {
// If a global rate limit definition is not found, use the connection rate limit as a fallback.
rl = c.RateLimit
}
if rl != nil {
if err := request.RateLimit(ctx, rl); err != nil {
return fmt.Errorf("%s websocket connection: rate limit error: %w", c.ExchangeName, err)
}
}
// This lock acts as a rolling gate to prevent WriteMessage panics. Acquire after rate limit check.
c.writeControl.Lock()
defer c.writeControl.Unlock()
return writeConn()
}
// SetupPingHandler will automatically send ping or pong messages based on
// WebsocketPingHandler configuration
func (c *connection) SetupPingHandler(epl request.EndpointLimit, handler PingHandler) {
if handler.UseGorillaHandler {
c.Connection.SetPingHandler(func(msg string) error {
err := c.Connection.WriteControl(handler.MessageType, []byte(msg), time.Now().Add(handler.Delay))
if err == gws.ErrCloseSent {
return nil
} else if e, ok := err.(net.Error); ok && e.Timeout() {
return nil
}
return err
})
return
}
c.Wg.Add(1)
go func() {
defer c.Wg.Done()
ticker := time.NewTicker(handler.Delay)
for {
select {
case <-c.shutdown:
ticker.Stop()
return
case <-ticker.C:
err := c.SendRawMessage(context.TODO(), epl, handler.MessageType, handler.Message)
if err != nil {
log.Errorf(log.WebsocketMgr, "%v websocket connection: ping handler failed to send message [%s]: %v", c.ExchangeName, handler.Message, err)
return
}
}
}
}()
}
// setConnectedStatus sets connection status if changed it will return true.
// TODO: Swap out these atomic switches and opt for sync.RWMutex.
func (c *connection) setConnectedStatus(b bool) bool {
if b {
return atomic.SwapInt32(&c.connected, 1) == 0
}
return atomic.SwapInt32(&c.connected, 0) == 1
}
// IsConnected exposes websocket connection status
func (c *connection) IsConnected() bool {
return atomic.LoadInt32(&c.connected) == 1
}
// ReadMessage reads messages, can handle text, gzip and binary
func (c *connection) ReadMessage() Response {
mType, resp, err := c.Connection.ReadMessage()
if err != nil {
// If any error occurs, a Response{Raw: nil, Type: 0} is returned, causing the
// reader routine to exit. This leaves the connection without an active reader,
// leading to potential buffer issue from the ongoing websocket writes.
// Such errors are passed to `c.readMessageErrors` when the connection is active.
// The `connectionMonitor` handles these errors by flushing the buffer, reconnecting,
// and resubscribing to the websocket to restore the connection.
if c.setConnectedStatus(false) {
// NOTE: When c.setConnectedStatus() returns true the underlying
// state was changed and this infers that the connection was
// externally closed and an error is reported else Shutdown()
// method on WebsocketConnection type has been called and can
// be skipped.
select {
case c.readMessageErrors <- fmt.Errorf("%w: %w", err, errConnectionFault):
default:
// bypass if there is no receiver, as this stops it returning
// when shutdown is called.
log.Warnf(log.WebsocketMgr, "%s failed to relay error: %v", c.ExchangeName, err)
}
}
return Response{}
}
select {
case c.Traffic <- struct{}{}:
default: // Non-Blocking write ensures 1 buffered signal per trafficCheckInterval to avoid flooding
}
var standardMessage []byte
switch mType {
case gws.TextMessage:
standardMessage = resp
case gws.BinaryMessage:
standardMessage, err = c.parseBinaryResponse(resp)
if err != nil {
log.Errorf(log.WebsocketMgr, "%v %v: Parse binary response error: %v", c.ExchangeName, removeURLQueryString(c.URL), err)
return Response{Raw: []byte(``)} // Non-nil response to avoid the reader returning on this case.
}
}
if c.Verbose {
log.Debugf(log.WebsocketMgr, "%v %v: Message received: %v", c.ExchangeName, removeURLQueryString(c.URL), string(standardMessage))
}
return Response{Raw: standardMessage, Type: mType}
}
// parseBinaryResponse parses a websocket binary response into a usable byte array
func (c *connection) parseBinaryResponse(resp []byte) ([]byte, error) {
var reader io.ReadCloser
var err error
if len(resp) >= 2 && resp[0] == 31 && resp[1] == 139 { // Detect GZIP
reader, err = gzip.NewReader(bytes.NewReader(resp))
if err != nil {
return nil, err
}
} else {
reader = flate.NewReader(bytes.NewReader(resp))
}
standardMessage, err := io.ReadAll(reader)
if err != nil {
return nil, err
}
return standardMessage, reader.Close()
}
// GenerateMessageID generates a message ID for the individual connection.
// If a bespoke function is set (by using SetupNewConnection) it will use that,
// otherwise it will use the defaultGenerateMessageID function.
func (c *connection) GenerateMessageID(highPrec bool) int64 {
if c.bespokeGenerateMessageID != nil {
return c.bespokeGenerateMessageID(highPrec)
}
return c.defaultGenerateMessageID(highPrec)
}
// defaultGenerateMessageID generates the default message ID
func (c *connection) defaultGenerateMessageID(highPrec bool) int64 {
var minValue int64 = 1e8
var maxValue int64 = 2e8
if highPrec {
maxValue = 2e12
minValue = 1e12
}
// utilization of hard coded positive numbers and default crypto/rand
// io.reader will panic on error instead of returning
randomNumber, err := rand.Int(rand.Reader, big.NewInt(maxValue-minValue+1))
if err != nil {
panic(err)
}
return randomNumber.Int64() + minValue
}
// Shutdown shuts down and closes specific connection
func (c *connection) Shutdown() error {
if c == nil || c.Connection == nil {
return nil
}
c.setConnectedStatus(false)
c.writeControl.Lock()
defer c.writeControl.Unlock()
return c.Connection.NetConn().Close()
}
// SetURL sets connection URL
func (c *connection) SetURL(url string) {
c.URL = url
}
// SetProxy sets connection proxy
func (c *connection) SetProxy(proxy string) {
c.ProxyURL = proxy
}
// GetURL returns the connection URL
func (c *connection) GetURL() string {
return c.URL
}
// SendMessageReturnResponse will send a WS message to the connection and wait for response
func (c *connection) SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature, request any) ([]byte, error) {
resps, err := c.SendMessageReturnResponses(ctx, epl, signature, request, 1)
if err != nil {
return nil, err
}
return resps[0], nil
}
// SendMessageReturnResponses will send a WS message to the connection and wait for N responses
// An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked
func (c *connection) SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int) ([][]byte, error) {
return c.SendMessageReturnResponsesWithInspector(ctx, epl, signature, payload, expected, nil)
}
// SendMessageReturnResponsesWithInspector will send a WS message to the connection and wait for N responses
// An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked
func (c *connection) SendMessageReturnResponsesWithInspector(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int, messageInspector Inspector) ([][]byte, error) {
outbound, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err)
}
ch, err := c.Match.Set(signature, expected)
if err != nil {
return nil, err
}
start := time.Now()
err = c.SendRawMessage(ctx, epl, gws.TextMessage, outbound)
if err != nil {
return nil, err
}
resps, err := c.waitForResponses(ctx, signature, ch, expected, messageInspector)
if err != nil {
return nil, err
}
if c.Reporter != nil {
c.Reporter.Latency(c.ExchangeName, outbound, time.Since(start))
}
return resps, err
}
// waitForResponses waits for N responses from a channel
func (c *connection) waitForResponses(ctx context.Context, signature any, ch <-chan []byte, expected int, messageInspector Inspector) ([][]byte, error) {
timeout := time.NewTimer(c.ResponseMaxLimit * time.Duration(expected))
defer timeout.Stop()
resps := make([][]byte, 0, expected)
inspection:
for range expected {
select {
case resp := <-ch:
resps = append(resps, resp)
// Checks recently received message to determine if this is in fact the final message in a sequence of messages.
if messageInspector != nil && messageInspector.IsFinal(resp) {
c.Match.RemoveSignature(signature)
break inspection
}
case <-timeout.C:
c.Match.RemoveSignature(signature)
return nil, fmt.Errorf("%s %w %v", c.ExchangeName, ErrSignatureTimeout, signature)
case <-ctx.Done():
c.Match.RemoveSignature(signature)
return nil, ctx.Err()
}
}
// Only check context verbosity. If the exchange is verbose, it will log the responses in the ReadMessage() call.
if request.IsVerbose(ctx, false) {
for i := range resps {
log.Debugf(log.WebsocketMgr, "%v %v: Received response [%d/%d]: %v", c.ExchangeName, removeURLQueryString(c.URL), i+1, len(resps), string(resps[i]))
}
}
return resps, nil
}
func removeURLQueryString(url string) string {
if index := strings.Index(url, "?"); index != -1 {
return url[:index]
}
return url
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,86 @@
package websocket
import (
"errors"
"fmt"
"sync"
)
// ErrSignatureNotMatched is returned when a signature does not match a request
var ErrSignatureNotMatched = errors.New("websocket response to request signature not matched")
var (
errSignatureCollision = errors.New("signature collision")
errInvalidBufferSize = errors.New("buffer size must be positive")
)
// NewMatch returns a new Match
func NewMatch() *Match {
return &Match{m: make(map[any]*incoming)}
}
// Match is a distributed subtype that handles the matching of requests and
// responses in a timely manner, reducing the need to differentiate between
// connections. Stream systems fan in all incoming payloads to one routine for
// processing.
type Match struct {
m map[any]*incoming
mu sync.Mutex
}
type incoming struct {
expected int
c chan<- []byte
}
// IncomingWithData matches with requests and takes in the returned payload, to
// be processed outside of a stream processing routine and returns true if a handler was found
func (m *Match) IncomingWithData(signature any, data []byte) bool {
m.mu.Lock()
defer m.mu.Unlock()
ch, ok := m.m[signature]
if !ok {
return false
}
ch.c <- data
ch.expected--
if ch.expected == 0 {
close(ch.c)
delete(m.m, signature)
}
return true
}
// RequireMatchWithData validates that incoming data matches a request's signature.
// If a match is found, the data is processed; otherwise, it returns an error.
func (m *Match) RequireMatchWithData(signature any, data []byte) error {
if m.IncomingWithData(signature, data) {
return nil
}
return fmt.Errorf("'%v' %w with data %v", signature, ErrSignatureNotMatched, string(data))
}
// Set the signature response channel for incoming data
func (m *Match) Set(signature any, bufSize int) (<-chan []byte, error) {
if bufSize <= 0 {
return nil, errInvalidBufferSize
}
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.m[signature]; ok {
return nil, errSignatureCollision
}
ch := make(chan []byte, bufSize)
m.m[signature] = &incoming{expected: bufSize, c: ch}
return ch, nil
}
// RemoveSignature removes the signature response from map and closes the channel.
func (m *Match) RemoveSignature(signature any) {
m.mu.Lock()
defer m.mu.Unlock()
if ch, ok := m.m[signature]; ok {
close(ch.c)
delete(m.m, signature)
}
}

View File

@@ -0,0 +1,68 @@
package websocket
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMatch(t *testing.T) {
t.Parallel()
load := []byte("42")
assert.False(t, new(Match).IncomingWithData("hello", load), "Should not match an uninitialised Match")
match := NewMatch()
assert.False(t, match.IncomingWithData("hello", load), "Should not match an empty signature")
_, err := match.Set("hello", 0)
require.ErrorIs(t, err, errInvalidBufferSize, "Must error on zero buffer size")
_, err = match.Set("hello", -1)
require.ErrorIs(t, err, errInvalidBufferSize, "Must error on negative buffer size")
ch, err := match.Set("hello", 2)
require.NoError(t, err, "Set must not error")
assert.True(t, match.IncomingWithData("hello", []byte("hello")))
assert.Equal(t, "hello", string(<-ch))
_, err = match.Set("hello", 2)
assert.ErrorIs(t, err, errSignatureCollision, "Should error on signature collision")
assert.True(t, match.IncomingWithData("hello", load), "Should match with matching message and signature")
assert.False(t, match.IncomingWithData("hello", load), "Should not match with matching message and signature")
assert.Len(t, ch, 1, "Channel should have 1 items, 1 was already read above")
}
func TestRemoveSignature(t *testing.T) {
t.Parallel()
match := NewMatch()
ch, err := match.Set("masterblaster", 1)
select {
case <-ch:
t.Fatal("Should not be able to read from an empty channel")
default:
}
require.NoError(t, err)
match.RemoveSignature("masterblaster")
select {
case garbage := <-ch:
require.Empty(t, garbage)
default:
t.Fatal("Should be able to read from a closed channel")
}
}
func TestRequireMatchWithData(t *testing.T) {
t.Parallel()
match := NewMatch()
err := match.RequireMatchWithData("hello", []byte("world"))
require.ErrorIs(t, err, ErrSignatureNotMatched, "Must error on unmatched signature")
assert.Contains(t, err.Error(), "world", "Should contain the data in the error message")
assert.Contains(t, err.Error(), "hello", "Should contain the signature in the error message")
ch, err := match.Set("hello", 1)
require.NoError(t, err, "Set must not error")
err = match.RequireMatchWithData("hello", []byte("world"))
require.NoError(t, err, "Must not error on matched signature")
assert.Equal(t, "world", string(<-ch))
}

View File

@@ -0,0 +1,349 @@
package websocket
import (
"context"
"errors"
"fmt"
"slices"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
"github.com/thrasher-corp/gocryptotrader/log"
)
// Public subscription errors
var (
ErrSubscriptionFailure = errors.New("subscription failure")
ErrSubscriptionsNotAdded = errors.New("subscriptions not added")
ErrSubscriptionsNotRemoved = errors.New("subscriptions not removed")
)
// Public subscription errors
var (
errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit")
)
// UnsubscribeChannels unsubscribes from a list of websocket channel
func (m *Manager) UnsubscribeChannels(conn Connection, channels subscription.List) error {
if len(channels) == 0 {
return nil // No channels to unsubscribe from is not an error
}
if wrapper, ok := m.connections[conn]; ok && conn != nil {
return m.unsubscribe(wrapper.subscriptions, channels, func(channels subscription.List) error {
return wrapper.setup.Unsubscriber(context.TODO(), conn, channels)
})
}
if m.Unsubscriber == nil {
return fmt.Errorf("%w: Global Unsubscriber not set", common.ErrNilPointer)
}
return m.unsubscribe(m.subscriptions, channels, func(channels subscription.List) error {
return m.Unsubscriber(channels)
})
}
func (m *Manager) unsubscribe(store *subscription.Store, channels subscription.List, unsub func(channels subscription.List) error) error {
if store == nil {
return nil // No channels to unsubscribe from is not an error
}
for _, s := range channels {
if store.Get(s) == nil {
return fmt.Errorf("%w: %s", subscription.ErrNotFound, s)
}
}
return unsub(channels)
}
// ResubscribeToChannel resubscribes to channel
// Sets state to Resubscribing, and exchanges which want to maintain a lock on it can respect this state and not RemoveSubscription
// Errors if subscription is already subscribing
func (m *Manager) ResubscribeToChannel(conn Connection, s *subscription.Subscription) error {
l := subscription.List{s}
if err := s.SetState(subscription.ResubscribingState); err != nil {
return fmt.Errorf("%w: %s", err, s)
}
if err := m.UnsubscribeChannels(conn, l); err != nil {
return err
}
return m.SubscribeToChannels(conn, l)
}
// SubscribeToChannels subscribes to websocket channels using the exchange specific Subscriber method
// Errors are returned for duplicates or exceeding max Subscriptions
func (m *Manager) SubscribeToChannels(conn Connection, subs subscription.List) error {
if slices.Contains(subs, nil) {
return fmt.Errorf("%w: List parameter contains an nil element", common.ErrNilPointer)
}
if err := m.checkSubscriptions(conn, subs); err != nil {
return err
}
if wrapper, ok := m.connections[conn]; ok && conn != nil {
return wrapper.setup.Subscriber(context.TODO(), conn, subs)
}
if m.Subscriber == nil {
return fmt.Errorf("%w: Global Subscriber not set", common.ErrNilPointer)
}
if err := m.Subscriber(subs); err != nil {
return fmt.Errorf("%w: %w", ErrSubscriptionFailure, err)
}
return nil
}
// AddSubscriptions adds subscriptions to the subscription store
// Sets state to Subscribing unless the state is already set
func (m *Manager) AddSubscriptions(conn Connection, subs ...*subscription.Subscription) error {
if m == nil {
return fmt.Errorf("%w: AddSubscriptions called on nil Websocket", common.ErrNilPointer)
}
var subscriptionStore **subscription.Store
if wrapper, ok := m.connections[conn]; ok && conn != nil {
subscriptionStore = &wrapper.subscriptions
} else {
subscriptionStore = &m.subscriptions
}
if *subscriptionStore == nil {
*subscriptionStore = subscription.NewStore()
}
var errs error
for _, s := range subs {
if s.State() == subscription.InactiveState {
if err := s.SetState(subscription.SubscribingState); err != nil {
errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s))
}
}
if err := (*subscriptionStore).Add(s); err != nil {
errs = common.AppendError(errs, err)
}
}
return errs
}
// AddSuccessfulSubscriptions marks subscriptions as subscribed and adds them to the subscription store
func (m *Manager) AddSuccessfulSubscriptions(conn Connection, subs ...*subscription.Subscription) error {
if m == nil {
return fmt.Errorf("%w: AddSuccessfulSubscriptions called on nil Websocket", common.ErrNilPointer)
}
var subscriptionStore **subscription.Store
if wrapper, ok := m.connections[conn]; ok && conn != nil {
subscriptionStore = &wrapper.subscriptions
} else {
subscriptionStore = &m.subscriptions
}
if *subscriptionStore == nil {
*subscriptionStore = subscription.NewStore()
}
var errs error
for _, s := range subs {
if err := s.SetState(subscription.SubscribedState); err != nil {
errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s))
}
if err := (*subscriptionStore).Add(s); err != nil {
errs = common.AppendError(errs, err)
}
}
return errs
}
// RemoveSubscriptions removes subscriptions from the subscription list and sets the status to Unsubscribed
func (m *Manager) RemoveSubscriptions(conn Connection, subs ...*subscription.Subscription) error {
if m == nil {
return fmt.Errorf("%w: RemoveSubscriptions called on nil Websocket", common.ErrNilPointer)
}
var subscriptionStore *subscription.Store
if wrapper, ok := m.connections[conn]; ok && conn != nil {
subscriptionStore = wrapper.subscriptions
} else {
subscriptionStore = m.subscriptions
}
if subscriptionStore == nil {
return fmt.Errorf("%w: RemoveSubscriptions called on uninitialised Websocket", common.ErrNilPointer)
}
var errs error
for _, s := range subs {
if err := s.SetState(subscription.UnsubscribedState); err != nil {
errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s))
}
if err := subscriptionStore.Remove(s); err != nil {
errs = common.AppendError(errs, err)
}
}
return errs
}
// GetSubscription returns a subscription at the key provided
// returns nil if no subscription is at that key or the key is nil
// Keys can implement subscription.MatchableKey in order to provide custom matching logic
func (m *Manager) GetSubscription(key any) *subscription.Subscription {
if m == nil || key == nil {
return nil
}
for _, c := range m.connectionManager {
if c.subscriptions == nil {
continue
}
sub := c.subscriptions.Get(key)
if sub != nil {
return sub
}
}
if m.subscriptions == nil {
return nil
}
return m.subscriptions.Get(key)
}
// GetSubscriptions returns a new slice of the subscriptions
func (m *Manager) GetSubscriptions() subscription.List {
if m == nil {
return nil
}
var subs subscription.List
for _, c := range m.connectionManager {
if c.subscriptions != nil {
subs = append(subs, c.subscriptions.List()...)
}
}
if m.subscriptions != nil {
subs = append(subs, m.subscriptions.List()...)
}
return subs
}
// checkSubscriptions checks subscriptions against the max subscription limit and if the subscription already exists
// The subscription state is not considered when counting existing subscriptions
func (m *Manager) checkSubscriptions(conn Connection, subs subscription.List) error {
var subscriptionStore *subscription.Store
if wrapper, ok := m.connections[conn]; ok && conn != nil {
subscriptionStore = wrapper.subscriptions
} else {
subscriptionStore = m.subscriptions
}
if subscriptionStore == nil {
return fmt.Errorf("%w: Websocket.subscriptions", common.ErrNilPointer)
}
existing := subscriptionStore.Len()
if m.MaxSubscriptionsPerConnection > 0 && existing+len(subs) > m.MaxSubscriptionsPerConnection {
return fmt.Errorf("%w: current subscriptions: %v, incoming subscriptions: %v, max subscriptions per connection: %v - please reduce enabled pairs",
errSubscriptionsExceedsLimit,
existing,
len(subs),
m.MaxSubscriptionsPerConnection)
}
for _, s := range subs {
if s.State() == subscription.ResubscribingState {
continue
}
if found := subscriptionStore.Get(s); found != nil {
return fmt.Errorf("%w: %s", subscription.ErrDuplicate, s)
}
}
return nil
}
// FlushChannels flushes channel subscriptions when there is a pair/asset change
func (m *Manager) FlushChannels() error {
if !m.IsEnabled() {
return fmt.Errorf("%s %w", m.exchangeName, ErrWebsocketNotEnabled)
}
if !m.IsConnected() {
return fmt.Errorf("%s %w", m.exchangeName, ErrNotConnected)
}
// If the exchange does not support subscribing and or unsubscribing the full connection needs to be flushed to
// maintain consistency.
if !m.features.Subscribe || !m.features.Unsubscribe {
m.m.Lock()
defer m.m.Unlock()
if err := m.shutdown(); err != nil {
return err
}
return m.connect()
}
if !m.useMultiConnectionManagement {
newSubs, err := m.GenerateSubs()
if err != nil {
return err
}
return m.updateChannelSubscriptions(nil, m.subscriptions, newSubs)
}
for x := range m.connectionManager {
newSubs, err := m.connectionManager[x].setup.GenerateSubscriptions()
if err != nil {
return err
}
// Case if there is nothing to unsubscribe from and the connection is nil
if len(newSubs) == 0 && m.connectionManager[x].connection == nil {
continue
}
// If there are subscriptions to subscribe to but no connection to subscribe to, establish a new connection.
if m.connectionManager[x].connection == nil {
conn := m.getConnectionFromSetup(m.connectionManager[x].setup)
if err := m.connectionManager[x].setup.Connector(context.TODO(), conn); err != nil {
return err
}
m.Wg.Add(1)
go m.Reader(context.TODO(), conn, m.connectionManager[x].setup.Handler)
m.connections[conn] = m.connectionManager[x]
m.connectionManager[x].connection = conn
}
err = m.updateChannelSubscriptions(m.connectionManager[x].connection, m.connectionManager[x].subscriptions, newSubs)
if err != nil {
return err
}
// If there are no subscriptions to subscribe to, close the connection as it is no longer needed.
if m.connectionManager[x].subscriptions.Len() == 0 {
delete(m.connections, m.connectionManager[x].connection) // Remove from lookup map
if err := m.connectionManager[x].connection.Shutdown(); err != nil {
log.Warnf(log.WebsocketMgr, "%v websocket: failed to shutdown connection: %v", m.exchangeName, err)
}
m.connectionManager[x].connection = nil
}
}
return nil
}
// updateChannelSubscriptions subscribes or unsubscribes from channels and checks that the correct number of channels
// have been subscribed to or unsubscribed from.
func (m *Manager) updateChannelSubscriptions(c Connection, store *subscription.Store, incoming subscription.List) error {
subs, unsubs := store.Diff(incoming)
if len(unsubs) != 0 {
if err := m.UnsubscribeChannels(c, unsubs); err != nil {
return err
}
if contained := store.Contained(unsubs); len(contained) > 0 {
return fmt.Errorf("%v %w `%s`", m.exchangeName, ErrSubscriptionsNotRemoved, contained)
}
}
if len(subs) != 0 {
if err := m.SubscribeToChannels(c, subs); err != nil {
return err
}
if missing := store.Missing(subs); len(missing) > 0 {
return fmt.Errorf("%v %w `%s`", m.exchangeName, ErrSubscriptionsNotAdded, missing)
}
}
return nil
}

View File

@@ -0,0 +1,305 @@
package websocket
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
)
// TestSubscribe logic test
func TestSubscribeUnsubscribe(t *testing.T) {
t.Parallel()
ws := NewManager()
assert.NoError(t, ws.Setup(newDefaultSetup()), "WS Setup should not error")
ws.Subscriber = currySimpleSub(ws)
ws.Unsubscriber = currySimpleUnsub(ws)
subs, err := ws.GenerateSubs()
require.NoError(t, err, "Generating test subscriptions must not error")
assert.ErrorIs(t, new(Manager).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
assert.NoError(t, ws.UnsubscribeChannels(nil, nil), "Unsubscribing from nil should not error")
assert.ErrorIs(t, ws.UnsubscribeChannels(nil, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return")
assert.NoError(t, ws.SubscribeToChannels(nil, subs), "Basic Subscribing should not error")
assert.Len(t, ws.GetSubscriptions(), 4, "Should have 4 subscriptions")
bySub := ws.GetSubscription(subscription.Subscription{Channel: "TestSub"})
if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") {
assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel")
assert.Same(t, bySub, subs[0], "GetSubscription returns the same pointer")
}
if assert.NotNil(t, ws.GetSubscription("purple"), "GetSubscription by string key should find a channel") {
assert.Equal(t, "TestSub2", ws.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel")
}
if assert.NotNil(t, ws.GetSubscription(testSubKey{"mauve"}), "GetSubscription by type key should find a channel") {
assert.Equal(t, "TestSub3", ws.GetSubscription(testSubKey{"mauve"}).Channel, "GetSubscription by type key should return a pointer a copy of the right channel")
}
if assert.NotNil(t, ws.GetSubscription(42), "GetSubscription by int key should find a channel") {
assert.Equal(t, "TestSub4", ws.GetSubscription(42).Channel, "GetSubscription by int key should return a pointer a copy of the right channel")
}
assert.Nil(t, ws.GetSubscription(nil), "GetSubscription by nil should return nil")
assert.Nil(t, ws.GetSubscription(45), "GetSubscription by invalid key should return nil")
assert.ErrorIs(t, ws.SubscribeToChannels(nil, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed")
assert.NoError(t, ws.SubscribeToChannels(nil, nil), "Subscribe to an nil List should not error")
assert.NoError(t, ws.UnsubscribeChannels(nil, subs), "Unsubscribing should not error")
ws.Subscriber = func(subscription.List) error { return errDastardlyReason }
assert.ErrorIs(t, ws.SubscribeToChannels(nil, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber")
err = ws.SubscribeToChannels(nil, subscription.List{nil})
assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription")
multi := NewManager()
set := newDefaultSetup()
set.UseMultiConnectionManagement = true
assert.NoError(t, multi.Setup(set))
amazingCandidate := &ConnectionSetup{
URL: "AMAZING",
Connector: func(context.Context, Connection) error { return nil },
GenerateSubscriptions: ws.GenerateSubs,
Subscriber: func(ctx context.Context, c Connection, s subscription.List) error {
return currySimpleSubConn(multi)(ctx, c, s)
},
Unsubscriber: func(ctx context.Context, c Connection, s subscription.List) error {
return currySimpleUnsubConn(multi)(ctx, c, s)
},
Handler: func(context.Context, []byte) error { return nil },
}
require.NoError(t, multi.SetupNewConnection(amazingCandidate))
amazingConn := multi.getConnectionFromSetup(amazingCandidate)
multi.connections = map[Connection]*connectionWrapper{
amazingConn: multi.connectionManager[0],
}
subs, err = amazingCandidate.GenerateSubscriptions()
require.NoError(t, err, "Generating test subscriptions must not error")
assert.ErrorIs(t, new(Manager).UnsubscribeChannels(nil, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
assert.ErrorIs(t, new(Manager).UnsubscribeChannels(amazingConn, subs), common.ErrNilPointer, "Should error when unsubscribing with nil unsubscribe function")
assert.NoError(t, multi.UnsubscribeChannels(amazingConn, nil), "Unsubscribing from nil should not error")
assert.ErrorIs(t, multi.UnsubscribeChannels(amazingConn, subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
assert.Nil(t, multi.GetSubscription(42), "GetSubscription on empty internal map should return")
assert.ErrorIs(t, multi.SubscribeToChannels(nil, subs), common.ErrNilPointer, "If no connection is set, Subscribe should error")
assert.NoError(t, multi.SubscribeToChannels(amazingConn, subs), "Basic Subscribing should not error")
assert.Len(t, multi.GetSubscriptions(), 4, "Should have 4 subscriptions")
bySub = multi.GetSubscription(subscription.Subscription{Channel: "TestSub"})
if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") {
assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel")
assert.Same(t, bySub, subs[0], "GetSubscription returns the same pointer")
}
if assert.NotNil(t, multi.GetSubscription("purple"), "GetSubscription by string key should find a channel") {
assert.Equal(t, "TestSub2", multi.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel")
}
if assert.NotNil(t, multi.GetSubscription(testSubKey{"mauve"}), "GetSubscription by type key should find a channel") {
assert.Equal(t, "TestSub3", multi.GetSubscription(testSubKey{"mauve"}).Channel, "GetSubscription by type key should return a pointer a copy of the right channel")
}
if assert.NotNil(t, multi.GetSubscription(42), "GetSubscription by int key should find a channel") {
assert.Equal(t, "TestSub4", multi.GetSubscription(42).Channel, "GetSubscription by int key should return a pointer a copy of the right channel")
}
assert.Nil(t, multi.GetSubscription(nil), "GetSubscription by nil should return nil")
assert.Nil(t, multi.GetSubscription(45), "GetSubscription by invalid key should return nil")
assert.ErrorIs(t, multi.SubscribeToChannels(amazingConn, subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed")
assert.NoError(t, multi.SubscribeToChannels(amazingConn, nil), "Subscribe to an nil List should not error")
assert.NoError(t, multi.UnsubscribeChannels(amazingConn, subs), "Unsubscribing should not error")
amazingCandidate.Subscriber = func(context.Context, Connection, subscription.List) error { return errDastardlyReason }
assert.ErrorIs(t, multi.SubscribeToChannels(amazingConn, subs), errDastardlyReason, "Should error correctly when error returned from Subscriber")
err = multi.SubscribeToChannels(amazingConn, subscription.List{nil})
assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription")
}
// TestResubscribe tests Resubscribing to existing subscriptions
func TestResubscribe(t *testing.T) {
t.Parallel()
ws := NewManager()
wackedOutSetup := newDefaultSetup()
wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1
err := ws.Setup(wackedOutSetup)
assert.ErrorIs(t, err, errInvalidMaxSubscriptions, "Invalid MaxWebsocketSubscriptionsPerConnection should error")
err = ws.Setup(newDefaultSetup())
assert.NoError(t, err, "WS Setup should not error")
ws.Subscriber = currySimpleSub(ws)
ws.Unsubscriber = currySimpleUnsub(ws)
channel := subscription.List{{Channel: "resubTest"}}
assert.ErrorIs(t, ws.ResubscribeToChannel(nil, channel[0]), subscription.ErrNotFound, "Resubscribe should error when channel isn't subscribed yet")
assert.NoError(t, ws.SubscribeToChannels(nil, channel), "Subscribe should not error")
assert.NoError(t, ws.ResubscribeToChannel(nil, channel[0]), "Resubscribe should not error now the channel is subscribed")
}
// TestSubscriptions tests adding, getting and removing subscriptions
func TestSubscriptions(t *testing.T) {
t.Parallel()
w := new(Manager) // Do not use NewManager; We want to exercise w.subs == nil
assert.ErrorIs(t, (*Manager)(nil).AddSubscriptions(nil), common.ErrNilPointer, "Should error correctly when nil websocket")
s := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel}
require.NoError(t, w.AddSubscriptions(nil, s), "Adding first subscription must not error")
assert.Same(t, s, w.GetSubscription(42), "Get Subscription should retrieve the same subscription")
assert.ErrorIs(t, w.AddSubscriptions(nil, s), subscription.ErrDuplicate, "Adding same subscription should return error")
assert.Equal(t, subscription.SubscribingState, s.State(), "Should set state to Subscribing")
err := w.RemoveSubscriptions(nil, s)
require.NoError(t, err, "RemoveSubscriptions must not error")
assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub")
assert.Equal(t, subscription.UnsubscribedState, s.State(), "Should set state to Unsubscribed")
require.NoError(t, s.SetState(subscription.ResubscribingState), "SetState must not error")
require.NoError(t, w.AddSubscriptions(nil, s), "Adding first subscription must not error")
assert.Equal(t, subscription.ResubscribingState, s.State(), "Should not change resubscribing state")
}
// TestSuccessfulSubscriptions tests adding, getting and removing subscriptions
func TestSuccessfulSubscriptions(t *testing.T) {
t.Parallel()
w := new(Manager) // Do not use NewManager; We want to exercise w.subs == nil
assert.ErrorIs(t, (*Manager)(nil).AddSuccessfulSubscriptions(nil, nil), common.ErrNilPointer, "Should error correctly when nil websocket")
c := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel}
require.NoError(t, w.AddSuccessfulSubscriptions(nil, c), "Adding first subscription must not error")
assert.Same(t, c, w.GetSubscription(42), "Get Subscription should retrieve the same subscription")
assert.ErrorIs(t, w.AddSuccessfulSubscriptions(nil, c), subscription.ErrInStateAlready, "Adding subscription in same state should return error")
require.NoError(t, c.SetState(subscription.SubscribingState), "SetState must not error")
assert.ErrorIs(t, w.AddSuccessfulSubscriptions(nil, c), subscription.ErrDuplicate, "Adding same subscription should return error")
err := w.RemoveSubscriptions(nil, c)
require.NoError(t, err, "RemoveSubscriptions must not error")
assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub")
assert.ErrorIs(t, w.RemoveSubscriptions(nil, c), subscription.ErrNotFound, "Should error correctly when not found")
assert.ErrorIs(t, (*Manager)(nil).RemoveSubscriptions(nil, nil), common.ErrNilPointer, "Should error correctly when nil websocket")
w.subscriptions = nil
assert.ErrorIs(t, w.RemoveSubscriptions(nil, c), common.ErrNilPointer, "Should error correctly when nil websocket")
}
// TestGetSubscription logic test
func TestGetSubscription(t *testing.T) {
t.Parallel()
assert.Nil(t, (*Manager).GetSubscription(nil, "imaginary"), "GetSubscription on a nil Websocket should return nil")
assert.Nil(t, (&Manager{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub store should return nil")
w := NewManager()
assert.Nil(t, w.GetSubscription(nil), "GetSubscription with a nil key should return nil")
s := &subscription.Subscription{Key: 42, Channel: "hello3"}
require.NoError(t, w.AddSubscriptions(nil, s), "AddSubscriptions must not error")
assert.Same(t, s, w.GetSubscription(42), "GetSubscription should delegate to the store")
}
// TestGetSubscriptions logic test
func TestGetSubscriptions(t *testing.T) {
t.Parallel()
assert.Nil(t, (*Manager).GetSubscriptions(nil), "GetSubscription on a nil Websocket should return nil")
assert.Nil(t, (&Manager{}).GetSubscriptions(), "GetSubscription on a Websocket with no sub store should return nil")
w := NewManager()
s := subscription.List{
{Key: 42, Channel: "hello3"},
{Key: 45, Channel: "hello4"},
}
err := w.AddSubscriptions(nil, s...)
require.NoError(t, err, "AddSubscriptions must not error")
assert.ElementsMatch(t, s, w.GetSubscriptions(), "GetSubscriptions should return the correct channel details")
}
func TestCheckSubscriptions(t *testing.T) {
t.Parallel()
ws := Manager{}
err := ws.checkSubscriptions(nil, nil)
assert.ErrorIs(t, err, common.ErrNilPointer, "checkSubscriptions should error correctly on nil w.subscriptions")
assert.ErrorContains(t, err, "Websocket.subscriptions", "checkSubscriptions should error giving context correctly on nil w.subscriptions")
ws.subscriptions = subscription.NewStore()
err = ws.checkSubscriptions(nil, nil)
assert.NoError(t, err, "checkSubscriptions should not error on a nil list")
ws.MaxSubscriptionsPerConnection = 1
err = ws.checkSubscriptions(nil, subscription.List{{}})
assert.NoError(t, err, "checkSubscriptions should not error when subscriptions is empty")
ws.subscriptions = subscription.NewStore()
err = ws.checkSubscriptions(nil, subscription.List{{}, {}})
assert.ErrorIs(t, err, errSubscriptionsExceedsLimit, "checkSubscriptions should error correctly")
ws.MaxSubscriptionsPerConnection = 2
ws.subscriptions = subscription.NewStore()
err = ws.subscriptions.Add(&subscription.Subscription{Key: 42, Channel: "test"})
require.NoError(t, err, "Add subscription must not error")
err = ws.checkSubscriptions(nil, subscription.List{{Key: 42, Channel: "test"}})
assert.ErrorIs(t, err, subscription.ErrDuplicate, "checkSubscriptions should error correctly")
err = ws.checkSubscriptions(nil, subscription.List{{}})
assert.NoError(t, err, "checkSubscriptions should not error")
}
func TestUpdateChannelSubscriptions(t *testing.T) {
t.Parallel()
ws := NewManager()
store := subscription.NewStore()
err := ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}})
require.ErrorIs(t, err, common.ErrNilPointer)
require.Zero(t, store.Len())
ws.Subscriber = func(subs subscription.List) error {
for _, sub := range subs {
if err := store.Add(sub); err != nil {
return err
}
}
return nil
}
ws.subscriptions = store
err = ws.updateChannelSubscriptions(nil, store, subscription.List{{Channel: "test"}})
require.NoError(t, err)
require.Equal(t, 1, store.Len())
err = ws.updateChannelSubscriptions(nil, store, subscription.List{})
require.ErrorIs(t, err, common.ErrNilPointer)
ws.Unsubscriber = func(subs subscription.List) error {
for _, sub := range subs {
if err := store.Remove(sub); err != nil {
return err
}
}
return nil
}
err = ws.updateChannelSubscriptions(nil, store, subscription.List{})
require.NoError(t, err)
require.Zero(t, store.Len())
}
func currySimpleSub(w *Manager) func(subscription.List) error {
return func(subs subscription.List) error {
return w.AddSuccessfulSubscriptions(nil, subs...)
}
}
func currySimpleSubConn(w *Manager) func(context.Context, Connection, subscription.List) error {
return func(_ context.Context, conn Connection, subs subscription.List) error {
return w.AddSuccessfulSubscriptions(conn, subs...)
}
}
func currySimpleUnsub(w *Manager) func(subscription.List) error {
return func(unsubs subscription.List) error {
return w.RemoveSubscriptions(nil, unsubs...)
}
}
func currySimpleUnsubConn(w *Manager) func(context.Context, Connection, subscription.List) error {
return func(_ context.Context, conn Connection, unsubs subscription.List) error {
return w.RemoveSubscriptions(conn, unsubs...)
}
}

View File

@@ -0,0 +1,56 @@
package websocket
import (
"time"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
"github.com/thrasher-corp/gocryptotrader/exchanges/order"
)
// PingHandler container for ping handler settings
type PingHandler struct {
Websocket bool
UseGorillaHandler bool
MessageType int
Message []byte
Delay time.Duration
}
// FundingData defines funding data
type FundingData struct {
Timestamp time.Time
CurrencyPair currency.Pair
AssetType asset.Item
Exchange string
Amount float64
Rate float64
Period int64
Side order.Side
}
// KlineData defines kline feed
type KlineData struct {
Timestamp time.Time
Pair currency.Pair
AssetType asset.Item
Exchange string
StartTime time.Time
CloseTime time.Time
Interval string
OpenPrice float64
ClosePrice float64
HighPrice float64
LowPrice float64
Volume float64
}
// UnhandledMessageWarning defines a container for unhandled message warnings
type UnhandledMessageWarning struct {
Message string
}
// Reporter interface groups observability functionality over Websocket request latency.
type Reporter interface {
Latency(name string, message []byte, t time.Duration)
}

View File

@@ -3,7 +3,7 @@ package exchange
import (
"testing"
"github.com/gorilla/websocket"
gws "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/config"
@@ -31,6 +31,6 @@ func TestMockHTTPInstance(t *testing.T) {
// TestMockWsInstance exercises MockWsInstance
func TestMockWsInstance(t *testing.T) {
b := MockWsInstance[binance.Binance](t, mockws.CurryWsMockUpgrader(t, func(_ testing.TB, _ []byte, _ *websocket.Conn) error { return nil }))
b := MockWsInstance[binance.Binance](t, mockws.CurryWsMockUpgrader(t, func(_ testing.TB, _ []byte, _ *gws.Conn) error { return nil }))
require.NotNil(t, b, "MockWsInstance must not be nil")
}