mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-13 15:09:42 +00:00
common: Update ErrorCollector to use mutex and simplify error collection in concurrent operations (#2090)
* refactor: Update ErrorCollector to use mutex and simplify error collection in concurrent operations * glorious: nits * linter: fix * another find * Apply suggestion from @gbjk Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> * Apply suggestion from @gbjk Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> * one liner defer * Update common/common.go Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> * gk: nits * Update common/common_test.go Co-authored-by: Scott <gloriousCode@users.noreply.github.com> * thrasher-: nits --------- Co-authored-by: shazbert <ryan.oharareid@thrasher.io> Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> Co-authored-by: shazbert <shazbert@DESKTOP-3QKKR6J.localdomain> Co-authored-by: Scott <gloriousCode@users.noreply.github.com>
This commit is contained in:
@@ -553,31 +553,33 @@ func ExcludeError(err, excl error) error {
|
||||
}
|
||||
|
||||
// ErrorCollector allows collecting a stream of errors from concurrent go routines
|
||||
// Users should call e.Wg.Done and send errors to e.C
|
||||
type ErrorCollector struct {
|
||||
C chan error
|
||||
Wg sync.WaitGroup
|
||||
errs error
|
||||
wg sync.WaitGroup
|
||||
m sync.Mutex
|
||||
}
|
||||
|
||||
// CollectErrors returns an ErrorCollector with WaitGroup and Channel buffer set to n
|
||||
func CollectErrors(n int) *ErrorCollector {
|
||||
e := &ErrorCollector{
|
||||
C: make(chan error, n),
|
||||
}
|
||||
e.Wg.Add(n)
|
||||
return e
|
||||
}
|
||||
|
||||
// Collect runs waits for e.Wg to be Done, closes the error channel, and return a error collection
|
||||
// Collect waits for the internal wait group to be done and returns an error collection
|
||||
// State is reset after each Collect, so successive calls are okay
|
||||
func (e *ErrorCollector) Collect() (errs error) {
|
||||
e.Wg.Wait()
|
||||
close(e.C)
|
||||
for err := range e.C {
|
||||
if err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
}
|
||||
e.wg.Wait()
|
||||
e.m.Lock()
|
||||
defer func() { e.errs = nil; e.m.Unlock() }()
|
||||
return e.errs
|
||||
}
|
||||
|
||||
// Go runs a function in a goroutine and collects any error it returns
|
||||
func (e *ErrorCollector) Go(f func() error) {
|
||||
if err := NilGuard(f); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
e.wg.Go(func() {
|
||||
if err := f(); err != nil {
|
||||
e.m.Lock()
|
||||
e.errs = AppendError(e.errs, err)
|
||||
e.m.Unlock()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// StartEndTimeCheck provides some basic checks which occur
|
||||
|
||||
@@ -590,23 +590,22 @@ func TestGenerateRandomString(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorCollector exercises the error collector
|
||||
func TestErrorCollector(t *testing.T) {
|
||||
e := CollectErrors(4)
|
||||
var e ErrorCollector
|
||||
require.Panics(t, func() { e.Go(nil) }, "Go with nil function must panic")
|
||||
for i := range 4 {
|
||||
go func() {
|
||||
e.Go(func() error {
|
||||
if i%2 == 0 {
|
||||
e.C <- errors.New("Collected error")
|
||||
} else {
|
||||
e.C <- nil
|
||||
return errors.New("collected error")
|
||||
}
|
||||
e.Wg.Done()
|
||||
}()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
v := e.Collect()
|
||||
errs, ok := v.(*multiError)
|
||||
require.True(t, ok, "Must return a multiError")
|
||||
assert.Len(t, errs.Unwrap(), 2, "Should have 2 errors")
|
||||
assert.NoError(t, e.Collect(), "should return nil when a previous collection emptied the errors")
|
||||
}
|
||||
|
||||
// TestBatch ensures the Batch function does not regress into common behavioural faults if implementation changes
|
||||
|
||||
@@ -224,18 +224,15 @@ func (e *Exchange) FetchTradablePairs(ctx context.Context, assetType asset.Item)
|
||||
// UpdateTradablePairs updates the exchanges available pairs and stores
|
||||
// them in the exchanges config
|
||||
func (e *Exchange) UpdateTradablePairs(ctx context.Context) error {
|
||||
assets := e.GetAssetTypes(false)
|
||||
errs := common.CollectErrors(len(assets))
|
||||
for x := range assets {
|
||||
go func(x int) {
|
||||
defer errs.Wg.Done()
|
||||
pairs, err := e.FetchTradablePairs(ctx, assets[x])
|
||||
var errs common.ErrorCollector
|
||||
for _, a := range e.GetAssetTypes(false) {
|
||||
errs.Go(func() error {
|
||||
pairs, err := e.FetchTradablePairs(ctx, a)
|
||||
if err != nil {
|
||||
errs.C <- err
|
||||
return
|
||||
return err
|
||||
}
|
||||
errs.C <- e.UpdatePairs(pairs, assets[x], false)
|
||||
}(x)
|
||||
return e.UpdatePairs(pairs, a, false)
|
||||
})
|
||||
}
|
||||
return errs.Collect()
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"text/template"
|
||||
"time"
|
||||
"unicode"
|
||||
@@ -1769,28 +1768,11 @@ func (b *Base) GetOpenInterest(context.Context, ...key.PairAsset) ([]futures.Ope
|
||||
|
||||
// ParallelChanOp performs a single method call in parallel across streams and waits to return any errors
|
||||
func (b *Base) ParallelChanOp(ctx context.Context, channels subscription.List, m func(context.Context, subscription.List) error, batchSize int) error {
|
||||
wg := sync.WaitGroup{}
|
||||
errC := make(chan error, len(channels))
|
||||
|
||||
for _, b := range common.Batch(channels, batchSize) {
|
||||
wg.Add(1)
|
||||
go func(c subscription.List) {
|
||||
defer wg.Done()
|
||||
if err := m(ctx, c); err != nil {
|
||||
errC <- err
|
||||
}
|
||||
}(b)
|
||||
var errs common.ErrorCollector
|
||||
for _, batchedSubs := range common.Batch(channels, batchSize) {
|
||||
errs.Go(func() error { return m(ctx, batchedSubs) })
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errC)
|
||||
|
||||
var errs error
|
||||
for err := range errC {
|
||||
errs = common.AppendError(errs, err)
|
||||
}
|
||||
|
||||
return errs
|
||||
return errs.Collect()
|
||||
}
|
||||
|
||||
// Bootstrap function allows for exchange authors to supplement or override common startup actions
|
||||
@@ -1825,27 +1807,16 @@ func Bootstrap(ctx context.Context, b IBotExchange) error {
|
||||
}
|
||||
}
|
||||
|
||||
a := b.GetAssetTypes(true)
|
||||
var wg sync.WaitGroup
|
||||
errC := make(chan error, len(a))
|
||||
for i := range a {
|
||||
wg.Add(1)
|
||||
go func(a asset.Item) {
|
||||
defer wg.Done()
|
||||
var errs common.ErrorCollector
|
||||
for _, a := range b.GetAssetTypes(true) {
|
||||
errs.Go(func() error {
|
||||
if err := b.UpdateOrderExecutionLimits(ctx, a); err != nil && !errors.Is(err, common.ErrNotYetImplemented) {
|
||||
errC <- fmt.Errorf("failed to set exchange order execution limits: %w", err)
|
||||
return fmt.Errorf("failed to set exchange order execution limits: %w", err)
|
||||
}
|
||||
}(a[i])
|
||||
return nil
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
close(errC)
|
||||
|
||||
var err error
|
||||
for e := range errC {
|
||||
err = common.AppendError(err, e)
|
||||
}
|
||||
|
||||
return err
|
||||
return errs.Collect()
|
||||
}
|
||||
|
||||
// Bootstrap is a fallback method for exchange startup actions
|
||||
|
||||
@@ -1066,14 +1066,10 @@ func (e *Exchange) wsAddOrder(ctx context.Context, req *WsAddOrderRequest) (stri
|
||||
// wsCancelOrders cancels open orders concurrently
|
||||
// It does not use the multiple txId facility of the cancelOrder API because the errors are not specific
|
||||
func (e *Exchange) wsCancelOrders(ctx context.Context, orderIDs []string) error {
|
||||
errs := common.CollectErrors(len(orderIDs))
|
||||
var errs common.ErrorCollector
|
||||
for _, id := range orderIDs {
|
||||
go func() {
|
||||
defer errs.Wg.Done()
|
||||
errs.C <- e.wsCancelOrder(ctx, id)
|
||||
}()
|
||||
errs.Go(func() error { return e.wsCancelOrder(ctx, id) })
|
||||
}
|
||||
|
||||
return errs.Collect()
|
||||
}
|
||||
|
||||
|
||||
@@ -307,13 +307,9 @@ func TestDoRequest(t *testing.T) {
|
||||
require.False(t, respErr.Error, "Error must be false")
|
||||
|
||||
// Check client side rate limit
|
||||
const numReqs = 5
|
||||
ec := common.CollectErrors(numReqs)
|
||||
|
||||
for range numReqs {
|
||||
go func() {
|
||||
defer ec.Wg.Done()
|
||||
|
||||
var ec common.ErrorCollector
|
||||
for range 5 {
|
||||
ec.Go(func() error {
|
||||
var resp struct {
|
||||
Response bool `json:"response"`
|
||||
}
|
||||
@@ -324,13 +320,13 @@ func TestDoRequest(t *testing.T) {
|
||||
Result: &resp,
|
||||
}, nil
|
||||
}, AuthenticatedRequest); err != nil {
|
||||
ec.C <- fmt.Errorf("SendPayload error: %w", err)
|
||||
return
|
||||
return fmt.Errorf("SendPayload error: %w", err)
|
||||
}
|
||||
if !resp.Response {
|
||||
ec.C <- fmt.Errorf("unexpected response: %+v", resp)
|
||||
return fmt.Errorf("unexpected response: %+v", resp)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
require.NoError(t, ec.Collect(), "Collect must return no errors")
|
||||
@@ -342,14 +338,9 @@ func TestDoRequest_Retries(t *testing.T) {
|
||||
r, err := New("test", new(http.Client), WithBackoff(func(int) time.Duration { return 0 }))
|
||||
require.NoError(t, err, "New requester must not error")
|
||||
|
||||
const numReqs = 4
|
||||
|
||||
ec := common.CollectErrors(numReqs)
|
||||
|
||||
for range numReqs {
|
||||
go func() {
|
||||
defer ec.Wg.Done()
|
||||
|
||||
var ec common.ErrorCollector
|
||||
for range 4 {
|
||||
ec.Go(func() error {
|
||||
var resp struct {
|
||||
Response bool `json:"response"`
|
||||
}
|
||||
@@ -362,13 +353,13 @@ func TestDoRequest_Retries(t *testing.T) {
|
||||
}
|
||||
|
||||
if err := r.SendPayload(t.Context(), Auth, itemFn, AuthenticatedRequest); err != nil {
|
||||
ec.C <- fmt.Errorf("SendPayload error: %w", err)
|
||||
return
|
||||
return fmt.Errorf("SendPayload error: %w", err)
|
||||
}
|
||||
if !resp.Response {
|
||||
ec.C <- fmt.Errorf("unexpected response: %+v", resp)
|
||||
return fmt.Errorf("unexpected response: %+v", resp)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
require.NoError(t, ec.Collect(), "Collect must return no errors")
|
||||
|
||||
Reference in New Issue
Block a user