exchanges/websocket: update websocket rate limiting to use requester rate limiting functionality (#1578)

* exchanges/websocket: update websocket rate limiting to use requester rate limiting functionality.

* glorious: nits

* rm unsused

* updoo

* glorious: purgerino

* reduce duplicate code

* thrasher: engrish

---------

Co-authored-by: shazbert <ryan.oharareid@thrasher.io>
This commit is contained in:
Ryan O'Hara-Reid
2024-09-02 16:43:05 +10:00
committed by GitHub
parent 7c9e6518f3
commit cb6b3421a7
43 changed files with 142 additions and 130 deletions

View File

@@ -9,13 +9,14 @@ import (
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
"github.com/thrasher-corp/gocryptotrader/exchanges/order"
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
)
// Connection defines a streaming services connection
type Connection interface {
Dial(*websocket.Dialer, http.Header) error
ReadMessage() Response
SendJSONMessage(any) error
SendJSONMessage(ctx context.Context, payload any) error
SetupPingHandler(PingHandler)
// GenerateMessageID generates a message ID for the individual connection.
// If a bespoke function is set (by using SetupNewConnection) it will use
@@ -24,7 +25,7 @@ type Connection interface {
GenerateMessageID(highPrecision bool) int64
SendMessageReturnResponse(ctx context.Context, signature any, request any) ([]byte, error)
SendMessageReturnResponses(ctx context.Context, signature any, request any, expected int) ([][]byte, error)
SendRawMessage(messageType int, message []byte) error
SendRawMessage(ctx context.Context, messageType int, message []byte) error
SetURL(string)
SetProxy(string)
GetURL() string
@@ -41,7 +42,7 @@ type Response struct {
type ConnectionSetup struct {
ResponseCheckTimeout time.Duration
ResponseMaxLimit time.Duration
RateLimit int64
RateLimit *request.RateLimiterWithWeight
URL string
Authenticated bool
ConnectionLevelReporter Reporter

View File

@@ -207,7 +207,7 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error {
if c.ResponseCheckTimeout == 0 &&
c.ResponseMaxLimit == 0 &&
c.RateLimit == 0 &&
c.RateLimit == nil &&
c.URL == "" &&
c.ConnectionLevelReporter == nil &&
c.BespokeGenerateMessageID == nil {

View File

@@ -17,6 +17,7 @@ import (
"time"
"github.com/gorilla/websocket"
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
"github.com/thrasher-corp/gocryptotrader/log"
)
@@ -54,51 +55,46 @@ func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header
}
// SendJSONMessage sends a JSON encoded message over the connection
func (w *WebsocketConnection) SendJSONMessage(data interface{}) error {
if !w.IsConnected() {
return fmt.Errorf("%s websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
}
w.writeControl.Lock()
defer w.writeControl.Unlock()
if w.Verbose {
if msg, err := json.Marshal(data); err == nil { // WriteJSON will error for us anyway
log.Debugf(log.WebsocketMgr, "%s websocket connection: sending message: %s\n", w.ExchangeName, msg)
func (w *WebsocketConnection) SendJSONMessage(ctx context.Context, data interface{}) error {
return w.writeToConn(ctx, func() error {
if w.Verbose {
if msg, err := json.Marshal(data); err == nil { // WriteJSON will error for us anyway
log.Debugf(log.WebsocketMgr, "%s websocket connection: sending message: %s\n", w.ExchangeName, msg)
}
}
}
if w.RateLimit > 0 {
time.Sleep(time.Duration(w.RateLimit) * time.Millisecond)
if !w.IsConnected() {
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
}
}
return w.Connection.WriteJSON(data)
return w.Connection.WriteJSON(data)
})
}
// SendRawMessage sends a message over the connection without JSON encoding it
func (w *WebsocketConnection) SendRawMessage(messageType int, message []byte) error {
func (w *WebsocketConnection) SendRawMessage(ctx context.Context, messageType int, message []byte) error {
return w.writeToConn(ctx, func() error {
if w.Verbose {
log.Debugf(log.WebsocketMgr, "%v websocket connection: sending message [%s]\n", w.ExchangeName, message)
}
return w.Connection.WriteMessage(messageType, message)
})
}
func (w *WebsocketConnection) writeToConn(ctx context.Context, writeConn func() error) error {
if !w.IsConnected() {
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
}
w.writeControl.Lock()
defer w.writeControl.Unlock()
if w.Verbose {
log.Debugf(log.WebsocketMgr, "%v websocket connection: sending message [%s]\n", w.ExchangeName, message)
}
if w.RateLimit > 0 {
time.Sleep(time.Duration(w.RateLimit) * time.Millisecond)
if !w.IsConnected() {
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
if w.RateLimit != nil {
err := request.RateLimit(ctx, w.RateLimit)
if err != nil {
return fmt.Errorf("%s websocket connection: rate limit error: %w", w.ExchangeName, err)
}
}
// This lock acts as a rolling gate to prevent WriteMessage panics. Acquire after rate limit check.
w.writeControl.Lock()
defer w.writeControl.Unlock()
// NOTE: Secondary check to ensure the connection is still active after
// semacquire and potential rate limit.
if !w.IsConnected() {
return fmt.Errorf("%v websocket connection: cannot send message to a disconnected websocket", w.ExchangeName)
}
return w.Connection.WriteMessage(messageType, message)
return writeConn()
}
// SetupPingHandler will automatically send ping or pong messages based on
@@ -129,7 +125,7 @@ func (w *WebsocketConnection) SetupPingHandler(handler PingHandler) {
ticker.Stop()
return
case <-ticker.C:
err := w.SendRawMessage(handler.MessageType, handler.Message)
err := w.SendRawMessage(context.TODO(), handler.MessageType, handler.Message)
if err != nil {
log.Errorf(log.WebsocketMgr,
"%v websocket connection: ping handler failed to send message [%s]",
@@ -303,7 +299,7 @@ func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, si
}
start := time.Now()
err = w.SendRawMessage(websocket.TextMessage, outbound)
err = w.SendRawMessage(ctx, websocket.TextMessage, outbound)
if err != nil {
return nil, err
}

View File

@@ -24,6 +24,7 @@ import (
"github.com/thrasher-corp/gocryptotrader/config"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/protocol"
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
)
@@ -629,7 +630,7 @@ func TestDial(t *testing.T) {
ExchangeName: "test1",
Verbose: true,
URL: websocketTestURL,
RateLimit: 10,
RateLimit: request.NewWeightedRateLimitByDuration(10 * time.Millisecond),
ResponseMaxLimit: 7000000000,
},
},
@@ -677,7 +678,7 @@ func TestSendMessage(t *testing.T) {
ExchangeName: "test1",
Verbose: true,
URL: websocketTestURL,
RateLimit: 10,
RateLimit: request.NewWeightedRateLimitByDuration(10 * time.Millisecond),
ResponseMaxLimit: 7000000000,
},
},
@@ -713,11 +714,11 @@ func TestSendMessage(t *testing.T) {
}
t.Fatal(err)
}
err = testData.WC.SendJSONMessage(Ping)
err = testData.WC.SendJSONMessage(context.Background(), Ping)
if err != nil {
t.Error(err)
}
err = testData.WC.SendRawMessage(websocket.TextMessage, []byte(Ping))
err = testData.WC.SendRawMessage(context.Background(), websocket.TextMessage, []byte(Ping))
if err != nil {
t.Error(err)
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/thrasher-corp/gocryptotrader/config"
"github.com/thrasher-corp/gocryptotrader/exchanges/fill"
"github.com/thrasher-corp/gocryptotrader/exchanges/protocol"
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
"github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer"
"github.com/thrasher-corp/gocryptotrader/exchanges/subscription"
"github.com/thrasher-corp/gocryptotrader/exchanges/trade"
@@ -132,7 +133,7 @@ type WebsocketConnection struct {
// writes methods
writeControl sync.Mutex
RateLimit int64
RateLimit *request.RateLimiterWithWeight
ExchangeName string
URL string
ProxyURL string