stream/match: Reduce complexity and limit locking when match occurs (#1581)

* stream match update

* update tests

* linter: fix

* glorious: nits + handle context cancellations

* glorious: whooops

* Websocket: Add SendMessageReturnResponses

* whooooooopsie

* gk: nitssssss

* Update exchanges/stream/stream_match.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* Update exchanges/stream/stream_match_test.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* linter: appease the linter gods

* glorious: nits

* glorious: nits

* Update exchanges/stream/stream_match_test.go

Co-authored-by: Scott <gloriousCode@users.noreply.github.com>

---------

Co-authored-by: Ryan O'Hara-Reid <ryan.oharareid@thrasher.io>
Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>
Co-authored-by: Scott <gloriousCode@users.noreply.github.com>
This commit is contained in:
Ryan O'Hara-Reid
2024-08-19 10:35:46 +10:00
committed by GitHub
parent 225429bda6
commit 17c2ef2ec7
23 changed files with 207 additions and 178 deletions

View File

@@ -577,7 +577,7 @@ func (b *Binance) manageSubs(op string, subs subscription.List) error {
Params: subs.QualifiedChannels(),
}
respRaw, err := b.Websocket.Conn.SendMessageReturnResponse(req.ID, req)
respRaw, err := b.Websocket.Conn.SendMessageReturnResponse(context.TODO(), req.ID, req)
if err == nil {
if v, d, _, rErr := jsonparser.Get(respRaw, "result"); rErr != nil {
err = rErr

View File

@@ -1315,7 +1315,7 @@ func TestWsCancelOffer(t *testing.T) {
}
func TestWsSubscribedResponse(t *testing.T) {
m, err := b.Websocket.Match.Set("subscribe:waiter1")
ch, err := b.Websocket.Match.Set("subscribe:waiter1", 1)
assert.NoError(t, err, "Setting a matcher should not error")
err = b.wsHandleData([]byte(`{"event":"subscribed","channel":"ticker","chanId":224555,"subId":"waiter1","symbol":"tBTCUSD","pair":"BTCUSD"}`))
if assert.Error(t, err, "Should error if sub is not registered yet") {
@@ -1328,13 +1328,12 @@ func TestWsSubscribedResponse(t *testing.T) {
require.NoError(t, err, "AddSubscriptions must not error")
err = b.wsHandleData([]byte(`{"event":"subscribed","channel":"ticker","chanId":224555,"subId":"waiter1","symbol":"tBTCUSD","pair":"BTCUSD"}`))
assert.NoError(t, err, "wsHandleData should not error")
if assert.NotEmpty(t, m.C, "Matcher should have received a sub notification") {
msg := <-m.C
if assert.NotEmpty(t, ch, "Matcher should have received a sub notification") {
msg := <-ch
cID, err := jsonparser.GetInt(msg, "chanId")
assert.NoError(t, err, "Should get chanId from sub notification without error")
assert.EqualValues(t, 224555, cID, "Should get the correct chanId through the matcher notification")
}
m.Cleanup()
}
func TestWsOrderBook(t *testing.T) {

View File

@@ -1756,7 +1756,7 @@ func (b *Bitfinex) subscribeToChan(chans subscription.List) error {
_ = b.Websocket.RemoveSubscriptions(c)
}()
respRaw, err := b.Websocket.Conn.SendMessageReturnResponse("subscribe:"+subID, req)
respRaw, err := b.Websocket.Conn.SendMessageReturnResponse(context.TODO(), "subscribe:"+subID, req)
if err != nil {
return fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pairs)
}
@@ -1849,7 +1849,7 @@ func (b *Bitfinex) unsubscribeFromChan(chans subscription.List) error {
"chanId": chanID,
}
respRaw, err := b.Websocket.Conn.SendMessageReturnResponse("unsubscribe:"+strconv.Itoa(chanID), req)
respRaw, err := b.Websocket.Conn.SendMessageReturnResponse(context.TODO(), "unsubscribe:"+strconv.Itoa(chanID), req)
if err != nil {
return err
}
@@ -1926,7 +1926,7 @@ func (b *Bitfinex) WsSendAuth(ctx context.Context) error {
func (b *Bitfinex) WsNewOrder(data *WsNewOrderRequest) (string, error) {
data.CustomID = b.Websocket.AuthConn.GenerateMessageID(false)
request := makeRequestInterface(wsOrderNew, data)
resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(data.CustomID, request)
resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), data.CustomID, request)
if err != nil {
return "", err
}
@@ -1983,7 +1983,7 @@ func (b *Bitfinex) WsNewOrder(data *WsNewOrderRequest) (string, error) {
// WsModifyOrder authenticated modify order request
func (b *Bitfinex) WsModifyOrder(data *WsUpdateOrderRequest) error {
request := makeRequestInterface(wsOrderUpdate, data)
resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(data.OrderID, request)
resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), data.OrderID, request)
if err != nil {
return err
}
@@ -2037,7 +2037,7 @@ func (b *Bitfinex) WsCancelOrder(orderID int64) error {
OrderID: orderID,
}
request := makeRequestInterface(wsOrderCancel, cancel)
resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(orderID, request)
resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), orderID, request)
if err != nil {
return err
}
@@ -2094,7 +2094,7 @@ func (b *Bitfinex) WsCancelOffer(orderID int64) error {
OrderID: orderID,
}
request := makeRequestInterface(wsFundingOfferCancel, cancel)
resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(orderID, request)
resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), orderID, request)
if err != nil {
return err
}

View File

@@ -118,7 +118,7 @@ func (by *Bybit) WsAuth(ctx context.Context) error {
Operation: "auth",
Args: []interface{}{creds.Key, intNonce, sign},
}
resp, err := by.Websocket.AuthConn.SendMessageReturnResponse(req.RequestID, req)
resp, err := by.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), req.RequestID, req)
if err != nil {
return err
}
@@ -220,12 +220,12 @@ func (by *Bybit) handleSpotSubscription(operation string, channelsToSubscribe su
for a := range payloads {
var response []byte
if payloads[a].auth {
response, err = by.Websocket.AuthConn.SendMessageReturnResponse(payloads[a].RequestID, payloads[a])
response, err = by.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), payloads[a].RequestID, payloads[a])
if err != nil {
return err
}
} else {
response, err = by.Websocket.Conn.SendMessageReturnResponse(payloads[a].RequestID, payloads[a])
response, err = by.Websocket.Conn.SendMessageReturnResponse(context.TODO(), payloads[a].RequestID, payloads[a])
if err != nil {
return err
}

View File

@@ -480,7 +480,7 @@ func (c *COINUT) WsGetInstruments() (Instruments, error) {
SecurityType: strings.ToUpper(asset.Spot.String()),
Nonce: getNonce(),
}
resp, err := c.Websocket.Conn.SendMessageReturnResponse(request.Nonce, request)
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.Nonce, request)
if err != nil {
return list, err
}
@@ -648,7 +648,7 @@ func (c *COINUT) Unsubscribe(channelToUnsubscribe subscription.List) error {
Subscribe: false,
Nonce: getNonce(),
}
resp, err := c.Websocket.Conn.SendMessageReturnResponse(subscribe.Nonce, subscribe)
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), subscribe.Nonce, subscribe)
if err != nil {
errs = common.AppendError(errs, err)
continue
@@ -691,7 +691,7 @@ func (c *COINUT) wsAuthenticate(ctx context.Context) error {
}
r.Hmac = crypto.HexEncodeToString(hmac)
resp, err := c.Websocket.Conn.SendMessageReturnResponse(r.Nonce, r)
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), r.Nonce, r)
if err != nil {
return err
}
@@ -714,7 +714,7 @@ func (c *COINUT) wsGetAccountBalance() (*UserBalance, error) {
Request: "user_balance",
Nonce: getNonce(),
}
resp, err := c.Websocket.Conn.SendMessageReturnResponse(accBalance.Nonce, accBalance)
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), accBalance.Nonce, accBalance)
if err != nil {
return nil, err
}
@@ -750,7 +750,7 @@ func (c *COINUT) wsSubmitOrder(o *WsSubmitOrderParameters) (*order.Detail, error
if o.OrderID > 0 {
orderSubmissionRequest.OrderID = o.OrderID
}
resp, err := c.Websocket.Conn.SendMessageReturnResponse(orderSubmissionRequest.Nonce, orderSubmissionRequest)
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), orderSubmissionRequest.Nonce, orderSubmissionRequest)
if err != nil {
return nil, err
}
@@ -793,7 +793,7 @@ func (c *COINUT) wsSubmitOrders(orders []WsSubmitOrderParameters) ([]order.Detai
orderRequest.Nonce = getNonce()
orderRequest.Request = "new_orders"
resp, err := c.Websocket.Conn.SendMessageReturnResponse(orderRequest.Nonce, orderRequest)
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), orderRequest.Nonce, orderRequest)
if err != nil {
errs = append(errs, err)
return nil, errs
@@ -829,7 +829,7 @@ func (c *COINUT) wsGetOpenOrders(curr string) (*WsUserOpenOrdersResponse, error)
openOrdersRequest.Nonce = getNonce()
openOrdersRequest.InstrumentID = c.instrumentMap.LookupID(curr)
resp, err := c.Websocket.Conn.SendMessageReturnResponse(openOrdersRequest.Nonce, openOrdersRequest)
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), openOrdersRequest.Nonce, openOrdersRequest)
if err != nil {
return response, err
}
@@ -862,7 +862,7 @@ func (c *COINUT) wsCancelOrder(cancellation *WsCancelOrderParameters) (*CancelOr
cancellationRequest.OrderID = cancellation.OrderID
cancellationRequest.Nonce = getNonce()
resp, err := c.Websocket.Conn.SendMessageReturnResponse(cancellationRequest.Nonce, cancellationRequest)
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), cancellationRequest.Nonce, cancellationRequest)
if err != nil {
return response, err
}
@@ -903,7 +903,7 @@ func (c *COINUT) wsCancelOrders(cancellations []WsCancelOrderParameters) (*Cance
cancelOrderRequest.Request = "cancel_orders"
cancelOrderRequest.Nonce = getNonce()
resp, err := c.Websocket.Conn.SendMessageReturnResponse(cancelOrderRequest.Nonce, cancelOrderRequest)
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), cancelOrderRequest.Nonce, cancelOrderRequest)
if err != nil {
return response, err
}
@@ -933,7 +933,7 @@ func (c *COINUT) wsGetTradeHistory(p currency.Pair, start, limit int64) (*WsTrad
request.Start = start
request.Limit = limit
resp, err := c.Websocket.Conn.SendMessageReturnResponse(request.Nonce, request)
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.Nonce, request)
if err != nil {
return response, err
}

