subscriptions: Encapsulate, replace Pair with Pairs and refactor; improve exchange support

* Websocket: Use ErrSubscribedAlready

instead of errChannelAlreadySubscribed

* Subscriptions: Replace Pair with Pairs

Given that some subscriptions have multiple pairs, support that as the
standard.

* Docs: Update subscriptions in add new exch

* RPC: Update Subscription Pairs

* Linter: Disable testifylint.Len

We deliberately use Equal over Len to avoid spamming the contents of large Slices

* Websocket: Add suffix to state consts

* Binance: Subscription Pairs support

* Bitfinex: Subscription Pairs support

* Bithumb: Subscription Pairs support

* Bitmex: Subscription Pairs support

* Bitstamp: Subscription Pairs support

* BTCMarkets: Subscription Pairs support

* BTSE: Subscription Pairs support

* Coinbase: Subscription Pairs support

* Coinut: Subscription Pairs support

* GateIO: Subscription Pairs support

* Gemini: Subscription Pairs support and improvement

* Hitbtc: Subscription Pairs support

* Huboi: Subscription Pairs support

* Kucoin: Subscription Pairs support

* Okcoin: Subscription Pairs support

* Poloniex: Subscription Pairs support

* Kraken: Add subscription Pairs support

Note: This is a naieve implementation because we want to rebase the
kraken websocket rewrite on top of this

* Bybit: Subscription Pairs support

* Okx: Subscription Pairs support

* Bitmex: Subsription configuration

* Fixes unauthenticated websocket left as CanUseAuth
* Fixes auth subs happening privately

* CoinbasePro: Subscription Configuration

* Consolidate ProductIDs when all subscriptions are for the same list

* Websocket: Log actual sent message when Verbose

* Subscriptions: Improve clarity of which key is which in Match

* Subscriptions: Lint fix for HugeParam

* Subscriptions: Add AddPairs and move keys from test

* Subscriptions: Simplify subscription keys and add key types

* Subscriptions: Add List.GroupPairs Rename sub.AddPairs

* Subscription: Fix ExactKey not matching 0 pairs

* Subscriptions: Remove unused IdentityKey and HasPairKey

* Subscriptions: Fix GetKey test

* Subscriptions: Test coverage improvements

* Websocket: Change State on Add/Remove

* Subscriptions: Improve error context

* Subscriptions: Fix Enable: false subs not ignored

* Bitfinex: Fix WsAuth test failing on DataHandler

DataHandler is eaten by dataMonitor now, so we need to use ToRoutine

* Deribit: Subscription Pairs support

* Websocket: Accept nil lists for checkSubscriptions

If the user passes in a nil (implicitly empty) list, we would not panic.
Therefore the burden of correctness about that data lies with them.
The list of subscriptions is empty, and that's okay, and possibly
convenient

* Websocket: Add context to NilPointer errors

* Subscriptions: Add context to nil errors

* Exchange: Fix error expectations in UnsubToWSChans
This commit is contained in:
Gareth Kirwan
2024-06-07 08:54:08 +07:00
committed by GitHub
parent afb6f75d88
commit 1199f38546
75 changed files with 4551 additions and 3891 deletions

View File

