From 1cabba73b9f84b6346fea51e1b8534461ec1327f Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Sat, 24 Aug 2024 12:18:20 +1000 Subject: [PATCH] common/gateio/stream: add thread-safe counter and overide default GenerateMessageID with connection specific implementation (#1615) * add counter and update gateio * Update exchanges/gateio/gateio.go Co-authored-by: Scott * thrasher: nits * add test case * linter: fix * revert change * thrasher nits --------- Co-authored-by: Ryan O'Hara-Reid Co-authored-by: Scott --- common/common.go | 17 ++++++++++++ common/common_test.go | 15 +++++++++++ exchanges/gateio/gateio.go | 1 + exchanges/gateio/gateio_test.go | 5 ++++ exchanges/gateio/gateio_websocket.go | 6 +++++ exchanges/gateio/gateio_wrapper.go | 9 ++++--- exchanges/stream/stream_types.go | 8 ++++++ exchanges/stream/websocket.go | 33 ++++++++++++++---------- exchanges/stream/websocket_connection.go | 12 ++++++++- exchanges/stream/websocket_test.go | 3 +++ exchanges/stream/websocket_types.go | 5 ++++ 11 files changed, 96 insertions(+), 18 deletions(-) diff --git a/common/common.go b/common/common.go index 241d2b45..207d6e4f 100644 --- a/common/common.go +++ b/common/common.go @@ -19,6 +19,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "unicode" @@ -672,3 +673,19 @@ func SortStrings[S ~[]E, E fmt.Stringer](x S) S { }) return n } + +// Counter is a thread-safe counter. +type Counter struct { + n int64 // privatised so you can't use counter as a value type +} + +// IncrementAndGet returns the next count after incrementing. +func (c *Counter) IncrementAndGet() int64 { + newID := atomic.AddInt64(&c.n, 1) + // Handle overflow by resetting the counter to 1 if it becomes negative + if newID < 0 { + atomic.StoreInt64(&c.n, 1) + return 1 + } + return newID +} diff --git a/common/common_test.go b/common/common_test.go index 0c6ca13e..81dc9b1a 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -862,3 +862,18 @@ func (a A) String() string { func TestSortStrings(t *testing.T) { assert.Equal(t, []A{1, 2, 5, 6}, SortStrings([]A{6, 2, 5, 1})) } + +func TestCounter(t *testing.T) { + t.Parallel() + c := Counter{n: -5} + require.Equal(t, int64(1), c.IncrementAndGet()) + require.Equal(t, int64(2), c.IncrementAndGet()) +} + +// 683185328 1.787 ns/op 0 B/op 0 allocs/op +func BenchmarkCounter(b *testing.B) { + c := Counter{} + for i := 0; i < b.N; i++ { + c.IncrementAndGet() + } +} diff --git a/exchanges/gateio/gateio.go b/exchanges/gateio/gateio.go index 0555e814..78606b2a 100644 --- a/exchanges/gateio/gateio.go +++ b/exchanges/gateio/gateio.go @@ -174,6 +174,7 @@ var ( // Gateio is the overarching type across this package type Gateio struct { exchange.Base + Counter common.Counter } // ***************************************** SubAccounts ******************************** diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index ca6b5b55..a69da1a0 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -3606,3 +3606,8 @@ func TestGetUnifiedAccount(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, payload) } + +func TestGenerateWebsocketMessageID(t *testing.T) { + t.Parallel() + require.NotEmpty(t, g.GenerateWebsocketMessageID(false)) +} diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 537c695b..367e3137 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -868,3 +868,9 @@ func (g *Gateio) listOfAssetsCurrencyPairEnabledFor(cp currency.Pair) map[asset. } return assetPairEnabled } + +// GenerateWebsocketMessageID generates a message ID for the individual +// connection. +func (g *Gateio) GenerateWebsocketMessageID(bool) int64 { + return g.Counter.IncrementAndGet() +} diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 24b1c3ee..d918762d 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -226,10 +226,11 @@ func (g *Gateio) Setup(exch *config.Exchange) error { return err } return g.Websocket.SetupNewConnection(stream.ConnectionSetup{ - URL: gateioWebsocketEndpoint, - RateLimit: gateioWebsocketRateLimit, - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + URL: gateioWebsocketEndpoint, + RateLimit: gateioWebsocketRateLimit, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) } diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index 9421950c..36bc3f29 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -17,6 +17,10 @@ type Connection interface { ReadMessage() Response SendJSONMessage(any) error SetupPingHandler(PingHandler) + // GenerateMessageID generates a message ID for the individual connection. + // If a bespoke function is set (by using SetupNewConnection) it will use + // that, otherwise it will use the defaultGenerateMessageID function defined + // in websocket_connection.go. GenerateMessageID(highPrecision bool) int64 SendMessageReturnResponse(ctx context.Context, signature any, request any) ([]byte, error) SendMessageReturnResponses(ctx context.Context, signature any, request any, expected int) ([][]byte, error) @@ -41,6 +45,10 @@ type ConnectionSetup struct { URL string Authenticated bool ConnectionLevelReporter Reporter + // BespokeGenerateMessageID is a function that returns a unique message ID. + // This is useful for when an exchange connection requires a unique or + // structured message ID for each message sent. + BespokeGenerateMessageID func(highPrecision bool) int64 } // PingHandler container for ping handler settings diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index b310ca5b..ad3c52d8 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -204,7 +204,13 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { if w == nil { return fmt.Errorf("%w: %w", errConnSetup, errWebsocketIsNil) } - if c == (ConnectionSetup{}) { + + if c.ResponseCheckTimeout == 0 && + c.ResponseMaxLimit == 0 && + c.RateLimit == 0 && + c.URL == "" && + c.ConnectionLevelReporter == nil && + c.BespokeGenerateMessageID == nil { return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigEmpty) } @@ -234,18 +240,19 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { } newConn := &WebsocketConnection{ - ExchangeName: w.exchangeName, - URL: connectionURL, - ProxyURL: w.GetProxyAddress(), - Verbose: w.verbose, - ResponseMaxLimit: c.ResponseMaxLimit, - Traffic: w.TrafficAlert, - readMessageErrors: w.ReadMessageErrors, - ShutdownC: w.ShutdownC, - Wg: &w.Wg, - Match: w.Match, - RateLimit: c.RateLimit, - Reporter: c.ConnectionLevelReporter, + ExchangeName: w.exchangeName, + URL: connectionURL, + ProxyURL: w.GetProxyAddress(), + Verbose: w.verbose, + ResponseMaxLimit: c.ResponseMaxLimit, + Traffic: w.TrafficAlert, + readMessageErrors: w.ReadMessageErrors, + ShutdownC: w.ShutdownC, + Wg: &w.Wg, + Match: w.Match, + RateLimit: c.RateLimit, + Reporter: c.ConnectionLevelReporter, + bespokeGenerateMessageID: c.BespokeGenerateMessageID, } if c.Authenticated { diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 37a5aecd..f0598622 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -229,8 +229,18 @@ func (w *WebsocketConnection) parseBinaryResponse(resp []byte) ([]byte, error) { return standardMessage, reader.Close() } -// GenerateMessageID Creates a random message ID +// GenerateMessageID generates a message ID for the individual connection. +// If a bespoke function is set (by using SetupNewConnection) it will use that, +// otherwise it will use the defaultGenerateMessageID function. func (w *WebsocketConnection) GenerateMessageID(highPrec bool) int64 { + if w.bespokeGenerateMessageID != nil { + return w.bespokeGenerateMessageID(highPrec) + } + return w.defaultGenerateMessageID(highPrec) +} + +// defaultGenerateMessageID generates the default message ID +func (w *WebsocketConnection) defaultGenerateMessageID(highPrec bool) int64 { var minValue int64 = 1e8 var maxValue int64 = 2e8 if highPrec { diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 324f3e32..5496ebcf 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -911,6 +911,9 @@ func TestGenerateMessageID(t *testing.T) { assert.NotContains(t, ids, id, "GenerateMessageID must not generate the same ID twice") ids[i] = id } + + wc.bespokeGenerateMessageID = func(bool) int64 { return 42 } + assert.EqualValues(t, 42, wc.GenerateMessageID(true), "GenerateMessageID must use bespokeGenerateMessageID") } // BenchmarkGenerateMessageID-8 2850018 408 ns/op 56 B/op 4 allocs/op diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 5d0009a4..07594991 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -145,5 +145,10 @@ type WebsocketConnection struct { Traffic chan struct{} readMessageErrors chan error + // bespokeGenerateMessageID is a function that returns a unique message ID + // defined externally. This is used for exchanges that require a unique + // message ID for each message sent. + bespokeGenerateMessageID func(highPrecision bool) int64 + Reporter Reporter }