diff --git a/common/common.go b/common/common.go index bc60ba0c..070db74a 100644 --- a/common/common.go +++ b/common/common.go @@ -553,31 +553,33 @@ func ExcludeError(err, excl error) error { } // ErrorCollector allows collecting a stream of errors from concurrent go routines -// Users should call e.Wg.Done and send errors to e.C type ErrorCollector struct { - C chan error - Wg sync.WaitGroup + errs error + wg sync.WaitGroup + m sync.Mutex } -// CollectErrors returns an ErrorCollector with WaitGroup and Channel buffer set to n -func CollectErrors(n int) *ErrorCollector { - e := &ErrorCollector{ - C: make(chan error, n), - } - e.Wg.Add(n) - return e -} - -// Collect runs waits for e.Wg to be Done, closes the error channel, and return a error collection +// Collect waits for the internal wait group to be done and returns an error collection +// State is reset after each Collect, so successive calls are okay func (e *ErrorCollector) Collect() (errs error) { - e.Wg.Wait() - close(e.C) - for err := range e.C { - if err != nil { - errs = AppendError(errs, err) - } + e.wg.Wait() + e.m.Lock() + defer func() { e.errs = nil; e.m.Unlock() }() + return e.errs +} + +// Go runs a function in a goroutine and collects any error it returns +func (e *ErrorCollector) Go(f func() error) { + if err := NilGuard(f); err != nil { + panic(err) } - return + e.wg.Go(func() { + if err := f(); err != nil { + e.m.Lock() + e.errs = AppendError(e.errs, err) + e.m.Unlock() + } + }) } // StartEndTimeCheck provides some basic checks which occur diff --git a/common/common_test.go b/common/common_test.go index e85eff09..278a4582 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -590,23 +590,22 @@ func TestGenerateRandomString(t *testing.T) { } } -// TestErrorCollector exercises the error collector func TestErrorCollector(t *testing.T) { - e := CollectErrors(4) + var e ErrorCollector + require.Panics(t, func() { e.Go(nil) }, "Go with nil function must panic") for i := range 4 { - go func() { + e.Go(func() error { if i%2 == 0 { - e.C <- errors.New("Collected error") - } else { - e.C <- nil + return errors.New("collected error") } - e.Wg.Done() - }() + return nil + }) } v := e.Collect() errs, ok := v.(*multiError) require.True(t, ok, "Must return a multiError") assert.Len(t, errs.Unwrap(), 2, "Should have 2 errors") + assert.NoError(t, e.Collect(), "should return nil when a previous collection emptied the errors") } // TestBatch ensures the Batch function does not regress into common behavioural faults if implementation changes diff --git a/exchanges/deribit/deribit_wrapper.go b/exchanges/deribit/deribit_wrapper.go index 5840a734..9691a820 100644 --- a/exchanges/deribit/deribit_wrapper.go +++ b/exchanges/deribit/deribit_wrapper.go @@ -224,18 +224,15 @@ func (e *Exchange) FetchTradablePairs(ctx context.Context, assetType asset.Item) // UpdateTradablePairs updates the exchanges available pairs and stores // them in the exchanges config func (e *Exchange) UpdateTradablePairs(ctx context.Context) error { - assets := e.GetAssetTypes(false) - errs := common.CollectErrors(len(assets)) - for x := range assets { - go func(x int) { - defer errs.Wg.Done() - pairs, err := e.FetchTradablePairs(ctx, assets[x]) + var errs common.ErrorCollector + for _, a := range e.GetAssetTypes(false) { + errs.Go(func() error { + pairs, err := e.FetchTradablePairs(ctx, a) if err != nil { - errs.C <- err - return + return err } - errs.C <- e.UpdatePairs(pairs, assets[x], false) - }(x) + return e.UpdatePairs(pairs, a, false) + }) } return errs.Collect() } diff --git a/exchanges/exchange.go b/exchanges/exchange.go index 44abcb04..4bca1263 100644 --- a/exchanges/exchange.go +++ b/exchanges/exchange.go @@ -10,7 +10,6 @@ import ( "sort" "strconv" "strings" - "sync" "text/template" "time" "unicode" @@ -1769,28 +1768,11 @@ func (b *Base) GetOpenInterest(context.Context, ...key.PairAsset) ([]futures.Ope // ParallelChanOp performs a single method call in parallel across streams and waits to return any errors func (b *Base) ParallelChanOp(ctx context.Context, channels subscription.List, m func(context.Context, subscription.List) error, batchSize int) error { - wg := sync.WaitGroup{} - errC := make(chan error, len(channels)) - - for _, b := range common.Batch(channels, batchSize) { - wg.Add(1) - go func(c subscription.List) { - defer wg.Done() - if err := m(ctx, c); err != nil { - errC <- err - } - }(b) + var errs common.ErrorCollector + for _, batchedSubs := range common.Batch(channels, batchSize) { + errs.Go(func() error { return m(ctx, batchedSubs) }) } - - wg.Wait() - close(errC) - - var errs error - for err := range errC { - errs = common.AppendError(errs, err) - } - - return errs + return errs.Collect() } // Bootstrap function allows for exchange authors to supplement or override common startup actions @@ -1825,27 +1807,16 @@ func Bootstrap(ctx context.Context, b IBotExchange) error { } } - a := b.GetAssetTypes(true) - var wg sync.WaitGroup - errC := make(chan error, len(a)) - for i := range a { - wg.Add(1) - go func(a asset.Item) { - defer wg.Done() + var errs common.ErrorCollector + for _, a := range b.GetAssetTypes(true) { + errs.Go(func() error { if err := b.UpdateOrderExecutionLimits(ctx, a); err != nil && !errors.Is(err, common.ErrNotYetImplemented) { - errC <- fmt.Errorf("failed to set exchange order execution limits: %w", err) + return fmt.Errorf("failed to set exchange order execution limits: %w", err) } - }(a[i]) + return nil + }) } - wg.Wait() - close(errC) - - var err error - for e := range errC { - err = common.AppendError(err, e) - } - - return err + return errs.Collect() } // Bootstrap is a fallback method for exchange startup actions diff --git a/exchanges/kraken/kraken_websocket.go b/exchanges/kraken/kraken_websocket.go index 5f3f99d1..0841e69d 100644 --- a/exchanges/kraken/kraken_websocket.go +++ b/exchanges/kraken/kraken_websocket.go @@ -1066,14 +1066,10 @@ func (e *Exchange) wsAddOrder(ctx context.Context, req *WsAddOrderRequest) (stri // wsCancelOrders cancels open orders concurrently // It does not use the multiple txId facility of the cancelOrder API because the errors are not specific func (e *Exchange) wsCancelOrders(ctx context.Context, orderIDs []string) error { - errs := common.CollectErrors(len(orderIDs)) + var errs common.ErrorCollector for _, id := range orderIDs { - go func() { - defer errs.Wg.Done() - errs.C <- e.wsCancelOrder(ctx, id) - }() + errs.Go(func() error { return e.wsCancelOrder(ctx, id) }) } - return errs.Collect() } diff --git a/exchanges/request/request_test.go b/exchanges/request/request_test.go index 7af7a1fd..1cea4e58 100644 --- a/exchanges/request/request_test.go +++ b/exchanges/request/request_test.go @@ -307,13 +307,9 @@ func TestDoRequest(t *testing.T) { require.False(t, respErr.Error, "Error must be false") // Check client side rate limit - const numReqs = 5 - ec := common.CollectErrors(numReqs) - - for range numReqs { - go func() { - defer ec.Wg.Done() - + var ec common.ErrorCollector + for range 5 { + ec.Go(func() error { var resp struct { Response bool `json:"response"` } @@ -324,13 +320,13 @@ func TestDoRequest(t *testing.T) { Result: &resp, }, nil }, AuthenticatedRequest); err != nil { - ec.C <- fmt.Errorf("SendPayload error: %w", err) - return + return fmt.Errorf("SendPayload error: %w", err) } if !resp.Response { - ec.C <- fmt.Errorf("unexpected response: %+v", resp) + return fmt.Errorf("unexpected response: %+v", resp) } - }() + return nil + }) } require.NoError(t, ec.Collect(), "Collect must return no errors") @@ -342,14 +338,9 @@ func TestDoRequest_Retries(t *testing.T) { r, err := New("test", new(http.Client), WithBackoff(func(int) time.Duration { return 0 })) require.NoError(t, err, "New requester must not error") - const numReqs = 4 - - ec := common.CollectErrors(numReqs) - - for range numReqs { - go func() { - defer ec.Wg.Done() - + var ec common.ErrorCollector + for range 4 { + ec.Go(func() error { var resp struct { Response bool `json:"response"` } @@ -362,13 +353,13 @@ func TestDoRequest_Retries(t *testing.T) { } if err := r.SendPayload(t.Context(), Auth, itemFn, AuthenticatedRequest); err != nil { - ec.C <- fmt.Errorf("SendPayload error: %w", err) - return + return fmt.Errorf("SendPayload error: %w", err) } if !resp.Response { - ec.C <- fmt.Errorf("unexpected response: %+v", resp) + return fmt.Errorf("unexpected response: %+v", resp) } - }() + return nil + }) } require.NoError(t, ec.Collect(), "Collect must return no errors")