mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-06-01 07:26:48 +00:00
stream/match: Reduce complexity and limit locking when match occurs (#1581)
* stream match update * update tests * linter: fix * glorious: nits + handle context cancellations * glorious: whooops * Websocket: Add SendMessageReturnResponses * whooooooopsie * gk: nitssssss * Update exchanges/stream/stream_match.go Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> * Update exchanges/stream/stream_match_test.go Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> * linter: appease the linter gods * glorious: nits * glorious: nits * Update exchanges/stream/stream_match_test.go Co-authored-by: Scott <gloriousCode@users.noreply.github.com> --------- Co-authored-by: Ryan O'Hara-Reid <ryan.oharareid@thrasher.io> Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> Co-authored-by: Scott <gloriousCode@users.noreply.github.com>
This commit is contained in:
@@ -5,11 +5,14 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
errSignatureCollision = errors.New("signature collision")
|
||||
errInvalidBufferSize = errors.New("buffer size must be positive")
|
||||
)
|
||||
|
||||
// NewMatch returns a new Match
|
||||
func NewMatch() *Match {
|
||||
return &Match{
|
||||
m: make(map[interface{}]chan []byte),
|
||||
}
|
||||
return &Match{m: make(map[any]*incoming)}
|
||||
}
|
||||
|
||||
// Match is a distributed subtype that handles the matching of requests and
|
||||
@@ -17,64 +20,54 @@ func NewMatch() *Match {
|
||||
// connections. Stream systems fan in all incoming payloads to one routine for
|
||||
// processing.
|
||||
type Match struct {
|
||||
m map[interface{}]chan []byte
|
||||
m map[any]*incoming
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Matcher defines a payload matching return mechanism
|
||||
type Matcher struct {
|
||||
C chan []byte
|
||||
sig interface{}
|
||||
m *Match
|
||||
}
|
||||
|
||||
// Incoming matches with request, disregarding the returned payload
|
||||
func (m *Match) Incoming(signature interface{}) bool {
|
||||
return m.IncomingWithData(signature, nil)
|
||||
type incoming struct {
|
||||
expected int
|
||||
c chan<- []byte
|
||||
}
|
||||
|
||||
// IncomingWithData matches with requests and takes in the returned payload, to
|
||||
// be processed outside of a stream processing routine and returns true if a handler was found
|
||||
func (m *Match) IncomingWithData(signature interface{}, data []byte) bool {
|
||||
func (m *Match) IncomingWithData(signature any, data []byte) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
ch, ok := m.m[signature]
|
||||
if ok {
|
||||
select {
|
||||
case ch <- data:
|
||||
default:
|
||||
// this shouldn't occur but if it does continue to process as normal
|
||||
return false
|
||||
}
|
||||
return true
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return false
|
||||
ch.c <- data
|
||||
ch.expected--
|
||||
if ch.expected == 0 {
|
||||
close(ch.c)
|
||||
delete(m.m, signature)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Set the signature response channel for incoming data
|
||||
func (m *Match) Set(signature interface{}) (Matcher, error) {
|
||||
var ch chan []byte
|
||||
m.mu.Lock()
|
||||
if _, ok := m.m[signature]; ok {
|
||||
m.mu.Unlock()
|
||||
return Matcher{}, errors.New("signature collision")
|
||||
func (m *Match) Set(signature any, bufSize int) (<-chan []byte, error) {
|
||||
if bufSize <= 0 {
|
||||
return nil, errInvalidBufferSize
|
||||
}
|
||||
// This is buffered so we don't need to wait for receiver.
|
||||
ch = make(chan []byte, 1)
|
||||
m.m[signature] = ch
|
||||
m.mu.Unlock()
|
||||
|
||||
return Matcher{
|
||||
C: ch,
|
||||
sig: signature,
|
||||
m: m,
|
||||
}, nil
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if _, ok := m.m[signature]; ok {
|
||||
return nil, errSignatureCollision
|
||||
}
|
||||
ch := make(chan []byte, bufSize)
|
||||
m.m[signature] = &incoming{expected: bufSize, c: ch}
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// Cleanup closes underlying channel and deletes signature from map
|
||||
func (m *Matcher) Cleanup() {
|
||||
m.m.mu.Lock()
|
||||
close(m.C)
|
||||
delete(m.m.m, m.sig)
|
||||
m.m.mu.Unlock()
|
||||
// RemoveSignature removes the signature response from map and closes the channel.
|
||||
func (m *Match) RemoveSignature(signature any) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if ch, ok := m.m[signature]; ok {
|
||||
close(ch.c)
|
||||
delete(m.m, signature)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,50 +1,53 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
bm := &Match{}
|
||||
if bm.Incoming("wow") {
|
||||
t.Fatal("Should not have matched")
|
||||
}
|
||||
load := []byte("42")
|
||||
assert.False(t, new(Match).IncomingWithData("hello", load), "Should not match an uninitialised Match")
|
||||
|
||||
nm := NewMatch()
|
||||
// try to match with unset signature
|
||||
if nm.Incoming("hello") {
|
||||
t.Fatal("should not be able to match")
|
||||
}
|
||||
match := NewMatch()
|
||||
assert.False(t, match.IncomingWithData("hello", load), "Should not match an empty signature")
|
||||
|
||||
m, err := nm.Set("hello")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err := match.Set("hello", 0)
|
||||
require.ErrorIs(t, err, errInvalidBufferSize, "Must error on zero buffer size")
|
||||
_, err = match.Set("hello", -1)
|
||||
require.ErrorIs(t, err, errInvalidBufferSize, "Must error on negative buffer size")
|
||||
ch, err := match.Set("hello", 2)
|
||||
require.NoError(t, err, "Set must not error")
|
||||
assert.True(t, match.IncomingWithData("hello", []byte("hello")))
|
||||
assert.Equal(t, "hello", string(<-ch))
|
||||
|
||||
_, err = nm.Set("hello")
|
||||
if err == nil {
|
||||
t.Fatal("error cannot be nil as this collision cannot occur")
|
||||
}
|
||||
_, err = match.Set("hello", 2)
|
||||
assert.ErrorIs(t, err, errSignatureCollision, "Should error on signature collision")
|
||||
|
||||
if m.sig != "hello" {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.True(t, match.IncomingWithData("hello", load), "Should match with matching message and signature")
|
||||
assert.False(t, match.IncomingWithData("hello", load), "Should not match with matching message and signature")
|
||||
|
||||
// try and match with initial payload
|
||||
if !nm.Incoming("hello") {
|
||||
t.Fatal("should of matched")
|
||||
}
|
||||
|
||||
// put in secondary payload with conflicting signature
|
||||
if nm.Incoming("hello") {
|
||||
fmt.Println("should not have been able to match")
|
||||
}
|
||||
|
||||
if data := <-m.C; data != nil {
|
||||
t.Fatal("data chan should be nil")
|
||||
}
|
||||
|
||||
m.Cleanup()
|
||||
assert.Len(t, ch, 1, "Channel should have 1 items, 1 was already read above")
|
||||
}
|
||||
|
||||
func TestRemoveSignature(t *testing.T) {
|
||||
t.Parallel()
|
||||
match := NewMatch()
|
||||
ch, err := match.Set("masterblaster", 1)
|
||||
select {
|
||||
case <-ch:
|
||||
t.Fatal("Should not be able to read from an empty channel")
|
||||
default:
|
||||
}
|
||||
require.NoError(t, err)
|
||||
match.RemoveSignature("masterblaster")
|
||||
select {
|
||||
case garbage := <-ch:
|
||||
require.Empty(t, garbage)
|
||||
default:
|
||||
t.Fatal("Should be able to read from a closed channel")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@@ -14,10 +15,11 @@ import (
|
||||
type Connection interface {
|
||||
Dial(*websocket.Dialer, http.Header) error
|
||||
ReadMessage() Response
|
||||
SendJSONMessage(interface{}) error
|
||||
SendJSONMessage(any) error
|
||||
SetupPingHandler(PingHandler)
|
||||
GenerateMessageID(highPrecision bool) int64
|
||||
SendMessageReturnResponse(signature interface{}, request interface{}) ([]byte, error)
|
||||
SendMessageReturnResponse(ctx context.Context, signature any, request any) ([]byte, error)
|
||||
SendMessageReturnResponses(ctx context.Context, signature any, request any, expected int) ([][]byte, error)
|
||||
SendRawMessage(messageType int, message []byte) error
|
||||
SetURL(string)
|
||||
SetProxy(string)
|
||||
|
||||
@@ -29,6 +29,8 @@ var (
|
||||
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
|
||||
ErrAlreadyDisabled = errors.New("websocket already disabled")
|
||||
ErrNotConnected = errors.New("websocket is not connected")
|
||||
ErrNoMessageListener = errors.New("websocket listener not found for message")
|
||||
ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature")
|
||||
)
|
||||
|
||||
// Private websocket errors
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -19,41 +20,6 @@ import (
|
||||
"github.com/thrasher-corp/gocryptotrader/log"
|
||||
)
|
||||
|
||||
// SendMessageReturnResponse will send a WS message to the connection and wait
|
||||
// for response
|
||||
func (w *WebsocketConnection) SendMessageReturnResponse(signature, request interface{}) ([]byte, error) {
|
||||
m, err := w.Match.Set(signature)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer m.Cleanup()
|
||||
|
||||
b, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
err = w.SendRawMessage(websocket.TextMessage, b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
timer := time.NewTimer(w.ResponseMaxLimit)
|
||||
|
||||
select {
|
||||
case payload := <-m.C:
|
||||
if w.Reporter != nil {
|
||||
w.Reporter.Latency(w.ExchangeName, b, time.Since(start))
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
case <-timer.C:
|
||||
timer.Stop()
|
||||
return nil, fmt.Errorf("%s websocket connection: timeout waiting for response with signature: %v", w.ExchangeName, signature)
|
||||
}
|
||||
}
|
||||
|
||||
// Dial sets proxy urls and then connects to the websocket
|
||||
func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header) error {
|
||||
if w.ProxyURL != "" {
|
||||
@@ -303,3 +269,56 @@ func (w *WebsocketConnection) SetProxy(proxy string) {
|
||||
func (w *WebsocketConnection) GetURL() string {
|
||||
return w.URL
|
||||
}
|
||||
|
||||
// SendMessageReturnResponse will send a WS message to the connection and wait for response
|
||||
func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, signature, request any) ([]byte, error) {
|
||||
resps, err := w.SendMessageReturnResponses(ctx, signature, request, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resps[0], nil
|
||||
}
|
||||
|
||||
// SendMessageReturnResponses will send a WS message to the connection and wait for N responses
|
||||
// An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked
|
||||
func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, signature, request any, expected int) ([][]byte, error) {
|
||||
outbound, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err)
|
||||
}
|
||||
|
||||
ch, err := w.Match.Set(signature, expected)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
err = w.SendRawMessage(websocket.TextMessage, outbound)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
timeout := time.NewTimer(w.ResponseMaxLimit * time.Duration(expected))
|
||||
|
||||
resps := make([][]byte, 0, expected)
|
||||
for err == nil && len(resps) < expected {
|
||||
select {
|
||||
case resp := <-ch:
|
||||
resps = append(resps, resp)
|
||||
case <-timeout.C:
|
||||
w.Match.RemoveSignature(signature)
|
||||
err = fmt.Errorf("%s %w %v", w.ExchangeName, ErrSignatureTimeout, signature)
|
||||
case <-ctx.Done():
|
||||
w.Match.RemoveSignature(signature)
|
||||
err = ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
timeout.Stop()
|
||||
|
||||
if err == nil && w.Reporter != nil {
|
||||
w.Reporter.Latency(w.ExchangeName, outbound, time.Since(start))
|
||||
}
|
||||
|
||||
return resps, err
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -724,8 +725,7 @@ func TestSendMessage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendMessageWithResponse logic test
|
||||
func TestSendMessageWithResponse(t *testing.T) {
|
||||
func TestSendMessageReturnResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
wc := &WebsocketConnection{
|
||||
Verbose: true,
|
||||
@@ -753,10 +753,20 @@ func TestSendMessageWithResponse(t *testing.T) {
|
||||
RequestID: wc.GenerateMessageID(false),
|
||||
}
|
||||
|
||||
_, err = wc.SendMessageReturnResponse(request.RequestID, request)
|
||||
_, err = wc.SendMessageReturnResponse(context.Background(), request.RequestID, request)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
cancelledCtx, fn := context.WithDeadline(context.Background(), time.Now())
|
||||
fn()
|
||||
_, err = wc.SendMessageReturnResponse(cancelledCtx, "123", request)
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
|
||||
// with timeout
|
||||
wc.ResponseMaxLimit = 1
|
||||
_, err = wc.SendMessageReturnResponse(context.Background(), "123", request)
|
||||
assert.ErrorIs(t, err, ErrSignatureTimeout, "SendMessageReturnResponse should error when request ID not found")
|
||||
}
|
||||
|
||||
type reporter struct {
|
||||
@@ -1182,7 +1192,7 @@ func TestLatency(t *testing.T) {
|
||||
RequestID: wc.GenerateMessageID(false),
|
||||
}
|
||||
|
||||
_, err = wc.SendMessageReturnResponse(request.RequestID, request)
|
||||
_, err = wc.SendMessageReturnResponse(context.Background(), request.RequestID, request)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user