View File

@@ -147,7 +147,7 @@ func (d *Deribit) wsLogin(ctx context.Context) error {
"signature": crypto.HexEncodeToString(hmac),
},
}
resp, err := d.Websocket.Conn.SendMessageReturnResponse(request.ID, request)
resp, err := d.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request)
if err != nil {
d.Websocket.SetCanUseAuthenticatedEndpoints(false)
return err
@@ -1165,7 +1165,7 @@ func (d *Deribit) handleSubscription(operation string, channels subscription.Lis
return err
}
for x := range payloads {
data, err := d.Websocket.Conn.SendMessageReturnResponse(payloads[x].ID, payloads[x])
data, err := d.Websocket.Conn.SendMessageReturnResponse(context.TODO(), payloads[x].ID, payloads[x])
if err != nil {
return err
}

View File

@@ -2406,7 +2406,7 @@ func (d *Deribit) sendWsPayload(ep request.EndpointLimit, input *WsRequest, resp
log.Debugf(log.RequestSys, "%s attempt %d", d.Name, attempt)
}
var payload []byte
payload, err = d.Websocket.Conn.SendMessageReturnResponse(input.ID, input)
payload, err = d.Websocket.Conn.SendMessageReturnResponse(context.TODO(), input.ID, input)
if err != nil {
return err
}

View File

@@ -700,7 +700,7 @@ func (g *Gateio) handleSubscription(event string, channelsToSubscribe subscripti
}
var errs error
for k := range payloads {
result, err := g.Websocket.Conn.SendMessageReturnResponse(payloads[k].ID, payloads[k])
result, err := g.Websocket.Conn.SendMessageReturnResponse(context.TODO(), payloads[k].ID, payloads[k])
if err != nil {
errs = common.AppendError(errs, err)
continue

View File

@@ -207,9 +207,9 @@ func (g *Gateio) handleDeliveryFuturesSubscription(event string, channelsToSubsc
for con, val := range payloads {
for k := range val {
if con == 0 {
respByte, err = g.Websocket.Conn.SendMessageReturnResponse(val[k].ID, val[k])
respByte, err = g.Websocket.Conn.SendMessageReturnResponse(context.TODO(), val[k].ID, val[k])
} else {
respByte, err = g.Websocket.AuthConn.SendMessageReturnResponse(val[k].ID, val[k])
respByte, err = g.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), val[k].ID, val[k])
}
if err != nil {
errs = common.AppendError(errs, err)

View File

@@ -287,9 +287,9 @@ func (g *Gateio) handleFuturesSubscription(event string, channelsToSubscribe sub
for con, val := range payloads {
for k := range val {
if con == 0 {
respByte, err = g.Websocket.Conn.SendMessageReturnResponse(val[k].ID, val[k])
respByte, err = g.Websocket.Conn.SendMessageReturnResponse(context.TODO(), val[k].ID, val[k])
} else {
respByte, err = g.Websocket.AuthConn.SendMessageReturnResponse(val[k].ID, val[k])
respByte, err = g.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), val[k].ID, val[k])
}
if err != nil {
errs = common.AppendError(errs, err)

View File

@@ -319,7 +319,7 @@ func (g *Gateio) handleOptionsSubscription(event string, channelsToSubscribe sub
}
var errs error
for k := range payloads {
result, err := g.Websocket.Conn.SendMessageReturnResponse(payloads[k].ID, payloads[k])
result, err := g.Websocket.Conn.SendMessageReturnResponse(context.TODO(), payloads[k].ID, payloads[k])
if err != nil {
errs = common.AppendError(errs, err)
continue

View File

@@ -632,7 +632,7 @@ func (h *HitBTC) wsPlaceOrder(pair currency.Pair, side string, price, quantity f
},
ID: id,
}
resp, err := h.Websocket.Conn.SendMessageReturnResponse(id, request)
resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), id, request)
if err != nil {
return nil, fmt.Errorf("%v %v", h.Name, err)
}
@@ -659,7 +659,7 @@ func (h *HitBTC) wsCancelOrder(clientOrderID string) (*WsCancelOrderResponse, er
},
ID: h.Websocket.Conn.GenerateMessageID(false),
}
resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request)
resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request)
if err != nil {
return nil, fmt.Errorf("%v %v", h.Name, err)
}
@@ -689,7 +689,7 @@ func (h *HitBTC) wsReplaceOrder(clientOrderID string, quantity, price float64) (
},
ID: h.Websocket.Conn.GenerateMessageID(false),
}
resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request)
resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request)
if err != nil {
return nil, fmt.Errorf("%v %v", h.Name, err)
}
@@ -714,7 +714,7 @@ func (h *HitBTC) wsGetActiveOrders() (*wsActiveOrdersResponse, error) {
Params: WsReplaceOrderRequestData{},
ID: h.Websocket.Conn.GenerateMessageID(false),
}
resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request)
resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request)
if err != nil {
return nil, fmt.Errorf("%v %v", h.Name, err)
}
@@ -739,7 +739,7 @@ func (h *HitBTC) wsGetTradingBalance() (*WsGetTradingBalanceResponse, error) {
Params: WsReplaceOrderRequestData{},
ID: h.Websocket.Conn.GenerateMessageID(false),
}
resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request)
resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request)
if err != nil {
return nil, fmt.Errorf("%v %v", h.Name, err)
}
@@ -763,7 +763,7 @@ func (h *HitBTC) wsGetCurrencies(currencyItem currency.Code) (*WsGetCurrenciesRe
},
ID: h.Websocket.Conn.GenerateMessageID(false),
}
resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request)
resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request)
if err != nil {
return nil, fmt.Errorf("%v %v", h.Name, err)
}
@@ -792,7 +792,7 @@ func (h *HitBTC) wsGetSymbols(c currency.Pair) (*WsGetSymbolsResponse, error) {
},
ID: h.Websocket.Conn.GenerateMessageID(false),
}
resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request)
resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request)
if err != nil {
return nil, fmt.Errorf("%v %v", h.Name, err)
}
@@ -824,7 +824,7 @@ func (h *HitBTC) wsGetTrades(c currency.Pair, limit int64, sort, by string) (*Ws
},
ID: h.Websocket.Conn.GenerateMessageID(false),
}
resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request)
resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request)
if err != nil {
return nil, fmt.Errorf("%v %v", h.Name, err)
}

