Files
gocryptotrader/exchanges/request/limit_test.go
Ryan O'Hara-Reid 719e6bebfe request/ratelimit: Add context value check and fix bug (#2073)
* Add WithNoDelayPermitted and fix bug on cancel all

* Remove reservations, as it only covered the last reservation when cancelling and needed to take the actual offset delay into account to return tokens correctly; update tests

* export error

* misc fix

* more misc fix

* Add concurrent protection, cancel in reverse and add tests

* lint: fix

* Update exchanges/request/limit.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* Update exchanges/request/limit.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* Update exchanges/request/limit.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* Update exchanges/request/limit.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* gk: nits doo

* linter: fix

* boss king: nits

* crank: nits

* crank: test patch which was cooked and had to be done manually

* Update exchanges/request/limit.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* gk: nits

* linter: fix

* thrasher: nits

* use error collector in tests

* nolint: direction

* gk: fixup!

* my life has elapsed

* thrasher-: Because of synctest, we can now be deterministic with values. This gets rid of many redundant wait calls that served no purpose

* thrasher-: patched

---------

Co-authored-by: shazbert <ryan.oharareid@thrasher.io>
Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>
Co-authored-by: shazbert <shazbert@DESKTOP-3QKKR6J.localdomain>
2025-11-27 11:10:11 +11:00

253 lines
8.6 KiB
Go