@@ -43,9 +43,8 @@ const (
Options
OptionCombo
FutureCombo
// Added to represent a USDT and USDC based linear derivatives(futures/perpetual) assets in Bybit V5.
LinearContract
LinearContract // Added to represent a USDT and USDC based linear derivatives(futures/perpetual) assets in Bybit V5
All
optionsFlag = OptionCombo | Options
futuresFlag = PerpetualContract | PerpetualSwap | Futures | DeliveryFutures | UpsideProfitContract | DownsideProfitContract | CoinMarginedFutures | USDTMarginedFutures | USDCMarginedFutures | LinearContract | FutureCombo
@@ -70,6 +69,7 @@ const (
options = "options"
optionCombo = "option_combo"
futureCombo = "future_combo"
all = "all"
)
var (
@@ -120,6 +120,8 @@ func (a Item) String() string {
return optionCombo
case FutureCombo:
return futureCombo
case All:
return all
default:
return ""
}
@@ -225,11 +227,10 @@ func New(input string) (Item, error) {
return OptionCombo, nil
case futureCombo:
return FutureCombo, nil
case all:
return All, nil
default:
return 0, fmt.Errorf("%w '%v', only supports %s",
ErrNotSupported,
input,
supportedList)
return 0, fmt.Errorf("%w '%v', only supports %s", ErrNotSupported, input, supportedList)
}
}

View File

@@ -1462,7 +1462,7 @@ func TestGetHistoricTrades(t *testing.T) {
if mockTests {
expected = 1002
}
assert.Equal(t, expected, len(result), "GetHistoricTrades should return correct number of entries") //nolint:testifylint // assert.Len doesn't produce clear messages on result
assert.Equal(t, expected, len(result), "GetHistoricTrades should return correct number of entries")
for _, r := range result {
if !assert.WithinRange(t, r.Timestamp, start, end, "All trades should be within time range") {
break
@@ -1982,7 +1982,7 @@ func BenchmarkWsHandleData(bb *testing.B) {
func TestSubscribe(t *testing.T) {
t.Parallel()
b := b
channels := []subscription.Subscription{
channels := subscription.List{
{Channel: "btcusdt@ticker"},
{Channel: "btcusdt@trade"},
}
@@ -2008,7 +2008,7 @@ func TestSubscribe(t *testing.T) {
func TestSubscribeBadResp(t *testing.T) {
t.Parallel()
channels := []subscription.Subscription{
channels := subscription.List{
{Channel: "moons@ticker"},
}
mock := func(msg []byte, w *websocket.Conn) error {
@@ -2434,19 +2434,19 @@ func TestSeedLocalCache(t *testing.T) {
func TestGenerateSubscriptions(t *testing.T) {
t.Parallel()
expected := []subscription.Subscription{}
expected := subscription.List{}
pairs, err := b.GetEnabledPairs(asset.Spot)
assert.NoError(t, err, "GetEnabledPairs should not error")
for _, p := range pairs {
for _, c := range []string{"kline_1m", "depth@100ms", "ticker", "trade"} {
expected = append(expected, subscription.Subscription{
expected = append(expected, &subscription.Subscription{
Channel: p.Format(currency.PairFormat{Delimiter: "", Uppercase: false}).String() + "@" + c,
Pair: p,
Pairs: currency.Pairs{p},
Asset: asset.Spot,
})
}
}
subs, err := b.GenerateSubscriptions()
subs, err := b.generateSubscriptions()
assert.NoError(t, err, "GenerateSubscriptions should not error")
if assert.Len(t, subs, len(expected), "Should have the correct number of subs") {
assert.ElementsMatch(t, subs, expected, "Should get the correct subscriptions")

View File

@@ -12,6 +12,7 @@ import (
"github.com/buger/jsonparser"
"github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
"github.com/thrasher-corp/gocryptotrader/exchanges/order"
@@ -502,8 +503,8 @@ func (b *Binance) UpdateLocalBuffer(wsdp *WebsocketDepthStream) (bool, error) {
return false, err
}
// GenerateSubscriptions generates the default subscription set
func (b *Binance) GenerateSubscriptions() ([]subscription.Subscription, error) {
// generateSubscriptions generates the default subscription set
func (b *Binance) generateSubscriptions() (subscription.List, error) {
var channels = make([]string, 0, len(b.Features.Subscriptions))
for i := range b.Features.Subscriptions {
name, err := channelName(b.Features.Subscriptions[i])
@@ -512,7 +513,7 @@ func (b *Binance) GenerateSubscriptions() ([]subscription.Subscription, error) {
}
channels = append(channels, name)
}
var subscriptions []subscription.Subscription
var subscriptions subscription.List
pairs, err := b.GetEnabledPairs(asset.Spot)
if err != nil {
return nil, err
@@ -521,9 +522,9 @@ func (b *Binance) GenerateSubscriptions() ([]subscription.Subscription, error) {
for z := range channels {
lp := pairs[y].Lower()
lp.Delimiter = ""
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: lp.String() + "@" + channels[z],
Pair: pairs[y],
Pairs: currency.Pairs{pairs[y]},
Asset: asset.Spot,
})
}
@@ -555,22 +556,21 @@ func channelName(s *subscription.Subscription) (string, error) {
}
// Subscribe subscribes to a set of channels
func (b *Binance) Subscribe(channels []subscription.Subscription) error {
func (b *Binance) Subscribe(channels subscription.List) error {
return b.ParallelChanOp(channels, b.subscribeToChan, 50)
}
// subscribeToChan handles a single subscription and parses the result
// on success it adds the subscription to the websocket
func (b *Binance) subscribeToChan(chans []subscription.Subscription) error {
func (b *Binance) subscribeToChan(chans subscription.List) error {
id := b.Websocket.Conn.GenerateMessageID(false)
cNames := make([]string, len(chans))
for i := range chans {
c := chans[i]
cNames[i] = c.Channel
c.State = subscription.SubscribingState
if err := b.Websocket.AddSubscription(&c); err != nil {
return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrSubscriptionFailure, c.Channel, c.Pair, err)
if err := b.Websocket.AddSubscriptions(c); err != nil {
return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrSubscriptionFailure, c.Channel, c.Pairs, err)
}
}
@@ -590,23 +590,29 @@ func (b *Binance) subscribeToChan(chans []subscription.Subscription) error {
}
if err != nil {
b.Websocket.RemoveSubscriptions(chans...)
if err2 := b.Websocket.RemoveSubscriptions(chans...); err2 != nil {
err = common.AppendError(err, err2)
}
err = fmt.Errorf("%w: %w; Channels: %s", stream.ErrSubscriptionFailure, err, strings.Join(cNames, ", "))
b.Websocket.DataHandler <- err
} else {
b.Websocket.AddSuccessfulSubscriptions(chans...)
for _, s := range chans {
if sErr := s.SetState(subscription.SubscribedState); sErr != nil {
err = common.AppendError(err, sErr)
}
}
}
return err
}
// Unsubscribe unsubscribes from a set of channels
func (b *Binance) Unsubscribe(channels []subscription.Subscription) error {
func (b *Binance) Unsubscribe(channels subscription.List) error {
return b.ParallelChanOp(channels, b.unsubscribeFromChan, 50)
}
// unsubscribeFromChan sends a websocket message to stop receiving data from a channel
func (b *Binance) unsubscribeFromChan(chans []subscription.Subscription) error {
func (b *Binance) unsubscribeFromChan(chans subscription.List) error {
id := b.Websocket.Conn.GenerateMessageID(false)
cNames := make([]string, len(chans))
@@ -633,10 +639,10 @@ func (b *Binance) unsubscribeFromChan(chans []subscription.Subscription) error {
err = fmt.Errorf("%w: %w; Channels: %s", stream.ErrUnsubscribeFailure, err, strings.Join(cNames, ", "))
b.Websocket.DataHandler <- err
} else {
b.Websocket.RemoveSubscriptions(chans...)
err = b.Websocket.RemoveSubscriptions(chans...)
}
return nil
return err
}
// ProcessUpdate processes the websocket orderbook update

View File

@@ -188,7 +188,7 @@ func (b *Binance) SetDefaults() {
GlobalResultLimit: 1000,
},
},
Subscriptions: []*subscription.Subscription{
Subscriptions: subscription.List{
{Enabled: true, Channel: subscription.TickerChannel},
{Enabled: true, Channel: subscription.AllTradesChannel},
{Enabled: true, Channel: subscription.CandlesChannel, Interval: kline.OneMin},
@@ -245,7 +245,7 @@ func (b *Binance) Setup(exch *config.Exchange) error {
Connector: b.WsConnect,
Subscriber: b.Subscribe,
Unsubscriber: b.Unsubscribe,
GenerateSubscriptions: b.GenerateSubscriptions,
GenerateSubscriptions: b.generateSubscriptions,
Features: &b.Features.Supports.WebsocketCapabilities,
OrderbookBufferConfig: buffer.Config{
SortBuffer: true,

View File

@@ -540,9 +540,9 @@ func (bi *Binanceus) UpdateLocalBuffer(wsdp *WebsocketDepthStream) (bool, error)
}
// GenerateSubscriptions generates the default subscription set
func (bi *Binanceus) GenerateSubscriptions() ([]subscription.Subscription, error) {
func (bi *Binanceus) GenerateSubscriptions() (subscription.List, error) {
var channels = []string{"@ticker", "@trade", "@kline_1m", "@depth@100ms"}
var subscriptions []subscription.Subscription
var subscriptions subscription.List
pairs, err := bi.GetEnabledPairs(asset.Spot)
if err != nil {
@@ -558,9 +558,9 @@ subs:
log.Warnf(log.WebsocketMgr, "BinanceUS has 1024 subscription limit, only subscribing within limit. Requested %v", len(pairs)*len(channels))
break subs
}
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: lp.String() + channels[z],
Pair: pairs[y],
Pairs: currency.Pairs{pairs[y]},
Asset: asset.Spot,
})
}
@@ -570,7 +570,7 @@ subs:
}
// Subscribe subscribes to a set of channels
func (bi *Binanceus) Subscribe(channelsToSubscribe []subscription.Subscription) error {
func (bi *Binanceus) Subscribe(channelsToSubscribe subscription.List) error {
payload := WebsocketPayload{
Method: "SUBSCRIBE",
}
@@ -590,12 +590,11 @@ func (bi *Binanceus) Subscribe(channelsToSubscribe []subscription.Subscription)
return err
}
}
bi.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe...)
return nil
return bi.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe...)
}
// Unsubscribe unsubscribes from a set of channels
func (bi *Binanceus) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) error {
func (bi *Binanceus) Unsubscribe(channelsToUnsubscribe subscription.List) error {
payload := WebsocketPayload{
Method: "UNSUBSCRIBE",
}
@@ -615,8 +614,7 @@ func (bi *Binanceus) Unsubscribe(channelsToUnsubscribe []subscription.Subscripti
return err
}
}
bi.Websocket.RemoveSubscriptions(channelsToUnsubscribe...)
return nil
return bi.Websocket.RemoveSubscriptions(channelsToUnsubscribe...)
}
func (bi *Binanceus) setupOrderbookManager() {

File diff suppressed because one or more lines are too long

View File

@@ -106,10 +106,7 @@ func (b *Bitfinex) WsDataHandler() {
select {
case b.Websocket.DataHandler <- err:
default:
log.Errorf(log.WebsocketMgr,
"%s websocket handle data error: %v",
b.Name,
err)
log.Errorf(log.WebsocketMgr, "%s websocket handle data error: %v", b.Name, err)
}
}
default:
@@ -149,7 +146,7 @@ func (b *Bitfinex) wsHandleData(respRaw []byte) error {
return b.handleWSChannelUpdate(c, eventType, d)
}
if b.Verbose {
log.Warnf(log.ExchangeSys, "%s %s; dropped WS message: %s", b.Name, stream.ErrSubscriptionNotFound, respRaw)
log.Warnf(log.ExchangeSys, "%s %s; dropped WS message: %s", b.Name, subscription.ErrNotFound, respRaw)
}
// We didn't have a mapping for this chanID; This probably means we have unsubscribed OR
// received our first message before processing the sub chanID
@@ -501,22 +498,25 @@ func (b *Bitfinex) handleWSSubscribed(respRaw []byte) error {
c := b.Websocket.GetSubscription(subID)
if c == nil {
return fmt.Errorf("%w: %w subID: %s", stream.ErrSubscriptionFailure, stream.ErrSubscriptionNotFound, subID)
return fmt.Errorf("%w: %w subID: %s", stream.ErrSubscriptionFailure, subscription.ErrNotFound, subID)
}
chanID, err := jsonparser.GetInt(respRaw, "chanId")
if err != nil {
return fmt.Errorf("%w: %w 'chanId': %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, errParsingWSField, err, c.Channel, c.Pair)
return fmt.Errorf("%w: %w 'chanId': %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, errParsingWSField, err, c.Channel, c.Pairs)
}
// Note: chanID's int type avoids conflicts with the string type subID key because of the type difference
c = c.Clone()
c.Key = int(chanID)
// subscribeToChan removes the old subID keyed Subscription
b.Websocket.AddSuccessfulSubscriptions(*c)
if err := b.Websocket.AddSuccessfulSubscriptions(c); err != nil {
return fmt.Errorf("%w: %w subID: %s", stream.ErrSubscriptionFailure, err, subID)
}
if b.Verbose {
log.Debugf(log.ExchangeSys, "%s Subscribed to Channel: %s Pair: %s ChannelID: %d\n", b.Name, c.Channel, c.Pair, chanID)
log.Debugf(log.ExchangeSys, "%s Subscribed to Channel: %s Pair: %s ChannelID: %d\n", b.Name, c.Channel, c.Pairs, chanID)
}
if !b.Websocket.Match.IncomingWithData("subscribe:"+subID, respRaw) {
return fmt.Errorf("%v channel subscribe listener not found", subID)
@@ -525,6 +525,10 @@ func (b *Bitfinex) handleWSSubscribed(respRaw []byte) error {
}
func (b *Bitfinex) handleWSChannelUpdate(c *subscription.Subscription, eventType string, d []interface{}) error {
if c == nil {
return fmt.Errorf("%w: Subscription param", common.ErrNilPointer)
}
if eventType == wsChecksum {
return b.handleWSChecksum(c, d)
}
@@ -533,6 +537,10 @@ func (b *Bitfinex) handleWSChannelUpdate(c *subscription.Subscription, eventType
return nil
}
if len(c.Pairs) != 1 {
return subscription.ErrNotSinglePair
}
switch c.Channel {
case wsBook:
return b.handleWSBookUpdate(c, d)
@@ -548,6 +556,9 @@ func (b *Bitfinex) handleWSChannelUpdate(c *subscription.Subscription, eventType
}
func (b *Bitfinex) handleWSChecksum(c *subscription.Subscription, d []interface{}) error {
if c == nil {
return fmt.Errorf("%w: Subscription param", common.ErrNilPointer)
}
var token int
if f, ok := d[2].(float64); !ok {
return common.GetTypeAssertError("float64", d[2], "checksum")
@@ -579,6 +590,12 @@ func (b *Bitfinex) handleWSChecksum(c *subscription.Subscription, d []interface{
}
func (b *Bitfinex) handleWSBookUpdate(c *subscription.Subscription, d []interface{}) error {
if c == nil {
return fmt.Errorf("%w: Subscription param", common.ErrNilPointer)
}
if len(c.Pairs) != 1 {
return subscription.ErrNotSinglePair
}
var newOrderbook []WebsocketBook
obSnapBundle, ok := d[1].([]interface{})
if !ok {
@@ -632,7 +649,7 @@ func (b *Bitfinex) handleWSBookUpdate(c *subscription.Subscription, d []interfac
Amount: rateAmount})
}
}
if err := b.WsInsertSnapshot(c.Pair, c.Asset, newOrderbook, fundingRate); err != nil {
if err := b.WsInsertSnapshot(c.Pairs[0], c.Asset, newOrderbook, fundingRate); err != nil {
return fmt.Errorf("inserting snapshot error: %s",
err)
}
@@ -664,7 +681,7 @@ func (b *Bitfinex) handleWSBookUpdate(c *subscription.Subscription, d []interfac
Amount: amountRate})
}
if err := b.WsUpdateOrderbook(c, c.Pair, c.Asset, newOrderbook, int64(sequenceNo), fundingRate); err != nil {
if err := b.WsUpdateOrderbook(c, c.Pairs[0], c.Asset, newOrderbook, int64(sequenceNo), fundingRate); err != nil {
return fmt.Errorf("updating orderbook error: %s",
err)
}
@@ -674,6 +691,12 @@ func (b *Bitfinex) handleWSBookUpdate(c *subscription.Subscription, d []interfac
}
func (b *Bitfinex) handleWSCandleUpdate(c *subscription.Subscription, d []interface{}) error {
if c == nil {
return fmt.Errorf("%w: Subscription param", common.ErrNilPointer)
}
if len(c.Pairs) != 1 {
return subscription.ErrNotSinglePair
}
candleBundle, ok := d[1].([]interface{})
if !ok || len(candleBundle) == 0 {
return nil
@@ -712,7 +735,7 @@ func (b *Bitfinex) handleWSCandleUpdate(c *subscription.Subscription, d []interf
}
klineData.Exchange = b.Name
klineData.AssetType = c.Asset
klineData.Pair = c.Pair
klineData.Pair = c.Pairs[0]
b.Websocket.DataHandler <- klineData
}
case float64:
@@ -741,13 +764,19 @@ func (b *Bitfinex) handleWSCandleUpdate(c *subscription.Subscription, d []interf
}
klineData.Exchange = b.Name
klineData.AssetType = c.Asset
klineData.Pair = c.Pair
klineData.Pair = c.Pairs[0]
b.Websocket.DataHandler <- klineData
}
return nil
}
func (b *Bitfinex) handleWSTickerUpdate(c *subscription.Subscription, d []interface{}) error {
if c == nil {
return fmt.Errorf("%w: Subscription param", common.ErrNilPointer)
}
if len(c.Pairs) != 1 {
return subscription.ErrNotSinglePair
}
tickerData, ok := d[1].([]interface{})
if !ok {
return errors.New("type assertion for tickerData")
@@ -755,7 +784,7 @@ func (b *Bitfinex) handleWSTickerUpdate(c *subscription.Subscription, d []interf
t := &ticker.Price{
AssetType: c.Asset,
Pair: c.Pair,
Pair: c.Pairs[0],
ExchangeName: b.Name,
}
@@ -821,6 +850,12 @@ func (b *Bitfinex) handleWSTickerUpdate(c *subscription.Subscription, d []interf
}
func (b *Bitfinex) handleWSTradesUpdate(c *subscription.Subscription, eventType string, d []interface{}) error {
if c == nil {
return fmt.Errorf("%w: Subscription param", common.ErrNilPointer)
}
if len(c.Pairs) != 1 {
return subscription.ErrNotSinglePair
}
if !b.IsSaveTradeDataEnabled() {
return nil
}
@@ -936,7 +971,7 @@ func (b *Bitfinex) handleWSTradesUpdate(c *subscription.Subscription, eventType
}
trades[i] = trade.Data{
TID: strconv.FormatInt(tradeHolder[i].ID, 10),
CurrencyPair: c.Pair,
CurrencyPair: c.Pairs[0],
Timestamp: time.UnixMilli(tradeHolder[i].Timestamp),
Price: price,
Amount: newAmount,
@@ -1510,6 +1545,12 @@ func (b *Bitfinex) WsInsertSnapshot(p currency.Pair, assetType asset.Item, books
// WsUpdateOrderbook updates the orderbook list, removing and adding to the
// orderbook sides
func (b *Bitfinex) WsUpdateOrderbook(c *subscription.Subscription, p currency.Pair, assetType asset.Item, book []WebsocketBook, sequenceNo int64, fundingRate bool) error {
if c == nil {
return fmt.Errorf("%w: Subscription param", common.ErrNilPointer)
}
if len(c.Pairs) != 1 {
return subscription.ErrNotSinglePair
}
orderbookUpdate := orderbook.Update{
Asset: assetType,
Pair: p,
@@ -1592,7 +1633,9 @@ func (b *Bitfinex) WsUpdateOrderbook(c *subscription.Subscription, p currency.Pa
if err = validateCRC32(ob, checkme.Token); err != nil {
log.Errorf(log.WebsocketMgr, "%s websocket orderbook update error, will resubscribe orderbook: %v", b.Name, err)
b.resubOrderbook(c)
if e2 := b.resubOrderbook(c); e2 != nil {
log.Errorf(log.WebsocketMgr, "%s error resubscribing orderbook: %v", b.Name, e2)
}
return err
}
}
@@ -1603,8 +1646,15 @@ func (b *Bitfinex) WsUpdateOrderbook(c *subscription.Subscription, p currency.Pa
// resubOrderbook resubscribes the orderbook after a consistency error, probably a failed checksum,
// which forces a fresh snapshot. If we don't do this the orderbook will keep erroring and drifting.
// Flushing the orderbook happens immediately, but the ReSub itself is a go routine to avoid blocking the WS data channel
func (b *Bitfinex) resubOrderbook(c *subscription.Subscription) {
if err := b.Websocket.Orderbook.FlushOrderbook(c.Pair, c.Asset); err != nil {
func (b *Bitfinex) resubOrderbook(c *subscription.Subscription) error {
if c == nil {
return fmt.Errorf("%w: Subscription param", common.ErrNilPointer)
}
if len(c.Pairs) != 1 {
return subscription.ErrNotSinglePair
}
if err := b.Websocket.Orderbook.FlushOrderbook(c.Pairs[0], c.Asset); err != nil {
// Non-fatal error
log.Errorf(log.ExchangeSys, "%s error flushing orderbook: %v", b.Name, err)
}
@@ -1614,13 +1664,15 @@ func (b *Bitfinex) resubOrderbook(c *subscription.Subscription) {
log.Errorf(log.ExchangeSys, "%s error resubscribing orderbook: %v", b.Name, err)
}
}()
return nil
}
// GenerateDefaultSubscriptions Adds default subscriptions to websocket to be handled by ManageSubscriptions()
func (b *Bitfinex) GenerateDefaultSubscriptions() ([]subscription.Subscription, error) {
func (b *Bitfinex) GenerateDefaultSubscriptions() (subscription.List, error) {
var channels = []string{wsBook, wsTrades, wsTicker, wsCandles}
var subscriptions []subscription.Subscription
var subscriptions subscription.List
assets := b.GetAssetTypes(true)
for i := range assets {
if !b.IsAssetWebsocketSupported(assets[i]) {
@@ -1643,9 +1695,9 @@ func (b *Bitfinex) GenerateDefaultSubscriptions() ([]subscription.Subscription,
params[CandlesPeriodKey] = "30"
}
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: channels[j],
Pair: enabledPairs[k],
Pairs: currency.Pairs{enabledPairs[k]},
Params: params,
Asset: assets[i],
})
@@ -1665,26 +1717,26 @@ func (b *Bitfinex) ConfigureWS() error {
}
// Subscribe sends a websocket message to receive data from channels
func (b *Bitfinex) Subscribe(channels []subscription.Subscription) error {
func (b *Bitfinex) Subscribe(channels subscription.List) error {
return b.ParallelChanOp(channels, b.subscribeToChan, 1)
}
// Unsubscribe sends a websocket message to stop receiving data from channels
func (b *Bitfinex) Unsubscribe(channels []subscription.Subscription) error {
func (b *Bitfinex) Unsubscribe(channels subscription.List) error {
return b.ParallelChanOp(channels, b.unsubscribeFromChan, 1)
}
// subscribeToChan handles a single subscription and parses the result
// on success it adds the subscription to the websocket
func (b *Bitfinex) subscribeToChan(chans []subscription.Subscription) error {
func (b *Bitfinex) subscribeToChan(chans subscription.List) error {
if len(chans) != 1 {
return errors.New("subscription batching limited to 1")
}
c := chans[0]
req, err := subscribeReq(&c)
req, err := subscribeReq(c)
if err != nil {
return fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pair)
return fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pairs)
}
// subId is a single round-trip identifier that provides linking sub requests to chanIDs
@@ -1695,23 +1747,22 @@ func (b *Bitfinex) subscribeToChan(chans []subscription.Subscription) error {
// Add a temporary Key so we can find this Sub when we get the resp without delay or context switch
// Otherwise we might drop the first messages after the subscribed resp
c.Key = subID // Note subID string type avoids conflicts with later chanID key
c.State = subscription.SubscribingState
err = b.Websocket.AddSubscription(&c)
if err != nil {
return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrSubscriptionFailure, c.Channel, c.Pair, err)
if err = b.Websocket.AddSubscriptions(c); err != nil {
return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrSubscriptionFailure, c.Channel, c.Pairs, err)
}
// Always remove the temporary subscription keyed by subID
defer b.Websocket.RemoveSubscriptions(c)
defer func() {
_ = b.Websocket.RemoveSubscriptions(c)
}()
respRaw, err := b.Websocket.Conn.SendMessageReturnResponse("subscribe:"+subID, req)
if err != nil {
return fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pair)
return fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pairs)
}
if err = b.getErrResp(respRaw); err != nil {
wErr := fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pair)
wErr := fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pairs)
b.Websocket.DataHandler <- wErr
return wErr
}
@@ -1721,6 +1772,13 @@ func (b *Bitfinex) subscribeToChan(chans []subscription.Subscription) error {
// subscribeReq returns a map of request params for subscriptions
func subscribeReq(c *subscription.Subscription) (map[string]interface{}, error) {
if c == nil {
return nil, fmt.Errorf("%w: Subscription param", common.ErrNilPointer)
}
if len(c.Pairs) != 1 {
return nil, subscription.ErrNotSinglePair
}
pair := c.Pairs[0]
req := map[string]interface{}{
"event": "subscribe",
"channel": c.Channel,
@@ -1743,13 +1801,13 @@ func subscribeReq(c *subscription.Subscription) (map[string]interface{}, error)
prefix = "f"
}
needsDelimiter := c.Pair.Len() > 6
needsDelimiter := pair.Len() > 6
var formattedPair string
if needsDelimiter {
formattedPair = c.Pair.Format(currency.PairFormat{Uppercase: true, Delimiter: ":"}).String()
formattedPair = pair.Format(currency.PairFormat{Uppercase: true, Delimiter: ":"}).String()
} else {
formattedPair = currency.PairFormat{Uppercase: true}.Format(c.Pair)
formattedPair = currency.PairFormat{Uppercase: true}.Format(pair)
}
if c.Channel == wsCandles {
@@ -1776,7 +1834,7 @@ func subscribeReq(c *subscription.Subscription) (map[string]interface{}, error)
}
// unsubscribeFromChan sends a websocket message to stop receiving data from a channel
func (b *Bitfinex) unsubscribeFromChan(chans []subscription.Subscription) error {
func (b *Bitfinex) unsubscribeFromChan(chans subscription.List) error {
if len(chans) != 1 {
return errors.New("subscription batching limited to 1")
}
@@ -1802,9 +1860,7 @@ func (b *Bitfinex) unsubscribeFromChan(chans []subscription.Subscription) error
return wErr
}
b.Websocket.RemoveSubscriptions(c)
return nil
return b.Websocket.RemoveSubscriptions(c)
}
// getErrResp takes a json response string and looks for an error event type

View File

@@ -620,8 +620,8 @@ func (b *Bitfinex) SubmitOrder(ctx context.Context, o *order.Submit) (*order.Sub
var orderID string
status := order.New
if b.Websocket.CanUseAuthenticatedWebsocketForWrapper() {
symbolStr, err := b.fixCasing(fPair, o.AssetType) //nolint:govet // intentional shadow of err
if err != nil {
var symbolStr string
if symbolStr, err = b.fixCasing(fPair, o.AssetType); err != nil {
return nil, err
}
orderType := strings.ToUpper(o.Type.String())

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
"github.com/thrasher-corp/gocryptotrader/exchanges/stream"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
@@ -166,10 +167,10 @@ func (b *Bithumb) wsHandleData(respRaw []byte) error {
return nil
}
// GenerateSubscriptions generates the default subscription set
func (b *Bithumb) GenerateSubscriptions() ([]subscription.Subscription, error) {
// generateSubscriptions generates the default subscription set
func (b *Bithumb) generateSubscriptions() (subscription.List, error) {
var channels = []string{"ticker", "transaction", "orderbookdepth"}
var subscriptions []subscription.Subscription
var subscriptions subscription.List
pairs, err := b.GetEnabledPairs(asset.Spot)
if err != nil {
return nil, err
@@ -179,44 +180,36 @@ func (b *Bithumb) GenerateSubscriptions() ([]subscription.Subscription, error) {
if err != nil {
return nil, err
}
pairs = pairs.Format(pFmt)
for x := range pairs {
for y := range channels {
subscriptions = append(subscriptions, subscription.Subscription{
Channel: channels[y],
Pair: pairs[x].Format(pFmt),
Asset: asset.Spot,
})
}
for y := range channels {
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: channels[y],
Pairs: pairs,
Asset: asset.Spot,
})
}
return subscriptions, nil
}
// Subscribe subscribes to a set of channels
func (b *Bithumb) Subscribe(channelsToSubscribe []subscription.Subscription) error {
subs := make(map[string]*WsSubscribe)
for i := range channelsToSubscribe {
s, ok := subs[channelsToSubscribe[i].Channel]
if !ok {
s = &WsSubscribe{
Type: channelsToSubscribe[i].Channel,
}
subs[channelsToSubscribe[i].Channel] = s
func (b *Bithumb) Subscribe(channelsToSubscribe subscription.List) error {
var errs error
for _, s := range channelsToSubscribe {
req := &WsSubscribe{
Type: s.Channel,
Symbols: s.Pairs,
}
if s.Channel == "ticker" {
req.TickTypes = wsDefaultTickTypes
}
err := b.Websocket.Conn.SendJSONMessage(req)
if err == nil {
err = b.Websocket.AddSuccessfulSubscriptions(s)
}
s.Symbols = append(s.Symbols, channelsToSubscribe[i].Pair)
}
tSub, ok := subs["ticker"]
if ok {
tSub.TickTypes = wsDefaultTickTypes
}
for _, s := range subs {
err := b.Websocket.Conn.SendJSONMessage(s)
if err != nil {
return err
errs = common.AppendError(errs, err)
}
}
b.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe...)
return nil
return errs
}

View File

@@ -2,7 +2,6 @@ package bithumb
import (
"errors"
"sync"
"testing"
"github.com/thrasher-corp/gocryptotrader/currency"
@@ -46,7 +45,6 @@ func TestWsHandleData(t *testing.T) {
},
},
Websocket: &stream.Websocket{
Wg: new(sync.WaitGroup),
DataHandler: make(chan interface{}, 1),
},
},
@@ -91,7 +89,7 @@ func TestWsHandleData(t *testing.T) {
func TestGenerateSubscriptions(t *testing.T) {
t.Parallel()
sub, err := b.GenerateSubscriptions()
sub, err := b.generateSubscriptions()
if err != nil {
t.Fatal(err)
}

View File

@@ -162,7 +162,7 @@ func (b *Bithumb) Setup(exch *config.Exchange) error {
RunningURL: ePoint,
Connector: b.WsConnect,
Subscriber: b.Subscribe,
GenerateSubscriptions: b.GenerateSubscriptions,
GenerateSubscriptions: b.generateSubscriptions,
Features: &b.Features.Supports.WebsocketCapabilities,
})
if err != nil {

View File

@@ -65,6 +65,11 @@ const (
bitmexActionUpdateData = "update"
)
var subscriptionNames = map[string]string{
subscription.OrderbookChannel: bitmexWSOrderbookL2,
subscription.AllTradesChannel: bitmexWSTrade,
}
// WsConnect initiates a new websocket connection
func (b *Bitmex) WsConnect() error {
if !b.Websocket.IsEnabled() || !b.IsEnabled() {
@@ -100,18 +105,11 @@ func (b *Bitmex) WsConnect() error {
if b.Websocket.CanUseAuthenticatedEndpoints() {
err = b.websocketSendAuth(context.TODO())
if err != nil {
log.Errorf(log.ExchangeSys,
"%v - authentication failed: %v\n",
b.Name,
err)
} else {
authsubs, err := b.GenerateAuthenticatedSubscriptions()
if err != nil {
return err
}
return b.Websocket.SubscribeToChannels(authsubs)
b.Websocket.SetCanUseAuthenticatedEndpoints(false)
log.Errorf(log.ExchangeSys, "%v - authentication failed: %v\n", b.Name, err)
}
}
return nil
}
@@ -544,122 +542,96 @@ func (b *Bitmex) processOrderbook(data []OrderBookL2, action string, p currency.
return nil
}
// GenerateDefaultSubscriptions Adds default subscriptions to websocket to be handled by ManageSubscriptions()
func (b *Bitmex) GenerateDefaultSubscriptions() ([]subscription.Subscription, error) {
channels := []string{bitmexWSOrderbookL2, bitmexWSTrade}
subscriptions := []subscription.Subscription{
{
Channel: bitmexWSAnnouncement,
},
// generateSubscriptions returns Adds default subscriptions to websocket to be handled by ManageSubscriptions()
func (b *Bitmex) generateSubscriptions() (subscription.List, error) {
authed := b.Websocket.CanUseAuthenticatedEndpoints()
assetPairs := map[asset.Item]currency.Pairs{}
for _, a := range b.GetAssetTypes(true) {
p, err := b.GetEnabledPairs(a)
if err != nil {
return nil, err
}
f, err := b.GetPairFormat(a, true)
if err != nil {
return nil, err
}
assetPairs[a] = p.Format(f)
}
assets := b.GetAssetTypes(true)
for x := range assets {
pFmt, err := b.GetPairFormat(assets[x], true)
if err != nil {
return nil, err
subs := subscription.List{}
for _, baseSub := range b.Features.Subscriptions {
if !authed && baseSub.Authenticated {
continue
}
contracts, err := b.GetEnabledPairs(assets[x])
if err != nil {
return nil, err
if baseSub.Asset == asset.Empty {
// Skip pair handling for subs which don't have an asset
subs = append(subs, baseSub.Clone())
continue
}
for y := range contracts {
for z := range channels {
if assets[x] == asset.Index && channels[z] == bitmexWSOrderbookL2 {
// There are no L2 orderbook for index assets
continue
}
subscriptions = append(subscriptions, subscription.Subscription{
Channel: channels[z] + ":" + pFmt.Format(contracts[y]),
Pair: contracts[y],
Asset: assets[x],
})
for a, p := range assetPairs {
if baseSub.Channel == bitmexWSOrderbookL2 && a == asset.Index {
continue // There are no L2 orderbook for index assets
}
if baseSub.Asset != asset.All && baseSub.Asset != a {
continue
}
s := baseSub.Clone()
s.Asset = a
s.Pairs = p
subs = append(subs, s)
}
}
return subscriptions, nil
}
// GenerateAuthenticatedSubscriptions Adds authenticated subscriptions to websocket to be handled by ManageSubscriptions()
func (b *Bitmex) GenerateAuthenticatedSubscriptions() ([]subscription.Subscription, error) {
if !b.Websocket.CanUseAuthenticatedEndpoints() {
return nil, nil
}
pFmt, err := b.GetPairFormat(asset.PerpetualContract, true)
if err != nil {
return nil, err
}
contracts, err := b.GetEnabledPairs(asset.PerpetualContract)
if err != nil {
return nil, err
}
channels := []string{bitmexWSExecution,
bitmexWSPosition,
}
subscriptions := []subscription.Subscription{
{
Channel: bitmexWSAffiliate,
},
{
Channel: bitmexWSOrder,
},
{
Channel: bitmexWSMargin,
},
{
Channel: bitmexWSPrivateNotifications,
},
{
Channel: bitmexWSTransact,
},
{
Channel: bitmexWSWallet,
},
}
for i := range channels {
for j := range contracts {
subscriptions = append(subscriptions, subscription.Subscription{
Channel: channels[i] + ":" + pFmt.Format(contracts[j]),
Pair: contracts[j],
Asset: asset.PerpetualContract,
})
}
}
return subscriptions, nil
return subs, nil
}
// Subscribe subscribes to a websocket channel
func (b *Bitmex) Subscribe(channelsToSubscribe []subscription.Subscription) error {
var subscriber WebsocketRequest
subscriber.Command = "subscribe"
for i := range channelsToSubscribe {
subscriber.Arguments = append(subscriber.Arguments,
channelsToSubscribe[i].Channel)
func (b *Bitmex) Subscribe(subs subscription.List) error {
req := WebsocketRequest{
Command: "subscribe",
}
err := b.Websocket.Conn.SendJSONMessage(subscriber)
if err != nil {
return err
for _, s := range subs {
for _, p := range s.Pairs {
cName := channelName(s.Channel)
req.Arguments = append(req.Arguments, cName+":"+p.String())
}
}
b.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe...)
return nil
err := b.Websocket.Conn.SendJSONMessage(req)
if err == nil {
err = b.Websocket.AddSuccessfulSubscriptions(subs...)
}
return err
}
// Unsubscribe sends a websocket message to stop receiving data from the channel
func (b *Bitmex) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) error {
var unsubscriber WebsocketRequest
unsubscriber.Command = "unsubscribe"
func (b *Bitmex) Unsubscribe(subs subscription.List) error {
req := WebsocketRequest{
Command: "unsubscribe",
}
for i := range channelsToUnsubscribe {
unsubscriber.Arguments = append(unsubscriber.Arguments,
channelsToUnsubscribe[i].Channel)
for _, s := range subs {
for _, p := range s.Pairs {
cName := channelName(s.Channel)
req.Arguments = append(req.Arguments, cName+":"+p.String())
}
}
err := b.Websocket.Conn.SendJSONMessage(unsubscriber)
if err != nil {
return err
err := b.Websocket.Conn.SendJSONMessage(req)
if err == nil {
err = b.Websocket.RemoveSubscriptions(subs...)
}
b.Websocket.RemoveSubscriptions(channelsToUnsubscribe...)
return nil
return err
}
// channelName converts global channel Names used in config of channel input into bitmex channel names
// returns the name unchanged if no match is found
func channelName(name string) string {
if s, ok := subscriptionNames[name]; ok {
return s
}
return name
}
// WebsocketSendAuth sends an authenticated subscription

View File

@@ -28,6 +28,7 @@ import (
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
"github.com/thrasher-corp/gocryptotrader/exchanges/stream"
"github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
"github.com/thrasher-corp/gocryptotrader/exchanges/ticker"
"github.com/thrasher-corp/gocryptotrader/exchanges/trade"
"github.com/thrasher-corp/gocryptotrader/log"
@@ -136,6 +137,19 @@ func (b *Bitmex) SetDefaults() {
Enabled: exchange.FeaturesEnabled{
AutoPairUpdates: true,
},
Subscriptions: subscription.List{
{Enabled: true, Channel: bitmexWSAnnouncement},
{Enabled: true, Channel: bitmexWSOrderbookL2, Asset: asset.All},
{Enabled: true, Channel: bitmexWSTrade, Asset: asset.All},
{Enabled: true, Channel: bitmexWSAffiliate, Authenticated: true},
{Enabled: true, Channel: bitmexWSOrder, Authenticated: true},
{Enabled: true, Channel: bitmexWSMargin, Authenticated: true},
{Enabled: true, Channel: bitmexWSPrivateNotifications, Authenticated: true},
{Enabled: true, Channel: bitmexWSTransact, Authenticated: true},
{Enabled: true, Channel: bitmexWSWallet, Authenticated: true},
{Enabled: true, Channel: bitmexWSExecution, Authenticated: true, Asset: asset.PerpetualContract},
{Enabled: true, Channel: bitmexWSPosition, Authenticated: true, Asset: asset.PerpetualContract},
},
}
b.Requester, err = request.New(b.Name,
@@ -185,7 +199,7 @@ func (b *Bitmex) Setup(exch *config.Exchange) error {
Connector: b.WsConnect,
Subscriber: b.Subscribe,
Unsubscriber: b.Unsubscribe,
GenerateSubscriptions: b.GenerateDefaultSubscriptions,
GenerateSubscriptions: b.generateSubscriptions,
Features: &b.Features.Supports.WebsocketCapabilities,
OrderbookBufferConfig: buffer.Config{
UpdateEntriesByID: true,

View File

@@ -231,30 +231,30 @@ func (b *Bitstamp) handleWSOrder(wsResp *websocketResponse, msg []byte) error {
return nil
}
func (b *Bitstamp) generateDefaultSubscriptions() ([]subscription.Subscription, error) {
func (b *Bitstamp) generateDefaultSubscriptions() (subscription.List, error) {
enabledCurrencies, err := b.GetEnabledPairs(asset.Spot)
if err != nil {
return nil, err
}
var subscriptions []subscription.Subscription
var subscriptions subscription.List
for i := range enabledCurrencies {
p, err := b.FormatExchangeCurrency(enabledCurrencies[i], asset.Spot)
if err != nil {
return nil, err
}
for j := range defaultSubChannels {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: defaultSubChannels[j] + "_" + p.String(),
Asset: asset.Spot,
Pair: p,
Pairs: currency.Pairs{p},
})
}
if b.Websocket.CanUseAuthenticatedEndpoints() {
for j := range defaultAuthSubChannels {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: defaultAuthSubChannels[j] + "_" + p.String(),
Asset: asset.Spot,
Pair: p,
Pairs: currency.Pairs{p},
Params: map[string]interface{}{
"auth": struct{}{},
},
@@ -266,7 +266,7 @@ func (b *Bitstamp) generateDefaultSubscriptions() ([]subscription.Subscription,
}
// Subscribe sends a websocket message to receive data from the channel
func (b *Bitstamp) Subscribe(channelsToSubscribe []subscription.Subscription) error {
func (b *Bitstamp) Subscribe(channelsToSubscribe subscription.List) error {
var errs error
var auth *WebsocketAuthResponse
@@ -281,44 +281,46 @@ func (b *Bitstamp) Subscribe(channelsToSubscribe []subscription.Subscription) er
}
}
for i := range channelsToSubscribe {
for _, s := range channelsToSubscribe {
req := websocketEventRequest{
Event: "bts:subscribe",
Data: websocketData{
Channel: channelsToSubscribe[i].Channel,
Channel: s.Channel,
},
}
if _, ok := channelsToSubscribe[i].Params["auth"]; ok && auth != nil {
if _, ok := s.Params["auth"]; ok && auth != nil {
req.Data.Channel = "private-" + req.Data.Channel + "-" + strconv.Itoa(int(auth.UserID))
req.Data.Auth = auth.Token
}
err := b.Websocket.Conn.SendJSONMessage(req)
if err == nil {
err = b.Websocket.AddSuccessfulSubscriptions(s)
}
if err != nil {
errs = common.AppendError(errs, err)
continue
}
b.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[i])
}
return errs
}
// Unsubscribe sends a websocket message to stop receiving data from the channel
func (b *Bitstamp) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) error {
func (b *Bitstamp) Unsubscribe(channelsToUnsubscribe subscription.List) error {
var errs error
for i := range channelsToUnsubscribe {
for _, s := range channelsToUnsubscribe {
req := websocketEventRequest{
Event: "bts:unsubscribe",
Data: websocketData{
Channel: channelsToUnsubscribe[i].Channel,
Channel: s.Channel,
},
}
err := b.Websocket.Conn.SendJSONMessage(req)
if err == nil {
err = b.Websocket.RemoveSubscriptions(s)
}
if err != nil {
errs = common.AppendError(errs, err)
continue
}
b.Websocket.RemoveSubscriptions(channelsToUnsubscribe[i])
}
return errs
}

View File

@@ -82,13 +82,13 @@ const (
immediateOrCancel = "IOC"
fillOrKill = "FOK"
subscribe = "subscribe"
fundChange = "fundChange"
orderChange = "orderChange"
heartbeat = "heartbeat"
tick = "tick"
wsOB = "orderbookUpdate"
tradeEndPoint = "trade"
subscribe = "subscribe"
fundChange = "fundChange"
orderChange = "orderChange"
heartbeat = "heartbeat"
tick = "tick"
wsOrderbookUpdate = "orderbookUpdate"
tradeEndPoint = "trade"
// Subscription management when connection and subscription established
addSubscription = "addSubscription"

View File

@@ -127,7 +127,7 @@ func (b *BTCMarkets) wsHandleData(respRaw []byte) error {
if b.Verbose {
log.Debugf(log.ExchangeSys, "%v - Websocket heartbeat received %s", b.Name, respRaw)
}
case wsOB:
case wsOrderbookUpdate:
var ob WsOrderbook
err := json.Unmarshal(respRaw, &ob)
if err != nil {
@@ -325,26 +325,24 @@ func (b *BTCMarkets) wsHandleData(respRaw []byte) error {
return nil
}
func (b *BTCMarkets) generateDefaultSubscriptions() ([]subscription.Subscription, error) {
var channels = []string{wsOB, tick, tradeEndPoint}
func (b *BTCMarkets) generateDefaultSubscriptions() (subscription.List, error) {
var channels = []string{wsOrderbookUpdate, tick, tradeEndPoint}
enabledCurrencies, err := b.GetEnabledPairs(asset.Spot)
if err != nil {
return nil, err
}
var subscriptions []subscription.Subscription
var subscriptions subscription.List
for i := range channels {
for j := range enabledCurrencies {
subscriptions = append(subscriptions, subscription.Subscription{
Channel: channels[i],
Pair: enabledCurrencies[j],
Asset: asset.Spot,
})
}
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: channels[i],
Pairs: enabledCurrencies,
Asset: asset.Spot,
})
}
if b.Websocket.CanUseAuthenticatedEndpoints() {
for i := range authChannels {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: authChannels[i],
})
}
@@ -353,96 +351,91 @@ func (b *BTCMarkets) generateDefaultSubscriptions() ([]subscription.Subscription
}
// Subscribe sends a websocket message to receive data from the channel
func (b *BTCMarkets) Subscribe(subs []subscription.Subscription) error {
var payload WsSubscribe
if len(subs) > 1 {
// TODO: Expand this to stream package as this assumes that we are doing
// an initial sync.
payload.MessageType = subscribe
} else {
payload.MessageType = addSubscription
payload.ClientType = clientType
func (b *BTCMarkets) Subscribe(subs subscription.List) error {
baseReq := &WsSubscribe{
MessageType: subscribe,
}
var authenticate bool
for i := range subs {
if !authenticate && common.StringDataContains(authChannels, subs[i].Channel) {
authenticate = true
var errs error
for _, s := range subs {
if baseReq.Key == "" && common.StringDataContains(authChannels, s.Channel) {
if err := b.authWsSubscibeReq(baseReq); err != nil {
return err
}
}
payload.Channels = append(payload.Channels, subs[i].Channel)
if subs[i].Pair.IsEmpty() {
continue
}
pair := subs[i].Pair.String()
if common.StringDataCompare(payload.MarketIDs, pair) {
continue
}
payload.MarketIDs = append(payload.MarketIDs, pair)
}
if authenticate {
creds, err := b.GetCredentials(context.TODO())
if baseReq.MessageType == subscribe && len(b.Websocket.GetSubscriptions()) != 0 {
baseReq.MessageType = addSubscription // After first *successful* subscription API requires addSubscription
baseReq.ClientType = clientType // Note: Only addSubscription requires/accepts clientType
}
r := baseReq
r.Channels = []string{s.Channel}
r.MarketIDs = s.Pairs.Strings()
err := b.Websocket.Conn.SendJSONMessage(r)
if err == nil {
err = b.Websocket.AddSuccessfulSubscriptions(s)
}
if err != nil {
return err
errs = common.AppendError(errs, err)
}
signTime := strconv.FormatInt(time.Now().UnixMilli(), 10)
strToSign := "/users/self/subscribe" + "\n" + signTime
var tempSign []byte
tempSign, err = crypto.GetHMAC(crypto.HashSHA512,
[]byte(strToSign),
[]byte(creds.Secret))
if err != nil {
return err
}
sign := crypto.Base64Encode(tempSign)
payload.Key = creds.Key
payload.Signature = sign
payload.Timestamp = signTime
}
if err := b.Websocket.Conn.SendJSONMessage(payload); err != nil {
return errs
}
func (b *BTCMarkets) authWsSubscibeReq(r *WsSubscribe) error {
creds, err := b.GetCredentials(context.TODO())
if err != nil {
return err
}
b.Websocket.AddSuccessfulSubscriptions(subs...)
r.Timestamp = strconv.FormatInt(time.Now().UnixMilli(), 10)
strToSign := "/users/self/subscribe" + "\n" + r.Timestamp
tempSign, err := crypto.GetHMAC(crypto.HashSHA512, []byte(strToSign), []byte(creds.Secret))
if err != nil {
return err
}
sign := crypto.Base64Encode(tempSign)
r.Key = creds.Key
r.Signature = sign
return nil
}
// Unsubscribe sends a websocket message to manage and remove a subscription.
func (b *BTCMarkets) Unsubscribe(subs []subscription.Subscription) error {
payload := WsSubscribe{
MessageType: removeSubscription,
ClientType: clientType,
}
for i := range subs {
payload.Channels = append(payload.Channels, subs[i].Channel)
if subs[i].Pair.IsEmpty() {
continue
func (b *BTCMarkets) Unsubscribe(subs subscription.List) error {
var errs error
for _, s := range subs {
req := WsSubscribe{
MessageType: removeSubscription,
ClientType: clientType,
Channels: []string{s.Channel},
MarketIDs: s.Pairs.Strings(),
}
pair := subs[i].Pair.String()
if common.StringDataCompare(payload.MarketIDs, pair) {
continue
err := b.Websocket.Conn.SendJSONMessage(req)
if err == nil {
err = b.Websocket.RemoveSubscriptions(s)
}
if err != nil {
errs = common.AppendError(errs, err)
}
payload.MarketIDs = append(payload.MarketIDs, pair)
}
err := b.Websocket.Conn.SendJSONMessage(payload)
if err != nil {
return err
}
b.Websocket.RemoveSubscriptions(subs...)
return nil
return errs
}
// ReSubscribeSpecificOrderbook removes the subscription and the subscribes
// again to fetch a new snapshot in the event of a de-sync event.
func (b *BTCMarkets) ReSubscribeSpecificOrderbook(pair currency.Pair) error {
sub := []subscription.Subscription{{
Channel: wsOB,
Pair: pair,
sub := subscription.List{{
Channel: wsOrderbookUpdate,
Pairs: currency.Pairs{pair},
Asset: asset.Spot,
}}
if err := b.Unsubscribe(sub); err != nil {
if err := b.Unsubscribe(sub); err != nil && !errors.Is(err, subscription.ErrNotFound) {
// ErrNotFound is okay, because we might be re-subscribing a single pair from a larger list
// BTC-Market handles unsub/sub of one pair gracefully and the other pairs are unaffected
return err
}
return b.Subscribe(sub)

View File

@@ -361,23 +361,23 @@ func (b *BTSE) orderbookFilter(price, amount float64) bool {
}
// GenerateDefaultSubscriptions Adds default subscriptions to websocket to be handled by ManageSubscriptions()
func (b *BTSE) GenerateDefaultSubscriptions() ([]subscription.Subscription, error) {
func (b *BTSE) GenerateDefaultSubscriptions() (subscription.List, error) {
var channels = []string{"orderBookL2Api:%s_0", "tradeHistory:%s"}
pairs, err := b.GetEnabledPairs(asset.Spot)
if err != nil {
return nil, err
}
var subscriptions []subscription.Subscription
var subscriptions subscription.List
if b.Websocket.CanUseAuthenticatedEndpoints() {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: "notificationApi",
})
}
for i := range channels {
for j := range pairs {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: fmt.Sprintf(channels[i], pairs[j]),
Pair: pairs[j],
Pairs: currency.Pairs{pairs[j]},
Asset: asset.Spot,
})
}
@@ -386,22 +386,21 @@ func (b *BTSE) GenerateDefaultSubscriptions() ([]subscription.Subscription, erro
}
// Subscribe sends a websocket message to receive data from the channel
func (b *BTSE) Subscribe(channelsToSubscribe []subscription.Subscription) error {
func (b *BTSE) Subscribe(channelsToSubscribe subscription.List) error {
var sub wsSub
sub.Operation = "subscribe"
for i := range channelsToSubscribe {
sub.Arguments = append(sub.Arguments, channelsToSubscribe[i].Channel)
}
err := b.Websocket.Conn.SendJSONMessage(sub)
if err != nil {
return err
if err == nil {
err = b.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe...)
}
b.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe...)
return nil
return err
}
// Unsubscribe sends a websocket message to stop receiving data from the channel
func (b *BTSE) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) error {
func (b *BTSE) Unsubscribe(channelsToUnsubscribe subscription.List) error {
var unSub wsSub
unSub.Operation = "unsubscribe"
for i := range channelsToUnsubscribe {
@@ -409,9 +408,8 @@ func (b *BTSE) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) er
channelsToUnsubscribe[i].Channel)
}
err := b.Websocket.Conn.SendJSONMessage(unSub)
if err != nil {
return err
if err == nil {
err = b.Websocket.RemoveSubscriptions(channelsToUnsubscribe...)
}
b.Websocket.RemoveSubscriptions(channelsToUnsubscribe...)
return nil
return err
}

View File

@@ -4,6 +4,7 @@ import (
"net/http"
"github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
"github.com/thrasher-corp/gocryptotrader/exchanges/stream"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
@@ -32,8 +33,8 @@ func (by *Bybit) WsInverseConnect() error {
}
// GenerateInverseDefaultSubscriptions generates default subscription
func (by *Bybit) GenerateInverseDefaultSubscriptions() ([]subscription.Subscription, error) {
var subscriptions []subscription.Subscription
func (by *Bybit) GenerateInverseDefaultSubscriptions() (subscription.List, error) {
var subscriptions subscription.List
var channels = []string{chanOrderbook, chanPublicTrade, chanPublicTicker}
pairs, err := by.GetEnabledPairs(asset.CoinMarginedFutures)
if err != nil {
@@ -42,9 +43,9 @@ func (by *Bybit) GenerateInverseDefaultSubscriptions() ([]subscription.Subscript
for z := range pairs {
for x := range channels {
subscriptions = append(subscriptions,
subscription.Subscription{
&subscription.Subscription{
Channel: channels[x],
Pair: pairs[z],
Pairs: currency.Pairs{pairs[z]},
Asset: asset.CoinMarginedFutures,
})
}
@@ -53,16 +54,16 @@ func (by *Bybit) GenerateInverseDefaultSubscriptions() ([]subscription.Subscript
}
// InverseSubscribe sends a subscription message to linear public channels.
func (by *Bybit) InverseSubscribe(channelSubscriptions []subscription.Subscription) error {
func (by *Bybit) InverseSubscribe(channelSubscriptions subscription.List) error {
return by.handleInversePayloadSubscription("subscribe", channelSubscriptions)
}
// InverseUnsubscribe sends an unsubscription messages through linear public channels.
func (by *Bybit) InverseUnsubscribe(channelSubscriptions []subscription.Subscription) error {
func (by *Bybit) InverseUnsubscribe(channelSubscriptions subscription.List) error {
return by.handleInversePayloadSubscription("unsubscribe", channelSubscriptions)
}
func (by *Bybit) handleInversePayloadSubscription(operation string, channelSubscriptions []subscription.Subscription) error {
func (by *Bybit) handleInversePayloadSubscription(operation string, channelSubscriptions subscription.List) error {
payloads, err := by.handleSubscriptions(asset.CoinMarginedFutures, operation, channelSubscriptions)
if err != nil {
return err

View File

@@ -41,8 +41,8 @@ func (by *Bybit) WsLinearConnect() error {
}
// GenerateLinearDefaultSubscriptions generates default subscription
func (by *Bybit) GenerateLinearDefaultSubscriptions() ([]subscription.Subscription, error) {
var subscriptions []subscription.Subscription
func (by *Bybit) GenerateLinearDefaultSubscriptions() (subscription.List, error) {
var subscriptions subscription.List
var channels = []string{chanOrderbook, chanPublicTrade, chanPublicTicker}
pairs, err := by.GetEnabledPairs(asset.USDTMarginedFutures)
if err != nil {
@@ -61,9 +61,9 @@ func (by *Bybit) GenerateLinearDefaultSubscriptions() ([]subscription.Subscripti
for p := range linearPairMap[a] {
for x := range channels {
subscriptions = append(subscriptions,
subscription.Subscription{
&subscription.Subscription{
Channel: channels[x],
Pair: pairs[p],
Pairs: currency.Pairs{pairs[p]},
Asset: a,
})
}
@@ -73,16 +73,16 @@ func (by *Bybit) GenerateLinearDefaultSubscriptions() ([]subscription.Subscripti
}
// LinearSubscribe sends a subscription message to linear public channels.
func (by *Bybit) LinearSubscribe(channelSubscriptions []subscription.Subscription) error {
func (by *Bybit) LinearSubscribe(channelSubscriptions subscription.List) error {
return by.handleLinearPayloadSubscription("subscribe", channelSubscriptions)
}
// LinearUnsubscribe sends an unsubscription messages through linear public channels.
func (by *Bybit) LinearUnsubscribe(channelSubscriptions []subscription.Subscription) error {
func (by *Bybit) LinearUnsubscribe(channelSubscriptions subscription.List) error {
return by.handleLinearPayloadSubscription("unsubscribe", channelSubscriptions)
}
func (by *Bybit) handleLinearPayloadSubscription(operation string, channelSubscriptions []subscription.Subscription) error {
func (by *Bybit) handleLinearPayloadSubscription(operation string, channelSubscriptions subscription.List) error {
payloads, err := by.handleSubscriptions(asset.USDTMarginedFutures, operation, channelSubscriptions)
if err != nil {
return err

View File

@@ -6,6 +6,7 @@ import (
"strconv"
"github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
"github.com/thrasher-corp/gocryptotrader/exchanges/stream"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
@@ -39,8 +40,8 @@ func (by *Bybit) WsOptionsConnect() error {
}
// GenerateOptionsDefaultSubscriptions generates default subscription
func (by *Bybit) GenerateOptionsDefaultSubscriptions() ([]subscription.Subscription, error) {
var subscriptions []subscription.Subscription
func (by *Bybit) GenerateOptionsDefaultSubscriptions() (subscription.List, error) {
var subscriptions subscription.List
var channels = []string{chanOrderbook, chanPublicTrade, chanPublicTicker}
pairs, err := by.GetEnabledPairs(asset.Options)
if err != nil {
@@ -49,9 +50,9 @@ func (by *Bybit) GenerateOptionsDefaultSubscriptions() ([]subscription.Subscript
for z := range pairs {
for x := range channels {
subscriptions = append(subscriptions,
subscription.Subscription{
&subscription.Subscription{
Channel: channels[x],
Pair: pairs[z],
Pairs: currency.Pairs{pairs[z]},
Asset: asset.Options,
})
}
@@ -60,16 +61,16 @@ func (by *Bybit) GenerateOptionsDefaultSubscriptions() ([]subscription.Subscript
}
// OptionSubscribe sends a subscription message to options public channels.
func (by *Bybit) OptionSubscribe(channelSubscriptions []subscription.Subscription) error {
func (by *Bybit) OptionSubscribe(channelSubscriptions subscription.List) error {
return by.handleOptionsPayloadSubscription("subscribe", channelSubscriptions)
}
// OptionUnsubscribe sends an unsubscription messages through options public channels.
func (by *Bybit) OptionUnsubscribe(channelSubscriptions []subscription.Subscription) error {
func (by *Bybit) OptionUnsubscribe(channelSubscriptions subscription.List) error {
return by.handleOptionsPayloadSubscription("unsubscribe", channelSubscriptions)
}
func (by *Bybit) handleOptionsPayloadSubscription(operation string, channelSubscriptions []subscription.Subscription) error {
func (by *Bybit) handleOptionsPayloadSubscription(operation string, channelSubscriptions subscription.List) error {
payloads, err := by.handleSubscriptions(asset.Options, operation, channelSubscriptions)
if err != nil {
return err

View File

@@ -134,11 +134,11 @@ func (by *Bybit) WsAuth(ctx context.Context) error {
}
// Subscribe sends a websocket message to receive data from the channel
func (by *Bybit) Subscribe(channelsToSubscribe []subscription.Subscription) error {
func (by *Bybit) Subscribe(channelsToSubscribe subscription.List) error {
return by.handleSpotSubscription("subscribe", channelsToSubscribe)
}
func (by *Bybit) handleSubscriptions(assetType asset.Item, operation string, channelsToSubscribe []subscription.Subscription) ([]SubscriptionArgument, error) {
func (by *Bybit) handleSubscriptions(assetType asset.Item, operation string, channelsToSubscribe subscription.List) ([]SubscriptionArgument, error) {
var args []SubscriptionArgument
arg := SubscriptionArgument{
Operation: operation,
@@ -166,17 +166,21 @@ func (by *Bybit) handleSubscriptions(assetType asset.Item, operation string, cha
return nil, err
}
for i := range channelsToSubscribe {
if len(channelsToSubscribe[i].Pairs) != 1 {
return nil, subscription.ErrNotSinglePair
}
pair := channelsToSubscribe[i].Pairs[0]
switch channelsToSubscribe[i].Channel {
case chanOrderbook:
arg.Arguments = append(arg.Arguments, fmt.Sprintf("%s.%d.%s", channelsToSubscribe[i].Channel, 50, channelsToSubscribe[i].Pair.Format(pairFormat).String()))
arg.Arguments = append(arg.Arguments, fmt.Sprintf("%s.%d.%s", channelsToSubscribe[i].Channel, 50, pair.Format(pairFormat).String()))
case chanPublicTrade, chanPublicTicker, chanLiquidation, chanLeverageTokenTicker, chanLeverageTokenNav:
arg.Arguments = append(arg.Arguments, channelsToSubscribe[i].Channel+"."+channelsToSubscribe[i].Pair.Format(pairFormat).String())
arg.Arguments = append(arg.Arguments, channelsToSubscribe[i].Channel+"."+pair.Format(pairFormat).String())
case chanKline, chanLeverageTokenKline:
interval, err := intervalToString(kline.FiveMin)
if err != nil {
return nil, err
}
arg.Arguments = append(arg.Arguments, channelsToSubscribe[i].Channel+"."+interval+"."+channelsToSubscribe[i].Pair.Format(pairFormat).String())
arg.Arguments = append(arg.Arguments, channelsToSubscribe[i].Channel+"."+interval+"."+pair.Format(pairFormat).String())
case chanPositions, chanExecution, chanOrder, chanWallet, chanGreeks, chanDCP:
if chanMap[channelsToSubscribe[i].Channel]&selectedChannels > 0 {
continue
@@ -204,11 +208,11 @@ func (by *Bybit) handleSubscriptions(assetType asset.Item, operation string, cha
}
// Unsubscribe sends a websocket message to stop receiving data from the channel
func (by *Bybit) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) error {
func (by *Bybit) Unsubscribe(channelsToUnsubscribe subscription.List) error {
return by.handleSpotSubscription("unsubscribe", channelsToUnsubscribe)
}
func (by *Bybit) handleSpotSubscription(operation string, channelsToSubscribe []subscription.Subscription) error {
func (by *Bybit) handleSpotSubscription(operation string, channelsToSubscribe subscription.List) error {
payloads, err := by.handleSubscriptions(asset.Spot, operation, channelsToSubscribe)
if err != nil {
return err
@@ -239,8 +243,8 @@ func (by *Bybit) handleSpotSubscription(operation string, channelsToSubscribe []
}
// GenerateDefaultSubscriptions generates default subscription
func (by *Bybit) GenerateDefaultSubscriptions() ([]subscription.Subscription, error) {
var subscriptions []subscription.Subscription
func (by *Bybit) GenerateDefaultSubscriptions() (subscription.List, error) {
var subscriptions subscription.List
var channels = []string{
chanPublicTicker,
chanOrderbook,
@@ -266,16 +270,16 @@ func (by *Bybit) GenerateDefaultSubscriptions() ([]subscription.Subscription, er
chanDCP,
chanWallet:
subscriptions = append(subscriptions,
subscription.Subscription{
&subscription.Subscription{
Channel: channels[x],
Asset: asset.Spot,
})
default:
for z := range pairs {
subscriptions = append(subscriptions,
subscription.Subscription{
&subscription.Subscription{
Channel: channels[x],
Pair: pairs[z],
Pairs: currency.Pairs{pairs[z]},
Asset: asset.Spot,
})
}

View File

@@ -688,20 +688,11 @@ func TestWsAuth(t *testing.T) {
}
var dialer websocket.Dialer
err := c.Websocket.Conn.Dial(&dialer, http.Header{})
if err != nil {
t.Fatal(err)
}
require.NoError(t, err, "Dial must not error")
go c.wsReadData()
err = c.Subscribe([]subscription.Subscription{
{
Channel: "user",
Pair: testPair,
},
})
if err != nil {
t.Error(err)
}
err = c.Subscribe(subscription.List{{Channel: "user", Pairs: currency.Pairs{testPair}}})
require.NoError(t, err, "Subscribe must not error")
timer := time.NewTimer(sharedtestvalues.WebsocketResponseDefaultTimeout)
select {
case badResponse := <-c.Websocket.DataHandler:

View File

@@ -362,17 +362,17 @@ type FillResponse struct {
// WebsocketSubscribe takes in subscription information
type WebsocketSubscribe struct {
Type string `json:"type"`
ProductIDs []string `json:"product_ids,omitempty"`
Channels []WsChannels `json:"channels,omitempty"`
Signature string `json:"signature,omitempty"`
Key string `json:"key,omitempty"`
Passphrase string `json:"passphrase,omitempty"`
Timestamp string `json:"timestamp,omitempty"`
Type string `json:"type"`
ProductIDs []string `json:"product_ids,omitempty"`
Channels []any `json:"channels,omitempty"`
Signature string `json:"signature,omitempty"`
Key string `json:"key,omitempty"`
Passphrase string `json:"passphrase,omitempty"`
Timestamp string `json:"timestamp,omitempty"`
}
// WsChannels defines outgoing channels for subscription purposes
type WsChannels struct {
// WsChannel defines a websocket subscription channel
type WsChannel struct {
Name string `json:"name"`
ProductIDs []string `json:"product_ids,omitempty"`
}

View File

@@ -10,11 +10,9 @@ import (
"time"
"github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/convert"
"github.com/thrasher-corp/gocryptotrader/common/crypto"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/account"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
"github.com/thrasher-corp/gocryptotrader/exchanges/order"
"github.com/thrasher-corp/gocryptotrader/exchanges/orderbook"
@@ -365,131 +363,105 @@ func (c *CoinbasePro) ProcessUpdate(update *WebsocketL2Update) error {
})
}
// GenerateDefaultSubscriptions Adds default subscriptions to websocket to be handled by ManageSubscriptions()
func (c *CoinbasePro) GenerateDefaultSubscriptions() ([]subscription.Subscription, error) {
var channels = []string{"heartbeat",
"level2_batch", /*Other orderbook feeds require authentication. This is batched in 50ms lots.*/
"ticker",
"user",
"matches"}
enabledPairs, err := c.GetEnabledPairs(asset.Spot)
// generateSubscriptions returns a list of subscriptions from the configured subscriptions feature
func (c *CoinbasePro) generateSubscriptions() (subscription.List, error) {
pairs, err := c.GetEnabledPairs(asset.Spot)
if err != nil {
return nil, err
}
var subscriptions []subscription.Subscription
for i := range channels {
if (channels[i] == "user" || channels[i] == "full") &&
!c.IsWebsocketAuthenticationSupported() {
pairFmt, err := c.GetPairFormat(asset.Spot, true)
if err != nil {
return nil, err
}
pairs = pairs.Format(pairFmt)
authed := c.IsWebsocketAuthenticationSupported()
subs := make(subscription.List, 0, len(c.Features.Subscriptions))
for _, baseSub := range c.Features.Subscriptions {
if !authed && baseSub.Authenticated {
continue
}
for j := range enabledPairs {
fPair, err := c.FormatExchangeCurrency(enabledPairs[j],
asset.Spot)
if err != nil {
return nil, err
}
subscriptions = append(subscriptions, subscription.Subscription{
Channel: channels[i],
Pair: fPair,
Asset: asset.Spot,
})
}
s := baseSub.Clone()
s.Asset = asset.Spot
s.Pairs = pairs
subs = append(subs, s)
}
return subscriptions, nil
return subs, nil
}
// Subscribe sends a websocket message to receive data from the channel
func (c *CoinbasePro) Subscribe(channelsToSubscribe []subscription.Subscription) error {
var creds *account.Credentials
var err error
if c.IsWebsocketAuthenticationSupported() {
creds, err = c.GetCredentials(context.TODO())
if err != nil {
return err
func (c *CoinbasePro) Subscribe(subs subscription.List) error {
r := &WebsocketSubscribe{
Type: "subscribe",
Channels: make([]any, 0, len(subs)),
}
// See if we have a consistent Pair list for all the subs that we can use globally
// If all the subs have the same pairs then we can use the top level ProductIDs field
// Otherwise each and every sub needs to have it's own list
for i, s := range subs {
if i == 0 {
r.ProductIDs = s.Pairs.Strings()
} else if !subs[0].Pairs.Equal(s.Pairs) {
r.ProductIDs = nil
break
}
}
subscribe := WebsocketSubscribe{
Type: "subscribe",
}
productIDs := make([]string, 0, len(channelsToSubscribe))
for i := range channelsToSubscribe {
p := channelsToSubscribe[i].Pair.String()
if p != "" && !common.StringDataCompare(productIDs, p) {
// get all unique productIDs in advance as we generate by channels
productIDs = append(productIDs, p)
}
}
subscriptions:
for i := range channelsToSubscribe {
for j := range subscribe.Channels {
if subscribe.Channels[j].Name == channelsToSubscribe[i].Channel {
continue subscriptions
}
}
subChan := WsChannels{
Name: channelsToSubscribe[i].Channel,
ProductIDs: productIDs,
}
if (channelsToSubscribe[i].Channel == "user" || channelsToSubscribe[i].Channel == "full") &&
creds != nil &&
subscribe.Signature == "" {
n := strconv.FormatInt(time.Now().Unix(), 10)
message := n + http.MethodGet + "/users/self/verify"
var hmac []byte
hmac, err = crypto.GetHMAC(crypto.HashSHA256,
[]byte(message),
[]byte(creds.Secret))
if err != nil {
for _, s := range subs {
if s.Authenticated && r.Key == "" && c.IsWebsocketAuthenticationSupported() {
if err := c.authWsSubscibeReq(r); err != nil {
return err
}
subscribe.Signature = crypto.Base64Encode(hmac)
subscribe.Key = creds.Key
subscribe.Passphrase = creds.ClientID
subscribe.Timestamp = n
}
subscribe.Channels = append(subscribe.Channels, subChan)
if len(r.ProductIDs) == 0 {
r.Channels = append(r.Channels, WsChannel{
Name: s.Channel,
ProductIDs: s.Pairs.Strings(),
})
} else {
// Coinbase does not support using [WsChannel{Name:"x"}] unless each ProductIDs field is populated
// Therefore we have to use Channels as an array of strings
r.Channels = append(r.Channels, s.Channel)
}
}
err = c.Websocket.Conn.SendJSONMessage(subscribe)
err := c.Websocket.Conn.SendJSONMessage(r)
if err == nil {
err = c.Websocket.AddSuccessfulSubscriptions(subs...)
}
return err
}
func (c *CoinbasePro) authWsSubscibeReq(r *WebsocketSubscribe) error {
creds, err := c.GetCredentials(context.TODO())
if err != nil {
return err
}
c.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe...)
r.Timestamp = strconv.FormatInt(time.Now().Unix(), 10)
message := r.Timestamp + http.MethodGet + "/users/self/verify"
hmac, err := crypto.GetHMAC(crypto.HashSHA256, []byte(message), []byte(creds.Secret))
if err != nil {
return err
}
r.Signature = crypto.Base64Encode(hmac)
r.Key = creds.Key
r.Passphrase = creds.ClientID
return nil
}
// Unsubscribe sends a websocket message to stop receiving data from the channel
func (c *CoinbasePro) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) error {
unsubscribe := WebsocketSubscribe{
Type: "unsubscribe",
func (c *CoinbasePro) Unsubscribe(subs subscription.List) error {
r := &WebsocketSubscribe{
Type: "unsubscribe",
Channels: make([]any, 0, len(subs)),
}
productIDs := make([]string, 0, len(channelsToUnsubscribe))
for i := range channelsToUnsubscribe {
p := channelsToUnsubscribe[i].Pair.String()
if p != "" && !common.StringDataCompare(productIDs, p) {
// get all unique productIDs in advance as we generate by channels
productIDs = append(productIDs, p)
}
}
unsubscriptions:
for i := range channelsToUnsubscribe {
for j := range unsubscribe.Channels {
if unsubscribe.Channels[j].Name == channelsToUnsubscribe[i].Channel {
continue unsubscriptions
}
}
unsubscribe.Channels = append(unsubscribe.Channels, WsChannels{
Name: channelsToUnsubscribe[i].Channel,
ProductIDs: productIDs,
for _, s := range subs {
r.Channels = append(r.Channels, WsChannel{
Name: s.Channel,
ProductIDs: s.Pairs.Strings(),
})
}
err := c.Websocket.Conn.SendJSONMessage(unsubscribe)
if err != nil {
return err
err := c.Websocket.Conn.SendJSONMessage(r)
if err == nil {
err = c.Websocket.RemoveSubscriptions(subs...)
}
c.Websocket.RemoveSubscriptions(channelsToUnsubscribe...)
return nil
return err
}

View File

@@ -23,6 +23,7 @@ import (
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
"github.com/thrasher-corp/gocryptotrader/exchanges/stream"
"github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
"github.com/thrasher-corp/gocryptotrader/exchanges/ticker"
"github.com/thrasher-corp/gocryptotrader/exchanges/trade"
"github.com/thrasher-corp/gocryptotrader/log"
@@ -105,6 +106,13 @@ func (c *CoinbasePro) SetDefaults() {
GlobalResultLimit: 300,
},
},
Subscriptions: subscription.List{
{Enabled: true, Channel: "heartbeat"},
{Enabled: true, Channel: "level2_batch"}, // Other orderbook feeds require authentication; This is batched in 50ms lots
{Enabled: true, Channel: "ticker"},
{Enabled: true, Channel: "user", Authenticated: true},
{Enabled: true, Channel: "matches"},
},
}
c.Requester, err = request.New(c.Name,
@@ -155,7 +163,7 @@ func (c *CoinbasePro) Setup(exch *config.Exchange) error {
Connector: c.WsConnect,
Subscriber: c.Subscribe,
Unsubscriber: c.Unsubscribe,
GenerateSubscriptions: c.GenerateDefaultSubscriptions,
GenerateSubscriptions: c.generateSubscriptions,
Features: &c.Features.Supports.WebsocketCapabilities,
OrderbookBufferConfig: buffer.Config{
SortBuffer: true,

View File

@@ -580,18 +580,18 @@ func (c *COINUT) WsProcessOrderbookUpdate(update *WsOrderbookUpdate) error {
}
// GenerateDefaultSubscriptions Adds default subscriptions to websocket to be handled by ManageSubscriptions()
func (c *COINUT) GenerateDefaultSubscriptions() ([]subscription.Subscription, error) {
func (c *COINUT) GenerateDefaultSubscriptions() (subscription.List, error) {
var channels = []string{"inst_tick", "inst_order_book", "inst_trade"}
var subscriptions []subscription.Subscription
var subscriptions subscription.List
enabledPairs, err := c.GetEnabledPairs(asset.Spot)
if err != nil {
return nil, err
}
for i := range channels {
for j := range enabledPairs {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: channels[i],
Pair: enabledPairs[j],
Pairs: currency.Pairs{enabledPairs[j]},
Asset: asset.Spot,
})
}
@@ -600,46 +600,50 @@ func (c *COINUT) GenerateDefaultSubscriptions() ([]subscription.Subscription, er
}
// Subscribe sends a websocket message to receive data from the channel
func (c *COINUT) Subscribe(channelsToSubscribe []subscription.Subscription) error {
func (c *COINUT) Subscribe(subs subscription.List) error {
var errs error
for i := range channelsToSubscribe {
fPair, err := c.FormatExchangeCurrency(channelsToSubscribe[i].Pair, asset.Spot)
for _, s := range subs {
if len(s.Pairs) != 1 {
return subscription.ErrNotSinglePair
}
fPair, err := c.FormatExchangeCurrency(s.Pairs[0], asset.Spot)
if err != nil {
errs = common.AppendError(errs, err)
continue
}
subscribe := wsRequest{
Request: channelsToSubscribe[i].Channel,
Request: s.Channel,
InstrumentID: c.instrumentMap.LookupID(fPair.String()),
Subscribe: true,
Nonce: getNonce(),
}
err = c.Websocket.Conn.SendJSONMessage(subscribe)
if err == nil {
err = c.Websocket.AddSuccessfulSubscriptions(s)
}
if err != nil {
errs = common.AppendError(errs, err)
continue
}
c.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[i])
}
if errs != nil {
return errs
}
return nil
return errs
}
// Unsubscribe sends a websocket message to stop receiving data from the channel
func (c *COINUT) Unsubscribe(channelToUnsubscribe []subscription.Subscription) error {
func (c *COINUT) Unsubscribe(channelToUnsubscribe subscription.List) error {
var errs error
for i := range channelToUnsubscribe {
fPair, err := c.FormatExchangeCurrency(channelToUnsubscribe[i].Pair, asset.Spot)
for _, s := range channelToUnsubscribe {
if len(s.Pairs) != 1 {
return subscription.ErrNotSinglePair
}
fPair, err := c.FormatExchangeCurrency(s.Pairs[0], asset.Spot)
if err != nil {
errs = common.AppendError(errs, err)
continue
}
subscribe := wsRequest{
Request: channelToUnsubscribe[i].Channel,
Request: s.Channel,
InstrumentID: c.instrumentMap.LookupID(fPair.String()),
Subscribe: false,
Nonce: getNonce(),
@@ -651,22 +655,20 @@ func (c *COINUT) Unsubscribe(channelToUnsubscribe []subscription.Subscription) e
}
var response map[string]interface{}
err = json.Unmarshal(resp, &response)
if err == nil {
val, ok := response["status"].([]any)
switch {
case !ok:
err = common.GetTypeAssertError("[]any", response["status"])
case len(val) == 0, val[0] != "OK":
err = common.AppendError(errs, fmt.Errorf("%v unsubscribe failed for channel %v", c.Name, s.Channel))
default:
err = c.Websocket.RemoveSubscriptions(s)
}
}
if err != nil {
errs = common.AppendError(errs, err)
continue
}
val, ok := response["status"].([]interface{})
if !ok {
errs = common.AppendError(errs, errors.New("unable to type assert response status"))
}
if val[0] != "OK" {
errs = common.AppendError(errs, fmt.Errorf("%v unsubscribe failed for channel %v",
c.Name,
channelToUnsubscribe[i].Channel))
continue
}
c.Websocket.RemoveSubscriptions(channelToUnsubscribe[i])
}
return errs
}

View File

@@ -834,8 +834,8 @@ func (d *Deribit) processOrderbook(respRaw []byte, channels []string) error {
}
// GenerateDefaultSubscriptions Adds default subscriptions to websocket to be handled by ManageSubscriptions()
func (d *Deribit) GenerateDefaultSubscriptions() ([]subscription.Subscription, error) {
var subscriptions []subscription.Subscription
func (d *Deribit) GenerateDefaultSubscriptions() (subscription.List, error) {
var subscriptions subscription.List
assets := d.GetAssetTypes(true)
subscriptionChannels := defaultSubscriptions
if d.Websocket.CanUseAuthenticatedEndpoints() {
@@ -870,9 +870,9 @@ func (d *Deribit) GenerateDefaultSubscriptions() ([]subscription.Subscription, e
continue
}
subscriptions = append(subscriptions,
subscription.Subscription{
&subscription.Subscription{
Channel: subscriptionChannels[x],
Pair: assetPairs[a][z],
Pairs: currency.Pairs{assetPairs[a][z]},
Params: map[string]interface{}{
"resolution": "1D",
},
@@ -890,9 +890,9 @@ func (d *Deribit) GenerateDefaultSubscriptions() ([]subscription.Subscription, e
a == asset.Futures) || (a != asset.Spot && a != asset.Futures) {
continue
}
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: subscriptionChannels[x],
Pair: assetPairs[a][z],
Pairs: currency.Pairs{assetPairs[a][z]},
Asset: a,
})
}
@@ -906,9 +906,9 @@ func (d *Deribit) GenerateDefaultSubscriptions() ([]subscription.Subscription, e
continue
}
subscriptions = append(subscriptions,
subscription.Subscription{
&subscription.Subscription{
Channel: subscriptionChannels[x],
Pair: assetPairs[a][z],
Pairs: currency.Pairs{assetPairs[a][z]},
// if needed, group and depth of orderbook can be passed as follow "group": "250", "depth": "20",
Interval: kline.HundredMilliseconds,
Asset: a,
@@ -919,9 +919,9 @@ func (d *Deribit) GenerateDefaultSubscriptions() ([]subscription.Subscription, e
},
)
if d.Websocket.CanUseAuthenticatedEndpoints() {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: orderbookChannel,
Pair: assetPairs[a][z],
Pairs: currency.Pairs{assetPairs[a][z]},
Asset: a,
Interval: kline.Interval(0),
Params: map[string]interface{}{
@@ -942,9 +942,9 @@ func (d *Deribit) GenerateDefaultSubscriptions() ([]subscription.Subscription, e
continue
}
subscriptions = append(subscriptions,
subscription.Subscription{
&subscription.Subscription{
Channel: subscriptionChannels[x],
Pair: assetPairs[a][z],
Pairs: currency.Pairs{assetPairs[a][z]},
Interval: kline.HundredMilliseconds,
Asset: a,
})
@@ -959,9 +959,9 @@ func (d *Deribit) GenerateDefaultSubscriptions() ([]subscription.Subscription, e
continue
}
subscriptions = append(subscriptions,
subscription.Subscription{
&subscription.Subscription{
Channel: subscriptionChannels[x],
Pair: assetPairs[a][z],
Pairs: currency.Pairs{assetPairs[a][z]},
Interval: kline.HundredMilliseconds,
Asset: a,
})
@@ -974,18 +974,18 @@ func (d *Deribit) GenerateDefaultSubscriptions() ([]subscription.Subscription, e
currencyPairsName := make(map[currency.Code]bool, 2*len(assetPairs[a]))
for z := range assetPairs[a] {
if okay = currencyPairsName[assetPairs[a][z].Base]; !okay {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Asset: a,
Channel: subscriptionChannels[x],
Pair: currency.Pair{Base: assetPairs[a][z].Base},
Pairs: currency.Pairs{currency.Pair{Base: assetPairs[a][z].Base}},
})
currencyPairsName[assetPairs[a][z].Base] = true
}
if okay = currencyPairsName[assetPairs[a][z].Quote]; !okay {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Asset: a,
Channel: subscriptionChannels[x],
Pair: currency.Pair{Base: assetPairs[a][z].Quote},
Pairs: currency.Pairs{currency.Pair{Base: assetPairs[a][z].Quote}},
})
currencyPairsName[assetPairs[a][z].Quote] = true
}
@@ -1001,19 +1001,19 @@ func (d *Deribit) GenerateDefaultSubscriptions() ([]subscription.Subscription, e
var okay bool
for z := range assetPairs[a] {
if okay = currencyPairsName[assetPairs[a][z].Base]; !okay {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Asset: a,
Channel: subscriptionChannels[x],
Pair: currency.Pair{Base: assetPairs[a][z].Base},
Pairs: currency.Pairs{currency.Pair{Base: assetPairs[a][z].Base}},
Interval: kline.HundredMilliseconds,
})
currencyPairsName[assetPairs[a][z].Base] = true
}
if okay = currencyPairsName[assetPairs[a][z].Quote]; !okay {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Asset: a,
Channel: subscriptionChannels[x],
Pair: currency.Pair{Base: assetPairs[a][z].Quote},
Pairs: currency.Pairs{currency.Pair{Base: assetPairs[a][z].Quote}},
Interval: kline.HundredMilliseconds,
})
currencyPairsName[assetPairs[a][z].Quote] = true
@@ -1028,17 +1028,17 @@ func (d *Deribit) GenerateDefaultSubscriptions() ([]subscription.Subscription, e
var okay bool
for z := range assetPairs[a] {
if okay = currencyPairsName[assetPairs[a][z].Base]; !okay {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: subscriptionChannels[x],
Pair: currency.Pair{Base: assetPairs[a][z].Base},
Pairs: currency.Pairs{currency.Pair{Base: assetPairs[a][z].Base}},
Asset: a,
})
currencyPairsName[assetPairs[a][z].Base] = true
}
if okay = currencyPairsName[assetPairs[a][z].Quote]; !okay {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: subscriptionChannels[x],
Pair: currency.Pair{Base: assetPairs[a][z].Quote},
Pairs: currency.Pairs{currency.Pair{Base: assetPairs[a][z].Quote}},
Asset: a,
})
currencyPairsName[assetPairs[a][z].Quote] = true
@@ -1050,7 +1050,7 @@ func (d *Deribit) GenerateDefaultSubscriptions() ([]subscription.Subscription, e
platformStateChannel,
userLockChannel,
platformStatePublicMethodsStateChannel:
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: subscriptionChannels[x],
})
case priceIndexChannel,
@@ -1060,7 +1060,7 @@ func (d *Deribit) GenerateDefaultSubscriptions() ([]subscription.Subscription, e
markPriceOptionsChannel,
estimatedExpirationPriceChannel:
for i := range indexENUMS {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: subscriptionChannels[x],
Params: map[string]interface{}{
"index_name": indexENUMS[i],
@@ -1072,9 +1072,12 @@ func (d *Deribit) GenerateDefaultSubscriptions() ([]subscription.Subscription, e
return subscriptions, nil
}
func (d *Deribit) generatePayloadFromSubscriptionInfos(operation string, subscs []subscription.Subscription) ([]WsSubscriptionInput, error) {
func (d *Deribit) generatePayloadFromSubscriptionInfos(operation string, subscs subscription.List) ([]WsSubscriptionInput, error) {
subscriptionPayloads := make([]WsSubscriptionInput, len(subscs))
for x := range subscs {
if len(subscs[x].Pairs) > 1 {
return nil, subscription.ErrNotSinglePair
}
sub := WsSubscriptionInput{
JSONRPCVersion: rpcVersion,
ID: d.Websocket.Conn.GenerateMessageID(false),
@@ -1090,16 +1093,16 @@ func (d *Deribit) generatePayloadFromSubscriptionInfos(operation string, subscs
sub.Method = "private/" + operation
}
var instrumentID string
if !subscs[x].Pair.IsEmpty() {
if len(subscs[x].Pairs) == 1 {
pairFormat, err := d.GetPairFormat(subscs[x].Asset, true)
if err != nil {
return nil, err
}
subscs[x].Pair = subscs[x].Pair.Format(pairFormat)
subscs[x].Pairs = subscs[x].Pairs.Format(pairFormat)
if subscs[x].Asset == asset.Futures {
instrumentID = d.formatFuturesTradablePair(subscs[x].Pair.Format(pairFormat))
instrumentID = d.formatFuturesTradablePair(subscs[x].Pairs[0])
} else {
instrumentID = subscs[x].Pair.String()
instrumentID = subscs[x].Pairs.Join()
}
}
switch subscs[x].Channel {
@@ -1110,7 +1113,7 @@ func (d *Deribit) generatePayloadFromSubscriptionInfos(operation string, subscs
userLockChannel:
sub.Params["channels"] = []string{subscs[x].Channel}
case orderbookChannel:
if subscs[x].Pair.IsEmpty() {
if len(subscs[x].Pairs) != 1 {
return nil, currency.ErrCurrencyPairEmpty
}
intervalString, err := d.GetResolutionFromInterval(subscs[x].Interval)
@@ -1129,14 +1132,14 @@ func (d *Deribit) generatePayloadFromSubscriptionInfos(operation string, subscs
}
sub.Params["channels"] = []string{orderbookChannel + "." + instrumentID + "." + group + "." + depth + "." + intervalString}
case chartTradesChannel:
if subscs[x].Pair.IsEmpty() {
if len(subscs[x].Pairs) != 1 {
return nil, currency.ErrCurrencyPairEmpty
}
resolution, okay := subscs[x].Params["resolution"].(string)
if !okay {
resolution = "1D"
}
sub.Params["channels"] = []string{chartTradesChannel + "." + d.formatFuturesTradablePair(subscs[x].Pair) + "." + resolution}
sub.Params["channels"] = []string{chartTradesChannel + "." + d.formatFuturesTradablePair(subscs[x].Pairs[0]) + "." + resolution}
case priceIndexChannel,
priceRankingChannel,
priceStatisticsChannel,
@@ -1149,28 +1152,37 @@ func (d *Deribit) generatePayloadFromSubscriptionInfos(operation string, subscs
}
sub.Params["channels"] = []string{subscs[x].Channel + "." + indexName}
case instrumentStateChannel:
if len(subscs[x].Pairs) != 1 {
return nil, currency.ErrCurrencyPairEmpty
}
kind := d.GetAssetKind(subscs[x].Asset)
currencyCode := getValidatedCurrencyCode(subscs[x].Pair)
currencyCode := getValidatedCurrencyCode(subscs[x].Pairs[0])
sub.Params["channels"] = []string{"instrument.state." + kind + "." + currencyCode}
case rawUsersOrdersKindCurrencyChannel:
if len(subscs[x].Pairs) != 1 {
return nil, currency.ErrCurrencyPairEmpty
}
kind := d.GetAssetKind(subscs[x].Asset)
currencyCode := getValidatedCurrencyCode(subscs[x].Pair)
currencyCode := getValidatedCurrencyCode(subscs[x].Pairs[0])
sub.Params["channels"] = []string{"user.orders." + kind + "." + currencyCode + ".raw"}
case quoteChannel,
incrementalTickerChannel:
if subscs[x].Pair.IsEmpty() {
if len(subscs[x].Pairs) != 1 {
return nil, currency.ErrCurrencyPairEmpty
}
sub.Params["channels"] = []string{subscs[x].Channel + "." + instrumentID}
case rawUserOrdersChannel:
if subscs[x].Pair.IsEmpty() {
if len(subscs[x].Pairs) != 1 {
return nil, currency.ErrCurrencyPairEmpty
}
sub.Params["channels"] = []string{"user.orders." + instrumentID + ".raw"}
case requestForQuoteChannel,
userMMPTriggerChannel,
userPortfolioChannel:
currencyCode := getValidatedCurrencyCode(subscs[x].Pair)
if len(subscs[x].Pairs) != 1 {
return nil, currency.ErrCurrencyPairEmpty
}
currencyCode := getValidatedCurrencyCode(subscs[x].Pairs[0])
sub.Params["channels"] = []string{subscs[x].Channel + "." + currencyCode}
case tradesChannel,
userChangesInstrumentsChannel,
@@ -1178,7 +1190,7 @@ func (d *Deribit) generatePayloadFromSubscriptionInfos(operation string, subscs
tickerChannel,
perpetualChannel,
userTradesChannelByInstrument:
if subscs[x].Pair.IsEmpty() {
if len(subscs[x].Pairs) != 1 {
return nil, currency.ErrCurrencyPairEmpty
}
if subscs[x].Interval.Duration() == 0 {
@@ -1195,7 +1207,10 @@ func (d *Deribit) generatePayloadFromSubscriptionInfos(operation string, subscs
rawUsersOrdersWithKindCurrencyAndIntervalChannel,
userTradesByKindCurrencyAndIntervalChannel:
kind := d.GetAssetKind(subscs[x].Asset)
currencyCode := getValidatedCurrencyCode(subscs[x].Pair)
if len(subscs[x].Pairs) != 1 {
return nil, currency.ErrCurrencyPairEmpty
}
currencyCode := getValidatedCurrencyCode(subscs[x].Pairs[0])
if subscs[x].Interval.Duration() == 0 {
sub.Params["channels"] = []string{subscs[x].Channel + "." + kind + "." + currencyCode}
continue
@@ -1214,12 +1229,12 @@ func (d *Deribit) generatePayloadFromSubscriptionInfos(operation string, subscs
}
// Subscribe sends a websocket message to receive data from the channel
func (d *Deribit) Subscribe(channelsToSubscribe []subscription.Subscription) error {
func (d *Deribit) Subscribe(channelsToSubscribe subscription.List) error {
return d.handleSubscription("subscribe", channelsToSubscribe)
}
// Unsubscribe sends a websocket message to stop receiving data from the channel
func (d *Deribit) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) error {
func (d *Deribit) Unsubscribe(channelsToUnsubscribe subscription.List) error {
return d.handleSubscription("unsubscribe", channelsToUnsubscribe)
}
@@ -1238,7 +1253,7 @@ func filterSubscriptionPayloads(subscription []WsSubscriptionInput) []WsSubscrip
return newSubscs
}
func (d *Deribit) handleSubscription(operation string, channels []subscription.Subscription) error {
func (d *Deribit) handleSubscription(operation string, channels subscription.List) error {
payloads, err := d.generatePayloadFromSubscriptionInfos(operation, channels)
if err != nil {
return err

View File

@@ -168,10 +168,10 @@ func (b *Base) SetSubscriptionsFromConfig() {
b.settingsMutex.Lock()
defer b.settingsMutex.Unlock()
if len(b.Config.Features.Subscriptions) == 0 {
// Set config from the defaults, including any disabled subscriptions
b.Config.Features.Subscriptions = b.Features.Subscriptions
return
}
b.Features.Subscriptions = []*subscription.Subscription{}
b.Features.Subscriptions = subscription.List{}
for _, s := range b.Config.Features.Subscriptions {
if s.Enabled {
b.Features.Subscriptions = append(b.Features.Subscriptions, s)
@@ -1126,7 +1126,7 @@ func (b *Base) FlushWebsocketChannels() error {
// SubscribeToWebsocketChannels appends to ChannelsToSubscribe
// which lets websocket.manageSubscriptions handle subscribing
func (b *Base) SubscribeToWebsocketChannels(channels []subscription.Subscription) error {
func (b *Base) SubscribeToWebsocketChannels(channels subscription.List) error {
if b.Websocket == nil {
return common.ErrFunctionNotSupported
}
@@ -1135,7 +1135,7 @@ func (b *Base) SubscribeToWebsocketChannels(channels []subscription.Subscription
// UnsubscribeToWebsocketChannels removes from ChannelsToSubscribe
// which lets websocket.manageSubscriptions handle unsubscribing
func (b *Base) UnsubscribeToWebsocketChannels(channels []subscription.Subscription) error {
func (b *Base) UnsubscribeToWebsocketChannels(channels subscription.List) error {
if b.Websocket == nil {
return common.ErrFunctionNotSupported
}
@@ -1143,7 +1143,7 @@ func (b *Base) UnsubscribeToWebsocketChannels(channels []subscription.Subscripti
}
// GetSubscriptions returns a copied list of subscriptions
func (b *Base) GetSubscriptions() ([]subscription.Subscription, error) {
func (b *Base) GetSubscriptions() (subscription.List, error) {
if b.Websocket == nil {
return nil, common.ErrFunctionNotSupported
}
@@ -1804,7 +1804,7 @@ func (b *Base) GetOpenInterest(context.Context, ...key.PairAsset) ([]futures.Ope
}
// ParallelChanOp performs a single method call in parallel across streams and waits to return any errors
func (b *Base) ParallelChanOp(channels []subscription.Subscription, m func([]subscription.Subscription) error, batchSize int) error {
func (b *Base) ParallelChanOp(channels subscription.List, m func(subscription.List) error, batchSize int) error {
wg := sync.WaitGroup{}
errC := make(chan error, len(channels))
if batchSize == 0 {
@@ -1818,7 +1818,7 @@ func (b *Base) ParallelChanOp(channels []subscription.Subscription, m func([]sub
j = len(channels)
}
wg.Add(1)
go func(c []subscription.Subscription) {
go func(c subscription.List) {
defer wg.Done()
if err := m(c); err != nil {
errC <- err

View File

@@ -880,8 +880,8 @@ func TestSetupDefaults(t *testing.T) {
DefaultURL: "ws://something.com",
RunningURL: "ws://something.com",
Connector: func() error { return nil },
GenerateSubscriptions: func() ([]subscription.Subscription, error) { return []subscription.Subscription{}, nil },
Subscriber: func([]subscription.Subscription) error { return nil },
GenerateSubscriptions: func() (subscription.List, error) { return subscription.List{}, nil },
Subscriber: func(subscription.List) error { return nil },
})
if err != nil {
t.Fatal(err)
@@ -1207,8 +1207,8 @@ func TestIsWebsocketEnabled(t *testing.T) {
DefaultURL: "ws://something.com",
RunningURL: "ws://something.com",
Connector: func() error { return nil },
GenerateSubscriptions: func() ([]subscription.Subscription, error) { return nil, nil },
Subscriber: func([]subscription.Subscription) error { return nil },
GenerateSubscriptions: func() (subscription.List, error) { return nil, nil },
Subscriber: func(subscription.List) error { return nil },
})
if err != nil {
t.Error(err)
@@ -1645,15 +1645,11 @@ func TestSubscribeToWebsocketChannels(t *testing.T) {
func TestUnsubscribeToWebsocketChannels(t *testing.T) {
b := Base{}
err := b.UnsubscribeToWebsocketChannels(nil)
if err == nil {
t.Fatal(err)
}
assert.ErrorIs(t, err, common.ErrFunctionNotSupported, "UnsubscribeToWebsocketChannels should error correctly with a nil Websocket")
b.Websocket = &stream.Websocket{}
err = b.UnsubscribeToWebsocketChannels(nil)
if err == nil {
t.Fatal(err)
}
assert.NoError(t, err, "UnsubscribeToWebsocketChannels from an empty/nil list should not error")
}
func TestGetSubscriptions(t *testing.T) {
@@ -2827,27 +2823,29 @@ func TestSetSubscriptionsFromConfig(t *testing.T) {
Features: &config.FeaturesConfig{},
},
}
subs := []*subscription.Subscription{
subs := subscription.List{
{Channel: subscription.CandlesChannel, Interval: kline.OneDay, Enabled: true},
{Channel: subscription.OrderbookChannel, Enabled: false},
}
b.Features.Subscriptions = subs
b.SetSubscriptionsFromConfig()
assert.ElementsMatch(t, subs, b.Config.Features.Subscriptions, "Config Subscriptions should be updated")
assert.ElementsMatch(t, subs, b.Features.Subscriptions, "Subscriptions should be the same")
assert.ElementsMatch(t, subscription.List{subs[0]}, b.Features.Subscriptions, "Actual Subscriptions should only contain Enabled")
subs = []*subscription.Subscription{
{Channel: subscription.OrderbookChannel, Interval: kline.OneDay, Enabled: true},
subs = subscription.List{
{Channel: subscription.OrderbookChannel, Enabled: true},
{Channel: subscription.CandlesChannel, Interval: kline.OneDay, Enabled: false},
}
b.Config.Features.Subscriptions = subs
b.SetSubscriptionsFromConfig()
assert.ElementsMatch(t, subs, b.Features.Subscriptions, "Subscriptions should be updated from Config")
assert.ElementsMatch(t, subs, b.Config.Features.Subscriptions, "Config Subscriptions should be the same")
assert.ElementsMatch(t, subscription.List{subs[0]}, b.Features.Subscriptions, "Subscriptions should only contain Enabled from Config")
}
// TestParallelChanOp unit tests the helper func ParallelChanOp
func TestParallelChanOp(t *testing.T) {
t.Parallel()
c := []subscription.Subscription{
c := subscription.List{
{Channel: "red"},
{Channel: "blue"},
{Channel: "violent"},
@@ -2858,7 +2856,7 @@ func TestParallelChanOp(t *testing.T) {
b := Base{}
errC := make(chan error, 1)
go func() {
errC <- b.ParallelChanOp(c, func(c []subscription.Subscription) error {
errC <- b.ParallelChanOp(c, func(c subscription.List) error {
time.Sleep(300 * time.Millisecond)
run <- struct{}{}
switch c[0].Channel {

View File

@@ -152,7 +152,7 @@ type WithdrawalHistory struct {
type Features struct {
Supports FeaturesSupported
Enabled FeaturesEnabled
Subscriptions []*subscription.Subscription
Subscriptions subscription.List
}
// FeaturesEnabled stores the exchange enabled features

View File

@@ -625,7 +625,7 @@ func (g *Gateio) processCrossMarginLoans(data []byte) error {
}
// GenerateDefaultSubscriptions returns default subscriptions
func (g *Gateio) GenerateDefaultSubscriptions() ([]subscription.Subscription, error) {
func (g *Gateio) GenerateDefaultSubscriptions() (subscription.List, error) {
channelsToSubscribe := defaultSubscriptions
if g.Websocket.CanUseAuthenticatedEndpoints() {
channelsToSubscribe = append(channelsToSubscribe, []string{
@@ -638,7 +638,7 @@ func (g *Gateio) GenerateDefaultSubscriptions() ([]subscription.Subscription, er
channelsToSubscribe = append(channelsToSubscribe, spotTradesChannel)
}
var subscriptions []subscription.Subscription
var subscriptions subscription.List
var err error
for i := range channelsToSubscribe {
var pairs []currency.Pair
@@ -678,9 +678,9 @@ func (g *Gateio) GenerateDefaultSubscriptions() ([]subscription.Subscription, er
return nil, err
}
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: channelsToSubscribe[i],
Pair: fpair.Upper(),
Pairs: currency.Pairs{fpair.Upper()},
Asset: assetType,
Params: params,
})
@@ -690,7 +690,7 @@ func (g *Gateio) GenerateDefaultSubscriptions() ([]subscription.Subscription, er
}
// handleSubscription sends a websocket message to receive data from the channel
func (g *Gateio) handleSubscription(event string, channelsToSubscribe []subscription.Subscription) error {
func (g *Gateio) handleSubscription(event string, channelsToSubscribe subscription.List) error {
payloads, err := g.generatePayload(event, channelsToSubscribe)
if err != nil {
return err
@@ -711,16 +711,19 @@ func (g *Gateio) handleSubscription(event string, channelsToSubscribe []subscrip
continue
}
if payloads[k].Event == "subscribe" {
g.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[k])
err = g.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[k])
} else {
g.Websocket.RemoveSubscriptions(channelsToSubscribe[k])
err = g.Websocket.RemoveSubscriptions(channelsToSubscribe[k])
}
if err != nil {
errs = common.AppendError(errs, err)
}
}
}
return errs
}
func (g *Gateio) generatePayload(event string, channelsToSubscribe []subscription.Subscription) ([]WsInput, error) {
func (g *Gateio) generatePayload(event string, channelsToSubscribe subscription.List) ([]WsInput, error) {
if len(channelsToSubscribe) == 0 {
return nil, errors.New("cannot generate payload, no channels supplied")
}
@@ -736,10 +739,13 @@ func (g *Gateio) generatePayload(event string, channelsToSubscribe []subscriptio
var intervalString string
payloads := make([]WsInput, 0, len(channelsToSubscribe))
for i := range channelsToSubscribe {
if len(channelsToSubscribe[i].Pairs) != 1 {
return nil, subscription.ErrNotSinglePair
}
var auth *WsAuthInput
timestamp := time.Now()
channelsToSubscribe[i].Pair.Delimiter = currency.UnderscoreDelimiter
params := []string{channelsToSubscribe[i].Pair.String()}
channelsToSubscribe[i].Pairs[0].Delimiter = currency.UnderscoreDelimiter
params := []string{channelsToSubscribe[i].Pairs[0].String()}
switch channelsToSubscribe[i].Channel {
case spotOrderbookChannel:
interval, okay := channelsToSubscribe[i].Params["interval"].(kline.Interval)
@@ -837,12 +843,12 @@ func (g *Gateio) generatePayload(event string, channelsToSubscribe []subscriptio
}
// Subscribe sends a websocket message to stop receiving data from the channel
func (g *Gateio) Subscribe(channelsToUnsubscribe []subscription.Subscription) error {
func (g *Gateio) Subscribe(channelsToUnsubscribe subscription.List) error {
return g.handleSubscription("subscribe", channelsToUnsubscribe)
}
// Unsubscribe sends a websocket message to stop receiving data from the channel
func (g *Gateio) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) error {
func (g *Gateio) Unsubscribe(channelsToUnsubscribe subscription.List) error {
return g.handleSubscription("unsubscribe", channelsToUnsubscribe)
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/account"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
"github.com/thrasher-corp/gocryptotrader/exchanges/kline"
@@ -141,7 +142,7 @@ func (g *Gateio) wsFunnelDeliveryFuturesConnectionData(ws stream.Connection) {
}
// GenerateDeliveryFuturesDefaultSubscriptions returns delivery futures default subscriptions params.
func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() ([]subscription.Subscription, error) {
func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() (subscription.List, error) {
_, err := g.GetCredentials(context.Background())
if err != nil {
g.Websocket.SetCanUseAuthenticatedEndpoints(false)
@@ -159,7 +160,7 @@ func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() ([]subscription.S
if err != nil {
return nil, err
}
var subscriptions []subscription.Subscription
var subscriptions subscription.List
for i := range channelsToSubscribe {
for j := range pairs {
params := make(map[string]interface{})
@@ -174,9 +175,9 @@ func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() ([]subscription.S
if err != nil {
return nil, err
}
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: channelsToSubscribe[i],
Pair: fpair.Upper(),
Pairs: currency.Pairs{fpair.Upper()},
Params: params,
})
}
@@ -185,17 +186,17 @@ func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() ([]subscription.S
}
// DeliveryFuturesSubscribe sends a websocket message to stop receiving data from the channel
func (g *Gateio) DeliveryFuturesSubscribe(channelsToUnsubscribe []subscription.Subscription) error {
func (g *Gateio) DeliveryFuturesSubscribe(channelsToUnsubscribe subscription.List) error {
return g.handleDeliveryFuturesSubscription("subscribe", channelsToUnsubscribe)
}
// DeliveryFuturesUnsubscribe sends a websocket message to stop receiving data from the channel
func (g *Gateio) DeliveryFuturesUnsubscribe(channelsToUnsubscribe []subscription.Subscription) error {
func (g *Gateio) DeliveryFuturesUnsubscribe(channelsToUnsubscribe subscription.List) error {
return g.handleDeliveryFuturesSubscription("unsubscribe", channelsToUnsubscribe)
}
// handleDeliveryFuturesSubscription sends a websocket message to receive data from the channel
func (g *Gateio) handleDeliveryFuturesSubscription(event string, channelsToSubscribe []subscription.Subscription) error {
func (g *Gateio) handleDeliveryFuturesSubscription(event string, channelsToSubscribe subscription.List) error {
payloads, err := g.generateDeliveryFuturesPayload(event, channelsToSubscribe)
if err != nil {
return err
@@ -222,16 +223,19 @@ func (g *Gateio) handleDeliveryFuturesSubscription(event string, channelsToSubsc
errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s error code: %d message: %s", val[k].Event, val[k].Channel, resp.Error.Code, resp.Error.Message))
continue
}
g.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[k])
if err = g.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[k]); err != nil {
errs = common.AppendError(errs, err)
}
}
}
}
return errs
}
func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscribe []subscription.Subscription) ([2][]WsInput, error) {
func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscribe subscription.List) ([2][]WsInput, error) {
payloads := [2][]WsInput{}
if len(channelsToSubscribe) == 0 {
return [2][]WsInput{}, errors.New("cannot generate payload, no channels supplied")
return payloads, errors.New("cannot generate payload, no channels supplied")
}
var creds *account.Credentials
var err error
@@ -241,12 +245,14 @@ func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscrib
g.Websocket.SetCanUseAuthenticatedEndpoints(false)
}
}
payloads := [2][]WsInput{}
for i := range channelsToSubscribe {
if len(channelsToSubscribe[i].Pairs) != 1 {
return payloads, subscription.ErrNotSinglePair
}
var auth *WsAuthInput
timestamp := time.Now()
var params []string
params = []string{channelsToSubscribe[i].Pair.String()}
params = []string{channelsToSubscribe[i].Pairs[0].String()}
if g.Websocket.CanUseAuthenticatedEndpoints() {
switch channelsToSubscribe[i].Channel {
case futuresOrdersChannel, futuresUserTradesChannel,
@@ -256,9 +262,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscrib
futuresAutoOrdersChannel:
value, ok := channelsToSubscribe[i].Params["user"].(string)
if ok {
params = append(
[]string{value},
params...)
params = append([]string{value}, params...)
}
var sigTemp string
sigTemp, err = g.generateWsSignature(creds.Secret, event, channelsToSubscribe[i].Channel, timestamp)
@@ -310,7 +314,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscrib
params = append(params, intervalString)
}
}
if strings.HasPrefix(channelsToSubscribe[i].Pair.Quote.Upper().String(), "USDT") {
if strings.HasPrefix(channelsToSubscribe[i].Pairs[0].Quote.Upper().String(), "USDT") {
payloads[0] = append(payloads[0], WsInput{
ID: g.Websocket.Conn.GenerateMessageID(false),
Event: event,

View File

@@ -122,7 +122,7 @@ func (g *Gateio) WsFuturesConnect() error {
}
// GenerateFuturesDefaultSubscriptions returns default subscriptions information.
func (g *Gateio) GenerateFuturesDefaultSubscriptions() ([]subscription.Subscription, error) {
func (g *Gateio) GenerateFuturesDefaultSubscriptions() (subscription.List, error) {
channelsToSubscribe := defaultFuturesSubscriptions
if g.Websocket.CanUseAuthenticatedEndpoints() {
channelsToSubscribe = append(channelsToSubscribe,
@@ -135,7 +135,7 @@ func (g *Gateio) GenerateFuturesDefaultSubscriptions() ([]subscription.Subscript
if err != nil {
return nil, err
}
subscriptions := make([]subscription.Subscription, len(channelsToSubscribe)*len(pairs))
subscriptions := make(subscription.List, len(channelsToSubscribe)*len(pairs))
count := 0
for i := range channelsToSubscribe {
for j := range pairs {
@@ -154,9 +154,9 @@ func (g *Gateio) GenerateFuturesDefaultSubscriptions() ([]subscription.Subscript
if err != nil {
return nil, err
}
subscriptions[count] = subscription.Subscription{
subscriptions[count] = &subscription.Subscription{
Channel: channelsToSubscribe[i],
Pair: fpair.Upper(),
Pairs: currency.Pairs{fpair.Upper()},
Params: params,
}
count++
@@ -166,12 +166,12 @@ func (g *Gateio) GenerateFuturesDefaultSubscriptions() ([]subscription.Subscript
}
// FuturesSubscribe sends a websocket message to stop receiving data from the channel
func (g *Gateio) FuturesSubscribe(channelsToUnsubscribe []subscription.Subscription) error {
func (g *Gateio) FuturesSubscribe(channelsToUnsubscribe subscription.List) error {
return g.handleFuturesSubscription("subscribe", channelsToUnsubscribe)
}
// FuturesUnsubscribe sends a websocket message to stop receiving data from the channel
func (g *Gateio) FuturesUnsubscribe(channelsToUnsubscribe []subscription.Subscription) error {
func (g *Gateio) FuturesUnsubscribe(channelsToUnsubscribe subscription.List) error {
return g.handleFuturesSubscription("unsubscribe", channelsToUnsubscribe)
}
@@ -276,7 +276,7 @@ func (g *Gateio) wsHandleFuturesData(respRaw []byte, assetType asset.Item) error
}
// handleFuturesSubscription sends a websocket message to receive data from the channel
func (g *Gateio) handleFuturesSubscription(event string, channelsToSubscribe []subscription.Subscription) error {
func (g *Gateio) handleFuturesSubscription(event string, channelsToSubscribe subscription.List) error {
payloads, err := g.generateFuturesPayload(event, channelsToSubscribe)
if err != nil {
return err
@@ -303,7 +303,9 @@ func (g *Gateio) handleFuturesSubscription(event string, channelsToSubscribe []s
errs = common.AppendError(errs, fmt.Errorf("error while %s to channel %s error code: %d message: %s", val[k].Event, val[k].Channel, resp.Error.Code, resp.Error.Message))
continue
}
g.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[k])
if err = g.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[k]); err != nil {
errs = common.AppendError(errs, err)
}
}
}
}
@@ -313,9 +315,10 @@ func (g *Gateio) handleFuturesSubscription(event string, channelsToSubscribe []s
return nil
}
func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe []subscription.Subscription) ([2][]WsInput, error) {
func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe subscription.List) ([2][]WsInput, error) {
payloads := [2][]WsInput{}
if len(channelsToSubscribe) == 0 {
return [2][]WsInput{}, errors.New("cannot generate payload, no channels supplied")
return payloads, errors.New("cannot generate payload, no channels supplied")
}
var creds *account.Credentials
var err error
@@ -325,12 +328,14 @@ func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe []subs
g.Websocket.SetCanUseAuthenticatedEndpoints(false)
}
}
payloads := [2][]WsInput{}
for i := range channelsToSubscribe {
if len(channelsToSubscribe[i].Pairs) != 1 {
return payloads, subscription.ErrNotSinglePair
}
var auth *WsAuthInput
timestamp := time.Now()
var params []string
params = []string{channelsToSubscribe[i].Pair.String()}
params = []string{channelsToSubscribe[i].Pairs[0].String()}
if g.Websocket.CanUseAuthenticatedEndpoints() {
switch channelsToSubscribe[i].Channel {
case futuresOrdersChannel, futuresUserTradesChannel,
@@ -394,7 +399,7 @@ func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe []subs
params = append(params, intervalString)
}
}
if strings.HasPrefix(channelsToSubscribe[i].Pair.Quote.Upper().String(), "USDT") {
if strings.HasPrefix(channelsToSubscribe[i].Pairs[0].Quote.Upper().String(), "USDT") {
payloads[0] = append(payloads[0], WsInput{
ID: g.Websocket.Conn.GenerateMessageID(false),
Event: event,

View File

@@ -105,7 +105,7 @@ func (g *Gateio) WsOptionsConnect() error {
}
// GenerateOptionsDefaultSubscriptions generates list of channel subscriptions for options asset type.
func (g *Gateio) GenerateOptionsDefaultSubscriptions() ([]subscription.Subscription, error) {
func (g *Gateio) GenerateOptionsDefaultSubscriptions() (subscription.List, error) {
channelsToSubscribe := defaultOptionsSubscriptions
var userID int64
if g.Websocket.CanUseAuthenticatedEndpoints() {
@@ -130,7 +130,7 @@ func (g *Gateio) GenerateOptionsDefaultSubscriptions() ([]subscription.Subscript
}
}
getEnabledPairs:
var subscriptions []subscription.Subscription
var subscriptions subscription.List
pairs, err := g.GetEnabledPairs(asset.Options)
if err != nil {
return nil, err
@@ -163,9 +163,9 @@ getEnabledPairs:
if err != nil {
return nil, err
}
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: channelsToSubscribe[i],
Pair: fpair.Upper(),
Pairs: currency.Pairs{fpair.Upper()},
Params: params,
})
}
@@ -173,7 +173,7 @@ getEnabledPairs:
return subscriptions, nil
}
func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe []subscription.Subscription) ([]WsInput, error) {
func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe subscription.List) ([]WsInput, error) {
if len(channelsToSubscribe) == 0 {
return nil, errors.New("cannot generate payload, no channels supplied")
}
@@ -181,6 +181,9 @@ func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe []subs
var intervalString string
payloads := make([]WsInput, len(channelsToSubscribe))
for i := range channelsToSubscribe {
if len(channelsToSubscribe[i].Pairs) != 1 {
return nil, subscription.ErrNotSinglePair
}
var auth *WsAuthInput
timestamp := time.Now()
var params []string
@@ -190,7 +193,7 @@ func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe []subs
optionsUnderlyingPriceChannel,
optionsUnderlyingCandlesticksChannel:
var uly currency.Pair
uly, err = g.GetUnderlyingFromCurrencyPair(channelsToSubscribe[i].Pair)
uly, err = g.GetUnderlyingFromCurrencyPair(channelsToSubscribe[i].Pairs[0])
if err != nil {
return nil, err
}
@@ -198,8 +201,8 @@ func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe []subs
case optionsBalancesChannel:
// options.balance channel does not require underlying or contract
default:
channelsToSubscribe[i].Pair.Delimiter = currency.UnderscoreDelimiter
params = append(params, channelsToSubscribe[i].Pair.String())
channelsToSubscribe[i].Pairs[0].Delimiter = currency.UnderscoreDelimiter
params = append(params, channelsToSubscribe[i].Pairs[0].String())
}
switch channelsToSubscribe[i].Channel {
case optionsOrderbookChannel:
@@ -299,17 +302,17 @@ func (g *Gateio) wsReadOptionsConnData() {
}
// OptionsSubscribe sends a websocket message to stop receiving data for asset type options
func (g *Gateio) OptionsSubscribe(channelsToUnsubscribe []subscription.Subscription) error {
func (g *Gateio) OptionsSubscribe(channelsToUnsubscribe subscription.List) error {
return g.handleOptionsSubscription("subscribe", channelsToUnsubscribe)
}
// OptionsUnsubscribe sends a websocket message to stop receiving data for asset type options
func (g *Gateio) OptionsUnsubscribe(channelsToUnsubscribe []subscription.Subscription) error {
func (g *Gateio) OptionsUnsubscribe(channelsToUnsubscribe subscription.List) error {
return g.handleOptionsSubscription("unsubscribe", channelsToUnsubscribe)
}
// handleOptionsSubscription sends a websocket message to receive data from the channel
func (g *Gateio) handleOptionsSubscription(event string, channelsToSubscribe []subscription.Subscription) error {
func (g *Gateio) handleOptionsSubscription(event string, channelsToSubscribe subscription.List) error {
payloads, err := g.generateOptionsPayload(event, channelsToSubscribe)
if err != nil {
return err
@@ -330,9 +333,12 @@ func (g *Gateio) handleOptionsSubscription(event string, channelsToSubscribe []s
continue
}
if payloads[k].Event == "subscribe" {
g.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[k])
err = g.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[k])
} else {
g.Websocket.RemoveSubscriptions(channelsToSubscribe[k])
err = g.Websocket.RemoveSubscriptions(channelsToSubscribe[k])
}
if err != nil {
errs = common.AppendError(errs, err)
}
}
}

View File

@@ -336,8 +336,15 @@ type wsSubscriptions struct {
Symbols []string `json:"symbols"`
}
type wsSubOp string
const (
wsSubscribeOp wsSubOp = "subscribe"
wsUnsubscribeOp wsSubOp = "unsubscribe"
)
type wsSubscribeRequest struct {
Type string `json:"type"`
Type wsSubOp `json:"type"`
Subscriptions []wsSubscriptions `json:"subscriptions"`
}

View File

@@ -13,7 +13,6 @@ import (
"time"
"github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/crypto"
"github.com/thrasher-corp/gocryptotrader/currency"
exchange "github.com/thrasher-corp/gocryptotrader/exchanges"
@@ -63,7 +62,7 @@ func (g *Gemini) WsConnect() error {
}
// GenerateDefaultSubscriptions Adds default subscriptions to websocket to be handled by ManageSubscriptions()
func (g *Gemini) GenerateDefaultSubscriptions() ([]subscription.Subscription, error) {
func (g *Gemini) GenerateDefaultSubscriptions() (subscription.List, error) {
// See gemini_types.go for more subscription/candle vars
var channels = []string{
marketDataLevel2,
@@ -75,105 +74,53 @@ func (g *Gemini) GenerateDefaultSubscriptions() ([]subscription.Subscription, er
return nil, err
}
var subscriptions []subscription.Subscription
var subscriptions subscription.List
for x := range channels {
for y := range pairs {
subscriptions = append(subscriptions, subscription.Subscription{
Channel: channels[x],
Pair: pairs[y],
Asset: asset.Spot,
})
}
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: channels[x],
Pairs: pairs,
Asset: asset.Spot,
})
}
return subscriptions, nil
}
// Subscribe sends a websocket message to receive data from the channel
func (g *Gemini) Subscribe(channelsToSubscribe []subscription.Subscription) error {
channels := make([]string, 0, len(channelsToSubscribe))
for x := range channelsToSubscribe {
if common.StringDataCompareInsensitive(channels, channelsToSubscribe[x].Channel) {
continue
}
channels = append(channels, channelsToSubscribe[x].Channel)
}
var pairs currency.Pairs
for x := range channelsToSubscribe {
if pairs.Contains(channelsToSubscribe[x].Pair, true) {
continue
}
pairs = append(pairs, channelsToSubscribe[x].Pair)
}
fmtPairs, err := g.FormatExchangeCurrencies(pairs, asset.Spot)
if err != nil {
return err
}
subs := make([]wsSubscriptions, len(channels))
for x := range channels {
subs[x] = wsSubscriptions{
Name: channels[x],
Symbols: strings.Split(fmtPairs, ","),
}
}
wsSub := wsSubscribeRequest{
Type: "subscribe",
Subscriptions: subs,
}
err = g.Websocket.Conn.SendJSONMessage(wsSub)
if err != nil {
return err
}
g.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe...)
return nil
func (g *Gemini) Subscribe(subs subscription.List) error {
return g.manageSubs(subs, wsSubscribeOp)
}
// Unsubscribe sends a websocket message to stop receiving data from the channel
func (g *Gemini) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) error {
channels := make([]string, 0, len(channelsToUnsubscribe))
for x := range channelsToUnsubscribe {
if common.StringDataCompareInsensitive(channels, channelsToUnsubscribe[x].Channel) {
continue
}
channels = append(channels, channelsToUnsubscribe[x].Channel)
}
func (g *Gemini) Unsubscribe(subs subscription.List) error {
return g.manageSubs(subs, wsUnsubscribeOp)
}
var pairs currency.Pairs
for x := range channelsToUnsubscribe {
if pairs.Contains(channelsToUnsubscribe[x].Pair, true) {
continue
}
pairs = append(pairs, channelsToUnsubscribe[x].Pair)
}
fmtPairs, err := g.FormatExchangeCurrencies(pairs, asset.Spot)
func (g *Gemini) manageSubs(subs subscription.List, op wsSubOp) error {
format, err := g.GetPairFormat(asset.Spot, true)
if err != nil {
return err
}
subs := make([]wsSubscriptions, len(channels))
for x := range channels {
subs[x] = wsSubscriptions{
Name: channels[x],
Symbols: strings.Split(fmtPairs, ","),
}
req := wsSubscribeRequest{
Type: op,
Subscriptions: make([]wsSubscriptions, 0, len(subs)),
}
for _, s := range subs {
req.Subscriptions = append(req.Subscriptions, wsSubscriptions{
Name: s.Channel,
Symbols: s.Pairs.Format(format).Strings(),
})
}
wsSub := wsSubscribeRequest{
Type: "unsubscribe",
Subscriptions: subs,
}
err = g.Websocket.Conn.SendJSONMessage(wsSub)
if err != nil {
if err := g.Websocket.Conn.SendJSONMessage(req); err != nil {
return err
}
g.Websocket.RemoveSubscriptions(channelsToUnsubscribe...)
return nil
if op == wsUnsubscribeOp {
return g.Websocket.RemoveSubscriptions(subs...)
}
return g.Websocket.AddSuccessfulSubscriptions(subs...)
}
// WsAuth will connect to Gemini's secure endpoint

View File

@@ -292,24 +292,23 @@ type ResponseError struct {
Message string `json:"message"`
}
// WsRequest defines a request obj for the JSON-RPC and gets a websocket
// response
// WsRequest defines a request obj for the JSON-RPC and gets a websocket response
type WsRequest struct {
Method string `json:"method"`
Params Params `json:"params,omitempty"`
ID int64 `json:"id"`
Method string `json:"method"`
Params WsParams `json:"params,omitempty"`
ID int64 `json:"id"`
}
// WsNotification defines a notification obj for the JSON-RPC this does not get
// a websocket response
type WsNotification struct {
JSONRPCVersion string `json:"jsonrpc,omitempty"`
Method string `json:"method"`
Params Params `json:"params"`
JSONRPCVersion string `json:"jsonrpc,omitempty"`
Method string `json:"method"`
Params WsParams `json:"params"`
}
// Params is params
type Params struct {
// WsParams are websocket params for a request
type WsParams struct {
Symbol string `json:"symbol,omitempty"`
Period string `json:"period,omitempty"`
Limit int64 `json:"limit,omitempty"`

View File

@@ -466,33 +466,33 @@ func (h *HitBTC) WsProcessOrderbookUpdate(update *WsOrderbook) error {
}
// GenerateDefaultSubscriptions Adds default subscriptions to websocket to be handled by ManageSubscriptions()
func (h *HitBTC) GenerateDefaultSubscriptions() ([]subscription.Subscription, error) {
var channels = []string{"subscribeTicker",
"subscribeOrderbook",
"subscribeTrades",
"subscribeCandles"}
var subscriptions []subscription.Subscription
if h.Websocket.CanUseAuthenticatedEndpoints() {
subscriptions = append(subscriptions, subscription.Subscription{
Channel: "subscribeReports",
})
func (h *HitBTC) GenerateDefaultSubscriptions() (subscription.List, error) {
var channels = []string{
"Ticker",
"Orderbook",
"Trades",
"Candles",
}
enabledCurrencies, err := h.GetEnabledPairs(asset.Spot)
var subscriptions subscription.List
if h.Websocket.CanUseAuthenticatedEndpoints() {
subscriptions = append(subscriptions, &subscription.Subscription{Channel: "Reports"})
}
pairs, err := h.GetEnabledPairs(asset.Spot)
if err != nil {
return nil, err
}
pairFmt, err := h.GetPairFormat(asset.Spot, true)
if err != nil {
return nil, err
}
pairFmt.Delimiter = ""
pairs = pairs.Format(pairFmt)
for i := range channels {
for j := range enabledCurrencies {
fPair, err := h.FormatExchangeCurrency(enabledCurrencies[j], asset.Spot)
if err != nil {
return nil, err
}
enabledCurrencies[j].Delimiter = ""
subscriptions = append(subscriptions, subscription.Subscription{
for j := range pairs {
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: channels[i],
Pair: fPair,
Pairs: currency.Pairs{pairs[j]},
Asset: asset.Spot,
})
}
@@ -501,70 +501,74 @@ func (h *HitBTC) GenerateDefaultSubscriptions() ([]subscription.Subscription, er
}
// Subscribe sends a websocket message to receive data from the channel
func (h *HitBTC) Subscribe(channelsToSubscribe []subscription.Subscription) error {
func (h *HitBTC) Subscribe(channelsToSubscribe subscription.List) error {
var errs error
for i := range channelsToSubscribe {
subscribe := WsRequest{
Method: channelsToSubscribe[i].Channel,
for _, s := range channelsToSubscribe {
if len(s.Pairs) != 1 {
return subscription.ErrNotSinglePair
}
pair := s.Pairs[0]
r := WsRequest{
Method: "subscribe" + s.Channel,
ID: h.Websocket.Conn.GenerateMessageID(false),
Params: WsParams{
Symbol: pair.String(),
},
}
switch s.Channel {
case "Trades":
r.Params.Limit = 100
case "Candles":
r.Params.Period = "M30"
r.Params.Limit = 100
}
if channelsToSubscribe[i].Pair.String() != "" {
subscribe.Params.Symbol = channelsToSubscribe[i].Pair.String()
err := h.Websocket.Conn.SendJSONMessage(r)
if err == nil {
err = h.Websocket.AddSuccessfulSubscriptions(s)
}
if strings.EqualFold(channelsToSubscribe[i].Channel, "subscribeTrades") {
subscribe.Params.Limit = 100
} else if strings.EqualFold(channelsToSubscribe[i].Channel, "subscribeCandles") {
subscribe.Params.Period = "M30"
subscribe.Params.Limit = 100
}
err := h.Websocket.Conn.SendJSONMessage(subscribe)
if err != nil {
errs = common.AppendError(errs, err)
continue
}
h.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[i])
}
if errs != nil {
return errs
}
return nil
return errs
}
// Unsubscribe sends a websocket message to stop receiving data from the channel
func (h *HitBTC) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) error {
func (h *HitBTC) Unsubscribe(subs subscription.List) error {
var errs error
for i := range channelsToUnsubscribe {
unsubscribeChannel := strings.Replace(channelsToUnsubscribe[i].Channel,
"subscribe",
"unsubscribe",
1)
for _, s := range subs {
if len(s.Pairs) != 1 {
return subscription.ErrNotSinglePair
}
pair := s.Pairs[0]
unsubscribe := WsNotification{
r := WsNotification{
JSONRPCVersion: rpcVersion,
Method: unsubscribeChannel,
Method: "unsubscribe" + s.Channel,
Params: WsParams{
Symbol: pair.String(),
},
}
unsubscribe.Params.Symbol = channelsToUnsubscribe[i].Pair.String()
if strings.EqualFold(unsubscribeChannel, "unsubscribeTrades") {
unsubscribe.Params.Limit = 100
} else if strings.EqualFold(unsubscribeChannel, "unsubscribeCandles") {
unsubscribe.Params.Period = "M30"
unsubscribe.Params.Limit = 100
switch s.Channel {
case "Trades":
r.Params.Limit = 100
case "Candles":
r.Params.Period = "M30"
r.Params.Limit = 100
}
err := h.Websocket.Conn.SendJSONMessage(unsubscribe)
err := h.Websocket.Conn.SendJSONMessage(r)
if err == nil {
err = h.Websocket.RemoveSubscriptions(s)
}
if err != nil {
errs = common.AppendError(errs, err)
continue
}
h.Websocket.RemoveSubscriptions(channelsToUnsubscribe[i])
}
if errs != nil {
return errs
}
return nil
return errs
}
// Unsubscribe sends a websocket message to stop receiving data from the channel

View File

@@ -515,15 +515,15 @@ func (h *HUOBI) WsProcessOrderbook(update *WsDepth, symbol string) error {
}
// GenerateDefaultSubscriptions Adds default subscriptions to websocket to be handled by ManageSubscriptions()
func (h *HUOBI) GenerateDefaultSubscriptions() ([]subscription.Subscription, error) {
func (h *HUOBI) GenerateDefaultSubscriptions() (subscription.List, error) {
var channels = []string{wsMarketKline,
wsMarketDepth,
wsMarketTrade,
wsMarketTicker}
var subscriptions []subscription.Subscription
var subscriptions subscription.List
if h.Websocket.CanUseAuthenticatedEndpoints() {
channels = append(channels, "orders.%v", "orders.%v.update")
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: "accounts",
})
}
@@ -536,9 +536,9 @@ func (h *HUOBI) GenerateDefaultSubscriptions() ([]subscription.Subscription, err
enabledCurrencies[j].Delimiter = ""
channel := fmt.Sprintf(channels[i],
enabledCurrencies[j].Lower().String())
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: channel,
Pair: enabledCurrencies[j],
Pairs: currency.Pairs{enabledCurrencies[j]},
})
}
}
@@ -546,7 +546,7 @@ func (h *HUOBI) GenerateDefaultSubscriptions() ([]subscription.Subscription, err
}
// Subscribe sends a websocket message to receive data from the channel
func (h *HUOBI) Subscribe(channelsToSubscribe []subscription.Subscription) error {
func (h *HUOBI) Subscribe(channelsToSubscribe subscription.List) error {
var creds *account.Credentials
if h.Websocket.CanUseAuthenticatedEndpoints() {
var err error
@@ -557,36 +557,30 @@ func (h *HUOBI) Subscribe(channelsToSubscribe []subscription.Subscription) error
}
var errs error
for i := range channelsToSubscribe {
var err error
if (strings.Contains(channelsToSubscribe[i].Channel, "orders.") ||
strings.Contains(channelsToSubscribe[i].Channel, "accounts")) && creds != nil {
err := h.wsAuthenticatedSubscribe(creds,
err = h.wsAuthenticatedSubscribe(creds,
"sub",
wsAccountsOrdersEndPoint+channelsToSubscribe[i].Channel,
channelsToSubscribe[i].Channel)
if err != nil {
errs = common.AppendError(errs, err)
continue
}
h.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[i])
continue
} else {
err = h.Websocket.Conn.SendJSONMessage(WsRequest{
Subscribe: channelsToSubscribe[i].Channel,
})
}
if err == nil {
err = h.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[i])
}
err := h.Websocket.Conn.SendJSONMessage(WsRequest{
Subscribe: channelsToSubscribe[i].Channel,
})
if err != nil {
errs = common.AppendError(errs, err)
continue
}
h.Websocket.AddSuccessfulSubscriptions(channelsToSubscribe[i])
}
if errs != nil {
return errs
}
return nil
}
// Unsubscribe sends a websocket message to stop receiving data from the channel
func (h *HUOBI) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) error {
func (h *HUOBI) Unsubscribe(channelsToUnsubscribe subscription.List) error {
var creds *account.Credentials
if h.Websocket.CanUseAuthenticatedEndpoints() {
var err error
@@ -597,32 +591,26 @@ func (h *HUOBI) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) e
}
var errs error
for i := range channelsToUnsubscribe {
var err error
if (strings.Contains(channelsToUnsubscribe[i].Channel, "orders.") ||
strings.Contains(channelsToUnsubscribe[i].Channel, "accounts")) && creds != nil {
err := h.wsAuthenticatedSubscribe(creds,
err = h.wsAuthenticatedSubscribe(creds,
"unsub",
wsAccountsOrdersEndPoint+channelsToUnsubscribe[i].Channel,
channelsToUnsubscribe[i].Channel)
if err != nil {
errs = common.AppendError(errs, err)
continue
}
h.Websocket.RemoveSubscriptions(channelsToUnsubscribe[i])
continue
} else {
err = h.Websocket.Conn.SendJSONMessage(WsRequest{
Unsubscribe: channelsToUnsubscribe[i].Channel,
})
}
if err == nil {
err = h.Websocket.RemoveSubscriptions(channelsToUnsubscribe[i])
}
err := h.Websocket.Conn.SendJSONMessage(WsRequest{
Unsubscribe: channelsToUnsubscribe[i].Channel,
})
if err != nil {
errs = common.AppendError(errs, err)
continue
}
h.Websocket.RemoveSubscriptions(channelsToUnsubscribe[i])
}
if errs != nil {
return errs
}
return nil
return errs
}
func (h *HUOBI) wsGenerateSignature(creds *account.Credentials, timestamp, endpoint string) ([]byte, error) {

View File

@@ -70,9 +70,9 @@ type IBotExchange interface {
EnableRateLimiter() error
GetServerTime(ctx context.Context, ai asset.Item) (time.Time, error)
GetWebsocket() (*stream.Websocket, error)
SubscribeToWebsocketChannels(channels []subscription.Subscription) error
UnsubscribeToWebsocketChannels(channels []subscription.Subscription) error
GetSubscriptions() ([]subscription.Subscription, error)
SubscribeToWebsocketChannels(channels subscription.List) error
UnsubscribeToWebsocketChannels(channels subscription.List) error
GetSubscriptions() (subscription.List, error)
FlushWebsocketChannels() error
AuthenticateWebsocket(ctx context.Context) error
GetOrderExecutionLimits(a asset.Item, cp currency.Pair) (order.MinMaxLevel, error)

View File

@@ -1216,10 +1216,10 @@ func setupWsTests(t *testing.T) {
// TestWebsocketSubscribe tests returning a message with an id
func TestWebsocketSubscribe(t *testing.T) {
setupWsTests(t)
err := k.Subscribe([]subscription.Subscription{
err := k.Subscribe(subscription.List{
{
Channel: defaultSubscribedChannels[0],
Pair: currency.NewPairWithDelimiter("XBT", "USD", "/"),
Pairs: currency.Pairs{currency.NewPairWithDelimiter("XBT", "USD", "/")},
},
})
if err != nil {

View File

@@ -500,11 +500,11 @@ type WithdrawStatusResponse struct {
// WebsocketSubscriptionEventRequest handles WS subscription events
type WebsocketSubscriptionEventRequest struct {
Event string `json:"event"` // subscribe
RequestID int64 `json:"reqid,omitempty"` // Optional, client originated ID reflected in response message.
Pairs []string `json:"pair,omitempty"` // Array of currency pairs (pair1,pair2,pair3).
Subscription WebsocketSubscriptionData `json:"subscription,omitempty"`
Channels []subscription.Subscription `json:"-"` // Keeps track of associated subscriptions in batched outgoings
Event string `json:"event"` // subscribe
RequestID int64 `json:"reqid,omitempty"` // Optional, client originated ID reflected in response message.
Pairs []string `json:"pair,omitempty"` // Array of currency pairs (pair1,pair2,pair3).
Subscription WebsocketSubscriptionData `json:"subscription,omitempty"`
Channels subscription.List `json:"-"` // Keeps track of associated subscriptions in batched outgoings
}
// WebsocketBaseEventRequest Just has an "event" property

View File

@@ -809,7 +809,7 @@ func (k *Kraken) wsProcessOrderBook(channelData *WebsocketChannelData, data map[
}
}(&subscription.Subscription{
Channel: krakenWsOrderbook,
Pair: outbound,
Pairs: currency.Pairs{outbound},
Asset: asset.Spot,
})
return err
@@ -1160,25 +1160,25 @@ func (k *Kraken) wsProcessCandles(channelData *WebsocketChannelData, data []inte
}
// GenerateDefaultSubscriptions Adds default subscriptions to websocket to be handled by ManageSubscriptions()
func (k *Kraken) GenerateDefaultSubscriptions() ([]subscription.Subscription, error) {
func (k *Kraken) GenerateDefaultSubscriptions() (subscription.List, error) {
enabledPairs, err := k.GetEnabledPairs(asset.Spot)
if err != nil {
return nil, err
}
var subscriptions []subscription.Subscription
var subscriptions subscription.List
for i := range defaultSubscribedChannels {
for j := range enabledPairs {
enabledPairs[j].Delimiter = "/"
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: defaultSubscribedChannels[i],
Pair: enabledPairs[j],
Pairs: currency.Pairs{enabledPairs[j]},
Asset: asset.Spot,
})
}
}
if k.Websocket.CanUseAuthenticatedEndpoints() {
for i := range authenticatedChannels {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: authenticatedChannels[i],
})
}
@@ -1187,7 +1187,7 @@ func (k *Kraken) GenerateDefaultSubscriptions() ([]subscription.Subscription, er
}
// Subscribe sends a websocket message to receive data from the channel
func (k *Kraken) Subscribe(channelsToSubscribe []subscription.Subscription) error {
func (k *Kraken) Subscribe(channelsToSubscribe subscription.List) error {
var subscriptions = make(map[string]*[]WebsocketSubscriptionEventRequest)
channels:
for i := range channelsToSubscribe {
@@ -1198,7 +1198,7 @@ channels:
}
for j := range *s {
(*s)[j].Pairs = append((*s)[j].Pairs, channelsToSubscribe[i].Pair.String())
(*s)[j].Pairs = append((*s)[j].Pairs, channelsToSubscribe[i].Pairs.Strings()...)
(*s)[j].Channels = append((*s)[j].Channels, channelsToSubscribe[i])
continue channels
}
@@ -1214,8 +1214,8 @@ channels:
if channelsToSubscribe[i].Channel == "book" {
outbound.Subscription.Depth = krakenWsOrderbookDepth
}
if !channelsToSubscribe[i].Pair.IsEmpty() {
outbound.Pairs = []string{channelsToSubscribe[i].Pair.String()}
for _, p := range channelsToSubscribe[i].Pairs {
outbound.Pairs = append(outbound.Pairs, p.String())
}
if common.StringDataContains(authenticatedChannels, channelsToSubscribe[i].Channel) {
outbound.Subscription.Token = authToken
@@ -1228,37 +1228,32 @@ channels:
var errs error
for _, subs := range subscriptions {
for i := range *subs {
var err error
if common.StringDataContains(authenticatedChannels, (*subs)[i].Subscription.Name) {
_, err := k.Websocket.AuthConn.SendMessageReturnResponse((*subs)[i].RequestID, (*subs)[i])
if err != nil {
errs = common.AppendError(errs, err)
continue
}
k.Websocket.AddSuccessfulSubscriptions((*subs)[i].Channels...)
continue
_, err = k.Websocket.AuthConn.SendMessageReturnResponse((*subs)[i].RequestID, (*subs)[i])
} else {
_, err = k.Websocket.Conn.SendMessageReturnResponse((*subs)[i].RequestID, (*subs)[i])
}
if err == nil {
err = k.Websocket.AddSuccessfulSubscriptions((*subs)[i].Channels...)
}
_, err := k.Websocket.Conn.SendMessageReturnResponse((*subs)[i].RequestID, (*subs)[i])
if err != nil {
errs = common.AppendError(errs, err)
continue
}
k.Websocket.AddSuccessfulSubscriptions((*subs)[i].Channels...)
}
}
return errs
}
// Unsubscribe sends a websocket message to stop receiving data from the channel
func (k *Kraken) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) error {
func (k *Kraken) Unsubscribe(channelsToUnsubscribe subscription.List) error {
var unsubs []WebsocketSubscriptionEventRequest
channels:
for x := range channelsToUnsubscribe {
for y := range unsubs {
if unsubs[y].Subscription.Name == channelsToUnsubscribe[x].Channel {
unsubs[y].Pairs = append(unsubs[y].Pairs,
channelsToUnsubscribe[x].Pair.String())
unsubs[y].Channels = append(unsubs[y].Channels,
channelsToUnsubscribe[x])
unsubs[y].Pairs = append(unsubs[y].Pairs, channelsToUnsubscribe[x].Pairs.Strings()...)
unsubs[y].Channels = append(unsubs[y].Channels, channelsToUnsubscribe[x])
continue channels
}
}
@@ -1276,7 +1271,7 @@ channels:
unsub := WebsocketSubscriptionEventRequest{
Event: krakenWsUnsubscribe,
Pairs: []string{channelsToUnsubscribe[x].Pair.String()},
Pairs: []string{channelsToUnsubscribe[x].Pairs[0].String()},
Subscription: WebsocketSubscriptionData{
Name: channelsToUnsubscribe[x].Channel,
Depth: depth,
@@ -1292,22 +1287,18 @@ channels:
var errs error
for i := range unsubs {
var err error
if common.StringDataContains(authenticatedChannels, unsubs[i].Subscription.Name) {
_, err := k.Websocket.AuthConn.SendMessageReturnResponse(unsubs[i].RequestID, unsubs[i])
if err != nil {
errs = common.AppendError(errs, err)
continue
}
k.Websocket.RemoveSubscriptions(unsubs[i].Channels...)
continue
_, err = k.Websocket.AuthConn.SendMessageReturnResponse(unsubs[i].RequestID, unsubs[i])
} else {
_, err = k.Websocket.Conn.SendMessageReturnResponse(unsubs[i].RequestID, unsubs[i])
}
if err == nil {
err = k.Websocket.RemoveSubscriptions(unsubs[i].Channels...)
}
_, err := k.Websocket.Conn.SendMessageReturnResponse(unsubs[i].RequestID, unsubs[i])
if err != nil {
errs = common.AppendError(errs, err)
continue
}
k.Websocket.RemoveSubscriptions(unsubs[i].Channels...)
}
return errs
}

View File

@@ -15,7 +15,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/key"
"github.com/thrasher-corp/gocryptotrader/config"
"github.com/thrasher-corp/gocryptotrader/core"
"github.com/thrasher-corp/gocryptotrader/currency"
exchange "github.com/thrasher-corp/gocryptotrader/exchanges"
@@ -26,7 +25,6 @@ import (
"github.com/thrasher-corp/gocryptotrader/exchanges/margin"
"github.com/thrasher-corp/gocryptotrader/exchanges/order"
"github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues"
"github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
"github.com/thrasher-corp/gocryptotrader/exchanges/ticker"
testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange"
@@ -41,45 +39,27 @@ const (
canManipulateRealOrders = false
)
var (
ku = &Kucoin{}
spotTradablePair, marginTradablePair, futuresTradablePair currency.Pair
)
var ku *Kucoin
var spotTradablePair, marginTradablePair, futuresTradablePair currency.Pair
func TestMain(m *testing.M) {
ku.SetDefaults()
cfg := config.GetConfig()
err := cfg.LoadConfig("../../testdata/configtest.json", true)
if err != nil {
ku = new(Kucoin)
if err := testexch.Setup(ku); err != nil {
log.Fatal(err)
}
exchCfg, err := cfg.GetExchangeConfig("Kucoin")
if err != nil {
log.Fatal(err)
}
exchCfg.API.AuthenticatedSupport = true
exchCfg.API.AuthenticatedWebsocketSupport = true
exchCfg.API.Credentials.Key = apiKey
exchCfg.API.Credentials.Secret = apiSecret
exchCfg.API.Credentials.ClientID = passPhrase
if apiKey != "" && apiSecret != "" && passPhrase != "" {
ku.API.AuthenticatedSupport = true
ku.API.AuthenticatedWebsocketSupport = true
ku.API.CredentialsValidator.RequiresBase64DecodeSecret = false
ku.SetCredentials(apiKey, apiSecret, passPhrase, "", "", "")
ku.Websocket.SetCanUseAuthenticatedEndpoints(true)
}
ku.SetDefaults()
ku.Websocket = sharedtestvalues.NewTestWebsocket()
ku.Websocket.Orderbook = buffer.Orderbook{}
err = ku.Setup(exchCfg)
if err != nil {
log.Fatal(err)
}
ku.Websocket.DataHandler = sharedtestvalues.GetWebsocketInterfaceChannelOverride()
ku.Websocket.TrafficAlert = sharedtestvalues.GetWebsocketStructChannelOverride()
setupWS()
getFirstTradablePairOfAssets()
ku.setupOrderbookManager()
fetchedFuturesSnapshotOrderbook = map[string]bool{}
os.Exit(m.Run())
}
@@ -1993,10 +1973,10 @@ func TestPushData(t *testing.T) {
testexch.FixtureToDataHandler(t, "testdata/wsHandleData.json", ku.wsHandleData)
}
func verifySubs(tb testing.TB, subs []subscription.Subscription, a asset.Item, prefix string, expected ...string) {
func verifySubs(tb testing.TB, subs subscription.List, a asset.Item, prefix string, expected ...string) {
tb.Helper()
var sub *subscription.Subscription
for i, s := range subs { //nolint:gocritic // prefer convenience over performance here for tests
for i, s := range subs {
if s.Asset == a && strings.HasPrefix(s.Channel, prefix) {
if len(expected) == 1 && !strings.Contains(s.Channel, expected[0]) {
continue
@@ -2005,7 +1985,7 @@ func verifySubs(tb testing.TB, subs []subscription.Subscription, a asset.Item, p
assert.Failf(tb, "Too many subs with prefix", "Asset %s; Prefix %s", a.String(), prefix)
return
}
sub = &subs[i]
sub = subs[i]
}
}
if assert.NotNil(tb, sub, "Should find a sub for asset %s with prefix %s for %s", a.String(), prefix, strings.Join(expected, ", ")) {
@@ -2024,12 +2004,11 @@ func verifySubs(tb testing.TB, subs []subscription.Subscription, a asset.Item, p
// In Both: ETH-BTC, LTC-USDT
// Only in Margin: TRX-BTC, SOL-USDC
func TestGenerateDefaultSubscriptions(t *testing.T) {
func TestGenerateSubscriptions(t *testing.T) {
t.Parallel()
subs, err := ku.GenerateDefaultSubscriptions()
assert.NoError(t, err, "GenerateDefaultSubscriptions should not error")
subs, err := ku.generateSubscriptions()
require.NoError(t, err, "generateSubscriptions must not error")
assert.Len(t, subs, 11, "Should generate the correct number of subs when not logged in")
@@ -2053,8 +2032,8 @@ func TestGenerateAuthSubscriptions(t *testing.T) {
ku := testInstance(t) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes
ku.Websocket.SetCanUseAuthenticatedEndpoints(true)
subs, err := ku.GenerateDefaultSubscriptions()
assert.NoError(t, err, "GenerateDefaultSubscriptions with Auth should not error")
subs, err := ku.generateSubscriptions()
require.NoError(t, err, "generateSubscriptions with Auth must not error")
assert.Len(t, subs, 24, "Should generate the correct number of subs when logged in")
verifySubs(t, subs, asset.Spot, "/market/ticker:all") // This takes care of margin as well.
@@ -2084,12 +2063,12 @@ func TestGenerateCandleSubscription(t *testing.T) {
t.Parallel()
ku := testInstance(t) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes
ku.Features.Subscriptions = []*subscription.Subscription{
ku.Features.Subscriptions = subscription.List{
{Channel: subscription.CandlesChannel, Interval: kline.FourHour},
}
subs, err := ku.GenerateDefaultSubscriptions()
assert.NoError(t, err, "GenerateDefaultSubscriptions with Candles should not error")
subs, err := ku.generateSubscriptions()
assert.NoError(t, err, "generateSubscriptions with Candles should not error")
assert.Len(t, subs, 6, "Should generate the correct number of subs for candles")
for _, c := range []string{"BTC-USDT", "ETH-USDT", "LTC-USDT", "ETH-BTC"} {
@@ -2104,12 +2083,12 @@ func TestGenerateMarketSubscription(t *testing.T) {
t.Parallel()
ku := testInstance(t) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes
ku.Features.Subscriptions = []*subscription.Subscription{
ku.Features.Subscriptions = subscription.List{
{Channel: marketSnapshotChannel},
}
subs, err := ku.GenerateDefaultSubscriptions()
assert.NoError(t, err, "GenerateDefaultSubscriptions with MarketSnapshot should not error")
subs, err := ku.generateSubscriptions()
assert.NoError(t, err, "generateSubscriptions with MarketSnapshot should not error")
assert.Len(t, subs, 7, "Should generate the correct number of subs for snapshot")
for _, c := range []string{"BTC", "ETH", "LTC", "USDT"} {
@@ -2370,19 +2349,6 @@ func TestGetPaginatedListOfSubAccounts(t *testing.T) {
}
}
func setupWS() {
if !ku.Websocket.IsEnabled() {
return
}
if !sharedtestvalues.AreAPICredentialsSet(ku) {
ku.Websocket.SetCanUseAuthenticatedEndpoints(false)
}
err := ku.WsConnect()
if err != nil {
log.Fatal(err)
}
}
func TestGetFundingHistory(t *testing.T) {
t.Parallel()
sharedtestvalues.SkipTestIfCredentialsUnset(t, ku)
@@ -2526,8 +2492,11 @@ func TestProcessMarketSnapshot(t *testing.T) {
func TestSubscribeMarketSnapshot(t *testing.T) {
t.Parallel()
setupWS()
err := ku.Subscribe([]subscription.Subscription{{Channel: marketSymbolSnapshotChannel, Pair: currency.Pair{Base: currency.BTC}}})
ku := testInstance(t) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes
testexch.SetupWs(t, ku)
err := ku.Subscribe(subscription.List{{Channel: marketSymbolSnapshotChannel, Pairs: currency.Pairs{currency.Pair{Base: currency.BTC}}}})
assert.NoError(t, err, "Subscribe to MarketSnapshot should not error")
}

View File

@@ -950,49 +950,50 @@ func (ku *Kucoin) processMarketSnapshot(respData []byte, topic string) error {
}
// Subscribe sends a websocket message to receive data from the channel
func (ku *Kucoin) Subscribe(subscriptions []subscription.Subscription) error {
return ku.handleSubscriptions(subscriptions, "subscribe")
func (ku *Kucoin) Subscribe(subscriptions subscription.List) error {
return ku.manageSubscriptions(subscriptions, "subscribe")
}
// Unsubscribe sends a websocket message to stop receiving data from the channel
func (ku *Kucoin) Unsubscribe(subscriptions []subscription.Subscription) error {
return ku.handleSubscriptions(subscriptions, "unsubscribe")
func (ku *Kucoin) Unsubscribe(subscriptions subscription.List) error {
return ku.manageSubscriptions(subscriptions, "unsubscribe")
}
func (ku *Kucoin) expandManualSubscriptions(in []subscription.Subscription) ([]subscription.Subscription, error) {
subs := make([]subscription.Subscription, 0, len(in))
for i := range in {
if isSymbolChannel(in[i].Channel) {
if in[i].Pair.IsEmpty() {
// expandManualSubscription takes a subscription list and expand all the subscriptions across the relevant assets and pairs
func (ku *Kucoin) expandManualSubscriptions(in subscription.List) (subscription.List, error) {
subs := make(subscription.List, 0, len(in))
for _, s := range in {
if isSymbolChannel(s.Channel) {
if len(s.Pairs) == 0 {
return nil, errSubscriptionPairRequired
}
a := in[i].Asset
a := s.Asset
if !a.IsValid() {
a = getChannelsAssetType(in[i].Channel)
a = getChannelsAssetType(s.Channel)
}
assetPairs := map[asset.Item]currency.Pairs{a: {in[i].Pair}}
n, err := ku.expandSubscription(&in[i], assetPairs)
assetPairs := map[asset.Item]currency.Pairs{a: s.Pairs}
n, err := ku.expandSubscription(s, assetPairs)
if err != nil {
return nil, err
}
subs = append(subs, n...)
} else {
subs = append(subs, in[i])
subs = append(subs, s)
}
}
return subs, nil
}
func (ku *Kucoin) handleSubscriptions(subs []subscription.Subscription, operation string) error {
func (ku *Kucoin) manageSubscriptions(subs subscription.List, operation string) error {
var errs error
subs, errs = ku.expandManualSubscriptions(subs)
for i := range subs {
for _, s := range subs {
msgID := strconv.FormatInt(ku.Websocket.Conn.GenerateMessageID(false), 10)
req := WsSubscriptionInput{
ID: msgID,
Type: operation,
Topic: subs[i].Channel,
PrivateChannel: subs[i].Authenticated,
Topic: s.Channel,
PrivateChannel: s.Authenticated,
Response: true,
}
if respRaw, err := ku.Websocket.Conn.SendMessageReturnResponse("msgID:"+msgID, req); err != nil {
@@ -1005,9 +1006,16 @@ func (ku *Kucoin) handleSubscriptions(subs []subscription.Subscription, operatio
case rType != "ack":
errs = common.AppendError(errs, fmt.Errorf("%w: %s from %s", errInvalidMsgType, rType, respRaw))
default:
ku.Websocket.AddSuccessfulSubscriptions(subs[i])
if ku.Verbose {
log.Debugf(log.ExchangeSys, "%s Subscribed to Channel: %s", ku.Name, subs[i].Channel)
if operation == "unsubscribe" {
err = ku.Websocket.RemoveSubscriptions(s)
} else {
err = ku.Websocket.AddSuccessfulSubscriptions(s)
if ku.Verbose {
log.Debugf(log.ExchangeSys, "%s Subscribed to Channel: %s", ku.Name, s.Channel)
}
}
if err != nil {
errs = common.AppendError(errs, err)
}
}
}
@@ -1034,8 +1042,8 @@ func getChannelsAssetType(channelName string) asset.Item {
}
}
// GenerateDefaultSubscriptions Adds default subscriptions to websocket.
func (ku *Kucoin) GenerateDefaultSubscriptions() ([]subscription.Subscription, error) {
// generateSubscriptions returns a list of subscriptions from the configured subscriptions feature
func (ku *Kucoin) generateSubscriptions() (subscription.List, error) {
assetPairs := map[asset.Item]currency.Pairs{}
for _, a := range ku.GetAssetTypes(false) {
if p, err := ku.GetEnabledPairs(a); err == nil {
@@ -1045,7 +1053,7 @@ func (ku *Kucoin) GenerateDefaultSubscriptions() ([]subscription.Subscription, e
}
}
authed := ku.Websocket.CanUseAuthenticatedEndpoints()
subscriptions := []subscription.Subscription{}
subscriptions := subscription.List{}
for _, s := range ku.Features.Subscriptions {
if !authed && s.Authenticated {
continue
@@ -1060,12 +1068,12 @@ func (ku *Kucoin) GenerateDefaultSubscriptions() ([]subscription.Subscription, e
}
// expandSubscription takes a subscription and expands it across the relevant assets and pairs passed in
func (ku *Kucoin) expandSubscription(baseSub *subscription.Subscription, assetPairs map[asset.Item]currency.Pairs) ([]subscription.Subscription, error) {
var subscriptions = []subscription.Subscription{}
func (ku *Kucoin) expandSubscription(baseSub *subscription.Subscription, assetPairs map[asset.Item]currency.Pairs) (subscription.List, error) {
var subscriptions = subscription.List{}
if baseSub == nil {
return nil, common.ErrNilPointer
}
s := *baseSub
s := baseSub.Clone()
s.Channel = channelName(s.Channel)
if !s.Asset.IsValid() {
s.Asset = getChannelsAssetType(s.Channel)
@@ -1078,7 +1086,7 @@ func (ku *Kucoin) expandSubscription(baseSub *subscription.Subscription, assetPa
switch {
case s.Channel == marginLoanChannel:
for _, c := range assetPairs[asset.Margin].GetCurrencies() {
i := s
i := s.Clone()
i.Channel = fmt.Sprintf(s.Channel, c)
subscriptions = append(subscriptions, i)
}
@@ -1087,13 +1095,13 @@ func (ku *Kucoin) expandSubscription(baseSub *subscription.Subscription, assetPa
if err != nil {
return nil, err
}
subs := spotOrMarginPairSubs(assetPairs, &s, false, interval)
subs := spotOrMarginPairSubs(assetPairs, s, false, interval)
subscriptions = append(subscriptions, subs...)
case s.Channel == marginFundingbookChangeChannel:
s.Channel = fmt.Sprintf(s.Channel, assetPairs[asset.Margin].GetCurrencies().Join())
subscriptions = append(subscriptions, s)
case s.Channel == marketSnapshotChannel:
subs, err := spotOrMarginCurrencySubs(assetPairs, &s)
subs, err := spotOrMarginCurrencySubs(assetPairs, s)
if err != nil {
return nil, err
}
@@ -1104,13 +1112,13 @@ func (ku *Kucoin) expandSubscription(baseSub *subscription.Subscription, assetPa
if err != nil {
continue
}
i := s
i := s.Clone()
i.Channel = fmt.Sprintf(s.Channel, c)
subscriptions = append(subscriptions, i)
}
case isSymbolChannel(s.Channel):
// Subscriptions which can use a single comma-separated sub per asset
subs := spotOrMarginPairSubs(assetPairs, &s, true)
subs := spotOrMarginPairSubs(assetPairs, s, true)
subscriptions = append(subscriptions, subs...)
default:
subscriptions = append(subscriptions, s)
@@ -1135,21 +1143,23 @@ func channelName(name string) string {
// spotOrMarginPairSubs accepts a map of pairs and a template subscription and returns a list of subscriptions for Spot and Margin pairs
// If there's a Spot subscription, it won't be added again as a Margin subscription
// If joined param is true then one subscription per asset type with the currencies comma delimited
func spotOrMarginPairSubs(assetPairs map[asset.Item]currency.Pairs, b *subscription.Subscription, join bool, fmtArgs ...any) []subscription.Subscription {
subs := []subscription.Subscription{}
func spotOrMarginPairSubs(assetPairs map[asset.Item]currency.Pairs, b *subscription.Subscription, join bool, fmtArgs ...any) subscription.List {
subs := subscription.List{}
add := func(a asset.Item, pairs currency.Pairs) {
if len(pairs) == 0 {
return
}
s := *b
s.Asset = a
if join {
f := append([]any{pairs.Join()}, fmtArgs...)
s := b.Clone()
s.Asset = a
s.Channel = fmt.Sprintf(b.Channel, f...)
subs = append(subs, s)
} else {
for i := range pairs {
f := append([]any{pairs[i].String()}, fmtArgs...)
s := b.Clone()
s.Asset = a
s.Channel = fmt.Sprintf(b.Channel, f...)
subs = append(subs, s)
}
@@ -1171,18 +1181,18 @@ func spotOrMarginPairSubs(assetPairs map[asset.Item]currency.Pairs, b *subscript
// spotOrMarginCurrencySubs accepts a map of pairs and a template subscription and returns a list of subscriptions for every currency in Spot and Margin pairs
// If there's a Spot subscription, it won't be added again as a Margin subscription
func spotOrMarginCurrencySubs(assetPairs map[asset.Item]currency.Pairs, b *subscription.Subscription) ([]subscription.Subscription, error) {
func spotOrMarginCurrencySubs(assetPairs map[asset.Item]currency.Pairs, b *subscription.Subscription) (subscription.List, error) {
if b == nil {
return nil, common.ErrNilPointer
}
subs := []subscription.Subscription{}
subs := subscription.List{}
add := func(a asset.Item, currs currency.Currencies) {
if len(currs) == 0 {
return
}
s := *b
s.Asset = a
for _, c := range currs {
s := b.Clone()
s.Asset = a
s.Channel = fmt.Sprintf(b.Channel, c)
subs = append(subs, s)
}

View File

@@ -141,7 +141,7 @@ func (ku *Kucoin) SetDefaults() {
GlobalResultLimit: 1500,
},
},
Subscriptions: []*subscription.Subscription{
Subscriptions: subscription.List{
// Where we can we use generic names
{Enabled: true, Channel: subscription.TickerChannel}, // marketTickerChannel
{Enabled: true, Channel: subscription.AllTradesChannel}, // marketMatchChannel
@@ -206,7 +206,7 @@ func (ku *Kucoin) Setup(exch *config.Exchange) error {
Connector: ku.WsConnect,
Subscriber: ku.Subscribe,
Unsubscriber: ku.Unsubscribe,
GenerateSubscriptions: ku.GenerateDefaultSubscriptions,
GenerateSubscriptions: ku.generateSubscriptions,
Features: &ku.Features.Supports.WebsocketCapabilities,
OrderbookBufferConfig: buffer.Config{
SortBuffer: true,

View File

@@ -582,9 +582,9 @@ func (o *Okcoin) wsProcessOrderbook(respRaw []byte, obChannel string) error {
// ReSubscribeSpecificOrderbook removes the subscription and the subscribes
// again to fetch a new snapshot in the event of a de-sync event.
func (o *Okcoin) ReSubscribeSpecificOrderbook(obChannel string, p currency.Pair) error {
subscription := []subscription.Subscription{{
subscription := subscription.List{{
Channel: obChannel,
Pair: p,
Pairs: currency.Pairs{p},
}}
if err := o.Unsubscribe(subscription); err != nil {
return err
@@ -764,8 +764,8 @@ func (o *Okcoin) CalculateOrderbookUpdateChecksum(orderbookData *orderbook.Base)
// GenerateDefaultSubscriptions Adds default subscriptions to websocket to be
// handled by ManageSubscriptions()
func (o *Okcoin) GenerateDefaultSubscriptions() ([]subscription.Subscription, error) {
var subscriptions []subscription.Subscription
func (o *Okcoin) GenerateDefaultSubscriptions() (subscription.List, error) {
var subscriptions subscription.List
pairs, err := o.GetEnabledPairs(asset.Spot)
if err != nil {
return nil, err
@@ -788,7 +788,7 @@ func (o *Okcoin) GenerateDefaultSubscriptions() ([]subscription.Subscription, er
for s := range channels {
switch channels[s] {
case wsInstruments:
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: channels[s],
Asset: asset.Spot,
})
@@ -799,20 +799,20 @@ func (o *Okcoin) GenerateDefaultSubscriptions() ([]subscription.Subscription, er
wsCandle5m, wsCandle3m, wsCandle1m, wsCandle3Mutc, wsCandle1Mutc, wsCandle1Wutc, wsCandle1Dutc,
wsCandle2Dutc, wsCandle3Dutc, wsCandle5Dutc, wsCandle12Hutc, wsCandle6Hutc:
for p := range pairs {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: channels[s],
Pair: pairs[p],
Pairs: currency.Pairs{pairs[p]},
})
}
case wsStatus:
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: channels[s],
})
case wsAccount:
currenciesMap := map[currency.Code]bool{}
for p := range pairs {
if reserved, okay := currenciesMap[pairs[p].Base]; !okay && !reserved {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: channels[s],
Params: map[string]interface{}{
"ccy": pairs[p].Base,
@@ -823,7 +823,7 @@ func (o *Okcoin) GenerateDefaultSubscriptions() ([]subscription.Subscription, er
}
for p := range pairs {
if reserved, okay := currenciesMap[pairs[p].Quote]; !okay && !reserved {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: channels[s],
Params: map[string]interface{}{
"ccy": pairs[p].Quote,
@@ -834,9 +834,9 @@ func (o *Okcoin) GenerateDefaultSubscriptions() ([]subscription.Subscription, er
}
case wsOrder, wsOrdersAlgo, wsAlgoAdvance:
for p := range pairs {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: channels[s],
Pair: pairs[p],
Pairs: currency.Pairs{pairs[p]},
Asset: asset.Spot,
})
}
@@ -848,23 +848,23 @@ func (o *Okcoin) GenerateDefaultSubscriptions() ([]subscription.Subscription, er
}
// Subscribe sends a websocket message to receive data from the channel
func (o *Okcoin) Subscribe(channelsToSubscribe []subscription.Subscription) error {
return o.handleSubscriptions("subscribe", channelsToSubscribe)
func (o *Okcoin) Subscribe(channelsToSubscribe subscription.List) error {
return o.manageSubscriptions("subscribe", channelsToSubscribe)
}
// Unsubscribe sends a websocket message to stop receiving data from the channel
func (o *Okcoin) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) error {
return o.handleSubscriptions("unsubscribe", channelsToUnsubscribe)
func (o *Okcoin) Unsubscribe(channelsToUnsubscribe subscription.List) error {
return o.manageSubscriptions("unsubscribe", channelsToUnsubscribe)
}
func (o *Okcoin) handleSubscriptions(operation string, subs []subscription.Subscription) error {
func (o *Okcoin) manageSubscriptions(operation string, subs subscription.List) error {
subscriptionRequest := WebsocketEventRequest{Operation: operation, Arguments: []map[string]string{}}
authRequest := WebsocketEventRequest{Operation: operation, Arguments: []map[string]string{}}
temp := WebsocketEventRequest{Operation: operation, Arguments: []map[string]string{}}
authTemp := WebsocketEventRequest{Operation: operation, Arguments: []map[string]string{}}
var err error
var channels []subscription.Subscription
var authChannels []subscription.Subscription
var channels subscription.List
var authChannels subscription.List
for i := 0; i < len(subs); i++ {
authenticatedChannelSubscription := isAuthenticatedChannel(subs[i].Channel)
// Temp type to evaluate max byte len after a marshal on batched unsubs
@@ -891,8 +891,11 @@ func (o *Okcoin) handleSubscriptions(operation string, subs []subscription.Subsc
if subs[i].Asset != asset.Empty {
argument["instType"] = strings.ToUpper(subs[i].Asset.String())
}
if !subs[i].Pair.IsEmpty() {
argument["instId"] = subs[i].Pair.String()
if len(subs[i].Pairs) > 1 {
return subscription.ErrNotSinglePair
}
if len(subs[i].Pairs) == 1 {
argument["instId"] = subs[i].Pairs[0].String()
}
if authenticatedChannelSubscription {
authTemp.Arguments = append(authTemp.Arguments, argument)
@@ -928,17 +931,20 @@ func (o *Okcoin) handleSubscriptions(operation string, subs []subscription.Subsc
if operation == "unsubscribe" {
if authenticatedChannelSubscription {
o.Websocket.RemoveSubscriptions(authChannels...)
err = o.Websocket.RemoveSubscriptions(authChannels...)
} else {
o.Websocket.RemoveSubscriptions(channels...)
err = o.Websocket.RemoveSubscriptions(channels...)
}
} else {
if authenticatedChannelSubscription {
o.Websocket.AddSuccessfulSubscriptions(authChannels...)
err = o.Websocket.AddSuccessfulSubscriptions(authChannels...)
} else {
o.Websocket.AddSuccessfulSubscriptions(channels...)
err = o.Websocket.AddSuccessfulSubscriptions(channels...)
}
}
if err != nil {
return err
}
// Drop prior unsubs and chunked payload args on successful unsubscription
if authenticatedChannelSubscription {
authChannels = nil
@@ -958,23 +964,19 @@ func (o *Okcoin) handleSubscriptions(operation string, subs []subscription.Subsc
}
}
if len(subscriptionRequest.Arguments) > 0 {
err = o.Websocket.Conn.SendJSONMessage(subscriptionRequest)
if err != nil {
if err := o.Websocket.Conn.SendJSONMessage(subscriptionRequest); err != nil {
return err
}
}
if len(authRequest.Arguments) > 0 {
err = o.Websocket.AuthConn.SendJSONMessage(authRequest)
if err != nil {
if err := o.Websocket.AuthConn.SendJSONMessage(authRequest); err != nil {
return err
}
}
if operation == "unsubscribe" {
o.Websocket.RemoveSubscriptions(channels...)
} else {
o.Websocket.AddSuccessfulSubscriptions(channels...)
return o.Websocket.RemoveSubscriptions(channels...)
}
return nil
return o.Websocket.AddSuccessfulSubscriptions(channels...)
}
// GetCandlesData represents a candlestick instances list.

View File

@@ -348,29 +348,31 @@ func (ok *Okx) wsReadData(ws stream.Connection) {
}
// Subscribe sends a websocket subscription request to several channels to receive data.
func (ok *Okx) Subscribe(channelsToSubscribe []subscription.Subscription) error {
func (ok *Okx) Subscribe(channelsToSubscribe subscription.List) error {
return ok.handleSubscription(operationSubscribe, channelsToSubscribe)
}
// Unsubscribe sends a websocket unsubscription request to several channels to receive data.
func (ok *Okx) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) error {
func (ok *Okx) Unsubscribe(channelsToUnsubscribe subscription.List) error {
return ok.handleSubscription(operationUnsubscribe, channelsToUnsubscribe)
}
// handleSubscription sends a subscription and unsubscription information thought the websocket endpoint.
// as of the okx, exchange this endpoint sends subscription and unsubscription messages but with a list of json objects.
func (ok *Okx) handleSubscription(operation string, subscriptions []subscription.Subscription) error {
func (ok *Okx) handleSubscription(operation string, subscriptions subscription.List) error {
request := WSSubscriptionInformationList{Operation: operation}
authRequests := WSSubscriptionInformationList{Operation: operation}
ok.WsRequestSemaphore <- 1
defer func() { <-ok.WsRequestSemaphore }()
var channels []subscription.Subscription
var authChannels []subscription.Subscription
var err error
var format currency.PairFormat
var channels subscription.List
var authChannels subscription.List
for i := 0; i < len(subscriptions); i++ {
s := subscriptions[i]
if len(s.Pairs) > 1 {
return subscription.ErrNotSinglePair
}
arg := SubscriptionInfo{
Channel: subscriptions[i].Channel,
Channel: s.Channel,
}
var instrumentID string
var underlying string
@@ -400,12 +402,12 @@ func (ok *Okx) handleSubscription(operation string, subscriptions []subscription
}
if arg.Channel == okxChannelGridPositions {
algoID, _ = subscriptions[i].Params["algoId"].(string)
algoID, _ = s.Params["algoId"].(string)
}
if arg.Channel == okcChannelGridSubOrders ||
arg.Channel == okxChannelGridPositions {
uid, _ = subscriptions[i].Params["uid"].(string)
uid, _ = s.Params["uid"].(string)
}
if strings.HasPrefix(arg.Channel, "candle") ||
@@ -416,26 +418,30 @@ func (ok *Okx) handleSubscription(operation string, subscriptions []subscription
arg.Channel == okxChannelOrderBooksTBT ||
arg.Channel == okxChannelFundingRate ||
arg.Channel == okxChannelTrades {
if subscriptions[i].Params["instId"] != "" {
instrumentID, okay = subscriptions[i].Params["instId"].(string)
if s.Params["instId"] != "" {
instrumentID, okay = s.Params["instId"].(string)
if !okay {
instrumentID = ""
}
} else if subscriptions[i].Params["instrumentID"] != "" {
instrumentID, okay = subscriptions[i].Params["instrumentID"].(string)
} else if s.Params["instrumentID"] != "" {
instrumentID, okay = s.Params["instrumentID"].(string)
if !okay {
instrumentID = ""
}
}
if instrumentID == "" {
format, err = ok.GetPairFormat(subscriptions[i].Asset, false)
if len(s.Pairs) != 1 {
return subscription.ErrNotSinglePair
}
format, err := ok.GetPairFormat(s.Asset, false)
if err != nil {
return err
}
if subscriptions[i].Pair.Base.String() == "" || subscriptions[i].Pair.Quote.String() == "" {
p := s.Pairs[0]
if p.Base.String() == "" || p.Quote.String() == "" {
return errIncompleteCurrencyPair
}
instrumentID = format.Format(subscriptions[i].Pair)
instrumentID = format.Format(p)
}
}
if arg.Channel == okxChannelInstruments ||
@@ -447,7 +453,7 @@ func (ok *Okx) handleSubscription(operation string, subscriptions []subscription
arg.Channel == okxChannelSpotGridOrder ||
arg.Channel == okxChannelGridOrdersContract ||
arg.Channel == okxChannelEstimatedPrice {
instrumentType = ok.GetInstrumentTypeFromAssetItem(subscriptions[i].Asset)
instrumentType = ok.GetInstrumentTypeFromAssetItem(s.Asset)
}
if arg.Channel == okxChannelPositions ||
@@ -455,7 +461,9 @@ func (ok *Okx) handleSubscription(operation string, subscriptions []subscription
arg.Channel == okxChannelAlgoOrders ||
arg.Channel == okxChannelEstimatedPrice ||
arg.Channel == okxChannelOptSummary {
underlying, _ = ok.GetUnderlying(subscriptions[i].Pair, subscriptions[i].Asset)
if len(s.Pairs) == 1 {
underlying, _ = ok.GetUnderlying(s.Pairs[0], s.Asset)
}
}
arg.InstrumentID = instrumentID
arg.Underlying = underlying
@@ -464,10 +472,9 @@ func (ok *Okx) handleSubscription(operation string, subscriptions []subscription
arg.AlgoID = algoID
if authSubscription {
var authChunk []byte
authChannels = append(authChannels, subscriptions[i])
authChannels = append(authChannels, s)
authRequests.Arguments = append(authRequests.Arguments, arg)
authChunk, err = json.Marshal(authRequests)
authChunk, err := json.Marshal(authRequests)
if err != nil {
return err
}
@@ -479,18 +486,20 @@ func (ok *Okx) handleSubscription(operation string, subscriptions []subscription
return err
}
if operation == operationUnsubscribe {
ok.Websocket.RemoveSubscriptions(channels...)
err = ok.Websocket.RemoveSubscriptions(channels...)
} else {
ok.Websocket.AddSuccessfulSubscriptions(channels...)
err = ok.Websocket.AddSuccessfulSubscriptions(channels...)
}
authChannels = []subscription.Subscription{}
if err != nil {
return err
}
authChannels = subscription.List{}
authRequests.Arguments = []SubscriptionInfo{}
}
} else {
var chunk []byte
channels = append(channels, subscriptions[i])
channels = append(channels, s)
request.Arguments = append(request.Arguments, arg)
chunk, err = json.Marshal(request)
chunk, err := json.Marshal(request)
if err != nil {
return err
}
@@ -501,41 +510,38 @@ func (ok *Okx) handleSubscription(operation string, subscriptions []subscription
return err
}
if operation == operationUnsubscribe {
ok.Websocket.RemoveSubscriptions(channels...)
err = ok.Websocket.RemoveSubscriptions(channels...)
} else {
ok.Websocket.AddSuccessfulSubscriptions(channels...)
err = ok.Websocket.AddSuccessfulSubscriptions(channels...)
}
channels = []subscription.Subscription{}
if err != nil {
return err
}
channels = subscription.List{}
request.Arguments = []SubscriptionInfo{}
continue
}
}
}
if len(request.Arguments) > 0 {
err = ok.Websocket.Conn.SendJSONMessage(request)
if err != nil {
if err := ok.Websocket.Conn.SendJSONMessage(request); err != nil {
return err
}
}
if len(authRequests.Arguments) > 0 && ok.Websocket.CanUseAuthenticatedEndpoints() {
err = ok.Websocket.AuthConn.SendJSONMessage(authRequests)
if err != nil {
if err := ok.Websocket.AuthConn.SendJSONMessage(authRequests); err != nil {
return err
}
}
if err != nil {
return err
channels = append(channels, authChannels...)
if operation == operationUnsubscribe {
return ok.Websocket.RemoveSubscriptions(channels...)
}
if operation == operationUnsubscribe {
channels = append(channels, authChannels...)
ok.Websocket.RemoveSubscriptions(channels...)
} else {
channels = append(channels, authChannels...)
ok.Websocket.AddSuccessfulSubscriptions(channels...)
}
return nil
return ok.Websocket.AddSuccessfulSubscriptions(channels...)
}
// WsHandleData will read websocket raw data and pass to appropriate handler
@@ -833,11 +839,11 @@ func (ok *Okx) wsProcessOrderBooks(data []byte) error {
}
if err != nil {
if errors.Is(err, errInvalidChecksum) {
err = ok.Subscribe([]subscription.Subscription{
err = ok.Subscribe(subscription.List{
{
Channel: response.Argument.Channel,
Asset: assets[0],
Pair: pair,
Pairs: currency.Pairs{pair},
},
})
if err != nil {
@@ -1294,8 +1300,8 @@ func (ok *Okx) wsProcessTickers(data []byte) error {
}
// GenerateDefaultSubscriptions returns a list of default subscription message.
func (ok *Okx) GenerateDefaultSubscriptions() ([]subscription.Subscription, error) {
var subscriptions []subscription.Subscription
func (ok *Okx) GenerateDefaultSubscriptions() (subscription.List, error) {
var subscriptions subscription.List
assets := ok.GetAssetTypes(true)
subs := make([]string, 0, len(defaultSubscribedChannels)+len(defaultAuthChannels))
subs = append(subs, defaultSubscribedChannels...)
@@ -1306,7 +1312,7 @@ func (ok *Okx) GenerateDefaultSubscriptions() ([]subscription.Subscription, erro
switch subs[c] {
case okxChannelOrders:
for x := range assets {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: subs[c],
Asset: assets[x],
})
@@ -1318,15 +1324,15 @@ func (ok *Okx) GenerateDefaultSubscriptions() ([]subscription.Subscription, erro
return nil, err
}
for p := range pairs {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: subs[c],
Asset: assets[x],
Pair: pairs[p],
Pairs: currency.Pairs{pairs[p]},
})
}
}
default:
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: subs[c],
})
}
@@ -1857,8 +1863,6 @@ func (ok *Okx) wsAuthChannelSubscription(operation, channel string, assetType as
var instrumentID string
var instrumentType string
var ccy string
var err error
var format currency.PairFormat
if params.InstrumentType {
instrumentType = ok.GetInstrumentTypeFromAssetItem(assetType)
if instrumentType != okxInstTypeMargin &&
@@ -1874,13 +1878,13 @@ func (ok *Okx) wsAuthChannelSubscription(operation, channel string, assetType as
}
}
if params.InstrumentID {
format, err = ok.GetPairFormat(assetType, false)
if err != nil {
return err
}
if !pair.IsPopulated() {
return errIncompleteCurrencyPair
}
format, err := ok.GetPairFormat(assetType, false)
if err != nil {
return err
}
instrumentID = format.Format(pair)
}
if params.Currency {

View File

@@ -351,10 +351,16 @@ type WebsocketTrollboxMessage struct {
Reputation float64
}
// WsCommand defines the request params after a websocket connection has been
// established
type WsCommand struct {
Command string `json:"command"`
type wsOp string
const (
wsSubscribeOp wsOp = "subscribe"
wsUnsubscribeOp wsOp = "unsubscribe"
)
// wsCommand defines the request params after a websocket connection has been established
type wsCommand struct {
Command wsOp `json:"command"`
Channel interface{} `json:"channel"`
APIKey string `json:"key,omitempty"`
Payload string `json:"payload,omitempty"`
@@ -467,9 +473,9 @@ type WsTradeNotificationResponse struct {
Date time.Time
}
// WsAuthorisationRequest Authenticated Ws Account data request
type WsAuthorisationRequest struct {
Command string `json:"command"`
// wsAuthorisationRequest Authenticated Ws Account data request
type wsAuthorisationRequest struct {
Command wsOp `json:"command"`
Channel int64 `json:"channel"`
Sign string `json:"sign"`
Key string `json:"key"`

View File

@@ -541,28 +541,28 @@ func (p *Poloniex) WsProcessOrderbookUpdate(sequenceNumber float64, data []inter
}
// GenerateDefaultSubscriptions Adds default subscriptions to websocket to be handled by ManageSubscriptions()
func (p *Poloniex) GenerateDefaultSubscriptions() ([]subscription.Subscription, error) {
func (p *Poloniex) GenerateDefaultSubscriptions() (subscription.List, error) {
enabledPairs, err := p.GetEnabledPairs(asset.Spot)
if err != nil {
return nil, err
}
subscriptions := make([]subscription.Subscription, 0, len(enabledPairs))
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions := make(subscription.List, 0, len(enabledPairs))
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: strconv.FormatInt(wsTickerDataID, 10),
})
if p.IsWebsocketAuthenticationSupported() {
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: strconv.FormatInt(wsAccountNotificationID, 10),
})
}
for j := range enabledPairs {
enabledPairs[j].Delimiter = currency.UnderscoreDelimiter
subscriptions = append(subscriptions, subscription.Subscription{
subscriptions = append(subscriptions, &subscription.Subscription{
Channel: "orderbook",
Pair: enabledPairs[j],
Pairs: currency.Pairs{enabledPairs[j]},
Asset: asset.Spot,
})
}
@@ -570,54 +570,16 @@ func (p *Poloniex) GenerateDefaultSubscriptions() ([]subscription.Subscription,
}
// Subscribe sends a websocket message to receive data from the channel
func (p *Poloniex) Subscribe(sub []subscription.Subscription) error {
var creds *account.Credentials
if p.IsWebsocketAuthenticationSupported() {
var err error
creds, err = p.GetCredentials(context.TODO())
if err != nil {
return err
}
}
var errs error
channels:
for i := range sub {
subscriptionRequest := WsCommand{
Command: "subscribe",
}
switch {
case strings.EqualFold(strconv.FormatInt(wsAccountNotificationID, 10),
sub[i].Channel) && creds != nil:
err := p.wsSendAuthorisedCommand(creds.Secret, creds.Key, "subscribe")
if err != nil {
errs = common.AppendError(errs, err)
continue channels
}
p.Websocket.AddSuccessfulSubscriptions(sub[i])
continue channels
case strings.EqualFold(strconv.FormatInt(wsTickerDataID, 10),
sub[i].Channel):
subscriptionRequest.Channel = wsTickerDataID
default:
subscriptionRequest.Channel = sub[i].Pair.String()
}
err := p.Websocket.Conn.SendJSONMessage(subscriptionRequest)
if err != nil {
errs = common.AppendError(errs, err)
continue
}
p.Websocket.AddSuccessfulSubscriptions(sub[i])
}
if errs != nil {
return errs
}
return nil
func (p *Poloniex) Subscribe(subs subscription.List) error {
return p.manageSubs(subs, wsSubscribeOp)
}
// Unsubscribe sends a websocket message to stop receiving data from the channel
func (p *Poloniex) Unsubscribe(unsub []subscription.Subscription) error {
func (p *Poloniex) Unsubscribe(subs subscription.List) error {
return p.manageSubs(subs, wsUnsubscribeOp)
}
func (p *Poloniex) manageSubs(subs subscription.List, op wsOp) error {
var creds *account.Credentials
if p.IsWebsocketAuthenticationSupported() {
var err error
@@ -626,42 +588,39 @@ func (p *Poloniex) Unsubscribe(unsub []subscription.Subscription) error {
return err
}
}
var errs error
channels:
for i := range unsub {
unsubscriptionRequest := WsCommand{
Command: "unsubscribe",
}
switch {
case strings.EqualFold(strconv.FormatInt(wsAccountNotificationID, 10),
unsub[i].Channel) && creds != nil:
err := p.wsSendAuthorisedCommand(creds.Secret, creds.Key, "unsubscribe")
if err != nil {
errs = common.AppendError(errs, err)
continue channels
for _, s := range subs {
var err error
if creds != nil && strings.EqualFold(strconv.FormatInt(wsAccountNotificationID, 10), s.Channel) {
err = p.wsSendAuthorisedCommand(creds.Secret, creds.Key, op)
} else {
req := wsCommand{Command: op}
if strings.EqualFold(strconv.FormatInt(wsTickerDataID, 10), s.Channel) {
req.Channel = wsTickerDataID
} else {
if len(s.Pairs) != 1 {
return subscription.ErrNotSinglePair
}
req.Channel = s.Pairs[0].String()
}
err = p.Websocket.Conn.SendJSONMessage(req)
}
if err == nil {
if op == wsSubscribeOp {
err = p.Websocket.AddSuccessfulSubscriptions(s)
} else {
err = p.Websocket.RemoveSubscriptions(s)
}
p.Websocket.RemoveSubscriptions(unsub[i])
continue channels
case strings.EqualFold(strconv.FormatInt(wsTickerDataID, 10),
unsub[i].Channel):
unsubscriptionRequest.Channel = wsTickerDataID
default:
unsubscriptionRequest.Channel = unsub[i].Pair.String()
}
err := p.Websocket.Conn.SendJSONMessage(unsubscriptionRequest)
if err != nil {
errs = common.AppendError(errs, err)
continue
}
p.Websocket.RemoveSubscriptions(unsub[i])
}
if errs != nil {
return errs
}
return nil
return errs
}
func (p *Poloniex) wsSendAuthorisedCommand(secret, key, command string) error {
func (p *Poloniex) wsSendAuthorisedCommand(secret, key string, op wsOp) error {
nonce := fmt.Sprintf("nonce=%v", time.Now().UnixNano())
hmac, err := crypto.GetHMAC(crypto.HashSHA512,
[]byte(nonce),
@@ -669,8 +628,8 @@ func (p *Poloniex) wsSendAuthorisedCommand(secret, key, command string) error {
if err != nil {
return err
}
request := WsAuthorisationRequest{
Command: command,
request := wsAuthorisationRequest{
Command: op,
Channel: 1000,
Sign: crypto.HexEncodeToString(hmac),
Key: key,

View File

@@ -262,7 +262,7 @@ func (c *CustomEx) SupportsREST() bool {
}
// GetSubscriptions is a mock method for CustomEx
func (c *CustomEx) GetSubscriptions() ([]subscription.Subscription, error) {
func (c *CustomEx) GetSubscriptions() (subscription.List, error) {
return nil, nil
}
@@ -312,12 +312,12 @@ func (c *CustomEx) SupportsWebsocket() bool {
}
// SubscribeToWebsocketChannels is a mock method for CustomEx
func (c *CustomEx) SubscribeToWebsocketChannels(_ []subscription.Subscription) error {
func (c *CustomEx) SubscribeToWebsocketChannels(_ subscription.List) error {
return nil
}
// UnsubscribeToWebsocketChannels is a mock method for CustomEx
func (c *CustomEx) UnsubscribeToWebsocketChannels(_ []subscription.Subscription) error {
func (c *CustomEx) UnsubscribeToWebsocketChannels(_ subscription.List) error {
return nil
}

View File

@@ -15,8 +15,6 @@ import (
exchange "github.com/thrasher-corp/gocryptotrader/exchanges"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
"github.com/thrasher-corp/gocryptotrader/exchanges/stream"
"github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
)
// This package is only to be referenced in test files
@@ -54,16 +52,10 @@ func GetWebsocketStructChannelOverride() chan struct{} {
// NewTestWebsocket returns a test websocket object
func NewTestWebsocket() *stream.Websocket {
return &stream.Websocket{
DataHandler: make(chan interface{}, WebsocketChannelOverrideCapacity),
ToRoutine: make(chan interface{}, 1000),
TrafficAlert: make(chan struct{}),
ReadMessageErrors: make(chan error),
Subscribe: make(chan []subscription.Subscription, 10),
Unsubscribe: make(chan []subscription.Subscription, 10),
Match: stream.NewMatch(),
Orderbook: buffer.Orderbook{},
}
w := stream.NewWebsocket()
w.DataHandler = make(chan interface{}, WebsocketChannelOverrideCapacity)
w.ToRoutine = make(chan interface{}, 1000)
return w
}
// SkipTestIfCredentialsUnset is a test helper function checking if the

View File

@@ -5,12 +5,14 @@ import (
"fmt"
"net"
"net/url"
"sync"
"slices"
"time"
"github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/config"
"github.com/thrasher-corp/gocryptotrader/exchanges/protocol"
"github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
"github.com/thrasher-corp/gocryptotrader/log"
)
@@ -22,12 +24,9 @@ const (
// Public websocket errors
var (
ErrWebsocketNotEnabled = errors.New("websocket not enabled")
ErrSubscriptionNotFound = errors.New("subscription not found")
ErrSubscribedAlready = errors.New("duplicate subscription")
ErrSubscriptionFailure = errors.New("subscription failure")
ErrSubscriptionNotSupported = errors.New("subscription channel not supported ")
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
ErrChannelInStateAlready = errors.New("channel already in state")
ErrAlreadyDisabled = errors.New("websocket already disabled")
ErrNotConnected = errors.New("websocket is not connected")
)
@@ -57,9 +56,6 @@ var (
errClosedConnection = errors.New("use of closed network connection")
errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit")
errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0")
errNoSubscriptionsSupplied = errors.New("no subscriptions supplied")
errChannelAlreadySubscribed = errors.New("channel already subscribed")
errInvalidChannelState = errors.New("invalid Channel state")
errSameProxyAddress = errors.New("cannot set proxy address to the same address")
errNoConnectFunc = errors.New("websocket connect func not set")
errAlreadyConnected = errors.New("websocket already connected")
@@ -84,11 +80,13 @@ func NewWebsocket() *Websocket {
return &Websocket{
DataHandler: make(chan interface{}, jobBuffer),
ToRoutine: make(chan interface{}, jobBuffer),
ShutdownC: make(chan struct{}),
TrafficAlert: make(chan struct{}, 1),
ReadMessageErrors: make(chan error),
Subscribe: make(chan []subscription.Subscription),
Unsubscribe: make(chan []subscription.Subscription),
Match: NewMatch(),
subscriptions: subscription.NewStore(),
features: &protocol.Features{},
Orderbook: buffer.Orderbook{},
}
}
@@ -181,7 +179,6 @@ func (w *Websocket) Setup(s *WebsocketSetup) error {
w.trafficTimeout = s.ExchangeConfig.WebsocketTrafficTimeout
w.ShutdownC = make(chan struct{})
w.Wg = new(sync.WaitGroup)
w.SetCanUseAuthenticatedEndpoints(s.ExchangeConfig.API.AuthenticatedWebsocketSupport)
if err := w.Orderbook.Setup(s.ExchangeConfig, &s.OrderbookBufferConfig, w.DataHandler); err != nil {
@@ -195,7 +192,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error {
return fmt.Errorf("%s %w", w.exchangeName, errInvalidMaxSubscriptions)
}
w.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection
w.setState(disconnected)
w.setState(disconnectedState)
return nil
}
@@ -243,7 +240,7 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error {
Traffic: w.TrafficAlert,
readMessageErrors: w.ReadMessageErrors,
ShutdownC: w.ShutdownC,
Wg: w.Wg,
Wg: &w.Wg,
Match: w.Match,
RateLimit: c.RateLimit,
Reporter: c.ConnectionLevelReporter,
@@ -277,20 +274,21 @@ func (w *Websocket) Connect() error {
return fmt.Errorf("%v %w", w.exchangeName, errAlreadyConnected)
}
w.subscriptionMutex.Lock()
w.subscriptions = subscriptionMap{}
w.subscriptionMutex.Unlock()
if w.subscriptions == nil {
return fmt.Errorf("%w: subscriptions", common.ErrNilPointer)
}
w.subscriptions.Clear()
w.dataMonitor()
w.trafficMonitor()
w.setState(connecting)
w.setState(connectingState)
err := w.connector()
if err != nil {
w.setState(disconnected)
w.setState(disconnectedState)
return fmt.Errorf("%v Error connecting %w", w.exchangeName, err)
}
w.setState(connected)
w.setState(connectedState)
if !w.IsConnectionMonitorRunning() {
err = w.connectionMonitor()
@@ -303,17 +301,12 @@ func (w *Websocket) Connect() error {
if err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
}
if len(subs) == 0 {
return nil
}
err = w.checkSubscriptions(subs)
if err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
}
err = w.Subscriber(subs)
if err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
if len(subs) != 0 {
if err := w.SubscribeToChannels(subs); err != nil {
return err
}
}
return nil
}
@@ -469,11 +462,9 @@ func (w *Websocket) Shutdown() error {
}
// flush any subscriptions from last connection if needed
w.subscriptionMutex.Lock()
w.subscriptions = subscriptionMap{}
w.subscriptionMutex.Unlock()
w.subscriptions.Clear()
w.setState(disconnected)
w.setState(disconnectedState)
close(w.ShutdownC)
w.Wg.Wait()
@@ -527,9 +518,7 @@ func (w *Websocket) FlushChannels() error {
if len(newsubs) != 0 {
// Purge subscription list as there will be conflicts
w.subscriptionMutex.Lock()
w.subscriptions = subscriptionMap{}
w.subscriptionMutex.Unlock()
w.subscriptions.Clear()
return w.SubscribeToChannels(newsubs)
}
return nil
@@ -606,17 +595,17 @@ func (w *Websocket) setState(s uint32) {
// IsInitialised returns whether the websocket has been Setup() already
func (w *Websocket) IsInitialised() bool {
return w.state.Load() != uninitialised
return w.state.Load() != uninitialisedState
}
// IsConnected returns whether the websocket is connected
func (w *Websocket) IsConnected() bool {
return w.state.Load() == connected
return w.state.Load() == connectedState
}
// IsConnecting returns whether the websocket is connecting
func (w *Websocket) IsConnecting() bool {
return w.state.Load() == connecting
return w.state.Load() == connectingState
}
func (w *Websocket) setEnabled(b bool) {
@@ -786,163 +775,134 @@ func (w *Websocket) GetName() string {
// GetChannelDifference finds the difference between the subscribed channels
// and the new subscription list when pairs are disabled or enabled.
func (w *Websocket) GetChannelDifference(genSubs []subscription.Subscription) (sub, unsub []subscription.Subscription) {
w.subscriptionMutex.RLock()
unsubMap := make(map[any]subscription.Subscription, len(w.subscriptions))
for k, c := range w.subscriptions {
unsubMap[k] = *c
func (w *Websocket) GetChannelDifference(newSubs subscription.List) (sub, unsub subscription.List) {
if w.subscriptions == nil {
w.subscriptions = subscription.NewStore()
}
w.subscriptionMutex.RUnlock()
for i := range genSubs {
key := genSubs[i].EnsureKeyed()
if _, ok := unsubMap[key]; ok {
delete(unsubMap, key) // If it's in both then we remove it from the unsubscribe list
} else {
sub = append(sub, genSubs[i]) // If it's in genSubs but not existing subs we want to subscribe
}
}
for x := range unsubMap {
unsub = append(unsub, unsubMap[x])
}
return
return w.subscriptions.Diff(newSubs)
}
// UnsubscribeChannels unsubscribes from a websocket channel
func (w *Websocket) UnsubscribeChannels(channels []subscription.Subscription) error {
if len(channels) == 0 {
return fmt.Errorf("%s websocket: %w", w.exchangeName, errNoSubscriptionsSupplied)
// UnsubscribeChannels unsubscribes from a list of websocket channel
func (w *Websocket) UnsubscribeChannels(channels subscription.List) error {
if w.subscriptions == nil || len(channels) == 0 {
return nil // No channels to unsubscribe from is not an error
}
w.subscriptionMutex.RLock()
for i := range channels {
key := channels[i].EnsureKeyed()
if _, ok := w.subscriptions[key]; !ok {
w.subscriptionMutex.RUnlock()
return fmt.Errorf("%s websocket: %w: %+v", w.exchangeName, ErrSubscriptionNotFound, channels[i])
for _, s := range channels {
if w.subscriptions.Get(s) == nil {
return fmt.Errorf("%w: %s", subscription.ErrNotFound, s)
}
}
w.subscriptionMutex.RUnlock()
return w.Unsubscriber(channels)
}
// ResubscribeToChannel resubscribes to channel
func (w *Websocket) ResubscribeToChannel(subscribedChannel *subscription.Subscription) error {
err := w.UnsubscribeChannels([]subscription.Subscription{*subscribedChannel})
if err != nil {
// Sets state to Resubscribing, and exchanges which want to maintain a lock on it can respect this state and not RemoveSubscription
// Errors if subscription is already subscribing
func (w *Websocket) ResubscribeToChannel(s *subscription.Subscription) error {
l := subscription.List{s}
if err := s.SetState(subscription.ResubscribingState); err != nil {
return fmt.Errorf("%w: %s", err, s)
}
if err := w.UnsubscribeChannels(l); err != nil {
return err
}
return w.SubscribeToChannels([]subscription.Subscription{*subscribedChannel})
return w.SubscribeToChannels(l)
}
// SubscribeToChannels appends supplied channels to channelsToSubscribe
func (w *Websocket) SubscribeToChannels(channels []subscription.Subscription) error {
if err := w.checkSubscriptions(channels); err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
// SubscribeToChannels subscribes to websocket channels using the exchange specific Subscriber method
// Errors are returned for duplicates or exceeding max Subscriptions
func (w *Websocket) SubscribeToChannels(subs subscription.List) error {
if slices.Contains(subs, nil) {
return fmt.Errorf("%w: List parameter contains an nil element", common.ErrNilPointer)
}
if err := w.Subscriber(channels); err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
if err := w.checkSubscriptions(subs); err != nil {
return err
}
if err := w.Subscriber(subs); err != nil {
return fmt.Errorf("%w: %w", ErrSubscriptionFailure, err)
}
return nil
}
// AddSubscription adds a subscription to the subscription lists
// Unlike AddSubscriptions this method will error if the subscription already exists
func (w *Websocket) AddSubscription(c *subscription.Subscription) error {
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
// AddSubscriptions adds subscriptions to the subscription store
// Sets state to Subscribing unless the state is already set
func (w *Websocket) AddSubscriptions(subs ...*subscription.Subscription) error {
if w == nil {
return fmt.Errorf("%w: AddSubscriptions called on nil Websocket", common.ErrNilPointer)
}
if w.subscriptions == nil {
w.subscriptions = subscriptionMap{}
w.subscriptions = subscription.NewStore()
}
key := c.EnsureKeyed()
if _, ok := w.subscriptions[key]; ok {
return ErrSubscribedAlready
var errs error
for _, s := range subs {
if s.State() == subscription.InactiveState {
if err := s.SetState(subscription.SubscribingState); err != nil {
errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s))
}
}
if err := w.subscriptions.Add(s); err != nil {
errs = common.AppendError(errs, err)
}
}
n := *c // Fresh copy; we don't want to use the pointer we were given and allow encapsulation/locks to be bypassed
w.subscriptions[key] = &n
return nil
return errs
}
// SetSubscriptionState sets an existing subscription state
// returns an error if the subscription is not found, or the new state is already set
func (w *Websocket) SetSubscriptionState(c *subscription.Subscription, state subscription.State) error {
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
// AddSuccessfulSubscriptions marks subscriptions as subscribed and adds them to the subscription store
func (w *Websocket) AddSuccessfulSubscriptions(subs ...*subscription.Subscription) error {
if w == nil {
return fmt.Errorf("%w: AddSuccessfulSubscriptions called on nil Websocket", common.ErrNilPointer)
}
if w.subscriptions == nil {
w.subscriptions = subscriptionMap{}
w.subscriptions = subscription.NewStore()
}
key := c.EnsureKeyed()
p, ok := w.subscriptions[key]
if !ok {
return ErrSubscriptionNotFound
var errs error
for _, s := range subs {
if err := s.SetState(subscription.SubscribedState); err != nil {
errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s))
}
if err := w.subscriptions.Add(s); err != nil {
errs = common.AppendError(errs, err)
}
}
if state == p.State {
return ErrChannelInStateAlready
}
if state > subscription.UnsubscribingState {
return errInvalidChannelState
}
p.State = state
return nil
return errs
}
// AddSuccessfulSubscriptions adds subscriptions to the subscription lists that
// has been successfully subscribed
func (w *Websocket) AddSuccessfulSubscriptions(channels ...subscription.Subscription) {
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
// RemoveSubscriptions removes subscriptions from the subscription list and sets the status to Unsubscribed
func (w *Websocket) RemoveSubscriptions(subs ...*subscription.Subscription) error {
if w == nil {
return fmt.Errorf("%w: RemoveSubscriptions called on nil Websocket", common.ErrNilPointer)
}
if w.subscriptions == nil {
w.subscriptions = subscriptionMap{}
return fmt.Errorf("%w: RemoveSubscriptions called on uninitialised Websocket", common.ErrNilPointer)
}
for _, cN := range channels { //nolint:gocritic // See below comment
c := cN // cN is an iteration var; Not safe to make a pointer to
key := c.EnsureKeyed()
c.State = subscription.SubscribedState
w.subscriptions[key] = &c
var errs error
for _, s := range subs {
if err := s.SetState(subscription.UnsubscribedState); err != nil {
errs = common.AppendError(errs, fmt.Errorf("%w: %s", err, s))
}
if err := w.subscriptions.Remove(s); err != nil {
errs = common.AppendError(errs, err)
}
}
return errs
}
// RemoveSubscriptions removes subscriptions from the subscription list
func (w *Websocket) RemoveSubscriptions(channels ...subscription.Subscription) {
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
if w.subscriptions == nil {
w.subscriptions = subscriptionMap{}
}
for i := range channels {
key := channels[i].EnsureKeyed()
delete(w.subscriptions, key)
}
}
// GetSubscription returns a pointer to a copy of the subscription at the key provided
// GetSubscription returns a subscription at the key provided
// returns nil if no subscription is at that key or the key is nil
// Keys can implement subscription.MatchableKey in order to provide custom matching logic
func (w *Websocket) GetSubscription(key any) *subscription.Subscription {
if key == nil || w == nil || w.subscriptions == nil {
if w == nil || w.subscriptions == nil || key == nil {
return nil
}
w.subscriptionMutex.RLock()
defer w.subscriptionMutex.RUnlock()
if s, ok := w.subscriptions[key]; ok {
c := *s
return &c
}
return nil
return w.subscriptions.Get(key)
}
// GetSubscriptions returns a new slice of the subscriptions
func (w *Websocket) GetSubscriptions() []subscription.Subscription {
w.subscriptionMutex.RLock()
defer w.subscriptionMutex.RUnlock()
subs := make([]subscription.Subscription, 0, len(w.subscriptions))
for _, c := range w.subscriptions {
subs = append(subs, *c)
func (w *Websocket) GetSubscriptions() subscription.List {
if w == nil || w.subscriptions == nil {
return nil
}
return subs
return w.subscriptions.List()
}
// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in a thread safe manner
@@ -978,28 +938,25 @@ func checkWebsocketURL(s string) error {
return nil
}
// checkSubscriptions checks subscriptions against the max subscription limit
// and if the subscription already exists.
func (w *Websocket) checkSubscriptions(subs []subscription.Subscription) error {
if len(subs) == 0 {
return errNoSubscriptionsSupplied
// checkSubscriptions checks subscriptions against the max subscription limit and if the subscription already exists
// The subscription state is not considered when counting existing subscriptions
func (w *Websocket) checkSubscriptions(subs subscription.List) error {
if w.subscriptions == nil {
return fmt.Errorf("%w: Websocket.subscriptions", common.ErrNilPointer)
}
w.subscriptionMutex.RLock()
defer w.subscriptionMutex.RUnlock()
if w.MaxSubscriptionsPerConnection > 0 && len(w.subscriptions)+len(subs) > w.MaxSubscriptionsPerConnection {
existing := w.subscriptions.Len()
if w.MaxSubscriptionsPerConnection > 0 && existing+len(subs) > w.MaxSubscriptionsPerConnection {
return fmt.Errorf("%w: current subscriptions: %v, incoming subscriptions: %v, max subscriptions per connection: %v - please reduce enabled pairs",
errSubscriptionsExceedsLimit,
len(w.subscriptions),
existing,
len(subs),
w.MaxSubscriptionsPerConnection)
}
for i := range subs {
key := subs[i].EnsureKeyed()
if _, ok := w.subscriptions[key]; ok {
return fmt.Errorf("%w for %+v", errChannelAlreadySubscribed, subs[i])
for _, s := range subs {
if found := w.subscriptions.Get(s); found != nil {
return fmt.Errorf("%w: %s", subscription.ErrDuplicate, s)
}
}

View File

@@ -90,25 +90,22 @@ func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header
// SendJSONMessage sends a JSON encoded message over the connection
func (w *WebsocketConnection) SendJSONMessage(data interface{}) error {
if !w.IsConnected() {
return fmt.Errorf("%s websocket connection: cannot send message to a disconnected websocket",
w.ExchangeName)
return fmt.Errorf("%s websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
}
w.writeControl.Lock()
defer w.writeControl.Unlock()
if w.Verbose {
log.Debugf(log.WebsocketMgr,
"%s websocket connection: sending message to websocket %+v\n",
w.ExchangeName,
data)
if msg, err := json.Marshal(data); err == nil { // WriteJSON will error for us anyway
log.Debugf(log.WebsocketMgr, "%s websocket connection: sending message: %s\n", w.ExchangeName, msg)
}
}
if w.RateLimit > 0 {
time.Sleep(time.Duration(w.RateLimit) * time.Millisecond)
if !w.IsConnected() {
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket",
w.ExchangeName)
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
}
}
return w.Connection.WriteJSON(data)
@@ -117,29 +114,23 @@ func (w *WebsocketConnection) SendJSONMessage(data interface{}) error {
// SendRawMessage sends a message over the connection without JSON encoding it
func (w *WebsocketConnection) SendRawMessage(messageType int, message []byte) error {
if !w.IsConnected() {
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket",
w.ExchangeName)
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
}
w.writeControl.Lock()
defer w.writeControl.Unlock()
if w.Verbose {
log.Debugf(log.WebsocketMgr,
"%v websocket connection: sending message [%s]\n",
w.ExchangeName,
message)
log.Debugf(log.WebsocketMgr, "%v websocket connection: sending message [%s]\n", w.ExchangeName, message)
}
if w.RateLimit > 0 {
time.Sleep(time.Duration(w.RateLimit) * time.Millisecond)
if !w.IsConnected() {
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket",
w.ExchangeName)
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
}
}
if !w.IsConnected() {
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket",
w.ExchangeName)
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
}
return w.Connection.WriteMessage(messageType, message)
}

View File

@@ -10,7 +10,6 @@ import (
"net"
"net/http"
"os"
"sort"
"strconv"
"strings"
"sync"
@@ -20,6 +19,7 @@ import (
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/config"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/protocol"
@@ -79,10 +79,10 @@ var defaultSetup = &WebsocketSetup{
DefaultURL: "testDefaultURL",
RunningURL: "wss://testRunningURL",
Connector: func() error { return nil },
Subscriber: func([]subscription.Subscription) error { return nil },
Unsubscriber: func([]subscription.Subscription) error { return nil },
GenerateSubscriptions: func() ([]subscription.Subscription, error) {
return []subscription.Subscription{
Subscriber: func(subscription.List) error { return nil },
Unsubscriber: func(subscription.List) error { return nil },
GenerateSubscriptions: func() (subscription.List, error) {
return subscription.List{
{Channel: "TestSub"},
{Channel: "TestSub2", Key: "purple"},
{Channel: "TestSub3", Key: testSubKey{"mauve"}},
@@ -147,16 +147,16 @@ func TestSetup(t *testing.T) {
err = w.Setup(websocketSetup)
assert.ErrorIs(t, err, errWebsocketSubscriberUnset, "Setup should error correctly")
websocketSetup.Subscriber = func([]subscription.Subscription) error { return nil }
websocketSetup.Subscriber = func(subscription.List) error { return nil }
websocketSetup.Features.Unsubscribe = true
err = w.Setup(websocketSetup)
assert.ErrorIs(t, err, errWebsocketUnsubscriberUnset, "Setup should error correctly")
websocketSetup.Unsubscriber = func([]subscription.Subscription) error { return nil }
websocketSetup.Unsubscriber = func(subscription.List) error { return nil }
err = w.Setup(websocketSetup)
assert.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset, "Setup should error correctly")
websocketSetup.GenerateSubscriptions = func() ([]subscription.Subscription, error) { return nil, nil }
websocketSetup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil }
err = w.Setup(websocketSetup)
assert.ErrorIs(t, err, errDefaultURLIsEmpty, "Setup should error correctly")
@@ -193,14 +193,13 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) {
signal := struct{}{}
patience := 10 * time.Millisecond
ws.trafficTimeout = 200 * time.Millisecond
ws.ShutdownC = make(chan struct{})
ws.state.Store(connected)
ws.state.Store(connectedState)
thenish := time.Now()
ws.trafficMonitor()
assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running")
require.Equal(t, connected, ws.state.Load(), "websocket must be connected")
require.Equal(t, connectedState, ws.state.Load(), "websocket must be connected")
for i := 0; i < 6; i++ { // Timeout will happen at 200ms so we want 6 * 50ms checks to pass
select {
@@ -226,7 +225,7 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) {
}
require.EventuallyWithT(t, func(c *assert.CollectT) {
assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected")
assert.Equal(c, disconnectedState, ws.state.Load(), "websocket must be disconnected")
assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down")
}, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown once we stop feeding trafficAlerts")
}
@@ -238,17 +237,16 @@ func TestTrafficMonitorConnecting(t *testing.T) {
err := ws.Setup(defaultSetup)
require.NoError(t, err, "Setup must not error")
ws.ShutdownC = make(chan struct{})
ws.state.Store(connecting)
ws.state.Store(connectingState)
ws.trafficTimeout = 50 * time.Millisecond
ws.trafficMonitor()
require.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running")
require.Equal(t, connecting, ws.state.Load(), "websocket must be connecting")
require.Equal(t, connectingState, ws.state.Load(), "websocket must be connecting")
<-time.After(4 * ws.trafficTimeout)
require.Equal(t, connecting, ws.state.Load(), "websocket must still be connecting after several checks")
ws.state.Store(connected)
require.Equal(t, connectingState, ws.state.Load(), "websocket must still be connecting after several checks")
ws.state.Store(connectedState)
require.EventuallyWithT(t, func(c *assert.CollectT) {
assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected")
assert.Equal(c, disconnectedState, ws.state.Load(), "websocket must be disconnected")
assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down")
}, 4*ws.trafficTimeout, 10*time.Millisecond, "trafficTimeout should trigger a shutdown after connecting status changes")
}
@@ -260,8 +258,7 @@ func TestTrafficMonitorShutdown(t *testing.T) {
err := ws.Setup(defaultSetup)
require.NoError(t, err, "Setup must not error")
ws.ShutdownC = make(chan struct{})
ws.state.Store(connected)
ws.state.Store(connectedState)
ws.trafficTimeout = time.Minute
ws.trafficMonitor()
assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running")
@@ -307,12 +304,17 @@ func TestConnectionMessageErrors(t *testing.T) {
assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "Connect should error correctly")
wsWrong.setEnabled(true)
wsWrong.setState(connecting)
wsWrong.Wg = &sync.WaitGroup{}
wsWrong.setState(connectingState)
err = wsWrong.Connect()
assert.ErrorIs(t, err, errAlreadyReconnecting, "Connect should error correctly")
wsWrong.setState(disconnected)
wsWrong.setState(disconnectedState)
err = wsWrong.Connect()
assert.ErrorIs(t, err, common.ErrNilPointer, "Connect should get a nil pointer error")
assert.ErrorContains(t, err, "subscriptions", "Connect should get a nil pointer error about subscriptions")
wsWrong.subscriptions = subscription.NewStore()
wsWrong.setState(disconnectedState)
wsWrong.connector = func() error { return errDastardlyReason }
err = wsWrong.Connect()
assert.ErrorIs(t, err, errDastardlyReason, "Connect should error correctly")
@@ -321,7 +323,7 @@ func TestConnectionMessageErrors(t *testing.T) {
err = ws.Setup(defaultSetup)
require.NoError(t, err, "Setup must not error")
ws.trafficTimeout = time.Minute
ws.connector = func() error { return nil }
ws.connector = connect
err = ws.Connect()
require.NoError(t, err, "Connect must not error")
@@ -381,7 +383,7 @@ func TestWebsocket(t *testing.T) {
err = ws.SetProxyAddress("https://192.168.0.1:1337")
assert.NoError(t, err, "SetProxyAddress should not error when not yet connected")
ws.setState(connected)
ws.setState(connectedState)
err = ws.SetProxyAddress("https://192.168.0.1:1336")
assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Connect and error from there")
@@ -404,14 +406,14 @@ func TestWebsocket(t *testing.T) {
assert.Equal(t, "wss://testRunningURL", ws.GetWebsocketURL(), "GetWebsocketURL should return correctly")
assert.Equal(t, time.Second*5, ws.trafficTimeout, "trafficTimeout should default correctly")
ws.setState(connected)
ws.setState(connectedState)
ws.AuthConn = &dodgyConnection{}
err = ws.Shutdown()
assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy authConn")
assert.ErrorIs(t, err, errCannotShutdown, "Shutdown should error correctly with a dodgy authConn")
ws.AuthConn = &WebsocketConnection{}
ws.setState(disconnected)
ws.setState(disconnectedState)
err = ws.Connect()
assert.NoError(t, err, "Connect should not error")
@@ -446,34 +448,39 @@ func TestWebsocket(t *testing.T) {
ws.Wg.Wait()
}
func currySimpleSub(w *Websocket) func(subscription.List) error {
return func(subs subscription.List) error {
return w.AddSuccessfulSubscriptions(subs...)
}
}
func currySimpleUnsub(w *Websocket) func(subscription.List) error {
return func(unsubs subscription.List) error {
return w.RemoveSubscriptions(unsubs...)
}
}
// TestSubscribe logic test
func TestSubscribeUnsubscribe(t *testing.T) {
t.Parallel()
ws := NewWebsocket()
assert.NoError(t, ws.Setup(defaultSetup), "WS Setup should not error")
fnSub := func(subs []subscription.Subscription) error {
ws.AddSuccessfulSubscriptions(subs...)
return nil
}
fnUnsub := func(unsubs []subscription.Subscription) error {
ws.RemoveSubscriptions(unsubs...)
return nil
}
ws.Subscriber = fnSub
ws.Unsubscriber = fnUnsub
ws.Subscriber = currySimpleSub(ws)
ws.Unsubscriber = currySimpleUnsub(ws)
subs, err := ws.GenerateSubs()
assert.NoError(t, err, "Generating test subscriptions should not error")
assert.ErrorIs(t, ws.UnsubscribeChannels(nil), errNoSubscriptionsSupplied, "Unsubscribing from nil should error")
assert.ErrorIs(t, ws.UnsubscribeChannels(subs), ErrSubscriptionNotFound, "Unsubscribing should error when not subscribed")
require.NoError(t, err, "Generating test subscriptions should not error")
assert.NoError(t, new(Websocket).UnsubscribeChannels(subs), "Should not error when w.subscriptions is nil")
assert.NoError(t, ws.UnsubscribeChannels(nil), "Unsubscribing from nil should not error")
assert.ErrorIs(t, ws.UnsubscribeChannels(subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return")
assert.NoError(t, ws.SubscribeToChannels(subs), "Basic Subscribing should not error")
assert.Len(t, ws.GetSubscriptions(), 4, "Should have 4 subscriptions")
byDefKey := ws.GetSubscription(subscription.DefaultKey{Channel: "TestSub"})
if assert.NotNil(t, byDefKey, "GetSubscription by default key should find a channel") {
assert.Equal(t, "TestSub", byDefKey.Channel, "GetSubscription by default key should return a pointer a copy of the right channel")
assert.NotSame(t, byDefKey, ws.subscriptions["TestSub"], "GetSubscription returns a fresh pointer")
bySub := ws.GetSubscription(subscription.Subscription{Channel: "TestSub"})
if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") {
assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel")
assert.Same(t, bySub, subs[0], "GetSubscription returns the same pointer")
}
if assert.NotNil(t, ws.GetSubscription("purple"), "GetSubscription by string key should find a channel") {
assert.Equal(t, "TestSub2", ws.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel")
@@ -486,9 +493,15 @@ func TestSubscribeUnsubscribe(t *testing.T) {
}
assert.Nil(t, ws.GetSubscription(nil), "GetSubscription by nil should return nil")
assert.Nil(t, ws.GetSubscription(45), "GetSubscription by invalid key should return nil")
assert.ErrorIs(t, ws.SubscribeToChannels(subs), errChannelAlreadySubscribed, "Subscribe should error when already subscribed")
assert.ErrorIs(t, ws.SubscribeToChannels(nil), errNoSubscriptionsSupplied, "Subscribe to nil should error")
assert.ErrorIs(t, ws.SubscribeToChannels(subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed")
assert.NoError(t, ws.SubscribeToChannels(nil), "Subscribe to an nil List should not error")
assert.NoError(t, ws.UnsubscribeChannels(subs), "Unsubscribing should not error")
ws.Subscriber = func(subscription.List) error { return errDastardlyReason }
assert.ErrorIs(t, ws.SubscribeToChannels(subs), errDastardlyReason, "Should error correctly when error returned from Subscriber")
err = ws.SubscribeToChannels(subscription.List{nil})
assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly when list contains a nil subscription")
}
// TestResubscribe tests Resubscribing to existing subscriptions
@@ -504,61 +517,56 @@ func TestResubscribe(t *testing.T) {
err = ws.Setup(defaultSetup)
assert.NoError(t, err, "WS Setup should not error")
fnSub := func(subs []subscription.Subscription) error {
ws.AddSuccessfulSubscriptions(subs...)
return nil
}
fnUnsub := func(unsubs []subscription.Subscription) error {
ws.RemoveSubscriptions(unsubs...)
return nil
}
ws.Subscriber = fnSub
ws.Unsubscriber = fnUnsub
ws.Subscriber = currySimpleSub(ws)
ws.Unsubscriber = currySimpleUnsub(ws)
channel := []subscription.Subscription{{Channel: "resubTest"}}
channel := subscription.List{{Channel: "resubTest"}}
assert.ErrorIs(t, ws.ResubscribeToChannel(&channel[0]), ErrSubscriptionNotFound, "Resubscribe should error when channel isn't subscribed yet")
assert.ErrorIs(t, ws.ResubscribeToChannel(channel[0]), subscription.ErrNotFound, "Resubscribe should error when channel isn't subscribed yet")
assert.NoError(t, ws.SubscribeToChannels(channel), "Subscribe should not error")
assert.NoError(t, ws.ResubscribeToChannel(&channel[0]), "Resubscribe should not error now the channel is subscribed")
assert.NoError(t, ws.ResubscribeToChannel(channel[0]), "Resubscribe should not error now the channel is subscribed")
}
// TestSubscriptionState tests Subscription state changes
func TestSubscriptionState(t *testing.T) {
// TestSubscriptions tests adding, getting and removing subscriptions
func TestSubscriptions(t *testing.T) {
t.Parallel()
ws := NewWebsocket()
w := new(Websocket) // Do not use NewWebsocket; We want to exercise w.subs == nil
assert.ErrorIs(t, (*Websocket)(nil).AddSubscriptions(nil), common.ErrNilPointer, "Should error correctly when nil websocket")
s := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel}
require.NoError(t, w.AddSubscriptions(s), "Adding first subscription should not error")
assert.Same(t, s, w.GetSubscription(42), "Get Subscription should retrieve the same subscription")
assert.ErrorIs(t, w.AddSubscriptions(s), subscription.ErrDuplicate, "Adding same subscription should return error")
assert.Equal(t, subscription.SubscribingState, s.State(), "Should set state to Subscribing")
c := &subscription.Subscription{Key: 42, Channel: "Gophers", State: subscription.SubscribingState}
assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), ErrSubscriptionNotFound, "Setting an imaginary sub should error")
err := w.RemoveSubscriptions(s)
require.NoError(t, err, "RemoveSubscriptions must not error")
assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub")
assert.Equal(t, subscription.UnsubscribedState, s.State(), "Should set state to Unsubscribed")
assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error")
found := ws.GetSubscription(42)
assert.NotNil(t, found, "Should find the subscription")
assert.Equal(t, subscription.SubscribingState, found.State, "Subscription should be Subscribing")
assert.ErrorIs(t, ws.AddSubscription(c), ErrSubscribedAlready, "Adding an already existing sub should error")
assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.SubscribingState), ErrChannelInStateAlready, "Setting Same state should error")
assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState+1), errInvalidChannelState, "Setting an invalid state should error")
ws.AddSuccessfulSubscriptions(*c)
found = ws.GetSubscription(42)
assert.NotNil(t, found, "Should find the subscription")
assert.Equal(t, subscription.SubscribedState, found.State, "Subscription should be subscribed state")
assert.NoError(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), "Setting Unsub state should not error")
found = ws.GetSubscription(42)
assert.Equal(t, subscription.UnsubscribingState, found.State, "Subscription should be unsubscribing state")
require.NoError(t, s.SetState(subscription.ResubscribingState), "SetState must not error")
require.NoError(t, w.AddSubscriptions(s), "Adding first subscription should not error")
assert.Equal(t, subscription.ResubscribingState, s.State(), "Should not change resubscribing state")
}
// TestRemoveSubscriptions tests removing a subscription
func TestRemoveSubscriptions(t *testing.T) {
// TestSuccessfulSubscriptions tests adding, getting and removing subscriptions
func TestSuccessfulSubscriptions(t *testing.T) {
t.Parallel()
ws := NewWebsocket()
w := new(Websocket) // Do not use NewWebsocket; We want to exercise w.subs == nil
assert.ErrorIs(t, (*Websocket)(nil).AddSuccessfulSubscriptions(nil), common.ErrNilPointer, "Should error correctly when nil websocket")
c := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel}
require.NoError(t, w.AddSuccessfulSubscriptions(c), "Adding first subscription should not error")
assert.Same(t, c, w.GetSubscription(42), "Get Subscription should retrieve the same subscription")
assert.ErrorIs(t, w.AddSuccessfulSubscriptions(c), subscription.ErrInStateAlready, "Adding subscription in same state should return error")
require.NoError(t, c.SetState(subscription.SubscribingState), "SetState must not error")
assert.ErrorIs(t, w.AddSuccessfulSubscriptions(c), subscription.ErrDuplicate, "Adding same subscription should return error")
c := &subscription.Subscription{Key: 42, Channel: "Unite!"}
assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error")
assert.NotNil(t, ws.GetSubscription(42), "Added subscription should be findable")
ws.RemoveSubscriptions(*c)
assert.Nil(t, ws.GetSubscription(42), "Remove should have removed the sub")
err := w.RemoveSubscriptions(c)
require.NoError(t, err, "RemoveSubscriptions must not error")
assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub")
assert.ErrorIs(t, w.RemoveSubscriptions(c), subscription.ErrNotFound, "Should error correctly when not found")
assert.ErrorIs(t, (*Websocket)(nil).RemoveSubscriptions(nil), common.ErrNilPointer, "Should error correctly when nil websocket")
w.subscriptions = nil
assert.ErrorIs(t, w.RemoveSubscriptions(c), common.ErrNilPointer, "Should error correctly when nil websocket")
}
// TestConnectionMonitorNoConnection logic test
@@ -566,10 +574,7 @@ func TestConnectionMonitorNoConnection(t *testing.T) {
t.Parallel()
ws := NewWebsocket()
ws.connectionMonitorDelay = 500
ws.DataHandler = make(chan interface{}, 1)
ws.ShutdownC = make(chan struct{}, 1)
ws.exchangeName = "hello"
ws.Wg = &sync.WaitGroup{}
ws.setEnabled(true)
err := ws.connectionMonitor()
require.NoError(t, err, "connectionMonitor must not error")
@@ -582,32 +587,27 @@ func TestConnectionMonitorNoConnection(t *testing.T) {
func TestGetSubscription(t *testing.T) {
t.Parallel()
assert.Nil(t, (*Websocket).GetSubscription(nil, "imaginary"), "GetSubscription on a nil Websocket should return nil")
assert.Nil(t, (&Websocket{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub map should return nil")
w := Websocket{
subscriptions: subscriptionMap{
42: {
Channel: "hello3",
},
},
}
assert.Nil(t, w.GetSubscription(43), "GetSubscription with an invalid key should return nil")
c := w.GetSubscription(42)
if assert.NotNil(t, c, "GetSubscription with an valid key should return a channel") {
assert.Equal(t, "hello3", c.Channel, "GetSubscription should return the correct channel details")
}
assert.Nil(t, (&Websocket{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub store should return nil")
w := NewWebsocket()
assert.Nil(t, w.GetSubscription(nil), "GetSubscription with a nil key should return nil")
s := &subscription.Subscription{Key: 42, Channel: "hello3"}
require.NoError(t, w.AddSubscriptions(s), "AddSubscriptions must not error")
assert.Same(t, s, w.GetSubscription(42), "GetSubscription should delegate to the store")
}
// TestGetSubscriptions logic test
func TestGetSubscriptions(t *testing.T) {
t.Parallel()
w := Websocket{
subscriptions: subscriptionMap{
42: {
Channel: "hello3",
},
},
assert.Nil(t, (*Websocket).GetSubscriptions(nil), "GetSubscription on a nil Websocket should return nil")
assert.Nil(t, (&Websocket{}).GetSubscriptions(), "GetSubscription on a Websocket with no sub store should return nil")
w := NewWebsocket()
s := subscription.List{
{Key: 42, Channel: "hello3"},
{Key: 45, Channel: "hello4"},
}
assert.Equal(t, "hello3", w.GetSubscriptions()[0].Channel, "GetSubscriptions should return the correct channel details")
err := w.AddSubscriptions(s...)
require.NoError(t, err, "AddSubscriptions must not error")
assert.ElementsMatch(t, s, w.GetSubscriptions(), "GetSubscriptions should return the correct channel details")
}
// TestSetCanUseAuthenticatedEndpoints logic test
@@ -883,7 +883,7 @@ func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) {
ws := &Websocket{}
assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false")
ws.setState(connected)
ws.setState(connectedState)
require.True(t, ws.IsConnected(), "IsConnected must return true")
assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false")
@@ -939,77 +939,42 @@ func TestCheckWebsocketURL(t *testing.T) {
assert.NoError(t, err, "checkWebsocketURL should not error")
}
// TestGetChannelDifference exercises GetChannelDifference
// See subscription.TestStoreDiff for further testing
func TestGetChannelDifference(t *testing.T) {
t.Parallel()
web := Websocket{}
newChans := []subscription.Subscription{
{
Channel: "Test1",
},
{
Channel: "Test2",
},
{
Channel: "Test3",
},
}
subs, unsubs := web.GetChannelDifference(newChans)
assert.Len(t, subs, 3, "Should get the correct number of subs")
assert.Empty(t, unsubs, "Should get the correct number of unsubs")
web.AddSuccessfulSubscriptions(subs...)
flushedSubs := []subscription.Subscription{
{
Channel: "Test2",
},
}
subs, unsubs = web.GetChannelDifference(flushedSubs)
assert.Empty(t, subs, "Should get the correct number of subs")
assert.Len(t, unsubs, 2, "Should get the correct number of unsubs")
flushedSubs = []subscription.Subscription{
{
Channel: "Test2",
},
{
Channel: "Test4",
},
}
subs, unsubs = web.GetChannelDifference(flushedSubs)
if assert.Len(t, subs, 1, "Should get the correct number of subs") {
assert.Equal(t, "Test4", subs[0].Channel, "Should subscribe to the right channel")
}
if assert.Len(t, unsubs, 2, "Should get the correct number of unsubs") {
sort.Slice(unsubs, func(i, j int) bool { return unsubs[i].Channel <= unsubs[j].Channel })
assert.Equal(t, "Test1", unsubs[0].Channel, "Should unsubscribe from the right channels")
assert.Equal(t, "Test3", unsubs[1].Channel, "Should unsubscribe from the right channels")
}
w := &Websocket{}
assert.NotPanics(t, func() { w.GetChannelDifference(subscription.List{}) }, "Should not panic when called without a store")
subs, unsubs := w.GetChannelDifference(subscription.List{{Channel: subscription.CandlesChannel}})
require.Equal(t, 1, len(subs), "Should get the correct number of subs")
require.Empty(t, unsubs, "Should get no unsubs")
require.NoError(t, w.AddSubscriptions(subs...), "AddSubscriptions must not error")
subs, unsubs = w.GetChannelDifference(subscription.List{{Channel: subscription.TickerChannel}})
require.Equal(t, 1, len(subs), "Should get the correct number of subs")
assert.Equal(t, 1, len(unsubs), "Should get the correct number of unsubs")
}
// GenSubs defines a theoretical exchange with pair management
type GenSubs struct {
EnabledPairs currency.Pairs
subscribos []subscription.Subscription
unsubscribos []subscription.Subscription
subscribos subscription.List
unsubscribos subscription.List
}
// generateSubs default subs created from the enabled pairs list
func (g *GenSubs) generateSubs() ([]subscription.Subscription, error) {
superduperchannelsubs := make([]subscription.Subscription, len(g.EnabledPairs))
func (g *GenSubs) generateSubs() (subscription.List, error) {
superduperchannelsubs := make(subscription.List, len(g.EnabledPairs))
for i := range g.EnabledPairs {
superduperchannelsubs[i] = subscription.Subscription{
superduperchannelsubs[i] = &subscription.Subscription{
Channel: "TEST:" + strconv.FormatInt(int64(i), 10),
Pair: g.EnabledPairs[i],
Pairs: currency.Pairs{g.EnabledPairs[i]},
}
}
return superduperchannelsubs, nil
}
func (g *GenSubs) SUBME(subs []subscription.Subscription) error {
func (g *GenSubs) SUBME(subs subscription.List) error {
if len(subs) == 0 {
return errors.New("WOW")
}
@@ -1017,7 +982,7 @@ func (g *GenSubs) SUBME(subs []subscription.Subscription) error {
return nil
}
func (g *GenSubs) UNSUBME(unsubs []subscription.Subscription) error {
func (g *GenSubs) UNSUBME(unsubs subscription.List) error {
if len(unsubs) == 0 {
return errors.New("WOW")
}
@@ -1044,87 +1009,70 @@ func TestFlushChannels(t *testing.T) {
err = dodgyWs.FlushChannels()
assert.ErrorIs(t, err, ErrNotConnected, "FlushChannels should error correctly")
w := Websocket{
connector: connect,
ShutdownC: make(chan struct{}),
Subscriber: newgen.SUBME,
Unsubscriber: newgen.UNSUBME,
Wg: new(sync.WaitGroup),
features: &protocol.Features{
// No features
},
trafficTimeout: time.Second * 30, // Added for when we utilise connect()
// in FlushChannels() so the traffic monitor doesn't time out and turn
// this to an unconnected state
}
w.setEnabled(true)
w.setState(connected)
w := NewWebsocket()
w.connector = connect
w.Subscriber = newgen.SUBME
w.Unsubscriber = newgen.UNSUBME
// Added for when we utilise connect() in FlushChannels() so the traffic monitor doesn't time out and turn this to an unconnected state
w.trafficTimeout = time.Second * 30
problemFunc := func() ([]subscription.Subscription, error) {
w.setEnabled(true)
w.setState(connectedState)
problemFunc := func() (subscription.List, error) {
return nil, errDastardlyReason
}
noSub := func() ([]subscription.Subscription, error) {
noSub := func() (subscription.List, error) {
return nil, nil
}
// Disable pair and flush system
newgen.EnabledPairs = []currency.Pair{
currency.NewPair(currency.BTC, currency.AUD)}
w.GenerateSubs = func() ([]subscription.Subscription, error) {
return []subscription.Subscription{{Channel: "test"}}, nil
w.GenerateSubs = func() (subscription.List, error) {
return subscription.List{{Channel: "test"}}, nil
}
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")
require.NoError(t, err, "Flush Channels must not error")
w.features.FullPayloadSubscribe = true
w.GenerateSubs = problemFunc
err = w.FlushChannels() // error on full subscribeToChannels
assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly")
assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly on GenerateSubs")
w.GenerateSubs = noSub
err = w.FlushChannels() // No subs to unsub
assert.NoError(t, err, "FlushChannels should not error")
err = w.FlushChannels() // No subs to sub
assert.NoError(t, err, "Flush Channels should not error")
w.GenerateSubs = newgen.generateSubs
subs, err := w.GenerateSubs()
require.NoError(t, err, "GenerateSubs must not error")
w.AddSuccessfulSubscriptions(subs...)
require.NoError(t, w.AddSubscriptions(subs...), "AddSubscriptions must not error")
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")
w.features.FullPayloadSubscribe = false
w.features.Subscribe = true
w.GenerateSubs = problemFunc
err = w.FlushChannels()
assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly")
w.GenerateSubs = newgen.generateSubs
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")
w.subscriptionMutex.Lock()
w.subscriptions = subscriptionMap{
41: {
Key: 41,
Channel: "match channel",
Pair: currency.NewPair(currency.BTC, currency.AUD),
},
42: {
Key: 42,
Channel: "unsub channel",
Pair: currency.NewPair(currency.THETA, currency.USDT),
},
}
w.subscriptionMutex.Unlock()
w.subscriptions = subscription.NewStore()
err = w.subscriptions.Add(&subscription.Subscription{
Key: 41,
Channel: "match channel",
Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.AUD)},
})
require.NoError(t, err, "AddSubscription must not error")
err = w.subscriptions.Add(&subscription.Subscription{
Key: 42,
Channel: "unsub channel",
Pairs: currency.Pairs{currency.NewPair(currency.THETA, currency.USDT)},
})
require.NoError(t, err, "AddSubscription must not error")
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")
w.setState(connected)
w.setState(connectedState)
w.features.Unsubscribe = true
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")
@@ -1132,27 +1080,20 @@ func TestFlushChannels(t *testing.T) {
func TestDisable(t *testing.T) {
t.Parallel()
w := Websocket{
ShutdownC: make(chan struct{}),
}
w := NewWebsocket()
w.setEnabled(true)
w.setState(connected)
w.setState(connectedState)
require.NoError(t, w.Disable(), "Disable must not error")
assert.ErrorIs(t, w.Disable(), ErrAlreadyDisabled, "Disable should error correctly")
}
func TestEnable(t *testing.T) {
t.Parallel()
w := Websocket{
connector: connect,
Wg: new(sync.WaitGroup),
ShutdownC: make(chan struct{}),
GenerateSubs: func() ([]subscription.Subscription, error) {
return []subscription.Subscription{{Channel: "test"}}, nil
},
Subscriber: func([]subscription.Subscription) error { return nil },
}
w := NewWebsocket()
w.connector = connect
w.Subscriber = func(subscription.List) error { return nil }
w.Unsubscriber = func(subscription.List) error { return nil }
w.GenerateSubs = func() (subscription.List, error) { return nil, nil }
require.NoError(t, w.Enable(), "Enable must not error")
assert.ErrorIs(t, w.Enable(), errWebsocketAlreadyEnabled, "Enable should error correctly")
}
@@ -1259,19 +1200,30 @@ func TestCheckSubscriptions(t *testing.T) {
t.Parallel()
ws := Websocket{}
err := ws.checkSubscriptions(nil)
assert.ErrorIs(t, err, errNoSubscriptionsSupplied, "checkSubscriptions should error correctly")
assert.ErrorIs(t, err, common.ErrNilPointer, "checkSubscriptions should error correctly on nil w.subscriptions")
assert.ErrorContains(t, err, "Websocket.subscriptions", "checkSubscriptions should error giving context correctly on nil w.subscriptions")
ws.subscriptions = subscription.NewStore()
err = ws.checkSubscriptions(nil)
assert.NoError(t, err, "checkSubscriptions should not error on a nil list")
ws.MaxSubscriptionsPerConnection = 1
err = ws.checkSubscriptions([]subscription.Subscription{{}, {}})
err = ws.checkSubscriptions(subscription.List{{}})
assert.NoError(t, err, "checkSubscriptions should not error when subscriptions is empty")
ws.subscriptions = subscription.NewStore()
err = ws.checkSubscriptions(subscription.List{{}, {}})
assert.ErrorIs(t, err, errSubscriptionsExceedsLimit, "checkSubscriptions should error correctly")
ws.MaxSubscriptionsPerConnection = 2
ws.subscriptions = subscriptionMap{42: {Key: 42, Channel: "test"}}
err = ws.checkSubscriptions([]subscription.Subscription{{Key: 42, Channel: "test"}})
assert.ErrorIs(t, err, errChannelAlreadySubscribed, "checkSubscriptions should error correctly")
ws.subscriptions = subscription.NewStore()
err = ws.subscriptions.Add(&subscription.Subscription{Key: 42, Channel: "test"})
require.NoError(t, err, "Add subscription must not error")
err = ws.checkSubscriptions(subscription.List{{Key: 42, Channel: "test"}})
assert.ErrorIs(t, err, subscription.ErrDuplicate, "checkSubscriptions should error correctly")
err = ws.checkSubscriptions([]subscription.Subscription{{}})
err = ws.checkSubscriptions(subscription.List{{}})
assert.NoError(t, err, "checkSubscriptions should not error")
}

View File

@@ -22,13 +22,11 @@ const (
UnhandledMessage = " - Unhandled websocket message: "
)
type subscriptionMap map[any]*subscription.Subscription
const (
uninitialised uint32 = iota
disconnected
connecting
connected
uninitialisedState uint32 = iota
disconnectedState
connectingState
connectedState
)
// Websocket defines a return type for websocket connections via the interface
@@ -52,20 +50,14 @@ type Websocket struct {
m sync.Mutex
connector func() error
subscriptionMutex sync.RWMutex
subscriptions subscriptionMap
Subscribe chan []subscription.Subscription
Unsubscribe chan []subscription.Subscription
subscriptions *subscription.Store
// Subscriber function for package defined websocket subscriber
// functionality
Subscriber func([]subscription.Subscription) error
// Unsubscriber function for packaged defined websocket unsubscriber
// functionality
Unsubscriber func([]subscription.Subscription) error
// GenerateSubs function for package defined websocket generate
// subscriptions functionality
GenerateSubs func() ([]subscription.Subscription, error)
// Subscriber function for exchange specific subscribe implementation
Subscriber func(subscription.List) error
// Subscriber function for exchange specific unsubscribe implementation
Unsubscriber func(subscription.List) error
// GenerateSubs function for exchange specific generating subscriptions from Features.Subscriptions, Pairs and Assets
GenerateSubs func() (subscription.List, error)
DataHandler chan interface{}
ToRoutine chan interface{}
@@ -74,7 +66,7 @@ type Websocket struct {
// shutdown synchronises shutdown event across routines
ShutdownC chan struct{}
Wg *sync.WaitGroup
Wg sync.WaitGroup
// Orderbook is a local buffer of orderbooks
Orderbook buffer.Orderbook
@@ -112,9 +104,9 @@ type WebsocketSetup struct {
RunningURL string
RunningURLAuth string
Connector func() error
Subscriber func([]subscription.Subscription) error
Unsubscriber func([]subscription.Subscription) error
GenerateSubscriptions func() ([]subscription.Subscription, error)
Subscriber func(subscription.List) error
Unsubscriber func(subscription.List) error
GenerateSubscriptions func() (subscription.List, error)
Features *protocol.Features
// Local orderbook buffer config values

View File

@@ -0,0 +1,88 @@
package subscription
import (
"fmt"
"github.com/thrasher-corp/gocryptotrader/currency"
)
// MatchableKey interface should be implemented by Key types which want a more complex matching than a simple key equality check
// The Subscription method allows keys to compare against keys of other types
type MatchableKey interface {
Match(MatchableKey) bool
GetSubscription() *Subscription
String() string
}
// ExactKey is key type for subscriptions where all the pairs in a Subscription must match exactly
type ExactKey struct {
*Subscription
}
var _ MatchableKey = ExactKey{} // Enforce ExactKey must implement MatchableKey
// GetSubscription returns the underlying subscription
func (k ExactKey) GetSubscription() *Subscription {
return k.Subscription
}
// String implements Stringer; returns the Asset, Channel and Pairs
// Does not provide concurrency protection on the subscription it points to
func (k ExactKey) String() string {
s := k.Subscription
if s == nil {
return "Uninitialised ExactKey"
}
p := s.Pairs.Format(currency.PairFormat{Uppercase: true, Delimiter: "/"})
return fmt.Sprintf("%s %s %s", s.Channel, s.Asset, p.Join())
}
// Match implements MatchableKey
// Returns true if the key fields exactly matches the subscription, including all Pairs
func (k ExactKey) Match(eachKey MatchableKey) bool {
if eachKey == nil {
return false
}
eachSub := eachKey.GetSubscription()
return eachSub != nil &&
eachSub.Channel == k.Channel &&
eachSub.Asset == k.Asset &&
eachSub.Pairs.Equal(k.Pairs) &&
eachSub.Levels == k.Levels &&
eachSub.Interval == k.Interval
}
// IgnoringPairsKey is a key type for finding subscriptions to group together for requests
type IgnoringPairsKey struct {
*Subscription
}
var _ MatchableKey = IgnoringPairsKey{} // Enforce IgnoringPairsKey must implement MatchableKey
// GetSubscription returns the underlying subscription
func (k IgnoringPairsKey) GetSubscription() *Subscription {
return k.Subscription
}
// String implements Stringer; returns the asset and Channel name but no pairs
func (k IgnoringPairsKey) String() string {
s := k.Subscription
if s == nil {
return "Uninitialised IgnoringPairsKey"
}
return fmt.Sprintf("%s %s", s.Channel, s.Asset)
}
// Match implements MatchableKey
func (k IgnoringPairsKey) Match(eachKey MatchableKey) bool {
if eachKey == nil {
return false
}
eachSub := eachKey.GetSubscription()
return eachSub != nil &&
eachSub.Channel == k.Channel &&
eachSub.Asset == k.Asset &&
eachSub.Levels == k.Levels &&
eachSub.Interval == k.Interval
}

View File

@@ -0,0 +1,110 @@
package subscription
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
"github.com/thrasher-corp/gocryptotrader/exchanges/kline"
)
// DummyKey is a test key type that ensures that cross compatible keys can be used
// It will panic if Match() is called
type DummyKey struct {
*Subscription
detonator testing.TB
}
var _ MatchableKey = DummyKey{} // Enforce DummyKey must implement MatchableKey
// GetSubscription returns the underlying subscription
func (k DummyKey) GetSubscription() *Subscription {
return k.Subscription
}
// Match implements MatchableKey
func (k DummyKey) Match(_ MatchableKey) bool {
k.detonator.Fatal("DummyKey Match should never be called")
return false
}
// TestExactKeyMatch exercises ExactKey.Match
func TestExactKeyMatch(t *testing.T) {
t.Parallel()
key := &ExactKey{&Subscription{Channel: TickerChannel}}
try := &DummyKey{&Subscription{Channel: OrderbookChannel}, t}
require.False(t, key.Match(nil), "Match on a nil must return false")
require.False(t, key.Match(try), "Gate 1: Match must reject a bad Channel")
try.Channel = TickerChannel
require.True(t, key.Match(try), "Gate 1: Match must accept a good Channel")
key.Asset = asset.Spot
require.False(t, key.Match(try), "Gate 2: Match must reject a bad Asset")
try.Asset = asset.Spot
require.True(t, key.Match(try), "Gate 2: Match must accept a good Asset")
key.Pairs = currency.Pairs{btcusdtPair}
require.False(t, key.Match(try), "Gate 3: Match must reject B empty Pairs when key has Pairs")
try.Pairs = currency.Pairs{btcusdtPair}
key.Pairs = nil
require.False(t, key.Match(try), "Gate 3: Match must reject B has Pairs when key has empty Pairs")
key.Pairs = currency.Pairs{btcusdtPair}
require.True(t, key.Match(try), "Gate 3: Match must accept matching pairs")
key.Pairs = currency.Pairs{ethusdcPair}
require.False(t, key.Match(try), "Gate 3: Match must reject when key.Pairs not matching")
try.Pairs = currency.Pairs{btcusdtPair, ethusdcPair}
require.False(t, key.Match(try), "Gate 3: Match must reject when key.Pairs is only a subset")
key.Pairs = currency.Pairs{ethusdcPair, btcusdtPair}
require.True(t, key.Match(try), "Gate 3: Match accept when Pairs match in different order")
key.Levels = 4
require.False(t, key.Match(try), "Gate 4: Match must reject a bad Level")
try.Levels = 4
require.True(t, key.Match(try), "Gate 4: Match must accept a good Level")
key.Interval = kline.FiveMin
require.False(t, key.Match(try), "Gate 5: Match must reject a bad Interval")
try.Interval = kline.FiveMin
require.True(t, key.Match(try), "Gate 5: Match must accept a good Interval")
}
// TestExactKeyString exercises ExactKey.String
func TestExactKeyString(t *testing.T) {
t.Parallel()
key := &ExactKey{}
assert.Equal(t, "Uninitialised ExactKey", key.String())
key = &ExactKey{&Subscription{Asset: asset.Spot, Channel: TickerChannel, Pairs: currency.Pairs{ethusdcPair, btcusdtPair}}}
assert.Equal(t, "ticker spot ETH/USDC,BTC/USDT", key.String())
}
// TestIgnoringPairsKeyMatch exercises IgnoringPairsKey.Match
func TestIgnoringPairsKeyMatch(t *testing.T) {
t.Parallel()
key := &IgnoringPairsKey{&Subscription{Channel: TickerChannel, Pairs: currency.Pairs{btcusdtPair}}}
try := &DummyKey{&Subscription{Channel: OrderbookChannel, Pairs: currency.Pairs{ethusdcPair}}, t}
require.False(t, key.Match(nil), "Match on a nil must return false")
require.False(t, key.Match(try), "Gate 1: Match must reject a bad Channel")
try.Channel = TickerChannel
require.True(t, key.Match(try), "Gate 1: Match must accept a good Channel")
key.Asset = asset.Spot
require.False(t, key.Match(try), "Gate 2: Match must reject a bad Asset")
try.Asset = asset.Spot
require.True(t, key.Match(try), "Gate 2: Match must accept a good Asset")
key.Levels = 4
require.False(t, key.Match(try), "Gate 3: Match must reject a bad Level")
try.Levels = 4
require.True(t, key.Match(try), "Gate 3: Match must accept a good Level")
key.Interval = kline.FiveMin
require.False(t, key.Match(try), "Gate 4: Match must reject a bad Interval")
try.Interval = kline.FiveMin
require.True(t, key.Match(try), "Gate 4: Match must accept a good Interval")
}
// TestIgnoringPairsKeyString exercises IgnoringPairsKey.String
func TestIgnoringPairsKeyString(t *testing.T) {
t.Parallel()
key := &IgnoringPairsKey{&Subscription{Asset: asset.Spot, Channel: TickerChannel, Pairs: currency.Pairs{ethusdcPair, btcusdtPair}}}
assert.Equal(t, "ticker spot", key.String())
}

View File

@@ -0,0 +1,32 @@
package subscription
import (
"slices"
)
// List is a container of subscription pointers
type List []*Subscription
// Strings returns a sorted slice of subscriptions
func (l List) Strings() []string {
s := make([]string, len(l))
for i := range l {
s[i] = l[i].String()
}
slices.Sort(s)
return s
}
// GroupPairs groups subscriptions which are identical apart from the Pairs
// The returned List contains cloned Subscriptions, and the original Subscriptions are left alone
func (l List) GroupPairs() (n List) {
s := NewStore()
for _, sub := range l {
if found := s.match(&IgnoringPairsKey{sub}); found == nil {
s.unsafeAdd(sub.Clone())
} else {
found.AddPairs(sub.Pairs...)
}
}
return s.List()
}

View File

@@ -0,0 +1,47 @@
package subscription
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
)
// TestListStrings exercises List.Strings()
func TestListStrings(t *testing.T) {
l := List{
&Subscription{
Channel: TickerChannel,
Asset: asset.Spot,
Pairs: currency.Pairs{ethusdcPair, btcusdtPair},
},
&Subscription{
Channel: OrderbookChannel,
Pairs: currency.Pairs{ethusdcPair},
},
}
exp := []string{"orderbook ETH/USDC", "ticker spot ETH/USDC,BTC/USDT"}
assert.ElementsMatch(t, exp, l.Strings(), "String must return correct sorted list")
}
// TestListGroupPairs exercises List.GroupPairs()
func TestListGroupPairs(t *testing.T) {
l := List{
{Asset: asset.Spot, Channel: TickerChannel, Pairs: currency.Pairs{ethusdcPair, btcusdtPair}},
}
for _, c := range []string{TickerChannel, OrderbookChannel} {
for _, p := range []currency.Pair{ethusdcPair, btcusdtPair} {
l = append(l, &Subscription{
Channel: c,
Asset: asset.Spot,
Pairs: currency.Pairs{p},
})
}
}
n := l.GroupPairs()
assert.Len(t, l, 5, "Orig list should not be changed")
assert.Len(t, n, 2, "New list should be grouped")
exp := []string{"ticker spot ETH/USDC,BTC/USDT", "orderbook spot ETH/USDC,BTC/USDT"}
assert.ElementsMatch(t, exp, n.Strings(), "String must return correct sorted list")
}

View File

@@ -0,0 +1,208 @@
package subscription
import (
"fmt"
"maps"
"sync"
"github.com/thrasher-corp/gocryptotrader/common"
)
// Store is a container of subscription pointers
type Store struct {
m map[any]*Subscription
mu sync.RWMutex
}
// NewStore creates a ready to use store and should always be used
func NewStore() *Store {
return &Store{
m: map[any]*Subscription{},
}
}
// NewStoreFromList creates a Store from a List
func NewStoreFromList(l List) (*Store, error) {
s := NewStore()
for _, sub := range l {
if sub == nil {
return nil, fmt.Errorf("%w: List parameter contains an nil element", common.ErrNilPointer)
}
if err := s.add(sub); err != nil {
return nil, err
}
}
return s, nil
}
// Add adds a subscription to the store
// Key can be already set; if omitted EnsureKeyed will be used
// Errors if it already exists
func (s *Store) Add(sub *Subscription) error {
if s == nil {
return fmt.Errorf("%w: Add called on nil Store", common.ErrNilPointer)
}
if s.m == nil {
return fmt.Errorf("%w: Add called on an Uninitialised Store", common.ErrNilPointer)
}
if sub == nil {
return fmt.Errorf("%w: Subscription param", common.ErrNilPointer)
}
s.mu.Lock()
defer s.mu.Unlock()
return s.add(sub)
}
// add adds a subscription to the store
// Key can be already set; if omitted EnsureKeyed will be used
// This method provides no locking protection
func (s *Store) add(sub *Subscription) error {
key := sub.EnsureKeyed()
if found := s.get(key); found != nil {
return fmt.Errorf("%w: %s", ErrDuplicate, sub)
}
s.m[key] = sub
return nil
}
// unsafeAdd adds a subscription to the store without checking if it is a duplicate
// Key can be already set; if omitted EnsureKeyed will be used
// This method provides no locking protection
func (s *Store) unsafeAdd(sub *Subscription) {
key := sub.EnsureKeyed()
s.m[key] = sub
}
// Get returns a pointer to a subscription or nil if not found
// If the key passed in is a Subscription then its Key will be used; which may be a pointer to itself.
// If key implements MatchableKey then key.Match will be used; Note that *Subscription implements MatchableKey
func (s *Store) Get(key any) *Subscription {
if s == nil || s.m == nil || key == nil {
return nil
}
s.mu.RLock()
defer s.mu.RUnlock()
return s.get(key)
}
// get returns a pointer to subscription or nil if not found
// If the key passed in is a Subscription then its Key will be used; which may be a pointer to itself.
// If key implements MatchableKey then key.Match will be used; Note that *Subscription implements MatchableKey
// This method provides no locking protection
func (s *Store) get(key any) *Subscription {
switch v := key.(type) {
case Subscription:
key = v.EnsureKeyed()
case *Subscription:
key = v.EnsureKeyed()
}
switch v := key.(type) {
case MatchableKey:
return s.match(v)
default:
return s.m[v]
}
}
// Remove removes a subscription from the store
// If the key passed in is a Subscription then its Key will be used; which may be a pointer to itself.
// If key implements MatchableKey then key.Match will be used; Note that *Subscription implements MatchableKey
func (s *Store) Remove(key any) error {
if s == nil {
return fmt.Errorf("%w: Remove called on nil Store", common.ErrNilPointer)
}
if s.m == nil {
return fmt.Errorf("%w: Remove called on an Uninitialised Store", common.ErrNilPointer)
}
if key == nil {
return fmt.Errorf("%w: key param", common.ErrNilPointer)
}
s.mu.Lock()
defer s.mu.Unlock()
if found := s.get(key); found != nil {
delete(s.m, found.Key)
return nil
}
return ErrNotFound
}
// List returns a slice of Subscriptions pointers
func (s *Store) List() List {
if s == nil || s.m == nil {
return List{}
}
s.mu.RLock()
defer s.mu.RUnlock()
subs := make(List, 0, len(s.m))
for _, sub := range s.m {
subs = append(subs, sub)
}
return subs
}
// Clear empties the subscription store
func (s *Store) Clear() {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
if s.m == nil {
s.m = map[any]*Subscription{}
}
clear(s.m)
}
// match returns the first subscription which matches the Key's Asset, Channel and Pairs
// If the key provided has:
// 1) Empty pairs then only Subscriptions without pairs will be considered
// 2) >=1 pairs then Subscriptions which contain all the pairs will be considered
// This method provides no locking protection
func (s *Store) match(key MatchableKey) *Subscription {
for eachKey, sub := range s.m {
if m, ok := eachKey.(MatchableKey); ok {
if key.Match(m) {
return sub
}
}
}
return nil
}
// Diff returns a list of the added and missing subs from a new list
// The store Diff is invoked upon is read-lock protected
// The new store is assumed to be a new instance and enjoys no locking protection
func (s *Store) Diff(compare List) (added, removed List) {
if s == nil || s.m == nil {
return
}
s.mu.RLock()
defer s.mu.RUnlock()
removedMap := maps.Clone(s.m)
for _, sub := range compare {
if found := s.get(sub); found != nil {
delete(removedMap, found.Key)
} else {
added = append(added, sub)
}
}
for _, c := range removedMap {
removed = append(removed, c)
}
return
}
// Len returns the number of subscriptions
func (s *Store) Len() int {
if s == nil {
return 0
}
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.m)
}

View File

@@ -0,0 +1,182 @@
package subscription
import (
"maps"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/currency"
)
// TestNewStore exercises NewStore
func TestNewStore(t *testing.T) {
s := NewStore()
require.IsType(t, &Store{}, s, "Must return a store ref")
require.NotNil(t, s.m, "storage map must be initialised")
}
// TestNewStoreFromList exercises NewStoreFromList
func TestNewStoreFromList(t *testing.T) {
s, err := NewStoreFromList(List{})
assert.NoError(t, err, "Should not error on empty list")
require.IsType(t, &Store{}, s, "Must return a store ref")
l := List{
{Channel: OrderbookChannel},
{Channel: TickerChannel},
}
s, err = NewStoreFromList(l)
assert.NoError(t, err, "Should not error on empty list")
assert.Len(t, s.m, 2, "Map should have 2 values")
assert.NotNil(t, s.get(l[0]), "Should be able to get a list element")
l = append(l, &Subscription{Channel: OrderbookChannel})
_, err = NewStoreFromList(l)
assert.ErrorIs(t, err, ErrDuplicate, "Should error correctly on duplicates")
l = List{nil, &Subscription{Channel: OrderbookChannel}}
_, err = NewStoreFromList(l)
assert.ErrorIs(t, err, common.ErrNilPointer, "Should error correctly on nils")
}
// TestAdd exercises Add and add methods
func TestAdd(t *testing.T) {
assert.ErrorIs(t, (*Store)(nil).Add(&Subscription{}), common.ErrNilPointer, "Should error nil pointer correctly")
assert.ErrorIs(t, (&Store{}).Add(nil), common.ErrNilPointer, "Should error nil pointer correctly")
assert.ErrorIs(t, (&Store{}).Add(&Subscription{}), common.ErrNilPointer, "Should error nil pointer correctly")
s := NewStore()
sub := &Subscription{Channel: TickerChannel}
require.NoError(t, s.Add(sub), "Should not error on a standard add")
assert.NotNil(t, s.get(sub), "Should have stored the sub")
assert.ErrorIs(t, s.Add(sub), ErrDuplicate, "Should error on duplicates")
assert.NotNil(t, sub.Key, sub, "Add should call EnsureKeyed")
}
// TestGet exercises Get and get methods
// Ensures that key's Match is used, but does not exercise subscription.Match; See TestMatch for that coverage
func TestGet(t *testing.T) {
assert.Nil(t, (*Store)(nil).Get(&Subscription{}), "Should return nil when called on nil")
assert.Nil(t, (&Store{}).Get(&Subscription{}), "Should return nil when called with no subscription map")
s := NewStore()
exp := List{
{Channel: AllOrdersChannel},
{Channel: TickerChannel, Pairs: currency.Pairs{btcusdtPair}},
{Key: 42, Channel: OrderbookChannel},
{Channel: CandlesChannel, Pairs: currency.Pairs{btcusdtPair, ethusdcPair}},
}
for _, sub := range exp {
require.NoError(t, s.Add(sub), "Adding subscription must not error)")
}
// Tests for a MatchableKey, ensuring that ExactKey works
assert.Nil(t, s.Get(Subscription{Channel: CandlesChannel}), "Should return nil without pairs")
assert.Nil(t, s.Get(Subscription{Channel: CandlesChannel, Pairs: currency.Pairs{ltcusdcPair}}), "Should return nil with wrong pair")
assert.Nil(t, s.Get(Subscription{Channel: CandlesChannel, Pairs: currency.Pairs{btcusdtPair}}), "Should return nil with only one right pair")
assert.Same(t, exp[3], s.Get(Subscription{Channel: CandlesChannel, Pairs: currency.Pairs{btcusdtPair, ethusdcPair}}), "Should return pointer when all pairs match")
assert.Nil(t, s.Get(Subscription{Channel: CandlesChannel, Pairs: currency.Pairs{btcusdtPair, ethusdcPair, ltcusdcPair}}), "Should return nil when key is superset of pairs")
}
// TestRemove exercises the Remove method
func TestRemove(t *testing.T) {
assert.ErrorIs(t, (*Store)(nil).Remove(&Subscription{}), common.ErrNilPointer, "Should error correctly when called on nil")
assert.ErrorIs(t, (&Store{}).Remove(nil), common.ErrNilPointer, "Should error correctly when called passing nil")
assert.ErrorIs(t, (&Store{}).Remove(&Subscription{}), common.ErrNilPointer, "Should error correctly when called with no subscription map")
s := NewStore()
require.NoError(t, s.Add(&Subscription{Channel: CandlesChannel, Pairs: currency.Pairs{btcusdtPair, ethusdcPair}}), "Adding subscription must not error")
assert.NotNil(t, s.Get(&ExactKey{&Subscription{Channel: CandlesChannel, Pairs: currency.Pairs{btcusdtPair, ethusdcPair}}}), "Should have added the sub")
assert.ErrorIs(t, s.Remove(&ExactKey{&Subscription{Channel: CandlesChannel, Pairs: currency.Pairs{btcusdtPair}}}), ErrNotFound, "Should error correctly when called with a non-matching key")
assert.NoError(t, s.Remove(&ExactKey{&Subscription{Channel: CandlesChannel, Pairs: currency.Pairs{btcusdtPair, ethusdcPair}}}), "Should not error when called with a matching key")
assert.Nil(t, s.Get(&ExactKey{&Subscription{Channel: CandlesChannel, Pairs: currency.Pairs{btcusdtPair, ethusdcPair}}}), "Should have removed the sub")
assert.ErrorIs(t, s.Remove(&ExactKey{&Subscription{Channel: CandlesChannel, Pairs: currency.Pairs{btcusdtPair, ethusdcPair}}}), ErrNotFound, "Should error correctly when called twice ")
}
// TestList exercises the List and Len methods
func TestList(t *testing.T) {
assert.Empty(t, (*Store)(nil).List(), "Should return an empty List when called on nil")
assert.Empty(t, (&Store{}).List(), "Should return an empty List when called on Store without map")
s := NewStore()
exp := List{
{Channel: OrderbookChannel},
{Channel: TickerChannel},
{Key: 42, Channel: CandlesChannel},
}
for _, sub := range exp {
require.NoError(t, s.Add(sub), "Adding subscription must not error)")
}
l := s.List()
require.Len(t, l, 3, "Must have 3 elements in the list")
assert.ElementsMatch(t, exp, l, "List Should have the same subscriptions")
require.Equal(t, 3, s.Len(), "Len must return 3")
require.Equal(t, 0, (*Store)(nil).Len(), "Len must return 0 on a nil store")
require.Equal(t, 0, (&Store{}).Len(), "Len must return 0 on an uninitialized store")
}
// TestStoreClear exercises the Clear method
func TestStoreClear(t *testing.T) {
assert.NotPanics(t, func() { (*Store)(nil).Clear() }, "Should not panic when called on nil")
s := &Store{}
assert.NotPanics(t, func() { s.Clear() }, "Should not panic when called with no subscription map")
assert.NotNil(t, s.m, "Should create a map when called on an empty Store")
require.NoError(t, s.Add(&Subscription{Channel: CandlesChannel}), "Adding subscription must not error")
require.Len(t, s.m, 1, "Must have a subscription")
s.Clear()
require.Empty(t, s.m, "Map must be empty after clearing")
assert.NotPanics(t, func() { s.Clear() }, "Should not panic when called on an empty map")
}
// TestStoreDiff exercises the Diff method
func TestStoreDiff(t *testing.T) {
s := NewStore()
assert.NotPanics(t, func() { (*Store)(nil).Diff(List{}) }, "Should not panic when called on nil")
assert.NotPanics(t, func() { (&Store{}).Diff(List{}) }, "Should not panic when called with no subscription map")
subs, unsubs := s.Diff(List{{Channel: TickerChannel}, {Channel: CandlesChannel}, {Channel: OrderbookChannel}})
assert.Equal(t, 3, len(subs), "Should get the correct number of subs")
assert.Empty(t, unsubs, "Should get no unsubs")
for _, sub := range subs {
require.NoError(t, s.add(sub), "add must not error")
}
assert.NotPanics(t, func() { s.Diff(nil) }, "Should not panic when called with nil list")
subs, unsubs = s.Diff(List{{Channel: CandlesChannel}})
assert.Empty(t, subs, "Should get no subs")
assert.Equal(t, 2, len(unsubs), "Should get the correct number of unsubs")
subs, unsubs = s.Diff(List{{Channel: TickerChannel}, {Channel: MyTradesChannel}})
require.Equal(t, 1, len(subs), "Should get the correct number of subs")
assert.Equal(t, MyTradesChannel, subs[0].Channel, "Should get correct channels in sub")
require.Equal(t, 2, len(unsubs), "Should get the correct number of unsubs")
EqualLists(t, unsubs, List{{Channel: OrderbookChannel}, {Channel: CandlesChannel}})
}
func EqualLists(tb testing.TB, a, b List) {
tb.Helper()
// Must not use store.Diff directly
s, err := NewStoreFromList(a)
require.NoError(tb, err, "NewStoreFromList must not error")
missingMap := maps.Clone(s.m)
var added, missing List
for _, sub := range b {
if found := s.get(sub); found != nil {
delete(missingMap, found.Key)
} else {
added = append(added, sub)
}
}
for _, c := range missingMap {
missing = append(missing, c)
}
if len(added) > 0 || len(missing) > 0 {
fail := "Differences:"
if len(added) > 0 {
fail = fail + "\n + " + strings.Join(added.Strings(), "\n + ")
}
if len(missing) > 0 {
fail = fail + "\n - " + strings.Join(missing.Strings(), "\n - ")
}
assert.Fail(tb, fail, "Subscriptions should be equal")
}
}

View File

@@ -1,92 +1,163 @@
package subscription
import (
"encoding/json"
"errors"
"fmt"
"maps"
"slices"
"sync"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
"github.com/thrasher-corp/gocryptotrader/exchanges/kline"
)
// DefaultKey is the fallback key for AddSuccessfulSubscriptions
type DefaultKey struct {
Channel string
Pair currency.Pair
Asset asset.Item
}
// State constants
const (
InactiveState State = iota
SubscribingState
SubscribedState
ResubscribingState
UnsubscribingState
UnsubscribedState
)
// Channel constants
const (
TickerChannel = "ticker"
OrderbookChannel = "orderbook"
CandlesChannel = "candles"
AllOrdersChannel = "allOrders"
AllTradesChannel = "allTrades"
MyTradesChannel = "myTrades"
MyOrdersChannel = "myOrders"
)
// Public errors
var (
ErrNotFound = errors.New("subscription not found")
ErrNotSinglePair = errors.New("only single pair subscriptions expected")
ErrInStateAlready = errors.New("subscription already in state")
ErrInvalidState = errors.New("invalid subscription state")
ErrDuplicate = errors.New("duplicate subscription")
)
// State tracks the status of a subscription channel
type State uint8
const (
UnknownState State = iota // UnknownState subscription state is not registered, but doesn't imply Inactive
SubscribingState // SubscribingState means channel is in the process of subscribing
SubscribedState // SubscribedState means the channel has finished a successful and acknowledged subscription
UnsubscribingState // UnsubscribingState means the channel has started to unsubscribe, but not yet confirmed
TickerChannel = "ticker" // TickerChannel Subscription Type
OrderbookChannel = "orderbook" // OrderbookChannel Subscription Type
CandlesChannel = "candles" // CandlesChannel Subscription Type
AllOrdersChannel = "allOrders" // AllOrdersChannel Subscription Type
AllTradesChannel = "allTrades" // AllTradesChannel Subscription Type
MyTradesChannel = "myTrades" // MyTradesChannel Subscription Type
MyOrdersChannel = "myOrders" // MyOrdersChannel Subscription Type
)
// Subscription container for streaming subscriptions
type Subscription struct {
Enabled bool `json:"enabled"`
Key any `json:"-"`
Channel string `json:"channel,omitempty"`
Pair currency.Pair `json:"pair,omitempty"`
Asset asset.Item `json:"asset,omitempty"`
Params map[string]interface{} `json:"params,omitempty"`
State State `json:"-"`
Interval kline.Interval `json:"interval,omitempty"`
Levels int `json:"levels,omitempty"`
Authenticated bool `json:"authenticated,omitempty"`
Enabled bool `json:"enabled"`
Key any `json:"-"`
Channel string `json:"channel,omitempty"`
Pairs currency.Pairs `json:"pairs,omitempty"`
Asset asset.Item `json:"asset,omitempty"`
Params map[string]any `json:"params,omitempty"`
Interval kline.Interval `json:"interval,omitempty"`
Levels int `json:"levels,omitempty"`
Authenticated bool `json:"authenticated,omitempty"`
state State
m sync.RWMutex
}
// MarshalJSON generates a JSON representation of a Subscription, specifically for config writing
// The only reason it exists is to avoid having to make Pair a pointer, since that would be generally painful
// If Pair becomes a pointer, this method is redundant and should be removed
func (s *Subscription) MarshalJSON() ([]byte, error) {
// None of the usual type embedding tricks seem to work for not emitting an nil Pair
// The embedded type's Pair always fills the empty value
type MaybePair struct {
Enabled bool `json:"enabled"`
Channel string `json:"channel,omitempty"`
Asset asset.Item `json:"asset,omitempty"`
Params map[string]interface{} `json:"params,omitempty"`
Interval kline.Interval `json:"interval,omitempty"`
Levels int `json:"levels,omitempty"`
Authenticated bool `json:"authenticated,omitempty"`
Pair *currency.Pair `json:"pair,omitempty"`
}
k := MaybePair{s.Enabled, s.Channel, s.Asset, s.Params, s.Interval, s.Levels, s.Authenticated, nil}
if s.Pair != currency.EMPTYPAIR {
k.Pair = &s.Pair
}
return json.Marshal(k)
}
// String implements the Stringer interface for Subscription, giving a human representation of the subscription
// String implements Stringer, and aims to informatively and uniquely identify a subscription for errors and information
// returns a string of the subscription key by delegating to MatchableKey.String() when possible
// If the key is not a MatchableKey then both the key and an ExactKey.String() will be returned; e.g. 1137: spot MyTrades
func (s *Subscription) String() string {
return fmt.Sprintf("%s %s %s", s.Channel, s.Asset, s.Pair)
key := s.EnsureKeyed()
s.m.RLock()
defer s.m.RUnlock()
if k, ok := key.(MatchableKey); ok {
return k.String()
}
return fmt.Sprintf("%v: %s", key, ExactKey{s}.String())
}
// EnsureKeyed sets the default key on a channel if it doesn't have one
// Returns key for convenience
// State returns the subscription state
func (s *Subscription) State() State {
s.m.RLock()
defer s.m.RUnlock()
return s.state
}
// SetState sets the subscription state
// Errors if already in that state or the new state is not valid
func (s *Subscription) SetState(state State) error {
s.m.Lock()
defer s.m.Unlock()
if state == s.state {
return ErrInStateAlready
}
if state > UnsubscribedState {
return ErrInvalidState
}
s.state = state
return nil
}
// SetKey does what it says on the tin safely for concurrency
func (s *Subscription) SetKey(key any) {
s.m.Lock()
defer s.m.Unlock()
s.Key = key
}
// EnsureKeyed returns the subscription key
// If no key exists then ExactKey will be used
func (s *Subscription) EnsureKeyed() any {
if s.Key == nil {
s.Key = DefaultKey{
Channel: s.Channel,
Asset: s.Asset,
Pair: s.Pair,
}
// Juggle RLock/WLock to minimize concurrent bottleneck for hottest path
s.m.RLock()
if s.Key != nil {
defer s.m.RUnlock()
return s.Key
}
s.m.RUnlock()
s.m.Lock()
defer s.m.Unlock()
if s.Key == nil { // Ensure race hasn't updated Key whilst we swapped locks
s.Key = &ExactKey{s}
}
return s.Key
}
// Clone returns a copy of a subscription
// Key is set to nil, because most Key types contain a pointer to the subscription, and because the clone isn't added to the store yet
// Users should allow a default key to be assigned on AddSubscription or can SetKey as necessary
func (s *Subscription) Clone() *Subscription {
s.m.RLock()
c := &Subscription{
Key: nil,
Enabled: s.Enabled,
Channel: s.Channel,
Asset: s.Asset,
Params: s.Params,
Interval: s.Interval,
Levels: s.Levels,
Authenticated: s.Authenticated,
state: s.state,
Pairs: s.Pairs,
}
s.Pairs = slices.Clone(s.Pairs)
s.Params = maps.Clone(s.Params)
s.m.RUnlock()
return c
}
// SetPairs does what it says on the tin safely for concurrency
func (s *Subscription) SetPairs(pairs currency.Pairs) {
s.m.Lock()
s.Pairs = pairs
s.m.Unlock()
}
// AddPairs does what it says on the tin safely for concurrency
func (s *Subscription) AddPairs(pairs ...currency.Pair) {
if len(pairs) == 0 {
return
}
s.m.Lock()
for _, p := range pairs {
s.Pairs = s.Pairs.Add(p)
}
s.m.Unlock()
}

View File

@@ -10,39 +10,80 @@ import (
"github.com/thrasher-corp/gocryptotrader/exchanges/kline"
)
// TestEnsureKeyed logic test
func TestEnsureKeyed(t *testing.T) {
t.Parallel()
c := Subscription{
var (
btcusdtPair = currency.NewPair(currency.BTC, currency.USDT)
ethusdcPair = currency.NewPair(currency.ETH, currency.USDC)
ltcusdcPair = currency.NewPair(currency.LTC, currency.USDC)
)
// TestSubscriptionString exercises the String method
func TestSubscriptionString(t *testing.T) {
s := &Subscription{
Channel: "candles",
Asset: asset.Spot,
Pair: currency.NewPair(currency.BTC, currency.USDT),
}
k1, ok := c.EnsureKeyed().(DefaultKey)
if assert.True(t, ok, "EnsureKeyed should return a DefaultKey") {
assert.Exactly(t, k1, c.Key, "EnsureKeyed should set the same key")
assert.Equal(t, k1.Channel, c.Channel, "DefaultKey channel should be correct")
assert.Equal(t, k1.Asset, c.Asset, "DefaultKey asset should be correct")
assert.Equal(t, k1.Pair, c.Pair, "DefaultKey currency should be correct")
}
type platypus string
c = Subscription{
Key: platypus("Gerald"),
Channel: "orderbook",
Asset: asset.Margin,
Pair: currency.NewPair(currency.ETH, currency.USDC),
}
k2, ok := c.EnsureKeyed().(platypus)
if assert.True(t, ok, "EnsureKeyed should return a platypus") {
assert.Exactly(t, k2, c.Key, "EnsureKeyed should set the same key")
assert.EqualValues(t, "Gerald", k2, "key should have the correct value")
Pairs: currency.Pairs{btcusdtPair, ethusdcPair.Format(currency.PairFormat{Delimiter: "/"})},
}
assert.Equal(t, "candles spot BTC/USDT,ETH/USDC", s.String(), "Subscription String should return correct value")
}
// TestMarshalling logic test
func TestMarshaling(t *testing.T) {
// TestState exercises the state getter
func TestState(t *testing.T) {
t.Parallel()
j, err := json.Marshal(&Subscription{Channel: CandlesChannel})
s := &Subscription{}
assert.Equal(t, InactiveState, s.State(), "State should return initial state")
s.state = SubscribedState
assert.Equal(t, SubscribedState, s.State(), "State should return correct state")
}
// TestSetState exercises the state setter
func TestSetState(t *testing.T) {
t.Parallel()
s := &Subscription{state: UnsubscribedState}
for i := InactiveState; i <= UnsubscribedState; i++ {
assert.NoErrorf(t, s.SetState(i), "State should not error setting state %s", i)
}
assert.ErrorIs(t, s.SetState(UnsubscribedState), ErrInStateAlready, "SetState should error on same state")
assert.ErrorIs(t, s.SetState(UnsubscribedState+1), ErrInvalidState, "Setting an invalid state should error")
}
// TestString exercises the Stringer implementation
func TestString(t *testing.T) {
s := &Subscription{
Channel: "candles",
Asset: asset.Spot,
Pairs: currency.Pairs{btcusdtPair},
}
_ = s.EnsureKeyed()
assert.Equal(t, "candles spot BTC/USDT", s.String(), "String with a MatchableKey")
s.Key = 42
assert.Equal(t, "42: candles spot BTC/USDT", s.String(), "String with a MatchableKey")
}
// TestEnsureKeyed exercises the key getter and ensures it sets a self-pointer key for non
func TestEnsureKeyed(t *testing.T) {
t.Parallel()
s := &Subscription{}
k1, ok := s.EnsureKeyed().(MatchableKey)
if assert.True(t, ok, "EnsureKeyed should return a MatchableKey") {
assert.Same(t, s, k1.GetSubscription(), "Key should point to the same struct")
}
type platypus string
s = &Subscription{
Key: platypus("Gerald"),
Channel: "orderbook",
}
k2 := s.EnsureKeyed()
assert.IsType(t, platypus(""), k2, "EnsureKeyed should return a platypus")
assert.Equal(t, s.Key, k2, "Key should be the key provided")
}
// TestSubscriptionMarshalling ensures json Marshalling is clean and concise
// Since there is no UnmarshalJSON, this just exercises the json field tags of Subscription, and regressions in conciseness
func TestSubscriptionMarshaling(t *testing.T) {
t.Parallel()
j, err := json.Marshal(&Subscription{Key: 42, Channel: CandlesChannel})
assert.NoError(t, err, "Marshalling should not error")
assert.Equal(t, `{"enabled":false,"channel":"candles"}`, string(j), "Marshalling should be clean and concise")
@@ -50,11 +91,46 @@ func TestMarshaling(t *testing.T) {
assert.NoError(t, err, "Marshalling should not error")
assert.Equal(t, `{"enabled":true,"channel":"orderbook","interval":"5m","levels":4}`, string(j), "Marshalling should be clean and concise")
j, err = json.Marshal(&Subscription{Enabled: true, Channel: OrderbookChannel, Interval: kline.FiveMin, Levels: 4, Pair: currency.NewPair(currency.BTC, currency.USDT)})
j, err = json.Marshal(&Subscription{Enabled: true, Channel: OrderbookChannel, Interval: kline.FiveMin, Levels: 4, Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.USDT)}})
assert.NoError(t, err, "Marshalling should not error")
assert.Equal(t, `{"enabled":true,"channel":"orderbook","interval":"5m","levels":4,"pair":"BTCUSDT"}`, string(j), "Marshalling should be clean and concise")
assert.Equal(t, `{"enabled":true,"channel":"orderbook","pairs":"BTCUSDT","interval":"5m","levels":4}`, string(j), "Marshalling should be clean and concise")
j, err = json.Marshal(&Subscription{Enabled: true, Channel: MyTradesChannel, Authenticated: true})
assert.NoError(t, err, "Marshalling should not error")
assert.Equal(t, `{"enabled":true,"channel":"myTrades","authenticated":true}`, string(j), "Marshalling should be clean and concise")
}
// TestClone exercises Clone
func TestClone(t *testing.T) {
a := &Subscription{
Channel: TickerChannel,
Interval: kline.OneHour,
Pairs: currency.Pairs{btcusdtPair},
Params: map[string]any{"a": 42},
}
a.EnsureKeyed()
b := a.Clone()
assert.IsType(t, new(Subscription), b, "Clone must return a Subscription pointer")
assert.NotSame(t, a, b, "Clone should return a new Subscription")
assert.Nil(t, b.Key, "Clone should have a nil key")
b.Pairs[0] = ethusdcPair
assert.Equal(t, btcusdtPair, a.Pairs[0], "Pairs should be (relatively) deep copied")
b.Params["a"] = 12
assert.Equal(t, 42, a.Params["a"], "Params should be (relatively) deep copied")
a.m.Lock()
assert.True(t, b.m.TryLock(), "Clone must use a different Mutex")
}
// TestSetKey exercises SetKey
func TestSetKey(t *testing.T) {
s := &Subscription{}
s.SetKey(14)
assert.Equal(t, 14, s.Key, "SetKey should set a key correctly")
}
// TestSetPairs exercises SetPairs
func TestSetPairs(t *testing.T) {
s := &Subscription{}
s.SetPairs(currency.Pairs{btcusdtPair})
assert.Equal(t, "BTCUSDT", s.Pairs.Join(), "SetPairs should set a key correctly")
}