View File

@@ -700,7 +700,7 @@ func (h *HUOBI) wsGetAccountsList(ctx context.Context) (*WsAuthenticatedAccounts
}
request.Signature = crypto.Base64Encode(hmac)
request.ClientID = h.Websocket.AuthConn.GenerateMessageID(true)
resp, err := h.Websocket.AuthConn.SendMessageReturnResponse(request.ClientID, request)
resp, err := h.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), request.ClientID, request)
if err != nil {
return nil, err
}
@@ -752,7 +752,7 @@ func (h *HUOBI) wsGetOrdersList(ctx context.Context, accountID int64, pair curre
request.Signature = crypto.Base64Encode(hmac)
request.ClientID = h.Websocket.AuthConn.GenerateMessageID(true)
resp, err := h.Websocket.AuthConn.SendMessageReturnResponse(request.ClientID, request)
resp, err := h.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), request.ClientID, request)
if err != nil {
return nil, err
}
@@ -794,7 +794,7 @@ func (h *HUOBI) wsGetOrderDetails(ctx context.Context, orderID string) (*WsAuthe
}
request.Signature = crypto.Base64Encode(hmac)
request.ClientID = h.Websocket.AuthConn.GenerateMessageID(true)
resp, err := h.Websocket.AuthConn.SendMessageReturnResponse(request.ClientID, request)
resp, err := h.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), request.ClientID, request)
if err != nil {
return nil, err
}

