stream/okx: allow rate limit definitions to be used by the stream package (#1641)

* stream: rate limiter definitions

* Update exchanges/request/request_test.go

Co-authored-by: Scott <gloriousCode@users.noreply.github.com>

---------

Co-authored-by: Ryan O'Hara-Reid <ryan.oharareid@thrasher.io>
Co-authored-by: Scott <gloriousCode@users.noreply.github.com>
This commit is contained in:
Ryan O'Hara-Reid
2024-09-13 13:56:46 +10:00
committed by GitHub
parent b8e836d74f
commit b461c32a5e
36 changed files with 328 additions and 231 deletions

View File

@@ -16,16 +16,19 @@ import (
type Connection interface {
Dial(*websocket.Dialer, http.Header) error
ReadMessage() Response
SendJSONMessage(ctx context.Context, payload any) error
SetupPingHandler(PingHandler)
// GenerateMessageID generates a message ID for the individual connection.
// If a bespoke function is set (by using SetupNewConnection) it will use
// that, otherwise it will use the defaultGenerateMessageID function defined
// in websocket_connection.go.
SetupPingHandler(request.EndpointLimit, PingHandler)
// GenerateMessageID generates a message ID for the individual connection. If a bespoke function is set
// (by using SetupNewConnection) it will use that, otherwise it will use the defaultGenerateMessageID function
// defined in websocket_connection.go.
GenerateMessageID(highPrecision bool) int64
SendMessageReturnResponse(ctx context.Context, signature any, request any) ([]byte, error)
SendMessageReturnResponses(ctx context.Context, signature any, request any, expected int) ([][]byte, error)
SendRawMessage(ctx context.Context, messageType int, message []byte) error
// SendMessageReturnResponse will send a WS message to the connection and wait for response
SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature any, request any) ([]byte, error)
// SendMessageReturnResponses will send a WS message to the connection and wait for N responses
SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature any, request any, expected int) ([][]byte, error)
// SendRawMessage sends a message over the connection without JSON encoding it
SendRawMessage(ctx context.Context, epl request.EndpointLimit, messageType int, message []byte) error
// SendJSONMessage sends a JSON encoded message over the connection
SendJSONMessage(ctx context.Context, epl request.EndpointLimit, payload any) error
SetURL(string)
SetProxy(string)
GetURL() string

View File

@@ -196,6 +196,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error {
w.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection
w.setState(disconnectedState)
w.rateLimitDefinitions = s.RateLimitDefinitions
return nil
}
@@ -253,6 +254,7 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error {
RateLimit: c.RateLimit,
Reporter: c.ConnectionLevelReporter,
bespokeGenerateMessageID: c.BespokeGenerateMessageID,
RateLimitDefinitions: w.rateLimitDefinitions,
}
if c.Authenticated {

View File

@@ -7,6 +7,7 @@ import (
"context"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
"io"
"math/big"
@@ -22,6 +23,11 @@ import (
"github.com/thrasher-corp/gocryptotrader/log"
)
var (
errWebsocketIsDisconnected = errors.New("websocket connection is disconnected")
errRateLimitNotFound = errors.New("rate limit definition not found")
)
// Dial sets proxy urls and then connects to the websocket
func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header) error {
if w.ProxyURL != "" {
@@ -56,8 +62,8 @@ func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header
}
// SendJSONMessage sends a JSON encoded message over the connection
func (w *WebsocketConnection) SendJSONMessage(ctx context.Context, data interface{}) error {
return w.writeToConn(ctx, func() error {
func (w *WebsocketConnection) SendJSONMessage(ctx context.Context, epl request.EndpointLimit, data any) error {
return w.writeToConn(ctx, epl, func() error {
if request.IsVerbose(ctx, w.Verbose) {
if msg, err := json.Marshal(data); err == nil { // WriteJSON will error for us anyway
log.Debugf(log.WebsocketMgr, "%v %v: Sending message: %v", w.ExchangeName, removeURLQueryString(w.URL), string(msg))
@@ -68,8 +74,8 @@ func (w *WebsocketConnection) SendJSONMessage(ctx context.Context, data interfac
}
// SendRawMessage sends a message over the connection without JSON encoding it
func (w *WebsocketConnection) SendRawMessage(ctx context.Context, messageType int, message []byte) error {
return w.writeToConn(ctx, func() error {
func (w *WebsocketConnection) SendRawMessage(ctx context.Context, epl request.EndpointLimit, messageType int, message []byte) error {
return w.writeToConn(ctx, epl, func() error {
if request.IsVerbose(ctx, w.Verbose) {
log.Debugf(log.WebsocketMgr, "%v %v: Sending message: %v", w.ExchangeName, removeURLQueryString(w.URL), string(message))
}
@@ -77,43 +83,51 @@ func (w *WebsocketConnection) SendRawMessage(ctx context.Context, messageType in
})
}
func (w *WebsocketConnection) writeToConn(ctx context.Context, writeConn func() error) error {
func (w *WebsocketConnection) writeToConn(ctx context.Context, epl request.EndpointLimit, writeConn func() error) error {
if !w.IsConnected() {
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
return fmt.Errorf("%v websocket connection: cannot send message %w", w.ExchangeName, errWebsocketIsDisconnected)
}
if w.RateLimit != nil {
err := request.RateLimit(ctx, w.RateLimit)
if err != nil {
var rl *request.RateLimiterWithWeight
if w.RateLimitDefinitions != nil {
var ok bool
if rl, ok = w.RateLimitDefinitions[epl]; !ok && w.RateLimit == nil {
// Return an error if no specific connection rate limit is found for the endpoint but a global rate limit is
// set. This ensures the system attempts to apply rate limiting, prioritizing endpoint-specific limits
// if they are defined.
return fmt.Errorf("%s websocket connection: %w for %v", w.ExchangeName, errRateLimitNotFound, epl)
}
}
if rl == nil {
// If a global rate limit definition is not found, use the connection rate limit as a fallback.
rl = w.RateLimit
}
if rl != nil {
if err := request.RateLimit(ctx, rl); err != nil {
return fmt.Errorf("%s websocket connection: rate limit error: %w", w.ExchangeName, err)
}
}
// This lock acts as a rolling gate to prevent WriteMessage panics. Acquire after rate limit check.
w.writeControl.Lock()
defer w.writeControl.Unlock()
// NOTE: Secondary check to ensure the connection is still active after
// semacquire and potential rate limit.
if !w.IsConnected() {
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
}
return writeConn()
}
// SetupPingHandler will automatically send ping or pong messages based on
// WebsocketPingHandler configuration
func (w *WebsocketConnection) SetupPingHandler(handler PingHandler) {
func (w *WebsocketConnection) SetupPingHandler(epl request.EndpointLimit, handler PingHandler) {
if handler.UseGorillaHandler {
h := func(msg string) error {
err := w.Connection.WriteControl(handler.MessageType,
[]byte(msg),
time.Now().Add(handler.Delay))
w.Connection.SetPingHandler(func(msg string) error {
err := w.Connection.WriteControl(handler.MessageType, []byte(msg), time.Now().Add(handler.Delay))
if err == websocket.ErrCloseSent {
return nil
} else if e, ok := err.(net.Error); ok && e.Timeout() {
return nil
}
return err
}
w.Connection.SetPingHandler(h)
})
return
}
w.Wg.Add(1)
@@ -126,12 +140,9 @@ func (w *WebsocketConnection) SetupPingHandler(handler PingHandler) {
ticker.Stop()
return
case <-ticker.C:
err := w.SendRawMessage(context.TODO(), handler.MessageType, handler.Message)
err := w.SendRawMessage(context.TODO(), epl, handler.MessageType, handler.Message)
if err != nil {
log.Errorf(log.WebsocketMgr,
"%v websocket connection: ping handler failed to send message [%s]",
w.ExchangeName,
handler.Message)
log.Errorf(log.WebsocketMgr, "%v websocket connection: ping handler failed to send message [%s]", w.ExchangeName, handler.Message)
return
}
}
@@ -272,8 +283,8 @@ func (w *WebsocketConnection) GetURL() string {
}
// SendMessageReturnResponse will send a WS message to the connection and wait for response
func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, signature, request any) ([]byte, error) {
resps, err := w.SendMessageReturnResponses(ctx, signature, request, 1)
func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature, request any) ([]byte, error) {
resps, err := w.SendMessageReturnResponses(ctx, epl, signature, request, 1)
if err != nil {
return nil, err
}
@@ -282,7 +293,7 @@ func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, sig
// SendMessageReturnResponses will send a WS message to the connection and wait for N responses
// An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked
func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, signature, payload any, expected int) ([][]byte, error) {
func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int) ([][]byte, error) {
outbound, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err)
@@ -294,7 +305,7 @@ func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, si
}
start := time.Now()
err = w.SendRawMessage(ctx, websocket.TextMessage, outbound)
err = w.SendRawMessage(ctx, epl, websocket.TextMessage, outbound)
if err != nil {
return nil, err
}

View File

@@ -714,11 +714,11 @@ func TestSendMessage(t *testing.T) {
}
t.Fatal(err)
}
err = testData.WC.SendJSONMessage(context.Background(), Ping)
err = testData.WC.SendJSONMessage(context.Background(), request.Unset, Ping)
if err != nil {
t.Error(err)
}
err = testData.WC.SendRawMessage(context.Background(), websocket.TextMessage, []byte(Ping))
err = testData.WC.SendRawMessage(context.Background(), request.Unset, websocket.TextMessage, []byte(Ping))
if err != nil {
t.Error(err)
}
@@ -745,7 +745,7 @@ func TestSendMessageReturnResponse(t *testing.T) {
go readMessages(t, wc)
request := testRequest{
req := testRequest{
Event: "subscribe",
Pairs: []string{currency.NewPairWithDelimiter("XBT", "USD", "/").String()},
Subscription: testRequestData{
@@ -754,19 +754,19 @@ func TestSendMessageReturnResponse(t *testing.T) {
RequestID: wc.GenerateMessageID(false),
}
_, err = wc.SendMessageReturnResponse(context.Background(), request.RequestID, request)
_, err = wc.SendMessageReturnResponse(context.Background(), request.Unset, req.RequestID, req)
if err != nil {
t.Error(err)
}
cancelledCtx, fn := context.WithDeadline(context.Background(), time.Now())
fn()
_, err = wc.SendMessageReturnResponse(cancelledCtx, "123", request)
_, err = wc.SendMessageReturnResponse(cancelledCtx, request.Unset, "123", req)
assert.ErrorIs(t, err, context.DeadlineExceeded)
// with timeout
wc.ResponseMaxLimit = 1
_, err = wc.SendMessageReturnResponse(context.Background(), "123", request)
_, err = wc.SendMessageReturnResponse(context.Background(), request.Unset, "123", req)
assert.ErrorIs(t, err, ErrSignatureTimeout, "SendMessageReturnResponse should error when request ID not found")
}
@@ -829,7 +829,7 @@ func TestSetupPingHandler(t *testing.T) {
t.Fatal(err)
}
wc.SetupPingHandler(PingHandler{
wc.SetupPingHandler(request.Unset, PingHandler{
UseGorillaHandler: true,
MessageType: websocket.PingMessage,
Delay: 100,
@@ -844,7 +844,7 @@ func TestSetupPingHandler(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wc.SetupPingHandler(PingHandler{
wc.SetupPingHandler(request.Unset, PingHandler{
MessageType: websocket.TextMessage,
Message: []byte(Ping),
Delay: 200,
@@ -1187,7 +1187,7 @@ func TestLatency(t *testing.T) {
go readMessages(t, wc)
request := testRequest{
req := testRequest{
Event: "subscribe",
Pairs: []string{currency.NewPairWithDelimiter("XBT", "USD", "/").String()},
Subscription: testRequestData{
@@ -1196,7 +1196,7 @@ func TestLatency(t *testing.T) {
RequestID: wc.GenerateMessageID(false),
}
_, err = wc.SendMessageReturnResponse(context.Background(), request.RequestID, request)
_, err = wc.SendMessageReturnResponse(context.Background(), request.Unset, req.RequestID, req)
if err != nil {
t.Error(err)
}
@@ -1248,3 +1248,29 @@ func TestRemoveURLQueryString(t *testing.T) {
assert.Equal(t, "https://www.google.com", removeURLQueryString("https://www.google.com"), "removeURLQueryString should not change URL")
assert.Equal(t, "", removeURLQueryString(""), "removeURLQueryString should be equal")
}
func TestWriteToConn(t *testing.T) {
t.Parallel()
wc := WebsocketConnection{}
require.ErrorIs(t, wc.writeToConn(context.Background(), request.Unset, func() error { return nil }), errWebsocketIsDisconnected)
wc.setConnectedStatus(true)
// No rate limits set
require.NoError(t, wc.writeToConn(context.Background(), request.Unset, func() error { return nil }))
// connection rate limit set
wc.RateLimit = request.NewWeightedRateLimitByDuration(time.Millisecond)
require.NoError(t, wc.writeToConn(context.Background(), request.Unset, func() error { return nil }))
// context cancelled
ctx, cancel := context.WithCancel(context.Background())
cancel()
require.ErrorIs(t, wc.writeToConn(ctx, request.Unset, func() error { return nil }), context.Canceled)
// definitions set but with fallover
wc.RateLimitDefinitions = request.RateLimitDefinitions{
request.Auth: request.NewWeightedRateLimitByDuration(time.Millisecond),
}
require.NoError(t, wc.writeToConn(context.Background(), request.Unset, func() error { return nil }))
// match with global rate limit
require.NoError(t, wc.writeToConn(context.Background(), request.Auth, func() error { return nil }))
// definitions set but connection rate limiter not set
wc.RateLimit = nil
require.ErrorIs(t, wc.writeToConn(ctx, request.Unset, func() error { return nil }), errRateLimitNotFound)
}

View File

@@ -96,6 +96,10 @@ type Websocket struct {
// MaxSubScriptionsPerConnection defines the maximum number of
// subscriptions per connection that is allowed by the exchange.
MaxSubscriptionsPerConnection int
// rateLimitDefinitions contains the rate limiters shared between Websocket and REST connections for all potential
// endpoints.
rateLimitDefinitions request.RateLimitDefinitions
}
// WebsocketSetup defines variables for setting up a websocket connection
@@ -121,6 +125,13 @@ type WebsocketSetup struct {
// MaxWebsocketSubscriptionsPerConnection defines the maximum number of
// subscriptions per connection that is allowed by the exchange.
MaxWebsocketSubscriptionsPerConnection int
// RateLimitDefinitions contains the rate limiters shared between WebSocket and REST connections for all endpoints.
// These rate limits take precedence over any rate limits specified in individual connection configurations.
// If no connection-specific rate limit is provided and the endpoint does not match any of these definitions,
// an error will be returned. However, if a connection configuration includes its own rate limit,
// it will fall back to that configurations rate limit without raising an error.
RateLimitDefinitions request.RateLimitDefinitions
}
// WebsocketConnection contains all the data needed to send a message to a WS
@@ -133,7 +144,12 @@ type WebsocketConnection struct {
// writes methods
writeControl sync.Mutex
RateLimit *request.RateLimiterWithWeight
// RateLimit is a rate limiter for the connection itself
RateLimit *request.RateLimiterWithWeight
// RateLimitDefinitions contains the rate limiters shared between WebSocket and REST connections for all
// potential endpoints.
RateLimitDefinitions request.RateLimitDefinitions
ExchangeName string
URL string
ProxyURL string