package request
import (
"context"
"sync/atomic"
"testing"
"testing/synctest"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/common"
"golang.org/x/time/rate"
)
func TestRateLimit(t *testing.T) {
t.Parallel()
synctest.Test(t, func(t *testing.T) { //nolint:thelper,nolintlint // false positive
err := (*RateLimiterWithWeight)(nil).RateLimit(t.Context())
assert.ErrorContains(t, err, "nil pointer: *request.RateLimiterWithWeight")
r := &RateLimiterWithWeight{limiter: rate.NewLimiter(rate.Limit(1), 1)}
err = r.RateLimit(t.Context())
assert.ErrorIs(t, err, errInvalidWeight, "should return errInvalidWeightCount for zero weight")
r = NewRateLimitWithWeight(time.Second, 10, 1)
start := time.Now()
err = r.RateLimit(t.Context())
elapsed := time.Since(start)
require.NoError(t, err, "rate limit must not error")
assert.Zero(t, elapsed, "first call should be immediate")
r = NewRateLimitWithWeight(time.Second, 10, 5)
start = time.Now()
err = r.RateLimit(t.Context())
elapsed = time.Since(start)
require.NoError(t, err, "rate limit must not error")
assert.Equal(t, 400*time.Millisecond, elapsed, "should wait 400ms (4 intervals) for weight 5")
r = NewRateLimitWithWeight(100*time.Millisecond, 1, 1)
start = time.Now()
err = r.RateLimit(WithDelayNotAllowed(t.Context()))
synctest.Wait()
elapsed = time.Since(start)
require.NoError(t, err, "first rate limit call must not error and must be immediate")
assert.Zero(t, elapsed, "first call should be immediate")
start = time.Now()
err = r.RateLimit(t.Context())
elapsed = time.Since(start)
require.NoError(t, err, "second rate limit call must not error")
assert.Equal(t, 100*time.Millisecond, elapsed, "second call should be delayed by exactly 100ms")
err = r.RateLimit(WithDelayNotAllowed(t.Context()))
assert.ErrorIs(t, err, ErrDelayNotAllowed, "should return correct error")
ctx, cancel := context.WithCancel(t.Context())
cancel()
err = r.RateLimit(ctx)
assert.ErrorIs(t, err, context.Canceled, "should return correct error when context is canceled")
// Rate limit is 100ms. Set deadline for 50ms.
ctx, cancel = context.WithTimeout(t.Context(), 50*time.Millisecond)
defer cancel()
err = r.RateLimit(ctx)
assert.ErrorIs(t, err, context.DeadlineExceeded, "should return correct error when context deadline exceeded")
})
}
func TestRateLimit_Concurrent_WithFailure(t *testing.T) {
t.Parallel()
synctest.Test(t, func(t *testing.T) { //nolint:thelper,nolintlint // false positive
r := NewRateLimitWithWeight(time.Second, 10, 1)
tn := time.Now()
errs := common.ErrorCollector{}
for i := range 10 {
ctx := t.Context()
if i%2 == 0 {
ctx = WithDelayNotAllowed(ctx)
}
errs.Go(func() error { return r.RateLimit(ctx) })
}
require.ErrorContains(t, errs.Collect(), "delay not allowed, delay not allowed, delay not allowed, delay not allowed", "must return correct error")
assert.Less(t, time.Since(tn), time.Millisecond*600, "should complete within reasonable time")
})
}
func TestRateLimit_Concurrent(t *testing.T) {
t.Parallel()
synctest.Test(t, func(t *testing.T) { //nolint:thelper,nolintlint // false positive
r := NewRateLimitWithWeight(time.Second, 10, 1)
tn := time.Now()
errs := common.ErrorCollector{}
for range 10 {
errs.Go(func() error { return r.RateLimit(t.Context()) })
}
require.NoError(t, errs.Collect(), "rate limit must not error")
assert.Less(t, time.Since(tn), time.Second, "should complete within reasonable time")
})
}
func TestRateLimit_Linear_WithFailure(t *testing.T) {
t.Parallel()
synctest.Test(t, func(t *testing.T) { //nolint:thelper,nolintlint // false positive
r := NewRateLimitWithWeight(time.Second, 10, 1)
tn := time.Now()
for i := range 10 {
ctx := t.Context()
if i%2 == 0 {
ctx = WithDelayNotAllowed(ctx)
}
if err := r.RateLimit(ctx); err != nil {
require.ErrorIs(t, err, ErrDelayNotAllowed, "must return correct error")
}
}
assert.Less(t, time.Since(tn), time.Millisecond*600, "should complete within reasonable time")
})
}
func TestRateLimit_Linear(t *testing.T) {
t.Parallel()
synctest.Test(t, func(t *testing.T) { //nolint:thelper,nolintlint // false positive
r := NewRateLimitWithWeight(time.Second, 10, 1)
tn := time.Now()
for range 10 {
require.NoError(t, r.RateLimit(t.Context()))
}
assert.Less(t, time.Since(tn), time.Second, "should complete within reasonable time")
})
}
// TestNewRateLimit checks the limit reported by NewRateLimit for a valid
// interval/action pair and for each non-positive interval or action count,
// where the limit is expected to be rate.Inf.
func TestNewRateLimit(t *testing.T) {
	t.Parallel()

	// Valid configuration: ten actions per second.
	tenPerSecond := NewRateLimit(time.Second, 10)
	require.NotNil(t, tenPerSecond, "limiter must not be nil")
	assert.Equal(t, rate.Limit(10), tenPerSecond.Limit(), "limit should be 10 per second")

	// Non-positive action counts.
	zeroActions := NewRateLimit(time.Second, 0)
	require.NotNil(t, zeroActions, "limiter must not be nil")
	assert.Equal(t, rate.Inf, zeroActions.Limit(), "limit should be infinite on zero actions")

	negativeActions := NewRateLimit(time.Second, -1)
	require.NotNil(t, negativeActions, "limiter must not be nil")
	assert.Equal(t, rate.Inf, negativeActions.Limit(), "limit should be infinite on negative actions")

	// Non-positive intervals.
	zeroInterval := NewRateLimit(0, 10)
	require.NotNil(t, zeroInterval, "limiter must not be nil")
	assert.Equal(t, rate.Inf, zeroInterval.Limit(), "limit should be infinite on zero interval")

	negativeInterval := NewRateLimit(-time.Second, 10)
	require.NotNil(t, negativeInterval, "limiter must not be nil")
	assert.Equal(t, rate.Inf, negativeInterval.Limit(), "limit should be infinite on negative interval")
}
// TestNewRateLimitWithWeight checks that the constructor stores the supplied
// weight and builds a limiter with the expected rate.
func TestNewRateLimitWithWeight(t *testing.T) {
	t.Parallel()

	weighted := NewRateLimitWithWeight(time.Second, 10, 5)
	require.NotNil(t, weighted, "limiter must not be nil")

	gotWeight := weighted.weight
	assert.Equal(t, Weight(5), gotWeight, "weight should be 5")

	gotLimit := weighted.limiter.Limit()
	assert.Equal(t, rate.Limit(10), gotLimit, "limit should be 10 per second")
}
// TestNewWeightedRateLimitByDuration checks that a one-second duration yields
// a limiter with weight 1 and a rate of one action per second.
func TestNewWeightedRateLimitByDuration(t *testing.T) {
	t.Parallel()

	weighted := NewWeightedRateLimitByDuration(time.Second)
	require.NotNil(t, weighted, "limiter must not be nil")

	gotWeight := weighted.weight
	assert.Equal(t, Weight(1), gotWeight, "weight should be 1")

	gotLimit := weighted.limiter.Limit()
	assert.Equal(t, rate.Limit(1), gotLimit, "limit should be 1 per second")
}
// TestGetRateLimiterWithWeight checks that wrapping an existing limiter keeps
// a reference to that limiter and records the requested weight.
func TestGetRateLimiterWithWeight(t *testing.T) {
	t.Parallel()

	base := rate.NewLimiter(rate.Limit(10), 1)
	wrapped := GetRateLimiterWithWeight(base, 5)
	require.NotNil(t, wrapped, "weighted limiter must not be nil")

	gotWeight := wrapped.weight
	assert.Equal(t, Weight(5), gotWeight, "weight should be 5")

	gotLimiter := wrapped.limiter
	assert.Equal(t, base, gotLimiter, "should reference same limiter")
}
// TestNewBasicRateLimit checks that the basic definitions contain exactly the
// Unset, Auth and UnAuth entries, each with the requested weight and rate,
// and that all three entries point at the same instance.
func TestNewBasicRateLimit(t *testing.T) {
	t.Parallel()

	definitions := NewBasicRateLimit(time.Second, 10, 5)
	require.NotNil(t, definitions, "definitions must not be nil")
	require.Len(t, definitions, 3, "must have 3 definitions")

	endpoints := []EndpointLimit{Unset, Auth, UnAuth}
	for idx := range endpoints {
		endpoint := endpoints[idx]
		entry, found := definitions[endpoint]
		require.Truef(t, found, "must have definition for %v", endpoint)
		assert.Equalf(t, Weight(5), entry.weight, "weight should be 5 for %v", endpoint)
		assert.Equalf(t, rate.Limit(10), entry.limiter.Limit(), "limit should be 10 per second for %v", endpoint)
	}

	// Pointer identity is checked pairwise across the three keys.
	assert.Same(t, definitions[Unset], definitions[Auth], "Unset and Auth should be same instance")
	assert.Same(t, definitions[Auth], definitions[UnAuth], "Auth and UnAuth should be same instance")
}
func TestWithDelayNotAllowed(t *testing.T) {
t.Parallel()
assert.True(t, hasDelayNotAllowed(WithDelayNotAllowed(t.Context())))
assert.False(t, hasDelayNotAllowed(t.Context()))
assert.False(t, hasDelayNotAllowed(WithVerbose(t.Context())))
}
// TestCancelAll checks that cancelAll accepts an empty reservation list and
// that cancelling two reservations taken at the same instant restores the
// limiter to one available token at that instant.
func TestCancelAll(t *testing.T) {
	t.Parallel()

	// An empty (nil) list must be handled without incident.
	var pending []*rate.Reservation
	cancelAll(pending, time.Now())

	limiter := rate.NewLimiter(rate.Limit(1), 1)
	now := time.Now()

	// The first reservation consumes the only token in the bucket.
	first := limiter.ReserveN(now, 1)
	pending = append(pending, first)
	require.Equal(t, 0.0, limiter.TokensAt(now), "must have zero tokens remaining")

	// The second reservation overdraws the bucket and carries a delay.
	second := limiter.ReserveN(now, 1)
	pending = append(pending, second)
	require.Equal(t, time.Second, pending[1].DelayFrom(now), "second reservation must have 1 second delay")
	require.Equal(t, -1.0, limiter.TokensAt(now), "must have negative tokens remaining")

	cancelAll(pending, now)
	require.Equal(t, 1.0, limiter.TokensAt(now), "must have 1 token remaining after cancellation")
}
// TestInitiateRateLimit covers Requester.InitiateRateLimit for a nil
// requester, a requester with rate limiting disabled, a requester with no
// definitions set, and a fully configured requester.
func TestInitiateRateLimit(t *testing.T) {
	t.Parallel()

	// Nil receiver.
	var requester *Requester
	err := requester.InitiateRateLimit(t.Context(), Unset)
	assert.ErrorIs(t, err, ErrRequestSystemIsNil, "should return correct error")

	// Rate limiting switched off: the call succeeds with no definitions set.
	requester = &Requester{}
	atomic.StoreInt32(&requester.disableRateLimiter, 1)
	err = requester.InitiateRateLimit(t.Context(), Unset)
	assert.NoError(t, err, "should not error when rate limiter is disabled")

	// Rate limiting switched back on while the definitions are still unset.
	atomic.StoreInt32(&requester.disableRateLimiter, 0)
	err = requester.InitiateRateLimit(t.Context(), Unset)
	assert.ErrorContains(t, err, "nil pointer: request.RateLimitDefinitions", "should return correct error when limiter is nil")

	// Definitions supplied: the call succeeds.
	requester.limiter = NewBasicRateLimit(time.Second, 10, 1)
	err = requester.InitiateRateLimit(t.Context(), Unset)
	assert.NoError(t, err, "should not error on valid rate limit initiation")
}