mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-18 07:26:50 +00:00
stream/match: Reduce complexity and limit locking when match occurs (#1581)
* stream match update * update tests * linter: fix * glorious: nits + handle context cancellations * glorious: whooops * Websocket: Add SendMessageReturnResponses * whooooooopsie * gk: nitssssss * Update exchanges/stream/stream_match.go Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> * Update exchanges/stream/stream_match_test.go Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> * linter: appease the linter gods * glorious: nits * glorious: nits * Update exchanges/stream/stream_match_test.go Co-authored-by: Scott <gloriousCode@users.noreply.github.com> --------- Co-authored-by: Ryan O'Hara-Reid <ryan.oharareid@thrasher.io> Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> Co-authored-by: Scott <gloriousCode@users.noreply.github.com>
This commit is contained in:
@@ -577,7 +577,7 @@ func (b *Binance) manageSubs(op string, subs subscription.List) error {
|
||||
Params: subs.QualifiedChannels(),
|
||||
}
|
||||
|
||||
respRaw, err := b.Websocket.Conn.SendMessageReturnResponse(req.ID, req)
|
||||
respRaw, err := b.Websocket.Conn.SendMessageReturnResponse(context.TODO(), req.ID, req)
|
||||
if err == nil {
|
||||
if v, d, _, rErr := jsonparser.Get(respRaw, "result"); rErr != nil {
|
||||
err = rErr
|
||||
|
||||
@@ -1315,7 +1315,7 @@ func TestWsCancelOffer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWsSubscribedResponse(t *testing.T) {
|
||||
m, err := b.Websocket.Match.Set("subscribe:waiter1")
|
||||
ch, err := b.Websocket.Match.Set("subscribe:waiter1", 1)
|
||||
assert.NoError(t, err, "Setting a matcher should not error")
|
||||
err = b.wsHandleData([]byte(`{"event":"subscribed","channel":"ticker","chanId":224555,"subId":"waiter1","symbol":"tBTCUSD","pair":"BTCUSD"}`))
|
||||
if assert.Error(t, err, "Should error if sub is not registered yet") {
|
||||
@@ -1328,13 +1328,12 @@ func TestWsSubscribedResponse(t *testing.T) {
|
||||
require.NoError(t, err, "AddSubscriptions must not error")
|
||||
err = b.wsHandleData([]byte(`{"event":"subscribed","channel":"ticker","chanId":224555,"subId":"waiter1","symbol":"tBTCUSD","pair":"BTCUSD"}`))
|
||||
assert.NoError(t, err, "wsHandleData should not error")
|
||||
if assert.NotEmpty(t, m.C, "Matcher should have received a sub notification") {
|
||||
msg := <-m.C
|
||||
if assert.NotEmpty(t, ch, "Matcher should have received a sub notification") {
|
||||
msg := <-ch
|
||||
cID, err := jsonparser.GetInt(msg, "chanId")
|
||||
assert.NoError(t, err, "Should get chanId from sub notification without error")
|
||||
assert.EqualValues(t, 224555, cID, "Should get the correct chanId through the matcher notification")
|
||||
}
|
||||
m.Cleanup()
|
||||
}
|
||||
|
||||
func TestWsOrderBook(t *testing.T) {
|
||||
|
||||
@@ -1756,7 +1756,7 @@ func (b *Bitfinex) subscribeToChan(chans subscription.List) error {
|
||||
_ = b.Websocket.RemoveSubscriptions(c)
|
||||
}()
|
||||
|
||||
respRaw, err := b.Websocket.Conn.SendMessageReturnResponse("subscribe:"+subID, req)
|
||||
respRaw, err := b.Websocket.Conn.SendMessageReturnResponse(context.TODO(), "subscribe:"+subID, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pairs)
|
||||
}
|
||||
@@ -1849,7 +1849,7 @@ func (b *Bitfinex) unsubscribeFromChan(chans subscription.List) error {
|
||||
"chanId": chanID,
|
||||
}
|
||||
|
||||
respRaw, err := b.Websocket.Conn.SendMessageReturnResponse("unsubscribe:"+strconv.Itoa(chanID), req)
|
||||
respRaw, err := b.Websocket.Conn.SendMessageReturnResponse(context.TODO(), "unsubscribe:"+strconv.Itoa(chanID), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1926,7 +1926,7 @@ func (b *Bitfinex) WsSendAuth(ctx context.Context) error {
|
||||
func (b *Bitfinex) WsNewOrder(data *WsNewOrderRequest) (string, error) {
|
||||
data.CustomID = b.Websocket.AuthConn.GenerateMessageID(false)
|
||||
request := makeRequestInterface(wsOrderNew, data)
|
||||
resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(data.CustomID, request)
|
||||
resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), data.CustomID, request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -1983,7 +1983,7 @@ func (b *Bitfinex) WsNewOrder(data *WsNewOrderRequest) (string, error) {
|
||||
// WsModifyOrder authenticated modify order request
|
||||
func (b *Bitfinex) WsModifyOrder(data *WsUpdateOrderRequest) error {
|
||||
request := makeRequestInterface(wsOrderUpdate, data)
|
||||
resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(data.OrderID, request)
|
||||
resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), data.OrderID, request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -2037,7 +2037,7 @@ func (b *Bitfinex) WsCancelOrder(orderID int64) error {
|
||||
OrderID: orderID,
|
||||
}
|
||||
request := makeRequestInterface(wsOrderCancel, cancel)
|
||||
resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(orderID, request)
|
||||
resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), orderID, request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -2094,7 +2094,7 @@ func (b *Bitfinex) WsCancelOffer(orderID int64) error {
|
||||
OrderID: orderID,
|
||||
}
|
||||
request := makeRequestInterface(wsFundingOfferCancel, cancel)
|
||||
resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(orderID, request)
|
||||
resp, err := b.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), orderID, request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -118,7 +118,7 @@ func (by *Bybit) WsAuth(ctx context.Context) error {
|
||||
Operation: "auth",
|
||||
Args: []interface{}{creds.Key, intNonce, sign},
|
||||
}
|
||||
resp, err := by.Websocket.AuthConn.SendMessageReturnResponse(req.RequestID, req)
|
||||
resp, err := by.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), req.RequestID, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -220,12 +220,12 @@ func (by *Bybit) handleSpotSubscription(operation string, channelsToSubscribe su
|
||||
for a := range payloads {
|
||||
var response []byte
|
||||
if payloads[a].auth {
|
||||
response, err = by.Websocket.AuthConn.SendMessageReturnResponse(payloads[a].RequestID, payloads[a])
|
||||
response, err = by.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), payloads[a].RequestID, payloads[a])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
response, err = by.Websocket.Conn.SendMessageReturnResponse(payloads[a].RequestID, payloads[a])
|
||||
response, err = by.Websocket.Conn.SendMessageReturnResponse(context.TODO(), payloads[a].RequestID, payloads[a])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -480,7 +480,7 @@ func (c *COINUT) WsGetInstruments() (Instruments, error) {
|
||||
SecurityType: strings.ToUpper(asset.Spot.String()),
|
||||
Nonce: getNonce(),
|
||||
}
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(request.Nonce, request)
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.Nonce, request)
|
||||
if err != nil {
|
||||
return list, err
|
||||
}
|
||||
@@ -648,7 +648,7 @@ func (c *COINUT) Unsubscribe(channelToUnsubscribe subscription.List) error {
|
||||
Subscribe: false,
|
||||
Nonce: getNonce(),
|
||||
}
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(subscribe.Nonce, subscribe)
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), subscribe.Nonce, subscribe)
|
||||
if err != nil {
|
||||
errs = common.AppendError(errs, err)
|
||||
continue
|
||||
@@ -691,7 +691,7 @@ func (c *COINUT) wsAuthenticate(ctx context.Context) error {
|
||||
}
|
||||
r.Hmac = crypto.HexEncodeToString(hmac)
|
||||
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(r.Nonce, r)
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), r.Nonce, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -714,7 +714,7 @@ func (c *COINUT) wsGetAccountBalance() (*UserBalance, error) {
|
||||
Request: "user_balance",
|
||||
Nonce: getNonce(),
|
||||
}
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(accBalance.Nonce, accBalance)
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), accBalance.Nonce, accBalance)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -750,7 +750,7 @@ func (c *COINUT) wsSubmitOrder(o *WsSubmitOrderParameters) (*order.Detail, error
|
||||
if o.OrderID > 0 {
|
||||
orderSubmissionRequest.OrderID = o.OrderID
|
||||
}
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(orderSubmissionRequest.Nonce, orderSubmissionRequest)
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), orderSubmissionRequest.Nonce, orderSubmissionRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -793,7 +793,7 @@ func (c *COINUT) wsSubmitOrders(orders []WsSubmitOrderParameters) ([]order.Detai
|
||||
|
||||
orderRequest.Nonce = getNonce()
|
||||
orderRequest.Request = "new_orders"
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(orderRequest.Nonce, orderRequest)
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), orderRequest.Nonce, orderRequest)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
return nil, errs
|
||||
@@ -829,7 +829,7 @@ func (c *COINUT) wsGetOpenOrders(curr string) (*WsUserOpenOrdersResponse, error)
|
||||
openOrdersRequest.Nonce = getNonce()
|
||||
openOrdersRequest.InstrumentID = c.instrumentMap.LookupID(curr)
|
||||
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(openOrdersRequest.Nonce, openOrdersRequest)
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), openOrdersRequest.Nonce, openOrdersRequest)
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
@@ -862,7 +862,7 @@ func (c *COINUT) wsCancelOrder(cancellation *WsCancelOrderParameters) (*CancelOr
|
||||
cancellationRequest.OrderID = cancellation.OrderID
|
||||
cancellationRequest.Nonce = getNonce()
|
||||
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(cancellationRequest.Nonce, cancellationRequest)
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), cancellationRequest.Nonce, cancellationRequest)
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
@@ -903,7 +903,7 @@ func (c *COINUT) wsCancelOrders(cancellations []WsCancelOrderParameters) (*Cance
|
||||
|
||||
cancelOrderRequest.Request = "cancel_orders"
|
||||
cancelOrderRequest.Nonce = getNonce()
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(cancelOrderRequest.Nonce, cancelOrderRequest)
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), cancelOrderRequest.Nonce, cancelOrderRequest)
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
@@ -933,7 +933,7 @@ func (c *COINUT) wsGetTradeHistory(p currency.Pair, start, limit int64) (*WsTrad
|
||||
request.Start = start
|
||||
request.Limit = limit
|
||||
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(request.Nonce, request)
|
||||
resp, err := c.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.Nonce, request)
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
|
||||
@@ -147,7 +147,7 @@ func (d *Deribit) wsLogin(ctx context.Context) error {
|
||||
"signature": crypto.HexEncodeToString(hmac),
|
||||
},
|
||||
}
|
||||
resp, err := d.Websocket.Conn.SendMessageReturnResponse(request.ID, request)
|
||||
resp, err := d.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request)
|
||||
if err != nil {
|
||||
d.Websocket.SetCanUseAuthenticatedEndpoints(false)
|
||||
return err
|
||||
@@ -1165,7 +1165,7 @@ func (d *Deribit) handleSubscription(operation string, channels subscription.Lis
|
||||
return err
|
||||
}
|
||||
for x := range payloads {
|
||||
data, err := d.Websocket.Conn.SendMessageReturnResponse(payloads[x].ID, payloads[x])
|
||||
data, err := d.Websocket.Conn.SendMessageReturnResponse(context.TODO(), payloads[x].ID, payloads[x])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2406,7 +2406,7 @@ func (d *Deribit) sendWsPayload(ep request.EndpointLimit, input *WsRequest, resp
|
||||
log.Debugf(log.RequestSys, "%s attempt %d", d.Name, attempt)
|
||||
}
|
||||
var payload []byte
|
||||
payload, err = d.Websocket.Conn.SendMessageReturnResponse(input.ID, input)
|
||||
payload, err = d.Websocket.Conn.SendMessageReturnResponse(context.TODO(), input.ID, input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -700,7 +700,7 @@ func (g *Gateio) handleSubscription(event string, channelsToSubscribe subscripti
|
||||
}
|
||||
var errs error
|
||||
for k := range payloads {
|
||||
result, err := g.Websocket.Conn.SendMessageReturnResponse(payloads[k].ID, payloads[k])
|
||||
result, err := g.Websocket.Conn.SendMessageReturnResponse(context.TODO(), payloads[k].ID, payloads[k])
|
||||
if err != nil {
|
||||
errs = common.AppendError(errs, err)
|
||||
continue
|
||||
|
||||
@@ -207,9 +207,9 @@ func (g *Gateio) handleDeliveryFuturesSubscription(event string, channelsToSubsc
|
||||
for con, val := range payloads {
|
||||
for k := range val {
|
||||
if con == 0 {
|
||||
respByte, err = g.Websocket.Conn.SendMessageReturnResponse(val[k].ID, val[k])
|
||||
respByte, err = g.Websocket.Conn.SendMessageReturnResponse(context.TODO(), val[k].ID, val[k])
|
||||
} else {
|
||||
respByte, err = g.Websocket.AuthConn.SendMessageReturnResponse(val[k].ID, val[k])
|
||||
respByte, err = g.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), val[k].ID, val[k])
|
||||
}
|
||||
if err != nil {
|
||||
errs = common.AppendError(errs, err)
|
||||
|
||||
@@ -287,9 +287,9 @@ func (g *Gateio) handleFuturesSubscription(event string, channelsToSubscribe sub
|
||||
for con, val := range payloads {
|
||||
for k := range val {
|
||||
if con == 0 {
|
||||
respByte, err = g.Websocket.Conn.SendMessageReturnResponse(val[k].ID, val[k])
|
||||
respByte, err = g.Websocket.Conn.SendMessageReturnResponse(context.TODO(), val[k].ID, val[k])
|
||||
} else {
|
||||
respByte, err = g.Websocket.AuthConn.SendMessageReturnResponse(val[k].ID, val[k])
|
||||
respByte, err = g.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), val[k].ID, val[k])
|
||||
}
|
||||
if err != nil {
|
||||
errs = common.AppendError(errs, err)
|
||||
|
||||
@@ -319,7 +319,7 @@ func (g *Gateio) handleOptionsSubscription(event string, channelsToSubscribe sub
|
||||
}
|
||||
var errs error
|
||||
for k := range payloads {
|
||||
result, err := g.Websocket.Conn.SendMessageReturnResponse(payloads[k].ID, payloads[k])
|
||||
result, err := g.Websocket.Conn.SendMessageReturnResponse(context.TODO(), payloads[k].ID, payloads[k])
|
||||
if err != nil {
|
||||
errs = common.AppendError(errs, err)
|
||||
continue
|
||||
|
||||
@@ -632,7 +632,7 @@ func (h *HitBTC) wsPlaceOrder(pair currency.Pair, side string, price, quantity f
|
||||
},
|
||||
ID: id,
|
||||
}
|
||||
resp, err := h.Websocket.Conn.SendMessageReturnResponse(id, request)
|
||||
resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), id, request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%v %v", h.Name, err)
|
||||
}
|
||||
@@ -659,7 +659,7 @@ func (h *HitBTC) wsCancelOrder(clientOrderID string) (*WsCancelOrderResponse, er
|
||||
},
|
||||
ID: h.Websocket.Conn.GenerateMessageID(false),
|
||||
}
|
||||
resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request)
|
||||
resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%v %v", h.Name, err)
|
||||
}
|
||||
@@ -689,7 +689,7 @@ func (h *HitBTC) wsReplaceOrder(clientOrderID string, quantity, price float64) (
|
||||
},
|
||||
ID: h.Websocket.Conn.GenerateMessageID(false),
|
||||
}
|
||||
resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request)
|
||||
resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%v %v", h.Name, err)
|
||||
}
|
||||
@@ -714,7 +714,7 @@ func (h *HitBTC) wsGetActiveOrders() (*wsActiveOrdersResponse, error) {
|
||||
Params: WsReplaceOrderRequestData{},
|
||||
ID: h.Websocket.Conn.GenerateMessageID(false),
|
||||
}
|
||||
resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request)
|
||||
resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%v %v", h.Name, err)
|
||||
}
|
||||
@@ -739,7 +739,7 @@ func (h *HitBTC) wsGetTradingBalance() (*WsGetTradingBalanceResponse, error) {
|
||||
Params: WsReplaceOrderRequestData{},
|
||||
ID: h.Websocket.Conn.GenerateMessageID(false),
|
||||
}
|
||||
resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request)
|
||||
resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%v %v", h.Name, err)
|
||||
}
|
||||
@@ -763,7 +763,7 @@ func (h *HitBTC) wsGetCurrencies(currencyItem currency.Code) (*WsGetCurrenciesRe
|
||||
},
|
||||
ID: h.Websocket.Conn.GenerateMessageID(false),
|
||||
}
|
||||
resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request)
|
||||
resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%v %v", h.Name, err)
|
||||
}
|
||||
@@ -792,7 +792,7 @@ func (h *HitBTC) wsGetSymbols(c currency.Pair) (*WsGetSymbolsResponse, error) {
|
||||
},
|
||||
ID: h.Websocket.Conn.GenerateMessageID(false),
|
||||
}
|
||||
resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request)
|
||||
resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%v %v", h.Name, err)
|
||||
}
|
||||
@@ -824,7 +824,7 @@ func (h *HitBTC) wsGetTrades(c currency.Pair, limit int64, sort, by string) (*Ws
|
||||
},
|
||||
ID: h.Websocket.Conn.GenerateMessageID(false),
|
||||
}
|
||||
resp, err := h.Websocket.Conn.SendMessageReturnResponse(request.ID, request)
|
||||
resp, err := h.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.ID, request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%v %v", h.Name, err)
|
||||
}
|
||||
|
||||
@@ -700,7 +700,7 @@ func (h *HUOBI) wsGetAccountsList(ctx context.Context) (*WsAuthenticatedAccounts
|
||||
}
|
||||
request.Signature = crypto.Base64Encode(hmac)
|
||||
request.ClientID = h.Websocket.AuthConn.GenerateMessageID(true)
|
||||
resp, err := h.Websocket.AuthConn.SendMessageReturnResponse(request.ClientID, request)
|
||||
resp, err := h.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), request.ClientID, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -752,7 +752,7 @@ func (h *HUOBI) wsGetOrdersList(ctx context.Context, accountID int64, pair curre
|
||||
request.Signature = crypto.Base64Encode(hmac)
|
||||
request.ClientID = h.Websocket.AuthConn.GenerateMessageID(true)
|
||||
|
||||
resp, err := h.Websocket.AuthConn.SendMessageReturnResponse(request.ClientID, request)
|
||||
resp, err := h.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), request.ClientID, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -794,7 +794,7 @@ func (h *HUOBI) wsGetOrderDetails(ctx context.Context, orderID string) (*WsAuthe
|
||||
}
|
||||
request.Signature = crypto.Base64Encode(hmac)
|
||||
request.ClientID = h.Websocket.AuthConn.GenerateMessageID(true)
|
||||
resp, err := h.Websocket.AuthConn.SendMessageReturnResponse(request.ClientID, request)
|
||||
resp, err := h.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), request.ClientID, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1229,9 +1229,9 @@ channels:
|
||||
for i := range *subs {
|
||||
var err error
|
||||
if common.StringDataContains(authenticatedChannels, (*subs)[i].Subscription.Name) {
|
||||
_, err = k.Websocket.AuthConn.SendMessageReturnResponse((*subs)[i].RequestID, (*subs)[i])
|
||||
_, err = k.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), (*subs)[i].RequestID, (*subs)[i])
|
||||
} else {
|
||||
_, err = k.Websocket.Conn.SendMessageReturnResponse((*subs)[i].RequestID, (*subs)[i])
|
||||
_, err = k.Websocket.Conn.SendMessageReturnResponse(context.TODO(), (*subs)[i].RequestID, (*subs)[i])
|
||||
}
|
||||
if err == nil {
|
||||
err = k.Websocket.AddSuccessfulSubscriptions((*subs)[i].Channels...)
|
||||
@@ -1288,9 +1288,9 @@ channels:
|
||||
for i := range unsubs {
|
||||
var err error
|
||||
if common.StringDataContains(authenticatedChannels, unsubs[i].Subscription.Name) {
|
||||
_, err = k.Websocket.AuthConn.SendMessageReturnResponse(unsubs[i].RequestID, unsubs[i])
|
||||
_, err = k.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), unsubs[i].RequestID, unsubs[i])
|
||||
} else {
|
||||
_, err = k.Websocket.Conn.SendMessageReturnResponse(unsubs[i].RequestID, unsubs[i])
|
||||
_, err = k.Websocket.Conn.SendMessageReturnResponse(context.TODO(), unsubs[i].RequestID, unsubs[i])
|
||||
}
|
||||
if err == nil {
|
||||
err = k.Websocket.RemoveSubscriptions(unsubs[i].Channels...)
|
||||
@@ -1308,7 +1308,7 @@ func (k *Kraken) wsAddOrder(request *WsAddOrderRequest) (string, error) {
|
||||
request.RequestID = id
|
||||
request.Event = krakenWsAddOrder
|
||||
request.Token = authToken
|
||||
jsonResp, err := k.Websocket.AuthConn.SendMessageReturnResponse(id, request)
|
||||
jsonResp, err := k.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), id, request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -1347,7 +1347,7 @@ func (k *Kraken) wsCancelOrder(orderID string) error {
|
||||
RequestID: id,
|
||||
}
|
||||
|
||||
resp, err := k.Websocket.AuthConn.SendMessageReturnResponse(id, request)
|
||||
resp, err := k.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), id, request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w %s: %w", errCancellingOrder, orderID, err)
|
||||
}
|
||||
@@ -1377,7 +1377,7 @@ func (k *Kraken) wsCancelAllOrders() (*WsCancelOrderResponse, error) {
|
||||
RequestID: id,
|
||||
}
|
||||
|
||||
jsonResp, err := k.Websocket.AuthConn.SendMessageReturnResponse(id, request)
|
||||
jsonResp, err := k.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), id, request)
|
||||
if err != nil {
|
||||
return &WsCancelOrderResponse{}, err
|
||||
}
|
||||
|
||||
@@ -956,7 +956,7 @@ func (ku *Kucoin) manageSubscriptions(subs subscription.List, operation string)
|
||||
PrivateChannel: s.Authenticated,
|
||||
Response: true,
|
||||
}
|
||||
if respRaw, err := ku.Websocket.Conn.SendMessageReturnResponse("msgID:"+msgID, req); err != nil {
|
||||
if respRaw, err := ku.Websocket.Conn.SendMessageReturnResponse(context.TODO(), "msgID:"+msgID, req); err != nil {
|
||||
errs = common.AppendError(errs, err)
|
||||
} else {
|
||||
rType, err := jsonparser.GetUnsafeString(respRaw, "type")
|
||||
|
||||
@@ -147,7 +147,7 @@ func (o *Okcoin) WsLogin(ctx context.Context, dialer *websocket.Dialer) error {
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err = o.Websocket.AuthConn.SendMessageReturnResponse("login", authRequest)
|
||||
_, err = o.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), "login", authRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package okcoin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -153,9 +154,9 @@ func (o *Okcoin) SendWebsocketRequest(operation string, data, result interface{}
|
||||
var err error
|
||||
// TODO: ratelimits for websocket
|
||||
if authenticated {
|
||||
byteData, err = o.Websocket.AuthConn.SendMessageReturnResponse(req.ID, req)
|
||||
byteData, err = o.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), req.ID, req)
|
||||
} else {
|
||||
byteData, err = o.Websocket.Conn.SendMessageReturnResponse(req.ID, req)
|
||||
byteData, err = o.Websocket.Conn.SendMessageReturnResponse(context.TODO(), req.ID, req)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -5,11 +5,14 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
errSignatureCollision = errors.New("signature collision")
|
||||
errInvalidBufferSize = errors.New("buffer size must be positive")
|
||||
)
|
||||
|
||||
// NewMatch returns a new Match
|
||||
func NewMatch() *Match {
|
||||
return &Match{
|
||||
m: make(map[interface{}]chan []byte),
|
||||
}
|
||||
return &Match{m: make(map[any]*incoming)}
|
||||
}
|
||||
|
||||
// Match is a distributed subtype that handles the matching of requests and
|
||||
@@ -17,64 +20,54 @@ func NewMatch() *Match {
|
||||
// connections. Stream systems fan in all incoming payloads to one routine for
|
||||
// processing.
|
||||
type Match struct {
|
||||
m map[interface{}]chan []byte
|
||||
m map[any]*incoming
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Matcher defines a payload matching return mechanism
|
||||
type Matcher struct {
|
||||
C chan []byte
|
||||
sig interface{}
|
||||
m *Match
|
||||
}
|
||||
|
||||
// Incoming matches with request, disregarding the returned payload
|
||||
func (m *Match) Incoming(signature interface{}) bool {
|
||||
return m.IncomingWithData(signature, nil)
|
||||
type incoming struct {
|
||||
expected int
|
||||
c chan<- []byte
|
||||
}
|
||||
|
||||
// IncomingWithData matches with requests and takes in the returned payload, to
|
||||
// be processed outside of a stream processing routine and returns true if a handler was found
|
||||
func (m *Match) IncomingWithData(signature interface{}, data []byte) bool {
|
||||
func (m *Match) IncomingWithData(signature any, data []byte) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
ch, ok := m.m[signature]
|
||||
if ok {
|
||||
select {
|
||||
case ch <- data:
|
||||
default:
|
||||
// this shouldn't occur but if it does continue to process as normal
|
||||
return false
|
||||
}
|
||||
return true
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return false
|
||||
ch.c <- data
|
||||
ch.expected--
|
||||
if ch.expected == 0 {
|
||||
close(ch.c)
|
||||
delete(m.m, signature)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Set the signature response channel for incoming data
|
||||
func (m *Match) Set(signature interface{}) (Matcher, error) {
|
||||
var ch chan []byte
|
||||
m.mu.Lock()
|
||||
if _, ok := m.m[signature]; ok {
|
||||
m.mu.Unlock()
|
||||
return Matcher{}, errors.New("signature collision")
|
||||
func (m *Match) Set(signature any, bufSize int) (<-chan []byte, error) {
|
||||
if bufSize <= 0 {
|
||||
return nil, errInvalidBufferSize
|
||||
}
|
||||
// This is buffered so we don't need to wait for receiver.
|
||||
ch = make(chan []byte, 1)
|
||||
m.m[signature] = ch
|
||||
m.mu.Unlock()
|
||||
|
||||
return Matcher{
|
||||
C: ch,
|
||||
sig: signature,
|
||||
m: m,
|
||||
}, nil
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if _, ok := m.m[signature]; ok {
|
||||
return nil, errSignatureCollision
|
||||
}
|
||||
ch := make(chan []byte, bufSize)
|
||||
m.m[signature] = &incoming{expected: bufSize, c: ch}
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// Cleanup closes underlying channel and deletes signature from map
|
||||
func (m *Matcher) Cleanup() {
|
||||
m.m.mu.Lock()
|
||||
close(m.C)
|
||||
delete(m.m.m, m.sig)
|
||||
m.m.mu.Unlock()
|
||||
// RemoveSignature removes the signature response from map and closes the channel.
|
||||
func (m *Match) RemoveSignature(signature any) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if ch, ok := m.m[signature]; ok {
|
||||
close(ch.c)
|
||||
delete(m.m, signature)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,50 +1,53 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
bm := &Match{}
|
||||
if bm.Incoming("wow") {
|
||||
t.Fatal("Should not have matched")
|
||||
}
|
||||
load := []byte("42")
|
||||
assert.False(t, new(Match).IncomingWithData("hello", load), "Should not match an uninitialised Match")
|
||||
|
||||
nm := NewMatch()
|
||||
// try to match with unset signature
|
||||
if nm.Incoming("hello") {
|
||||
t.Fatal("should not be able to match")
|
||||
}
|
||||
match := NewMatch()
|
||||
assert.False(t, match.IncomingWithData("hello", load), "Should not match an empty signature")
|
||||
|
||||
m, err := nm.Set("hello")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err := match.Set("hello", 0)
|
||||
require.ErrorIs(t, err, errInvalidBufferSize, "Must error on zero buffer size")
|
||||
_, err = match.Set("hello", -1)
|
||||
require.ErrorIs(t, err, errInvalidBufferSize, "Must error on negative buffer size")
|
||||
ch, err := match.Set("hello", 2)
|
||||
require.NoError(t, err, "Set must not error")
|
||||
assert.True(t, match.IncomingWithData("hello", []byte("hello")))
|
||||
assert.Equal(t, "hello", string(<-ch))
|
||||
|
||||
_, err = nm.Set("hello")
|
||||
if err == nil {
|
||||
t.Fatal("error cannot be nil as this collision cannot occur")
|
||||
}
|
||||
_, err = match.Set("hello", 2)
|
||||
assert.ErrorIs(t, err, errSignatureCollision, "Should error on signature collision")
|
||||
|
||||
if m.sig != "hello" {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.True(t, match.IncomingWithData("hello", load), "Should match with matching message and signature")
|
||||
assert.False(t, match.IncomingWithData("hello", load), "Should not match with matching message and signature")
|
||||
|
||||
// try and match with initial payload
|
||||
if !nm.Incoming("hello") {
|
||||
t.Fatal("should of matched")
|
||||
}
|
||||
|
||||
// put in secondary payload with conflicting signature
|
||||
if nm.Incoming("hello") {
|
||||
fmt.Println("should not have been able to match")
|
||||
}
|
||||
|
||||
if data := <-m.C; data != nil {
|
||||
t.Fatal("data chan should be nil")
|
||||
}
|
||||
|
||||
m.Cleanup()
|
||||
assert.Len(t, ch, 1, "Channel should have 1 items, 1 was already read above")
|
||||
}
|
||||
|
||||
func TestRemoveSignature(t *testing.T) {
|
||||
t.Parallel()
|
||||
match := NewMatch()
|
||||
ch, err := match.Set("masterblaster", 1)
|
||||
select {
|
||||
case <-ch:
|
||||
t.Fatal("Should not be able to read from an empty channel")
|
||||
default:
|
||||
}
|
||||
require.NoError(t, err)
|
||||
match.RemoveSignature("masterblaster")
|
||||
select {
|
||||
case garbage := <-ch:
|
||||
require.Empty(t, garbage)
|
||||
default:
|
||||
t.Fatal("Should be able to read from a closed channel")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@@ -14,10 +15,11 @@ import (
|
||||
type Connection interface {
|
||||
Dial(*websocket.Dialer, http.Header) error
|
||||
ReadMessage() Response
|
||||
SendJSONMessage(interface{}) error
|
||||
SendJSONMessage(any) error
|
||||
SetupPingHandler(PingHandler)
|
||||
GenerateMessageID(highPrecision bool) int64
|
||||
SendMessageReturnResponse(signature interface{}, request interface{}) ([]byte, error)
|
||||
SendMessageReturnResponse(ctx context.Context, signature any, request any) ([]byte, error)
|
||||
SendMessageReturnResponses(ctx context.Context, signature any, request any, expected int) ([][]byte, error)
|
||||
SendRawMessage(messageType int, message []byte) error
|
||||
SetURL(string)
|
||||
SetProxy(string)
|
||||
|
||||
@@ -29,6 +29,8 @@ var (
|
||||
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
|
||||
ErrAlreadyDisabled = errors.New("websocket already disabled")
|
||||
ErrNotConnected = errors.New("websocket is not connected")
|
||||
ErrNoMessageListener = errors.New("websocket listener not found for message")
|
||||
ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature")
|
||||
)
|
||||
|
||||
// Private websocket errors
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -19,41 +20,6 @@ import (
|
||||
"github.com/thrasher-corp/gocryptotrader/log"
|
||||
)
|
||||
|
||||
// SendMessageReturnResponse will send a WS message to the connection and wait
|
||||
// for response
|
||||
func (w *WebsocketConnection) SendMessageReturnResponse(signature, request interface{}) ([]byte, error) {
|
||||
m, err := w.Match.Set(signature)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer m.Cleanup()
|
||||
|
||||
b, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
err = w.SendRawMessage(websocket.TextMessage, b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
timer := time.NewTimer(w.ResponseMaxLimit)
|
||||
|
||||
select {
|
||||
case payload := <-m.C:
|
||||
if w.Reporter != nil {
|
||||
w.Reporter.Latency(w.ExchangeName, b, time.Since(start))
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
case <-timer.C:
|
||||
timer.Stop()
|
||||
return nil, fmt.Errorf("%s websocket connection: timeout waiting for response with signature: %v", w.ExchangeName, signature)
|
||||
}
|
||||
}
|
||||
|
||||
// Dial sets proxy urls and then connects to the websocket
|
||||
func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header) error {
|
||||
if w.ProxyURL != "" {
|
||||
@@ -303,3 +269,56 @@ func (w *WebsocketConnection) SetProxy(proxy string) {
|
||||
func (w *WebsocketConnection) GetURL() string {
|
||||
return w.URL
|
||||
}
|
||||
|
||||
// SendMessageReturnResponse will send a WS message to the connection and wait for response
|
||||
func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, signature, request any) ([]byte, error) {
|
||||
resps, err := w.SendMessageReturnResponses(ctx, signature, request, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resps[0], nil
|
||||
}
|
||||
|
||||
// SendMessageReturnResponses will send a WS message to the connection and wait for N responses
|
||||
// An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked
|
||||
func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, signature, request any, expected int) ([][]byte, error) {
|
||||
outbound, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err)
|
||||
}
|
||||
|
||||
ch, err := w.Match.Set(signature, expected)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
err = w.SendRawMessage(websocket.TextMessage, outbound)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
timeout := time.NewTimer(w.ResponseMaxLimit * time.Duration(expected))
|
||||
|
||||
resps := make([][]byte, 0, expected)
|
||||
for err == nil && len(resps) < expected {
|
||||
select {
|
||||
case resp := <-ch:
|
||||
resps = append(resps, resp)
|
||||
case <-timeout.C:
|
||||
w.Match.RemoveSignature(signature)
|
||||
err = fmt.Errorf("%s %w %v", w.ExchangeName, ErrSignatureTimeout, signature)
|
||||
case <-ctx.Done():
|
||||
w.Match.RemoveSignature(signature)
|
||||
err = ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
timeout.Stop()
|
||||
|
||||
if err == nil && w.Reporter != nil {
|
||||
w.Reporter.Latency(w.ExchangeName, outbound, time.Since(start))
|
||||
}
|
||||
|
||||
return resps, err
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -724,8 +725,7 @@ func TestSendMessage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendMessageWithResponse logic test
|
||||
func TestSendMessageWithResponse(t *testing.T) {
|
||||
func TestSendMessageReturnResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
wc := &WebsocketConnection{
|
||||
Verbose: true,
|
||||
@@ -753,10 +753,20 @@ func TestSendMessageWithResponse(t *testing.T) {
|
||||
RequestID: wc.GenerateMessageID(false),
|
||||
}
|
||||
|
||||
_, err = wc.SendMessageReturnResponse(request.RequestID, request)
|
||||
_, err = wc.SendMessageReturnResponse(context.Background(), request.RequestID, request)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
cancelledCtx, fn := context.WithDeadline(context.Background(), time.Now())
|
||||
fn()
|
||||
_, err = wc.SendMessageReturnResponse(cancelledCtx, "123", request)
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
|
||||
// with timeout
|
||||
wc.ResponseMaxLimit = 1
|
||||
_, err = wc.SendMessageReturnResponse(context.Background(), "123", request)
|
||||
assert.ErrorIs(t, err, ErrSignatureTimeout, "SendMessageReturnResponse should error when request ID not found")
|
||||
}
|
||||
|
||||
type reporter struct {
|
||||
@@ -1182,7 +1192,7 @@ func TestLatency(t *testing.T) {
|
||||
RequestID: wc.GenerateMessageID(false),
|
||||
}
|
||||
|
||||
_, err = wc.SendMessageReturnResponse(request.RequestID, request)
|
||||
_, err = wc.SendMessageReturnResponse(context.Background(), request.RequestID, request)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user