mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-13 15:09:42 +00:00
Dispatch: Fix race during stop (#1443)
* Dispatch: Assertify tests * Dispatch: Fix race during stop If we have blocking writers, then we need to synchronise them exiting before closing off their channels. * Dispatch: Rename Routes mutex for clarity
This commit is contained in:
@@ -128,27 +128,31 @@ func (d *Dispatcher) stop() error {
|
||||
// Release finished workers
|
||||
close(d.shutdown)
|
||||
|
||||
d.rMtx.Lock()
|
||||
for key, pipes := range d.routes {
|
||||
for i := range pipes {
|
||||
// Boot off receivers waiting on pipes.
|
||||
close(pipes[i])
|
||||
}
|
||||
// Flush all pipes, re-subscription will need to occur.
|
||||
d.routes[key] = nil
|
||||
}
|
||||
d.rMtx.Unlock()
|
||||
ch := make(chan struct{}, 1)
|
||||
go func(ch chan<- struct{}) {
|
||||
d.wg.Wait()
|
||||
ch <- struct{}{}
|
||||
}(ch)
|
||||
|
||||
ch := make(chan struct{})
|
||||
timer := time.NewTimer(time.Second)
|
||||
go func(ch chan<- struct{}) { d.wg.Wait(); ch <- struct{}{} }(ch)
|
||||
select {
|
||||
case <-ch:
|
||||
log.Debugln(log.DispatchMgr, "Dispatch manager shutdown.")
|
||||
return nil
|
||||
case <-timer.C:
|
||||
case <-time.After(time.Second):
|
||||
return errDispatchShutdown
|
||||
}
|
||||
|
||||
// Wait for all relayers to have exited, including any blocking channel writes, before closing channels
|
||||
d.routesMtx.Lock()
|
||||
for key, pipes := range d.routes {
|
||||
for i := range pipes {
|
||||
close(pipes[i])
|
||||
}
|
||||
d.routes[key] = nil
|
||||
}
|
||||
d.routesMtx.Unlock()
|
||||
|
||||
log.Debugln(log.DispatchMgr, "Dispatch manager shutdown")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isRunning returns if the dispatch system is running
|
||||
@@ -172,19 +176,24 @@ func (d *Dispatcher) relayer() {
|
||||
// every real job created has an ID set
|
||||
continue
|
||||
}
|
||||
d.rMtx.RLock()
|
||||
d.routesMtx.Lock()
|
||||
pipes, ok := d.routes[j.ID]
|
||||
if !ok {
|
||||
log.Warnf(log.DispatchMgr, "%v: %v\n", errDispatcherUUIDNotFoundInRouteList, j.ID)
|
||||
d.rMtx.RUnlock()
|
||||
d.routesMtx.Unlock()
|
||||
continue
|
||||
}
|
||||
for i := range pipes {
|
||||
d.wg.Add(1)
|
||||
go func(p chan any) {
|
||||
p <- j.Data
|
||||
defer d.wg.Done()
|
||||
select {
|
||||
case p <- j.Data:
|
||||
case <-d.shutdown: // Avoids race on blocking consumer when we go to stop
|
||||
}
|
||||
}(pipes[i])
|
||||
}
|
||||
d.rMtx.RUnlock()
|
||||
d.routesMtx.Unlock()
|
||||
case <-d.shutdown:
|
||||
d.wg.Done()
|
||||
return
|
||||
@@ -242,8 +251,8 @@ func (d *Dispatcher) subscribe(id uuid.UUID) (chan interface{}, error) {
|
||||
return nil, ErrNotRunning
|
||||
}
|
||||
|
||||
d.rMtx.Lock()
|
||||
defer d.rMtx.Unlock()
|
||||
d.routesMtx.Lock()
|
||||
defer d.routesMtx.Unlock()
|
||||
if _, ok := d.routes[id]; !ok {
|
||||
return nil, errDispatcherUUIDNotFoundInRouteList
|
||||
}
|
||||
@@ -281,8 +290,8 @@ func (d *Dispatcher) unsubscribe(id uuid.UUID, usedChan chan interface{}) error
|
||||
return nil
|
||||
}
|
||||
|
||||
d.rMtx.Lock()
|
||||
defer d.rMtx.Unlock()
|
||||
d.routesMtx.Lock()
|
||||
defer d.routesMtx.Unlock()
|
||||
pipes, ok := d.routes[id]
|
||||
if !ok {
|
||||
return errDispatcherUUIDNotFoundInRouteList
|
||||
@@ -334,8 +343,8 @@ func (d *Dispatcher) getNewID(genFn func() (uuid.UUID, error)) (uuid.UUID, error
|
||||
return uuid.Nil, err
|
||||
}
|
||||
|
||||
d.rMtx.Lock()
|
||||
defer d.rMtx.Unlock()
|
||||
d.routesMtx.Lock()
|
||||
defer d.routesMtx.Unlock()
|
||||
// Check to see if it already exists
|
||||
if _, ok := d.routes[newID]; ok {
|
||||
return uuid.Nil, errUUIDCollision
|
||||
|
||||
@@ -2,14 +2,13 @@ package dispatch
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -19,149 +18,89 @@ var (
|
||||
|
||||
func TestGlobalDispatcher(t *testing.T) {
|
||||
err := Start(0, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
running := IsRunning()
|
||||
if !running {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", IsRunning(), true)
|
||||
}
|
||||
require.NoError(t, err, "Start should not error")
|
||||
assert.True(t, IsRunning(), "IsRunning should return true")
|
||||
|
||||
err = Stop()
|
||||
if err != nil {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
running = IsRunning()
|
||||
if running {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", IsRunning(), false)
|
||||
}
|
||||
assert.NoError(t, err, "Stop should not error")
|
||||
assert.False(t, IsRunning(), "IsRunning should return false")
|
||||
}
|
||||
|
||||
func TestStartStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
var d *Dispatcher
|
||||
|
||||
if d.isRunning() {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", d.isRunning(), false)
|
||||
}
|
||||
assert.False(t, d.isRunning(), "IsRunning should return false")
|
||||
|
||||
err := d.stop()
|
||||
if !errors.Is(err, errDispatcherNotInitialized) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errDispatcherNotInitialized)
|
||||
}
|
||||
assert.ErrorIs(t, err, errDispatcherNotInitialized, "stop should error correctly")
|
||||
|
||||
err = d.start(10, 0)
|
||||
if !errors.Is(err, errDispatcherNotInitialized) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errDispatcherNotInitialized)
|
||||
}
|
||||
assert.ErrorIs(t, err, errDispatcherNotInitialized, "start should error correctly")
|
||||
|
||||
d = NewDispatcher()
|
||||
|
||||
err = d.stop()
|
||||
if !errors.Is(err, ErrNotRunning) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, ErrNotRunning)
|
||||
}
|
||||
|
||||
if d.isRunning() {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", d.isRunning(), false)
|
||||
}
|
||||
assert.ErrorIs(t, err, ErrNotRunning, "stop should error correctly")
|
||||
assert.False(t, d.isRunning(), "IsRunning should return false")
|
||||
|
||||
err = d.start(1, 100)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
if !d.isRunning() {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", d.isRunning(), true)
|
||||
}
|
||||
assert.NoError(t, err, "start should not error")
|
||||
assert.True(t, d.isRunning(), "IsRunning should return true")
|
||||
|
||||
err = d.start(0, 0)
|
||||
if !errors.Is(err, errDispatcherAlreadyRunning) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errDispatcherAlreadyRunning)
|
||||
}
|
||||
assert.ErrorIs(t, err, errDispatcherAlreadyRunning, "start should error correctly")
|
||||
|
||||
// Add route option
|
||||
id, err := d.getNewID(uuid.NewV4)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
assert.NoError(t, err, "getNewID should not error")
|
||||
|
||||
// Add pipe
|
||||
_, err = d.subscribe(id)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
assert.NoError(t, err, "subscribe should not error")
|
||||
|
||||
// Max out jobs channel
|
||||
for x := 0; x < 99; x++ {
|
||||
err = d.publish(id, "woah-nelly")
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
assert.NoError(t, err, "publish should not error")
|
||||
}
|
||||
|
||||
err = d.stop()
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
if d.isRunning() {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", d.isRunning(), false)
|
||||
}
|
||||
assert.NoError(t, err, "stop should not error")
|
||||
assert.False(t, d.isRunning(), "IsRunning should return false")
|
||||
}
|
||||
|
||||
func TestSubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
var d *Dispatcher
|
||||
_, err := d.subscribe(uuid.Nil)
|
||||
if !errors.Is(err, errDispatcherNotInitialized) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errDispatcherNotInitialized)
|
||||
}
|
||||
assert.ErrorIs(t, err, errDispatcherNotInitialized, "subscribe should error correctly")
|
||||
|
||||
d = NewDispatcher()
|
||||
|
||||
_, err = d.subscribe(uuid.Nil)
|
||||
if !errors.Is(err, errIDNotSet) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errIDNotSet)
|
||||
}
|
||||
assert.ErrorIs(t, err, errIDNotSet, "subscribe should error correctly")
|
||||
|
||||
_, err = d.subscribe(nonEmptyUUID)
|
||||
if !errors.Is(err, ErrNotRunning) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, ErrNotRunning)
|
||||
}
|
||||
assert.ErrorIs(t, err, ErrNotRunning, "subscribe should error correctly")
|
||||
|
||||
err = d.start(0, 0)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
require.NoError(t, err, "start should not error")
|
||||
|
||||
id, err := d.getNewID(uuid.NewV4)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
require.NoError(t, err, "getNewID should not error")
|
||||
|
||||
_, err = d.subscribe(nonEmptyUUID)
|
||||
if !errors.Is(err, errDispatcherUUIDNotFoundInRouteList) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errDispatcherUUIDNotFoundInRouteList)
|
||||
}
|
||||
assert.ErrorIs(t, err, errDispatcherUUIDNotFoundInRouteList, "subscribe should error correctly")
|
||||
|
||||
d.outbound.New = func() interface{} { return "omg" }
|
||||
_, err = d.subscribe(id)
|
||||
if !errors.Is(err, errTypeAssertionFailure) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errTypeAssertionFailure)
|
||||
}
|
||||
assert.ErrorIs(t, err, errTypeAssertionFailure, "subscribe should error correctly")
|
||||
|
||||
d.outbound.New = getChan
|
||||
ch, err := d.subscribe(id)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
if ch == nil {
|
||||
t.Fatal("expected channel value")
|
||||
}
|
||||
assert.NoError(t, err, "subscribe should not error")
|
||||
assert.NotNil(t, ch, "Channel should not be nil")
|
||||
}
|
||||
|
||||
func TestUnsubscribe(t *testing.T) {
|
||||
@@ -169,73 +108,43 @@ func TestUnsubscribe(t *testing.T) {
|
||||
var d *Dispatcher
|
||||
|
||||
err := d.unsubscribe(uuid.Nil, nil)
|
||||
if !errors.Is(err, errDispatcherNotInitialized) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errDispatcherNotInitialized)
|
||||
}
|
||||
assert.ErrorIs(t, err, errDispatcherNotInitialized, "unsubscribe should error correctly")
|
||||
|
||||
d = NewDispatcher()
|
||||
|
||||
err = d.unsubscribe(uuid.Nil, nil)
|
||||
if !errors.Is(err, errIDNotSet) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errIDNotSet)
|
||||
}
|
||||
assert.ErrorIs(t, err, errIDNotSet, "unsubscribe should error correctly")
|
||||
|
||||
err = d.unsubscribe(nonEmptyUUID, nil)
|
||||
if !errors.Is(err, errChannelIsNil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errChannelIsNil)
|
||||
}
|
||||
assert.ErrorIs(t, err, errChannelIsNil, "unsubscribe should error correctly")
|
||||
|
||||
// will return nil if not running
|
||||
err = d.unsubscribe(nonEmptyUUID, make(chan interface{}))
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
assert.NoError(t, err, "unsubscribe should not error")
|
||||
|
||||
err = d.start(0, 0)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
require.NoError(t, err, "start should not error")
|
||||
|
||||
err = d.unsubscribe(nonEmptyUUID, make(chan interface{}))
|
||||
if !errors.Is(err, errDispatcherUUIDNotFoundInRouteList) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errDispatcherUUIDNotFoundInRouteList)
|
||||
}
|
||||
assert.ErrorIs(t, err, errDispatcherUUIDNotFoundInRouteList, "unsubscribe should error correctly")
|
||||
|
||||
id, err := d.getNewID(uuid.NewV4)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
require.NoError(t, err, "getNewID should not error")
|
||||
|
||||
err = d.unsubscribe(id, make(chan interface{}))
|
||||
if !errors.Is(err, errChannelNotFoundInUUIDRef) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errChannelNotFoundInUUIDRef)
|
||||
}
|
||||
|
||||
// Skip over this when matching pipes
|
||||
_, err = d.subscribe(id)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
assert.ErrorIs(t, err, errChannelNotFoundInUUIDRef, "unsubscribe should error correctly")
|
||||
|
||||
ch, err := d.subscribe(id)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
require.NoError(t, err, "subscribe should not error")
|
||||
|
||||
err = d.unsubscribe(id, ch)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
assert.NoError(t, err, "unsubscribe should not error")
|
||||
|
||||
ch2, err := d.subscribe(id)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
require.NoError(t, err, "subscribe should not error")
|
||||
|
||||
err = d.unsubscribe(id, ch2)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
assert.NoError(t, err, "unsubscribe should not error")
|
||||
}
|
||||
|
||||
func TestPublish(t *testing.T) {
|
||||
@@ -243,82 +152,56 @@ func TestPublish(t *testing.T) {
|
||||
var d *Dispatcher
|
||||
|
||||
err := d.publish(uuid.Nil, nil)
|
||||
if !errors.Is(err, errDispatcherNotInitialized) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errDispatcherNotInitialized)
|
||||
}
|
||||
assert.ErrorIs(t, err, errDispatcherNotInitialized, "publish should error correctly")
|
||||
|
||||
d = NewDispatcher()
|
||||
|
||||
err = d.publish(nonEmptyUUID, "test")
|
||||
if !errors.Is(err, nil) { // If not running, don't send back an error.
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
assert.NoError(t, err, "publish should not error")
|
||||
|
||||
err = d.start(2, 10)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
require.NoError(t, err, "start should not error")
|
||||
|
||||
err = d.publish(uuid.Nil, nil)
|
||||
if !errors.Is(err, errIDNotSet) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errIDNotSet)
|
||||
}
|
||||
assert.ErrorIs(t, err, errIDNotSet, "publish should error correctly")
|
||||
|
||||
err = d.publish(nonEmptyUUID, nil)
|
||||
if !errors.Is(err, errNoData) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errNoData)
|
||||
}
|
||||
assert.ErrorIs(t, err, errNoData, "publish should error correctly")
|
||||
|
||||
// demonstrate job limit error
|
||||
d.routes[nonEmptyUUID] = []chan interface{}{
|
||||
make(chan interface{}),
|
||||
}
|
||||
for x := 0; x < 200; x++ {
|
||||
err2 := d.publish(nonEmptyUUID, "test")
|
||||
if !errors.Is(err2, nil) {
|
||||
err = err2
|
||||
if err = d.publish(nonEmptyUUID, "test"); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !errors.Is(err, errDispatcherJobsAtLimit) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errDispatcherJobsAtLimit)
|
||||
}
|
||||
assert.ErrorIs(t, err, errDispatcherJobsAtLimit, "publish should eventually error at limit")
|
||||
}
|
||||
|
||||
func TestPublishReceive(t *testing.T) {
|
||||
t.Parallel()
|
||||
d := NewDispatcher()
|
||||
if err := d.start(0, 0); !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
err := d.start(0, 0)
|
||||
require.NoError(t, err, "start should not error")
|
||||
|
||||
id, err := d.getNewID(uuid.NewV4)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
require.NoError(t, err, "getNewID should not error")
|
||||
|
||||
incoming, err := d.subscribe(id)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
require.NoError(t, err, "subscribe should not error")
|
||||
|
||||
go func(d *Dispatcher, id uuid.UUID) {
|
||||
for x := 0; x < 10; x++ {
|
||||
err2 := d.publish(id, "WOW")
|
||||
if !errors.Is(err2, nil) {
|
||||
panic(err2)
|
||||
}
|
||||
err := d.publish(id, "WOW")
|
||||
assert.NoError(t, err, "publish should not error")
|
||||
}
|
||||
}(d, id)
|
||||
|
||||
data, ok := (<-incoming).(string)
|
||||
if !ok {
|
||||
t.Fatal("type assertion failure expected string")
|
||||
}
|
||||
|
||||
if data != "WOW" {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
assert.True(t, ok, "Should get a string type from the pipe")
|
||||
assert.Equal(t, "WOW", data, "Should get correct value from the pipe")
|
||||
}
|
||||
|
||||
func TestGetNewID(t *testing.T) {
|
||||
@@ -326,161 +209,101 @@ func TestGetNewID(t *testing.T) {
|
||||
var d *Dispatcher
|
||||
|
||||
_, err := d.getNewID(uuid.NewV4)
|
||||
if !errors.Is(err, errDispatcherNotInitialized) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errDispatcherNotInitialized)
|
||||
}
|
||||
assert.ErrorIs(t, err, errDispatcherNotInitialized, "getNewID should error correctly")
|
||||
|
||||
d = NewDispatcher()
|
||||
|
||||
err = d.start(0, 0)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
require.NoError(t, err, "start should not error")
|
||||
|
||||
_, err = d.getNewID(nil)
|
||||
if !errors.Is(err, errUUIDGeneratorFunctionIsNil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errUUIDGeneratorFunctionIsNil)
|
||||
}
|
||||
assert.ErrorIs(t, err, errUUIDGeneratorFunctionIsNil, "getNewID should error correctly")
|
||||
|
||||
_, err = d.getNewID(func() (uuid.UUID, error) { return uuid.Nil, errTest })
|
||||
if !errors.Is(err, errTest) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errTest)
|
||||
}
|
||||
assert.ErrorIs(t, err, errTest, "getNewID should error correctly")
|
||||
|
||||
_, err = d.getNewID(func() (uuid.UUID, error) { return [uuid.Size]byte{254}, nil })
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
assert.NoError(t, err, "getNewID should not error")
|
||||
|
||||
_, err = d.getNewID(func() (uuid.UUID, error) { return [uuid.Size]byte{254}, nil })
|
||||
if !errors.Is(err, errUUIDCollision) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errUUIDCollision)
|
||||
}
|
||||
assert.ErrorIs(t, err, errUUIDCollision, "getNewID should error correctly")
|
||||
}
|
||||
|
||||
func TestMux(t *testing.T) {
|
||||
t.Parallel()
|
||||
var mux *Mux
|
||||
_, err := mux.Subscribe(uuid.Nil)
|
||||
if !errors.Is(err, errMuxIsNil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errMuxIsNil)
|
||||
}
|
||||
assert.ErrorIs(t, err, errMuxIsNil, "Subscribe should error correctly")
|
||||
|
||||
err = mux.Unsubscribe(uuid.Nil, nil)
|
||||
if !errors.Is(err, errMuxIsNil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errMuxIsNil)
|
||||
}
|
||||
assert.ErrorIs(t, err, errMuxIsNil, "Unsubscribe should error correctly")
|
||||
|
||||
err = mux.Publish(nil)
|
||||
if !errors.Is(err, errMuxIsNil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errMuxIsNil)
|
||||
}
|
||||
assert.ErrorIs(t, err, errMuxIsNil, "Publish should error correctly")
|
||||
|
||||
_, err = mux.GetID()
|
||||
if !errors.Is(err, errMuxIsNil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errMuxIsNil)
|
||||
}
|
||||
assert.ErrorIs(t, err, errMuxIsNil, "GetID should error correctly")
|
||||
|
||||
d := NewDispatcher()
|
||||
err = d.start(0, 0)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
require.NoError(t, err, "start should not error")
|
||||
|
||||
mux = GetNewMux(d)
|
||||
|
||||
err = mux.Publish(nil)
|
||||
if !errors.Is(err, errNoData) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errNoData)
|
||||
}
|
||||
assert.ErrorIs(t, err, errNoData, "Publish should error correctly")
|
||||
|
||||
err = mux.Publish("lol")
|
||||
if !errors.Is(err, errNoIDs) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errNoIDs)
|
||||
}
|
||||
assert.ErrorIs(t, err, errNoIDs, "Publish should error correctly")
|
||||
|
||||
id, err := mux.GetID()
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
require.NoError(t, err, "GetID should not error")
|
||||
|
||||
_, err = mux.Subscribe(uuid.Nil)
|
||||
if !errors.Is(err, errIDNotSet) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errIDNotSet)
|
||||
}
|
||||
assert.ErrorIs(t, err, errIDNotSet, "Subscribe should error correctly")
|
||||
|
||||
pipe, err := mux.Subscribe(id)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
require.NoError(t, err, "Subscribe should not error")
|
||||
|
||||
var errChan = make(chan error)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
// Makes sure receiver is waiting for update
|
||||
go func(ch <-chan interface{}, errChan chan error, wg *sync.WaitGroup) {
|
||||
wg.Done()
|
||||
response, ok := (<-ch).(string)
|
||||
if !ok {
|
||||
errChan <- errors.New("type assertion failure")
|
||||
return
|
||||
}
|
||||
|
||||
if response != "string" {
|
||||
errChan <- errors.New("unexpected return")
|
||||
return
|
||||
}
|
||||
errChan <- nil
|
||||
}(pipe.c, errChan, &wg)
|
||||
|
||||
wg.Wait()
|
||||
var ready = make(chan bool)
|
||||
|
||||
payload := "string"
|
||||
go func(payload string) {
|
||||
err2 := mux.Publish(payload, id)
|
||||
if err2 != nil {
|
||||
fmt.Println(err2)
|
||||
}
|
||||
}(payload)
|
||||
|
||||
err = <-errChan
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
go func() {
|
||||
close(ready)
|
||||
response, ok := (<-pipe.c).(string)
|
||||
assert.True(t, ok, "Should get a string type value from Publish")
|
||||
assert.Equal(t, payload, response, "Should get correct value from Publish")
|
||||
}()
|
||||
|
||||
<-ready
|
||||
|
||||
err = mux.Publish(payload, id)
|
||||
assert.NoError(t, err, "Publish should not error")
|
||||
|
||||
err = pipe.Release()
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
assert.NoError(t, err, "Release should not error")
|
||||
}
|
||||
|
||||
func TestMuxSubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
d := NewDispatcher()
|
||||
err := d.start(0, 0)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
require.NoError(t, err, "start should not error")
|
||||
mux := GetNewMux(d)
|
||||
itemID, err := mux.GetID()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err, "GetID should not error")
|
||||
|
||||
var pipes []Pipe
|
||||
for i := 0; i < 1000; i++ {
|
||||
newPipe, err := mux.Subscribe(itemID)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, err, "Subscribe should not error")
|
||||
pipes = append(pipes, newPipe)
|
||||
}
|
||||
|
||||
for i := range pipes {
|
||||
err := pipes[i].Release()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, err, "Release should not error")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -488,11 +311,11 @@ func TestMuxPublish(t *testing.T) {
|
||||
t.Parallel()
|
||||
d := NewDispatcher()
|
||||
err := d.start(0, 0)
|
||||
assert.NoError(t, err, "start should not error")
|
||||
require.NoError(t, err, "start should not error")
|
||||
|
||||
mux := GetNewMux(d)
|
||||
itemID, err := mux.GetID()
|
||||
assert.NoError(t, err, "GetID should not error")
|
||||
require.NoError(t, err, "GetID should not error")
|
||||
|
||||
overloadCeiling := DefaultMaxWorkers * DefaultJobsLimit * 2
|
||||
|
||||
@@ -506,7 +329,7 @@ func TestMuxPublish(t *testing.T) {
|
||||
ready := make(chan any)
|
||||
demux := make(chan any, 1)
|
||||
pipe, err := mux.Subscribe(itemID)
|
||||
assert.NoError(t, err, "Subscribe should not error")
|
||||
require.NoError(t, err, "Subscribe should not error")
|
||||
|
||||
// Subscribers must be actively selecting in order to receive anything
|
||||
go func() {
|
||||
@@ -558,14 +381,10 @@ func TestMuxPublish(t *testing.T) {
|
||||
func BenchmarkSubscribe(b *testing.B) {
|
||||
d := NewDispatcher()
|
||||
err := d.start(0, 0)
|
||||
if !errors.Is(err, nil) {
|
||||
b.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
require.NoError(b, err, "start should not error")
|
||||
mux := GetNewMux(d)
|
||||
newID, err := mux.GetID()
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
require.NoError(b, err, "GetID should not error")
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
_, err := mux.Subscribe(newID)
|
||||
|
||||
@@ -29,8 +29,8 @@ type Dispatcher struct {
|
||||
// then publish the data across the full registered channels for that uuid.
|
||||
// See relayer() method below.
|
||||
routes map[uuid.UUID][]chan interface{}
|
||||
// rMtx protects the routes variable ensuring acceptable read/write access
|
||||
rMtx sync.RWMutex
|
||||
// routesMtx protects the routes variable ensuring acceptable read/write access
|
||||
routesMtx sync.Mutex
|
||||
|
||||
// Persistent buffered job queue for relayers
|
||||
jobs chan job
|
||||
|
||||
Reference in New Issue
Block a user