common: Update ErrorCollector to use mutex and simplify error collection in concurrent operations (#2090)

* refactor: Update ErrorCollector to use mutex and simplify error collection in concurrent operations

* glorious: nits

* linter: fix

* another find

* Apply suggestion from @gbjk

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* Apply suggestion from @gbjk

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* one liner defer

* Update common/common.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* gk: nits

* Update common/common_test.go

Co-authored-by: Scott <gloriousCode@users.noreply.github.com>

* thrasher-: nits

---------

Co-authored-by: shazbert <ryan.oharareid@thrasher.io>
Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>
Co-authored-by: shazbert <shazbert@DESKTOP-3QKKR6J.localdomain>
Co-authored-by: Scott <gloriousCode@users.noreply.github.com>
This commit is contained in:
Ryan O'Hara-Reid
2025-11-11 15:08:48 +11:00
committed by GitHub
parent fefb866b02
commit 497e13dc62
6 changed files with 63 additions and 107 deletions

View File

@@ -553,31 +553,33 @@ func ExcludeError(err, excl error) error {
}
// ErrorCollector allows collecting a stream of errors from concurrent go routines
// Users should call e.Wg.Done and send errors to e.C
type ErrorCollector struct {
C chan error
Wg sync.WaitGroup
errs error
wg sync.WaitGroup
m sync.Mutex
}
// CollectErrors returns an ErrorCollector with WaitGroup and Channel buffer set to n
func CollectErrors(n int) *ErrorCollector {
e := &ErrorCollector{
C: make(chan error, n),
}
e.Wg.Add(n)
return e
}
// Collect runs waits for e.Wg to be Done, closes the error channel, and return a error collection
// Collect waits for the internal wait group to be done and returns an error collection
// State is reset after each Collect, so successive calls are okay
func (e *ErrorCollector) Collect() (errs error) {
e.Wg.Wait()
close(e.C)
for err := range e.C {
if err != nil {
errs = AppendError(errs, err)
}
e.wg.Wait()
e.m.Lock()
defer func() { e.errs = nil; e.m.Unlock() }()
return e.errs
}
// Go runs a function in a goroutine and collects any error it returns
func (e *ErrorCollector) Go(f func() error) {
if err := NilGuard(f); err != nil {
panic(err)
}
return
e.wg.Go(func() {
if err := f(); err != nil {
e.m.Lock()
e.errs = AppendError(e.errs, err)
e.m.Unlock()
}
})
}
// StartEndTimeCheck provides some basic checks which occur

View File

@@ -590,23 +590,22 @@ func TestGenerateRandomString(t *testing.T) {
}
}
// TestErrorCollector exercises the error collector
func TestErrorCollector(t *testing.T) {
e := CollectErrors(4)
var e ErrorCollector
require.Panics(t, func() { e.Go(nil) }, "Go with nil function must panic")
for i := range 4 {
go func() {
e.Go(func() error {
if i%2 == 0 {
e.C <- errors.New("Collected error")
} else {
e.C <- nil
return errors.New("collected error")
}
e.Wg.Done()
}()
return nil
})
}
v := e.Collect()
errs, ok := v.(*multiError)
require.True(t, ok, "Must return a multiError")
assert.Len(t, errs.Unwrap(), 2, "Should have 2 errors")
assert.NoError(t, e.Collect(), "should return nil when a previous collection emptied the errors")
}
// TestBatch ensures the Batch function does not regress into common behavioural faults if implementation changes

View File

@@ -224,18 +224,15 @@ func (e *Exchange) FetchTradablePairs(ctx context.Context, assetType asset.Item)
// UpdateTradablePairs updates the exchanges available pairs and stores
// them in the exchanges config
func (e *Exchange) UpdateTradablePairs(ctx context.Context) error {
assets := e.GetAssetTypes(false)
errs := common.CollectErrors(len(assets))
for x := range assets {
go func(x int) {
defer errs.Wg.Done()
pairs, err := e.FetchTradablePairs(ctx, assets[x])
var errs common.ErrorCollector
for _, a := range e.GetAssetTypes(false) {
errs.Go(func() error {
pairs, err := e.FetchTradablePairs(ctx, a)
if err != nil {
errs.C <- err
return
return err
}
errs.C <- e.UpdatePairs(pairs, assets[x], false)
}(x)
return e.UpdatePairs(pairs, a, false)
})
}
return errs.Collect()
}

View File

@@ -10,7 +10,6 @@ import (
"sort"
"strconv"
"strings"
"sync"
"text/template"
"time"
"unicode"
@@ -1769,28 +1768,11 @@ func (b *Base) GetOpenInterest(context.Context, ...key.PairAsset) ([]futures.Ope
// ParallelChanOp performs a single method call in parallel across streams and waits to return any errors
func (b *Base) ParallelChanOp(ctx context.Context, channels subscription.List, m func(context.Context, subscription.List) error, batchSize int) error {
wg := sync.WaitGroup{}
errC := make(chan error, len(channels))
for _, b := range common.Batch(channels, batchSize) {
wg.Add(1)
go func(c subscription.List) {
defer wg.Done()
if err := m(ctx, c); err != nil {
errC <- err
}
}(b)
var errs common.ErrorCollector
for _, batchedSubs := range common.Batch(channels, batchSize) {
errs.Go(func() error { return m(ctx, batchedSubs) })
}
wg.Wait()
close(errC)
var errs error
for err := range errC {
errs = common.AppendError(errs, err)
}
return errs
return errs.Collect()
}
// Bootstrap function allows for exchange authors to supplement or override common startup actions
@@ -1825,27 +1807,16 @@ func Bootstrap(ctx context.Context, b IBotExchange) error {
}
}
a := b.GetAssetTypes(true)
var wg sync.WaitGroup
errC := make(chan error, len(a))
for i := range a {
wg.Add(1)
go func(a asset.Item) {
defer wg.Done()
var errs common.ErrorCollector
for _, a := range b.GetAssetTypes(true) {
errs.Go(func() error {
if err := b.UpdateOrderExecutionLimits(ctx, a); err != nil && !errors.Is(err, common.ErrNotYetImplemented) {
errC <- fmt.Errorf("failed to set exchange order execution limits: %w", err)
return fmt.Errorf("failed to set exchange order execution limits: %w", err)
}
}(a[i])
return nil
})
}
wg.Wait()
close(errC)
var err error
for e := range errC {
err = common.AppendError(err, e)
}
return err
return errs.Collect()
}
// Bootstrap is a fallback method for exchange startup actions

View File

@@ -1066,14 +1066,10 @@ func (e *Exchange) wsAddOrder(ctx context.Context, req *WsAddOrderRequest) (stri
// wsCancelOrders cancels open orders concurrently
// It does not use the multiple txId facility of the cancelOrder API because the errors are not specific
func (e *Exchange) wsCancelOrders(ctx context.Context, orderIDs []string) error {
errs := common.CollectErrors(len(orderIDs))
var errs common.ErrorCollector
for _, id := range orderIDs {
go func() {
defer errs.Wg.Done()
errs.C <- e.wsCancelOrder(ctx, id)
}()
errs.Go(func() error { return e.wsCancelOrder(ctx, id) })
}
return errs.Collect()
}

View File

@@ -307,13 +307,9 @@ func TestDoRequest(t *testing.T) {
require.False(t, respErr.Error, "Error must be false")
// Check client side rate limit
const numReqs = 5
ec := common.CollectErrors(numReqs)
for range numReqs {
go func() {
defer ec.Wg.Done()
var ec common.ErrorCollector
for range 5 {
ec.Go(func() error {
var resp struct {
Response bool `json:"response"`
}
@@ -324,13 +320,13 @@ func TestDoRequest(t *testing.T) {
Result: &resp,
}, nil
}, AuthenticatedRequest); err != nil {
ec.C <- fmt.Errorf("SendPayload error: %w", err)
return
return fmt.Errorf("SendPayload error: %w", err)
}
if !resp.Response {
ec.C <- fmt.Errorf("unexpected response: %+v", resp)
return fmt.Errorf("unexpected response: %+v", resp)
}
}()
return nil
})
}
require.NoError(t, ec.Collect(), "Collect must return no errors")
@@ -342,14 +338,9 @@ func TestDoRequest_Retries(t *testing.T) {
r, err := New("test", new(http.Client), WithBackoff(func(int) time.Duration { return 0 }))
require.NoError(t, err, "New requester must not error")
const numReqs = 4
ec := common.CollectErrors(numReqs)
for range numReqs {
go func() {
defer ec.Wg.Done()
var ec common.ErrorCollector
for range 4 {
ec.Go(func() error {
var resp struct {
Response bool `json:"response"`
}
@@ -362,13 +353,13 @@ func TestDoRequest_Retries(t *testing.T) {
}
if err := r.SendPayload(t.Context(), Auth, itemFn, AuthenticatedRequest); err != nil {
ec.C <- fmt.Errorf("SendPayload error: %w", err)
return
return fmt.Errorf("SendPayload error: %w", err)
}
if !resp.Response {
ec.C <- fmt.Errorf("unexpected response: %+v", resp)
return fmt.Errorf("unexpected response: %+v", resp)
}
}()
return nil
})
}
require.NoError(t, ec.Collect(), "Collect must return no errors")