stream/match: Reduce complexity and limit locking when match occurs (#1581)

* stream match update

* update tests

* linter: fix

* glorious: nits + handle context cancellations

* glorious: whooops

* Websocket: Add SendMessageReturnResponses

* whooooooopsie

* gk: nitssssss

* Update exchanges/stream/stream_match.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* Update exchanges/stream/stream_match_test.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* linter: appease the linter gods

* glorious: nits

* glorious: nits

* Update exchanges/stream/stream_match_test.go

Co-authored-by: Scott <gloriousCode@users.noreply.github.com>

---------

Co-authored-by: Ryan O'Hara-Reid <ryan.oharareid@thrasher.io>
Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>
Co-authored-by: Scott <gloriousCode@users.noreply.github.com>
This commit is contained in:
Ryan O'Hara-Reid
2024-08-19 10:35:46 +10:00
committed by GitHub
parent 225429bda6
commit 17c2ef2ec7
23 changed files with 207 additions and 178 deletions

View File

@@ -5,11 +5,14 @@ import (
"sync"
)
var (
errSignatureCollision = errors.New("signature collision")
errInvalidBufferSize = errors.New("buffer size must be positive")
)
// NewMatch returns a new Match
func NewMatch() *Match {
return &Match{
m: make(map[interface{}]chan []byte),
}
return &Match{m: make(map[any]*incoming)}
}
// Match is a distributed subtype that handles the matching of requests and
@@ -17,64 +20,54 @@ func NewMatch() *Match {
// connections. Stream systems fan in all incoming payloads to one routine for
// processing.
type Match struct {
m map[interface{}]chan []byte
m map[any]*incoming
mu sync.Mutex
}
// Matcher defines a payload matching return mechanism
type Matcher struct {
C chan []byte
sig interface{}
m *Match
}
// Incoming matches with request, disregarding the returned payload
func (m *Match) Incoming(signature interface{}) bool {
return m.IncomingWithData(signature, nil)
type incoming struct {
expected int
c chan<- []byte
}
// IncomingWithData matches with requests and takes in the returned payload, to
// be processed outside of a stream processing routine and returns true if a handler was found
func (m *Match) IncomingWithData(signature interface{}, data []byte) bool {
func (m *Match) IncomingWithData(signature any, data []byte) bool {
m.mu.Lock()
defer m.mu.Unlock()
ch, ok := m.m[signature]
if ok {
select {
case ch <- data:
default:
// this shouldn't occur but if it does continue to process as normal
return false
}
return true
if !ok {
return false
}
return false
ch.c <- data
ch.expected--
if ch.expected == 0 {
close(ch.c)
delete(m.m, signature)
}
return true
}
// Set the signature response channel for incoming data
func (m *Match) Set(signature interface{}) (Matcher, error) {
var ch chan []byte
m.mu.Lock()
if _, ok := m.m[signature]; ok {
m.mu.Unlock()
return Matcher{}, errors.New("signature collision")
func (m *Match) Set(signature any, bufSize int) (<-chan []byte, error) {
if bufSize <= 0 {
return nil, errInvalidBufferSize
}
// This is buffered so we don't need to wait for receiver.
ch = make(chan []byte, 1)
m.m[signature] = ch
m.mu.Unlock()
return Matcher{
C: ch,
sig: signature,
m: m,
}, nil
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.m[signature]; ok {
return nil, errSignatureCollision
}
ch := make(chan []byte, bufSize)
m.m[signature] = &incoming{expected: bufSize, c: ch}
return ch, nil
}
// Cleanup closes underlying channel and deletes signature from map
func (m *Matcher) Cleanup() {
m.m.mu.Lock()
close(m.C)
delete(m.m.m, m.sig)
m.m.mu.Unlock()
// RemoveSignature removes the signature response from map and closes the channel.
func (m *Match) RemoveSignature(signature any) {
m.mu.Lock()
defer m.mu.Unlock()
if ch, ok := m.m[signature]; ok {
close(ch.c)
delete(m.m, signature)
}
}

View File

@@ -1,50 +1,53 @@
package stream
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMatch(t *testing.T) {
t.Parallel()
bm := &Match{}
if bm.Incoming("wow") {
t.Fatal("Should not have matched")
}
load := []byte("42")
assert.False(t, new(Match).IncomingWithData("hello", load), "Should not match an uninitialised Match")
nm := NewMatch()
// try to match with unset signature
if nm.Incoming("hello") {
t.Fatal("should not be able to match")
}
match := NewMatch()
assert.False(t, match.IncomingWithData("hello", load), "Should not match an empty signature")
m, err := nm.Set("hello")
if err != nil {
t.Fatal(err)
}
_, err := match.Set("hello", 0)
require.ErrorIs(t, err, errInvalidBufferSize, "Must error on zero buffer size")
_, err = match.Set("hello", -1)
require.ErrorIs(t, err, errInvalidBufferSize, "Must error on negative buffer size")
ch, err := match.Set("hello", 2)
require.NoError(t, err, "Set must not error")
assert.True(t, match.IncomingWithData("hello", []byte("hello")))
assert.Equal(t, "hello", string(<-ch))
_, err = nm.Set("hello")
if err == nil {
t.Fatal("error cannot be nil as this collision cannot occur")
}
_, err = match.Set("hello", 2)
assert.ErrorIs(t, err, errSignatureCollision, "Should error on signature collision")
if m.sig != "hello" {
t.Fatal(err)
}
assert.True(t, match.IncomingWithData("hello", load), "Should match with matching message and signature")
assert.False(t, match.IncomingWithData("hello", load), "Should not match with matching message and signature")
// try and match with initial payload
if !nm.Incoming("hello") {
t.Fatal("should of matched")
}
// put in secondary payload with conflicting signature
if nm.Incoming("hello") {
fmt.Println("should not have been able to match")
}
if data := <-m.C; data != nil {
t.Fatal("data chan should be nil")
}
m.Cleanup()
assert.Len(t, ch, 1, "Channel should have 1 items, 1 was already read above")
}
func TestRemoveSignature(t *testing.T) {
t.Parallel()
match := NewMatch()
ch, err := match.Set("masterblaster", 1)
select {
case <-ch:
t.Fatal("Should not be able to read from an empty channel")
default:
}
require.NoError(t, err)
match.RemoveSignature("masterblaster")
select {
case garbage := <-ch:
require.Empty(t, garbage)
default:
t.Fatal("Should be able to read from a closed channel")
}
}

View File

@@ -1,6 +1,7 @@
package stream
import (
"context"
"net/http"
"time"
@@ -14,10 +15,11 @@ import (
type Connection interface {
Dial(*websocket.Dialer, http.Header) error
ReadMessage() Response
SendJSONMessage(interface{}) error
SendJSONMessage(any) error
SetupPingHandler(PingHandler)
GenerateMessageID(highPrecision bool) int64
SendMessageReturnResponse(signature interface{}, request interface{}) ([]byte, error)
SendMessageReturnResponse(ctx context.Context, signature any, request any) ([]byte, error)
SendMessageReturnResponses(ctx context.Context, signature any, request any, expected int) ([][]byte, error)
SendRawMessage(messageType int, message []byte) error
SetURL(string)
SetProxy(string)

View File

@@ -29,6 +29,8 @@ var (
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
ErrAlreadyDisabled = errors.New("websocket already disabled")
ErrNotConnected = errors.New("websocket is not connected")
ErrNoMessageListener = errors.New("websocket listener not found for message")
ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature")
)
// Private websocket errors

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"compress/flate"
"compress/gzip"
"context"
"crypto/rand"
"encoding/json"
"fmt"
@@ -19,41 +20,6 @@ import (
"github.com/thrasher-corp/gocryptotrader/log"
)
// SendMessageReturnResponse will send a WS message to the connection and wait
// for response
func (w *WebsocketConnection) SendMessageReturnResponse(signature, request interface{}) ([]byte, error) {
m, err := w.Match.Set(signature)
if err != nil {
return nil, err
}
defer m.Cleanup()
b, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err)
}
start := time.Now()
err = w.SendRawMessage(websocket.TextMessage, b)
if err != nil {
return nil, err
}
timer := time.NewTimer(w.ResponseMaxLimit)
select {
case payload := <-m.C:
if w.Reporter != nil {
w.Reporter.Latency(w.ExchangeName, b, time.Since(start))
}
return payload, nil
case <-timer.C:
timer.Stop()
return nil, fmt.Errorf("%s websocket connection: timeout waiting for response with signature: %v", w.ExchangeName, signature)
}
}
// Dial sets proxy urls and then connects to the websocket
func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header) error {
if w.ProxyURL != "" {
@@ -303,3 +269,56 @@ func (w *WebsocketConnection) SetProxy(proxy string) {
func (w *WebsocketConnection) GetURL() string {
return w.URL
}
// SendMessageReturnResponse will send a WS message to the connection and wait for response
func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, signature, request any) ([]byte, error) {
resps, err := w.SendMessageReturnResponses(ctx, signature, request, 1)
if err != nil {
return nil, err
}
return resps[0], nil
}
// SendMessageReturnResponses will send a WS message to the connection and wait for N responses
// An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked
func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, signature, request any, expected int) ([][]byte, error) {
outbound, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err)
}
ch, err := w.Match.Set(signature, expected)
if err != nil {
return nil, err
}
start := time.Now()
err = w.SendRawMessage(websocket.TextMessage, outbound)
if err != nil {
return nil, err
}
timeout := time.NewTimer(w.ResponseMaxLimit * time.Duration(expected))
resps := make([][]byte, 0, expected)
for err == nil && len(resps) < expected {
select {
case resp := <-ch:
resps = append(resps, resp)
case <-timeout.C:
w.Match.RemoveSignature(signature)
err = fmt.Errorf("%s %w %v", w.ExchangeName, ErrSignatureTimeout, signature)
case <-ctx.Done():
w.Match.RemoveSignature(signature)
err = ctx.Err()
}
}
timeout.Stop()
if err == nil && w.Reporter != nil {
w.Reporter.Latency(w.ExchangeName, outbound, time.Since(start))
}
return resps, err
}

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"compress/flate"
"compress/gzip"
"context"
"encoding/json"
"errors"
"fmt"
@@ -724,8 +725,7 @@ func TestSendMessage(t *testing.T) {
}
}
// TestSendMessageWithResponse logic test
func TestSendMessageWithResponse(t *testing.T) {
func TestSendMessageReturnResponse(t *testing.T) {
t.Parallel()
wc := &WebsocketConnection{
Verbose: true,
@@ -753,10 +753,20 @@ func TestSendMessageWithResponse(t *testing.T) {
RequestID: wc.GenerateMessageID(false),
}
_, err = wc.SendMessageReturnResponse(request.RequestID, request)
_, err = wc.SendMessageReturnResponse(context.Background(), request.RequestID, request)
if err != nil {
t.Error(err)
}
cancelledCtx, fn := context.WithDeadline(context.Background(), time.Now())
fn()
_, err = wc.SendMessageReturnResponse(cancelledCtx, "123", request)
assert.ErrorIs(t, err, context.DeadlineExceeded)
// with timeout
wc.ResponseMaxLimit = 1
_, err = wc.SendMessageReturnResponse(context.Background(), "123", request)
assert.ErrorIs(t, err, ErrSignatureTimeout, "SendMessageReturnResponse should error when request ID not found")
}
type reporter struct {
@@ -1182,7 +1192,7 @@ func TestLatency(t *testing.T) {
RequestID: wc.GenerateMessageID(false),
}
_, err = wc.SendMessageReturnResponse(request.RequestID, request)
_, err = wc.SendMessageReturnResponse(context.Background(), request.RequestID, request)
if err != nil {
t.Error(err)
}