View File

@@ -1229,9 +1229,9 @@ channels:
for i := range *subs {
var err error
if common.StringDataContains(authenticatedChannels, (*subs)[i].Subscription.Name) {
_, err = k.Websocket.AuthConn.SendMessageReturnResponse((*subs)[i].RequestID, (*subs)[i])
_, err = k.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), (*subs)[i].RequestID, (*subs)[i])
} else {
_, err = k.Websocket.Conn.SendMessageReturnResponse((*subs)[i].RequestID, (*subs)[i])
_, err = k.Websocket.Conn.SendMessageReturnResponse(context.TODO(), (*subs)[i].RequestID, (*subs)[i])
}
if err == nil {
err = k.Websocket.AddSuccessfulSubscriptions((*subs)[i].Channels...)
@@ -1288,9 +1288,9 @@ channels:
for i := range unsubs {
var err error
if common.StringDataContains(authenticatedChannels, unsubs[i].Subscription.Name) {
_, err = k.Websocket.AuthConn.SendMessageReturnResponse(unsubs[i].RequestID, unsubs[i])
_, err = k.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), unsubs[i].RequestID, unsubs[i])
} else {
_, err = k.Websocket.Conn.SendMessageReturnResponse(unsubs[i].RequestID, unsubs[i])
_, err = k.Websocket.Conn.SendMessageReturnResponse(context.TODO(), unsubs[i].RequestID, unsubs[i])
}
if err == nil {
err = k.Websocket.RemoveSubscriptions(unsubs[i].Channels...)
@@ -1308,7 +1308,7 @@ func (k *Kraken) wsAddOrder(request *WsAddOrderRequest) (string, error) {
request.RequestID = id
request.Event = krakenWsAddOrder
request.Token = authToken
jsonResp, err := k.Websocket.AuthConn.SendMessageReturnResponse(id, request)
jsonResp, err := k.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), id, request)
if err != nil {
return "", err
}
@@ -1347,7 +1347,7 @@ func (k *Kraken) wsCancelOrder(orderID string) error {
RequestID: id,
}
resp, err := k.Websocket.AuthConn.SendMessageReturnResponse(id, request)
resp, err := k.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), id, request)
if err != nil {
return fmt.Errorf("%w %s: %w", errCancellingOrder, orderID, err)
}
@@ -1377,7 +1377,7 @@ func (k *Kraken) wsCancelAllOrders() (*WsCancelOrderResponse, error) {
RequestID: id,
}
jsonResp, err := k.Websocket.AuthConn.SendMessageReturnResponse(id, request)
jsonResp, err := k.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), id, request)
if err != nil {
return &WsCancelOrderResponse{}, err
}

