mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-06-08 07:26:48 +00:00
stream/okx: allow rate limit definitions to be used by the stream package (#1641)
* stream: rate limiter definitions * Update exchanges/request/request_test.go Co-authored-by: Scott <gloriousCode@users.noreply.github.com> --------- Co-authored-by: Ryan O'Hara-Reid <ryan.oharareid@thrasher.io> Co-authored-by: Scott <gloriousCode@users.noreply.github.com>
This commit is contained in:
@@ -16,16 +16,19 @@ import (
|
||||
type Connection interface {
|
||||
Dial(*websocket.Dialer, http.Header) error
|
||||
ReadMessage() Response
|
||||
SendJSONMessage(ctx context.Context, payload any) error
|
||||
SetupPingHandler(PingHandler)
|
||||
// GenerateMessageID generates a message ID for the individual connection.
|
||||
// If a bespoke function is set (by using SetupNewConnection) it will use
|
||||
// that, otherwise it will use the defaultGenerateMessageID function defined
|
||||
// in websocket_connection.go.
|
||||
SetupPingHandler(request.EndpointLimit, PingHandler)
|
||||
// GenerateMessageID generates a message ID for the individual connection. If a bespoke function is set
|
||||
// (by using SetupNewConnection) it will use that, otherwise it will use the defaultGenerateMessageID function
|
||||
// defined in websocket_connection.go.
|
||||
GenerateMessageID(highPrecision bool) int64
|
||||
SendMessageReturnResponse(ctx context.Context, signature any, request any) ([]byte, error)
|
||||
SendMessageReturnResponses(ctx context.Context, signature any, request any, expected int) ([][]byte, error)
|
||||
SendRawMessage(ctx context.Context, messageType int, message []byte) error
|
||||
// SendMessageReturnResponse will send a WS message to the connection and wait for response
|
||||
SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature any, request any) ([]byte, error)
|
||||
// SendMessageReturnResponses will send a WS message to the connection and wait for N responses
|
||||
SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature any, request any, expected int) ([][]byte, error)
|
||||
// SendRawMessage sends a message over the connection without JSON encoding it
|
||||
SendRawMessage(ctx context.Context, epl request.EndpointLimit, messageType int, message []byte) error
|
||||
// SendJSONMessage sends a JSON encoded message over the connection
|
||||
SendJSONMessage(ctx context.Context, epl request.EndpointLimit, payload any) error
|
||||
SetURL(string)
|
||||
SetProxy(string)
|
||||
GetURL() string
|
||||
|
||||
@@ -196,6 +196,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error {
|
||||
w.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection
|
||||
w.setState(disconnectedState)
|
||||
|
||||
w.rateLimitDefinitions = s.RateLimitDefinitions
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -253,6 +254,7 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error {
|
||||
RateLimit: c.RateLimit,
|
||||
Reporter: c.ConnectionLevelReporter,
|
||||
bespokeGenerateMessageID: c.BespokeGenerateMessageID,
|
||||
RateLimitDefinitions: w.rateLimitDefinitions,
|
||||
}
|
||||
|
||||
if c.Authenticated {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
@@ -22,6 +23,11 @@ import (
|
||||
"github.com/thrasher-corp/gocryptotrader/log"
|
||||
)
|
||||
|
||||
var (
|
||||
errWebsocketIsDisconnected = errors.New("websocket connection is disconnected")
|
||||
errRateLimitNotFound = errors.New("rate limit definition not found")
|
||||
)
|
||||
|
||||
// Dial sets proxy urls and then connects to the websocket
|
||||
func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header) error {
|
||||
if w.ProxyURL != "" {
|
||||
@@ -56,8 +62,8 @@ func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header
|
||||
}
|
||||
|
||||
// SendJSONMessage sends a JSON encoded message over the connection
|
||||
func (w *WebsocketConnection) SendJSONMessage(ctx context.Context, data interface{}) error {
|
||||
return w.writeToConn(ctx, func() error {
|
||||
func (w *WebsocketConnection) SendJSONMessage(ctx context.Context, epl request.EndpointLimit, data any) error {
|
||||
return w.writeToConn(ctx, epl, func() error {
|
||||
if request.IsVerbose(ctx, w.Verbose) {
|
||||
if msg, err := json.Marshal(data); err == nil { // WriteJSON will error for us anyway
|
||||
log.Debugf(log.WebsocketMgr, "%v %v: Sending message: %v", w.ExchangeName, removeURLQueryString(w.URL), string(msg))
|
||||
@@ -68,8 +74,8 @@ func (w *WebsocketConnection) SendJSONMessage(ctx context.Context, data interfac
|
||||
}
|
||||
|
||||
// SendRawMessage sends a message over the connection without JSON encoding it
|
||||
func (w *WebsocketConnection) SendRawMessage(ctx context.Context, messageType int, message []byte) error {
|
||||
return w.writeToConn(ctx, func() error {
|
||||
func (w *WebsocketConnection) SendRawMessage(ctx context.Context, epl request.EndpointLimit, messageType int, message []byte) error {
|
||||
return w.writeToConn(ctx, epl, func() error {
|
||||
if request.IsVerbose(ctx, w.Verbose) {
|
||||
log.Debugf(log.WebsocketMgr, "%v %v: Sending message: %v", w.ExchangeName, removeURLQueryString(w.URL), string(message))
|
||||
}
|
||||
@@ -77,43 +83,51 @@ func (w *WebsocketConnection) SendRawMessage(ctx context.Context, messageType in
|
||||
})
|
||||
}
|
||||
|
||||
func (w *WebsocketConnection) writeToConn(ctx context.Context, writeConn func() error) error {
|
||||
func (w *WebsocketConnection) writeToConn(ctx context.Context, epl request.EndpointLimit, writeConn func() error) error {
|
||||
if !w.IsConnected() {
|
||||
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
|
||||
return fmt.Errorf("%v websocket connection: cannot send message %w", w.ExchangeName, errWebsocketIsDisconnected)
|
||||
}
|
||||
if w.RateLimit != nil {
|
||||
err := request.RateLimit(ctx, w.RateLimit)
|
||||
if err != nil {
|
||||
|
||||
var rl *request.RateLimiterWithWeight
|
||||
if w.RateLimitDefinitions != nil {
|
||||
var ok bool
|
||||
if rl, ok = w.RateLimitDefinitions[epl]; !ok && w.RateLimit == nil {
|
||||
// Return an error if no specific connection rate limit is found for the endpoint but a global rate limit is
|
||||
// set. This ensures the system attempts to apply rate limiting, prioritizing endpoint-specific limits
|
||||
// if they are defined.
|
||||
return fmt.Errorf("%s websocket connection: %w for %v", w.ExchangeName, errRateLimitNotFound, epl)
|
||||
}
|
||||
}
|
||||
|
||||
if rl == nil {
|
||||
// If a global rate limit definition is not found, use the connection rate limit as a fallback.
|
||||
rl = w.RateLimit
|
||||
}
|
||||
|
||||
if rl != nil {
|
||||
if err := request.RateLimit(ctx, rl); err != nil {
|
||||
return fmt.Errorf("%s websocket connection: rate limit error: %w", w.ExchangeName, err)
|
||||
}
|
||||
}
|
||||
// This lock acts as a rolling gate to prevent WriteMessage panics. Acquire after rate limit check.
|
||||
w.writeControl.Lock()
|
||||
defer w.writeControl.Unlock()
|
||||
// NOTE: Secondary check to ensure the connection is still active after
|
||||
// semacquire and potential rate limit.
|
||||
if !w.IsConnected() {
|
||||
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
|
||||
}
|
||||
return writeConn()
|
||||
}
|
||||
|
||||
// SetupPingHandler will automatically send ping or pong messages based on
|
||||
// WebsocketPingHandler configuration
|
||||
func (w *WebsocketConnection) SetupPingHandler(handler PingHandler) {
|
||||
func (w *WebsocketConnection) SetupPingHandler(epl request.EndpointLimit, handler PingHandler) {
|
||||
if handler.UseGorillaHandler {
|
||||
h := func(msg string) error {
|
||||
err := w.Connection.WriteControl(handler.MessageType,
|
||||
[]byte(msg),
|
||||
time.Now().Add(handler.Delay))
|
||||
w.Connection.SetPingHandler(func(msg string) error {
|
||||
err := w.Connection.WriteControl(handler.MessageType, []byte(msg), time.Now().Add(handler.Delay))
|
||||
if err == websocket.ErrCloseSent {
|
||||
return nil
|
||||
} else if e, ok := err.(net.Error); ok && e.Timeout() {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
w.Connection.SetPingHandler(h)
|
||||
})
|
||||
return
|
||||
}
|
||||
w.Wg.Add(1)
|
||||
@@ -126,12 +140,9 @@ func (w *WebsocketConnection) SetupPingHandler(handler PingHandler) {
|
||||
ticker.Stop()
|
||||
return
|
||||
case <-ticker.C:
|
||||
err := w.SendRawMessage(context.TODO(), handler.MessageType, handler.Message)
|
||||
err := w.SendRawMessage(context.TODO(), epl, handler.MessageType, handler.Message)
|
||||
if err != nil {
|
||||
log.Errorf(log.WebsocketMgr,
|
||||
"%v websocket connection: ping handler failed to send message [%s]",
|
||||
w.ExchangeName,
|
||||
handler.Message)
|
||||
log.Errorf(log.WebsocketMgr, "%v websocket connection: ping handler failed to send message [%s]", w.ExchangeName, handler.Message)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -272,8 +283,8 @@ func (w *WebsocketConnection) GetURL() string {
|
||||
}
|
||||
|
||||
// SendMessageReturnResponse will send a WS message to the connection and wait for response
|
||||
func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, signature, request any) ([]byte, error) {
|
||||
resps, err := w.SendMessageReturnResponses(ctx, signature, request, 1)
|
||||
func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature, request any) ([]byte, error) {
|
||||
resps, err := w.SendMessageReturnResponses(ctx, epl, signature, request, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -282,7 +293,7 @@ func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, sig
|
||||
|
||||
// SendMessageReturnResponses will send a WS message to the connection and wait for N responses
|
||||
// An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked
|
||||
func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, signature, payload any, expected int) ([][]byte, error) {
|
||||
func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int) ([][]byte, error) {
|
||||
outbound, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err)
|
||||
@@ -294,7 +305,7 @@ func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, si
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
err = w.SendRawMessage(ctx, websocket.TextMessage, outbound)
|
||||
err = w.SendRawMessage(ctx, epl, websocket.TextMessage, outbound)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -714,11 +714,11 @@ func TestSendMessage(t *testing.T) {
|
||||
}
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = testData.WC.SendJSONMessage(context.Background(), Ping)
|
||||
err = testData.WC.SendJSONMessage(context.Background(), request.Unset, Ping)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
err = testData.WC.SendRawMessage(context.Background(), websocket.TextMessage, []byte(Ping))
|
||||
err = testData.WC.SendRawMessage(context.Background(), request.Unset, websocket.TextMessage, []byte(Ping))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -745,7 +745,7 @@ func TestSendMessageReturnResponse(t *testing.T) {
|
||||
|
||||
go readMessages(t, wc)
|
||||
|
||||
request := testRequest{
|
||||
req := testRequest{
|
||||
Event: "subscribe",
|
||||
Pairs: []string{currency.NewPairWithDelimiter("XBT", "USD", "/").String()},
|
||||
Subscription: testRequestData{
|
||||
@@ -754,19 +754,19 @@ func TestSendMessageReturnResponse(t *testing.T) {
|
||||
RequestID: wc.GenerateMessageID(false),
|
||||
}
|
||||
|
||||
_, err = wc.SendMessageReturnResponse(context.Background(), request.RequestID, request)
|
||||
_, err = wc.SendMessageReturnResponse(context.Background(), request.Unset, req.RequestID, req)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
cancelledCtx, fn := context.WithDeadline(context.Background(), time.Now())
|
||||
fn()
|
||||
_, err = wc.SendMessageReturnResponse(cancelledCtx, "123", request)
|
||||
_, err = wc.SendMessageReturnResponse(cancelledCtx, request.Unset, "123", req)
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
|
||||
// with timeout
|
||||
wc.ResponseMaxLimit = 1
|
||||
_, err = wc.SendMessageReturnResponse(context.Background(), "123", request)
|
||||
_, err = wc.SendMessageReturnResponse(context.Background(), request.Unset, "123", req)
|
||||
assert.ErrorIs(t, err, ErrSignatureTimeout, "SendMessageReturnResponse should error when request ID not found")
|
||||
}
|
||||
|
||||
@@ -829,7 +829,7 @@ func TestSetupPingHandler(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wc.SetupPingHandler(PingHandler{
|
||||
wc.SetupPingHandler(request.Unset, PingHandler{
|
||||
UseGorillaHandler: true,
|
||||
MessageType: websocket.PingMessage,
|
||||
Delay: 100,
|
||||
@@ -844,7 +844,7 @@ func TestSetupPingHandler(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wc.SetupPingHandler(PingHandler{
|
||||
wc.SetupPingHandler(request.Unset, PingHandler{
|
||||
MessageType: websocket.TextMessage,
|
||||
Message: []byte(Ping),
|
||||
Delay: 200,
|
||||
@@ -1187,7 +1187,7 @@ func TestLatency(t *testing.T) {
|
||||
|
||||
go readMessages(t, wc)
|
||||
|
||||
request := testRequest{
|
||||
req := testRequest{
|
||||
Event: "subscribe",
|
||||
Pairs: []string{currency.NewPairWithDelimiter("XBT", "USD", "/").String()},
|
||||
Subscription: testRequestData{
|
||||
@@ -1196,7 +1196,7 @@ func TestLatency(t *testing.T) {
|
||||
RequestID: wc.GenerateMessageID(false),
|
||||
}
|
||||
|
||||
_, err = wc.SendMessageReturnResponse(context.Background(), request.RequestID, request)
|
||||
_, err = wc.SendMessageReturnResponse(context.Background(), request.Unset, req.RequestID, req)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -1248,3 +1248,29 @@ func TestRemoveURLQueryString(t *testing.T) {
|
||||
assert.Equal(t, "https://www.google.com", removeURLQueryString("https://www.google.com"), "removeURLQueryString should not change URL")
|
||||
assert.Equal(t, "", removeURLQueryString(""), "removeURLQueryString should be equal")
|
||||
}
|
||||
|
||||
func TestWriteToConn(t *testing.T) {
|
||||
t.Parallel()
|
||||
wc := WebsocketConnection{}
|
||||
require.ErrorIs(t, wc.writeToConn(context.Background(), request.Unset, func() error { return nil }), errWebsocketIsDisconnected)
|
||||
wc.setConnectedStatus(true)
|
||||
// No rate limits set
|
||||
require.NoError(t, wc.writeToConn(context.Background(), request.Unset, func() error { return nil }))
|
||||
// connection rate limit set
|
||||
wc.RateLimit = request.NewWeightedRateLimitByDuration(time.Millisecond)
|
||||
require.NoError(t, wc.writeToConn(context.Background(), request.Unset, func() error { return nil }))
|
||||
// context cancelled
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
require.ErrorIs(t, wc.writeToConn(ctx, request.Unset, func() error { return nil }), context.Canceled)
|
||||
// definitions set but with fallover
|
||||
wc.RateLimitDefinitions = request.RateLimitDefinitions{
|
||||
request.Auth: request.NewWeightedRateLimitByDuration(time.Millisecond),
|
||||
}
|
||||
require.NoError(t, wc.writeToConn(context.Background(), request.Unset, func() error { return nil }))
|
||||
// match with global rate limit
|
||||
require.NoError(t, wc.writeToConn(context.Background(), request.Auth, func() error { return nil }))
|
||||
// definitions set but connection rate limiter not set
|
||||
wc.RateLimit = nil
|
||||
require.ErrorIs(t, wc.writeToConn(ctx, request.Unset, func() error { return nil }), errRateLimitNotFound)
|
||||
}
|
||||
|
||||
@@ -96,6 +96,10 @@ type Websocket struct {
|
||||
// MaxSubScriptionsPerConnection defines the maximum number of
|
||||
// subscriptions per connection that is allowed by the exchange.
|
||||
MaxSubscriptionsPerConnection int
|
||||
|
||||
// rateLimitDefinitions contains the rate limiters shared between Websocket and REST connections for all potential
|
||||
// endpoints.
|
||||
rateLimitDefinitions request.RateLimitDefinitions
|
||||
}
|
||||
|
||||
// WebsocketSetup defines variables for setting up a websocket connection
|
||||
@@ -121,6 +125,13 @@ type WebsocketSetup struct {
|
||||
// MaxWebsocketSubscriptionsPerConnection defines the maximum number of
|
||||
// subscriptions per connection that is allowed by the exchange.
|
||||
MaxWebsocketSubscriptionsPerConnection int
|
||||
|
||||
// RateLimitDefinitions contains the rate limiters shared between WebSocket and REST connections for all endpoints.
|
||||
// These rate limits take precedence over any rate limits specified in individual connection configurations.
|
||||
// If no connection-specific rate limit is provided and the endpoint does not match any of these definitions,
|
||||
// an error will be returned. However, if a connection configuration includes its own rate limit,
|
||||
// it will fall back to that configuration’s rate limit without raising an error.
|
||||
RateLimitDefinitions request.RateLimitDefinitions
|
||||
}
|
||||
|
||||
// WebsocketConnection contains all the data needed to send a message to a WS
|
||||
@@ -133,7 +144,12 @@ type WebsocketConnection struct {
|
||||
// writes methods
|
||||
writeControl sync.Mutex
|
||||
|
||||
RateLimit *request.RateLimiterWithWeight
|
||||
// RateLimit is a rate limiter for the connection itself
|
||||
RateLimit *request.RateLimiterWithWeight
|
||||
// RateLimitDefinitions contains the rate limiters shared between WebSocket and REST connections for all
|
||||
// potential endpoints.
|
||||
RateLimitDefinitions request.RateLimitDefinitions
|
||||
|
||||
ExchangeName string
|
||||
URL string
|
||||
ProxyURL string
|
||||
|
||||
Reference in New Issue
Block a user