View File

@@ -956,7 +956,7 @@ func (ku *Kucoin) manageSubscriptions(subs subscription.List, operation string)
PrivateChannel: s.Authenticated,
Response: true,
}
if respRaw, err := ku.Websocket.Conn.SendMessageReturnResponse("msgID:"+msgID, req); err != nil {
if respRaw, err := ku.Websocket.Conn.SendMessageReturnResponse(context.TODO(), "msgID:"+msgID, req); err != nil {
errs = common.AppendError(errs, err)
} else {
rType, err := jsonparser.GetUnsafeString(respRaw, "type")

View File

@@ -147,7 +147,7 @@ func (o *Okcoin) WsLogin(ctx context.Context, dialer *websocket.Dialer) error {
},
},
}
_, err = o.Websocket.AuthConn.SendMessageReturnResponse("login", authRequest)
_, err = o.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), "login", authRequest)
if err != nil {
return err
}

View File

@@ -1,6 +1,7 @@
package okcoin
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -153,9 +154,9 @@ func (o *Okcoin) SendWebsocketRequest(operation string, data, result interface{}
var err error
// TODO: ratelimits for websocket
if authenticated {
byteData, err = o.Websocket.AuthConn.SendMessageReturnResponse(req.ID, req)
byteData, err = o.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), req.ID, req)
} else {
byteData, err = o.Websocket.Conn.SendMessageReturnResponse(req.ID, req)
byteData, err = o.Websocket.Conn.SendMessageReturnResponse(context.TODO(), req.ID, req)
}
if err != nil {
return err

View File

@@ -5,11 +5,14 @@ import (
"sync"
)
var (
errSignatureCollision = errors.New("signature collision")
errInvalidBufferSize = errors.New("buffer size must be positive")
)
// NewMatch returns a new Match
func NewMatch() *Match {
return &Match{
m: make(map[interface{}]chan []byte),
}
return &Match{m: make(map[any]*incoming)}
}
// Match is a distributed subtype that handles the matching of requests and
@@ -17,64 +20,54 @@ func NewMatch() *Match {
// connections. Stream systems fan in all incoming payloads to one routine for
// processing.
type Match struct {
m map[interface{}]chan []byte
m map[any]*incoming
mu sync.Mutex
}
// Matcher defines a payload matching return mechanism
type Matcher struct {
C chan []byte
sig interface{}
m *Match
}
// Incoming matches with request, disregarding the returned payload
func (m *Match) Incoming(signature interface{}) bool {
return m.IncomingWithData(signature, nil)
type incoming struct {
expected int
c chan<- []byte
}
// IncomingWithData matches with requests and takes in the returned payload, to
// be processed outside of a stream processing routine and returns true if a handler was found
func (m *Match) IncomingWithData(signature interface{}, data []byte) bool {
func (m *Match) IncomingWithData(signature any, data []byte) bool {
m.mu.Lock()
defer m.mu.Unlock()
ch, ok := m.m[signature]
if ok {
select {
case ch <- data:
default:
// this shouldn't occur but if it does continue to process as normal
return false
}
return true
if !ok {
return false
}
return false
ch.c <- data
ch.expected--
if ch.expected == 0 {
close(ch.c)
delete(m.m, signature)
}
return true
}
// Set the signature response channel for incoming data
func (m *Match) Set(signature interface{}) (Matcher, error) {
var ch chan []byte
m.mu.Lock()
if _, ok := m.m[signature]; ok {
m.mu.Unlock()
return Matcher{}, errors.New("signature collision")
func (m *Match) Set(signature any, bufSize int) (<-chan []byte, error) {
if bufSize <= 0 {
return nil, errInvalidBufferSize
}
// This is buffered so we don't need to wait for receiver.
ch = make(chan []byte, 1)
m.m[signature] = ch
m.mu.Unlock()
return Matcher{
C: ch,
sig: signature,
m: m,
}, nil
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.m[signature]; ok {
return nil, errSignatureCollision
}
ch := make(chan []byte, bufSize)
m.m[signature] = &incoming{expected: bufSize, c: ch}
return ch, nil
}
// Cleanup closes underlying channel and deletes signature from map
func (m *Matcher) Cleanup() {
m.m.mu.Lock()
close(m.C)
delete(m.m.m, m.sig)
m.m.mu.Unlock()
// RemoveSignature removes the signature response from map and closes the channel.
func (m *Match) RemoveSignature(signature any) {
m.mu.Lock()
defer m.mu.Unlock()
if ch, ok := m.m[signature]; ok {
close(ch.c)
delete(m.m, signature)
}
}

View File

@@ -1,50 +1,53 @@
package stream
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMatch(t *testing.T) {
t.Parallel()
bm := &Match{}
if bm.Incoming("wow") {
t.Fatal("Should not have matched")
}
load := []byte("42")
assert.False(t, new(Match).IncomingWithData("hello", load), "Should not match an uninitialised Match")
nm := NewMatch()
// try to match with unset signature
if nm.Incoming("hello") {
t.Fatal("should not be able to match")
}
match := NewMatch()
assert.False(t, match.IncomingWithData("hello", load), "Should not match an empty signature")
m, err := nm.Set("hello")
if err != nil {
t.Fatal(err)
}
_, err := match.Set("hello", 0)
require.ErrorIs(t, err, errInvalidBufferSize, "Must error on zero buffer size")
_, err = match.Set("hello", -1)
require.ErrorIs(t, err, errInvalidBufferSize, "Must error on negative buffer size")
ch, err := match.Set("hello", 2)
require.NoError(t, err, "Set must not error")
assert.True(t, match.IncomingWithData("hello", []byte("hello")))
assert.Equal(t, "hello", string(<-ch))
_, err = nm.Set("hello")
if err == nil {
t.Fatal("error cannot be nil as this collision cannot occur")
}
_, err = match.Set("hello", 2)
assert.ErrorIs(t, err, errSignatureCollision, "Should error on signature collision")
if m.sig != "hello" {
t.Fatal(err)
}
assert.True(t, match.IncomingWithData("hello", load), "Should match with matching message and signature")
assert.False(t, match.IncomingWithData("hello", load), "Should not match with matching message and signature")
// try and match with initial payload
if !nm.Incoming("hello") {
t.Fatal("should of matched")
}
// put in secondary payload with conflicting signature
if nm.Incoming("hello") {
fmt.Println("should not have been able to match")
}
if data := <-m.C; data != nil {
t.Fatal("data chan should be nil")
}
m.Cleanup()
assert.Len(t, ch, 1, "Channel should have 1 items, 1 was already read above")
}
func TestRemoveSignature(t *testing.T) {
t.Parallel()
match := NewMatch()
ch, err := match.Set("masterblaster", 1)
select {
case <-ch:
t.Fatal("Should not be able to read from an empty channel")
default:
}
require.NoError(t, err)
match.RemoveSignature("masterblaster")
select {
case garbage := <-ch:
require.Empty(t, garbage)
default:
t.Fatal("Should be able to read from a closed channel")
}
}

View File

@@ -1,6 +1,7 @@
package stream
import (
"context"
"net/http"
"time"
@@ -14,10 +15,11 @@ import (
type Connection interface {
Dial(*websocket.Dialer, http.Header) error
ReadMessage() Response
SendJSONMessage(interface{}) error
SendJSONMessage(any) error
SetupPingHandler(PingHandler)
GenerateMessageID(highPrecision bool) int64
SendMessageReturnResponse(signature interface{}, request interface{}) ([]byte, error)
SendMessageReturnResponse(ctx context.Context, signature any, request any) ([]byte, error)
SendMessageReturnResponses(ctx context.Context, signature any, request any, expected int) ([][]byte, error)
SendRawMessage(messageType int, message []byte) error
SetURL(string)
SetProxy(string)

View File

@@ -29,6 +29,8 @@ var (
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
ErrAlreadyDisabled = errors.New("websocket already disabled")
ErrNotConnected = errors.New("websocket is not connected")
ErrNoMessageListener = errors.New("websocket listener not found for message")
ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature")
)
// Private websocket errors

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"compress/flate"
"compress/gzip"
"context"
"crypto/rand"
"encoding/json"
"fmt"
@@ -19,41 +20,6 @@ import (
"github.com/thrasher-corp/gocryptotrader/log"
)
// SendMessageReturnResponse will send a WS message to the connection and wait
// for response
func (w *WebsocketConnection) SendMessageReturnResponse(signature, request interface{}) ([]byte, error) {
m, err := w.Match.Set(signature)
if err != nil {
return nil, err
}
defer m.Cleanup()
b, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err)
}
start := time.Now()
err = w.SendRawMessage(websocket.TextMessage, b)
if err != nil {
return nil, err
}
timer := time.NewTimer(w.ResponseMaxLimit)
select {
case payload := <-m.C:
if w.Reporter != nil {
w.Reporter.Latency(w.ExchangeName, b, time.Since(start))
}
return payload, nil
case <-timer.C:
timer.Stop()
return nil, fmt.Errorf("%s websocket connection: timeout waiting for response with signature: %v", w.ExchangeName, signature)
}
}
// Dial sets proxy urls and then connects to the websocket
func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header) error {
if w.ProxyURL != "" {
@@ -303,3 +269,56 @@ func (w *WebsocketConnection) SetProxy(proxy string) {
func (w *WebsocketConnection) GetURL() string {
return w.URL
}
// SendMessageReturnResponse will send a WS message to the connection and wait for response
func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, signature, request any) ([]byte, error) {
resps, err := w.SendMessageReturnResponses(ctx, signature, request, 1)
if err != nil {
return nil, err
}
return resps[0], nil
}
// SendMessageReturnResponses will send a WS message to the connection and wait for N responses
// An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked
func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, signature, request any, expected int) ([][]byte, error) {
outbound, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err)
}
ch, err := w.Match.Set(signature, expected)
if err != nil {
return nil, err
}
start := time.Now()
err = w.SendRawMessage(websocket.TextMessage, outbound)
if err != nil {
return nil, err
}
timeout := time.NewTimer(w.ResponseMaxLimit * time.Duration(expected))
resps := make([][]byte, 0, expected)
for err == nil && len(resps) < expected {
select {
case resp := <-ch:
resps = append(resps, resp)
case <-timeout.C:
w.Match.RemoveSignature(signature)
err = fmt.Errorf("%s %w %v", w.ExchangeName, ErrSignatureTimeout, signature)
case <-ctx.Done():
w.Match.RemoveSignature(signature)
err = ctx.Err()
}
}
timeout.Stop()
if err == nil && w.Reporter != nil {
w.Reporter.Latency(w.ExchangeName, outbound, time.Since(start))
}
return resps, err
}

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"compress/flate"
"compress/gzip"
"context"
"encoding/json"
"errors"
"fmt"
@@ -724,8 +725,7 @@ func TestSendMessage(t *testing.T) {
}
}
// TestSendMessageWithResponse logic test
func TestSendMessageWithResponse(t *testing.T) {
func TestSendMessageReturnResponse(t *testing.T) {
t.Parallel()
wc := &WebsocketConnection{
Verbose: true,
@@ -753,10 +753,20 @@ func TestSendMessageWithResponse(t *testing.T) {
RequestID: wc.GenerateMessageID(false),
}
_, err = wc.SendMessageReturnResponse(request.RequestID, request)
_, err = wc.SendMessageReturnResponse(context.Background(), request.RequestID, request)
if err != nil {
t.Error(err)
}
cancelledCtx, fn := context.WithDeadline(context.Background(), time.Now())
fn()
_, err = wc.SendMessageReturnResponse(cancelledCtx, "123", request)
assert.ErrorIs(t, err, context.DeadlineExceeded)
// with timeout
wc.ResponseMaxLimit = 1
_, err = wc.SendMessageReturnResponse(context.Background(), "123", request)
assert.ErrorIs(t, err, ErrSignatureTimeout, "SendMessageReturnResponse should error when request ID not found")
}
type reporter struct {
@@ -1182,7 +1192,7 @@ func TestLatency(t *testing.T) {
RequestID: wc.GenerateMessageID(false),
}
_, err = wc.SendMessageReturnResponse(request.RequestID, request)
_, err = wc.SendMessageReturnResponse(context.Background(), request.RequestID, request)
if err != nil {
t.Error(err)
}