mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-13 15:09:42 +00:00
request/ratelimit: Add context value check and fix bug (#2073)
* Add WithNoDelayPermitted and fix bug on cancel all * rm reservations as it is only for last reservation when cancelling and needed to take into account of the actual offset delay for correct returning of tokens, update tests * export error * misc fix * more misc fix * Add concurrent protection, cancel in reverse and add tests * lint: fix * Update exchanges/request/limit.go Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> * Update exchanges/request/limit.go Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> * Update exchanges/request/limit.go Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> * Update exchanges/request/limit.go Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> * gk: nits doo * linter: fix * boss king: nits * crank: nits * crank: test patch which was cooked and had to be done manually * Update exchanges/request/limit.go Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> * gk: nits * linter: fix * thrasher: nits * use error collector in tests * nolint: direction * gk: fixup! * my life has elapsed * thrasher-: Because of synctest, we can now be deterministic with values. This rids a lot of the redundant wait calls which served no purpose * thrasher-: patched --------- Co-authored-by: shazbert <ryan.oharareid@thrasher.io> Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com> Co-authored-by: shazbert <shazbert@DESKTOP-3QKKR6J.localdomain>
This commit is contained in:
@@ -204,7 +204,7 @@ func (c *connection) writeToConn(ctx context.Context, epl request.EndpointLimit,
|
||||
}
|
||||
|
||||
if rl != nil {
|
||||
if err := request.RateLimit(ctx, rl); err != nil {
|
||||
if err := rl.RateLimit(ctx); err != nil {
|
||||
return fmt.Errorf("%s websocket connection: rate limit error: %w", c.ExchangeName, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,66 +4,54 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// Defines rate limiting errors
|
||||
// Rate limiting errors.
|
||||
var (
|
||||
ErrRateLimiterAlreadyDisabled = errors.New("rate limiter already disabled")
|
||||
ErrRateLimiterAlreadyEnabled = errors.New("rate limiter already enabled")
|
||||
ErrDelayNotAllowed = errors.New("delay not allowed")
|
||||
|
||||
errLimiterSystemIsNil = errors.New("limiter system is nil")
|
||||
errInvalidWeightCount = errors.New("invalid weight count must equal or greater than 1")
|
||||
errSpecificRateLimiterIsNil = errors.New("specific rate limiter is nil")
|
||||
errInvalidWeight = errors.New("weight must be equal-or-greater than 1")
|
||||
)
|
||||
|
||||
// RateLimitNotRequired is a no-op rate limiter
|
||||
// RateLimitNotRequired is a no-op rate limiter.
|
||||
var RateLimitNotRequired *RateLimiterWithWeight
|
||||
|
||||
// Const here define individual functionality sub types for rate limiting
|
||||
// Const here define individual functionality sub types for rate limiting.
|
||||
const (
|
||||
Unset EndpointLimit = iota
|
||||
Auth
|
||||
UnAuth
|
||||
)
|
||||
|
||||
// EndpointLimit defines individual endpoint rate limits that are set when
|
||||
// New is called.
|
||||
// EndpointLimit defines individual endpoint rate limits.
|
||||
type EndpointLimit uint16
|
||||
|
||||
// Weight defines the number of reservations to be used. This is a generalised
|
||||
// weight for rate limiting. e.g. n weight = n request. i.e. 50 Weight = 50
|
||||
// requests.
|
||||
// Weight defines the number of reservations to be used. This is a generalised weight for rate limiting.
|
||||
// e.g. n weight = n request. i.e. 50 Weight = 50 requests.
|
||||
type Weight uint8
|
||||
|
||||
// RateLimitDefinitions is a map of endpoint limits to rate limiters
|
||||
// RateLimitDefinitions is a map of endpoint limits to rate limiters.
|
||||
type RateLimitDefinitions map[any]*RateLimiterWithWeight
|
||||
|
||||
// RateLimiterWithWeight is a rate limiter coupled with a weight count which
|
||||
// refers to the number or weighting of the request. This is used to define
|
||||
// the rate limit for a specific endpoint.
|
||||
// RateLimiterWithWeight is a rate limiter coupled with a weight which refers to the number or weighting of the request.
|
||||
// This is used to define the rate limit for a specific endpoint.
|
||||
type RateLimiterWithWeight struct {
|
||||
*rate.Limiter
|
||||
Weight
|
||||
limiter *rate.Limiter
|
||||
weight Weight
|
||||
m sync.Mutex
|
||||
}
|
||||
|
||||
// Reservations is a slice of rate reservations
|
||||
type Reservations []*rate.Reservation
|
||||
|
||||
// CancelAll cancels all potential reservations to free up rate limiter for
|
||||
// context cancellations and deadline exceeded cases.
|
||||
func (r Reservations) CancelAll() {
|
||||
for x := range r {
|
||||
r[x].Cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// NewRateLimit creates a new RateLimit based of time interval and how many
|
||||
// actions allowed and breaks it down to an actions-per-second basis -- Burst
|
||||
// rate is kept as one as this is not supported for out-bound requests.
|
||||
// NewRateLimit creates a new RateLimit based of time interval and how many actions allowed and breaks it down to an
|
||||
// actions-per-second basis -- Burst rate is kept as one as this is not supported for out-bound requests.
|
||||
func NewRateLimit(interval time.Duration, actions int) *rate.Limiter {
|
||||
if actions <= 0 || interval <= 0 {
|
||||
// Returns an un-restricted rate limiter
|
||||
@@ -75,34 +63,32 @@ func NewRateLimit(interval time.Duration, actions int) *rate.Limiter {
|
||||
return rate.NewLimiter(rate.Limit(rps), 1)
|
||||
}
|
||||
|
||||
// NewRateLimitWithWeight creates a new RateLimit based of time interval and how
|
||||
// many actions allowed. This also has a weight count which refers to the number
|
||||
// or weighting of the request. This is used to define the rate limit for a
|
||||
// NewRateLimitWithWeight creates a new RateLimit based of time interval and how many actions allowed. This also has a
|
||||
// weight count which refers to the number or weighting of the request. This is used to define the rate limit for a
|
||||
// specific endpoint.
|
||||
func NewRateLimitWithWeight(interval time.Duration, actions int, weight Weight) *RateLimiterWithWeight {
|
||||
return GetRateLimiterWithWeight(NewRateLimit(interval, actions), weight)
|
||||
}
|
||||
|
||||
// NewWeightedRateLimitByDuration creates a new RateLimit based of time
|
||||
// interval. This equates to 1 action per interval. The weight is set to 1.
|
||||
// NewWeightedRateLimitByDuration creates a new RateLimit based of time interval. This equates to 1 action per interval.
|
||||
// The weight is set to 1.
|
||||
func NewWeightedRateLimitByDuration(interval time.Duration) *RateLimiterWithWeight {
|
||||
return NewRateLimitWithWeight(interval, 1, 1)
|
||||
}
|
||||
|
||||
// GetRateLimiterWithWeight couples a rate limiter with a weight count into an
|
||||
// accepted defined rate limiter with weight struct
|
||||
// GetRateLimiterWithWeight couples a rate limiter with a weight count into an accepted defined rate limiter with weight
|
||||
// struct.
|
||||
func GetRateLimiterWithWeight(l *rate.Limiter, weight Weight) *RateLimiterWithWeight {
|
||||
return &RateLimiterWithWeight{l, weight}
|
||||
return &RateLimiterWithWeight{limiter: l, weight: weight}
|
||||
}
|
||||
|
||||
// NewBasicRateLimit returns an object that implements the limiter interface
|
||||
// for basic rate limit
|
||||
// NewBasicRateLimit returns an object that implements the limiter interface for basic rate limit.
|
||||
func NewBasicRateLimit(interval time.Duration, actions int, weight Weight) RateLimitDefinitions {
|
||||
rl := NewRateLimitWithWeight(interval, actions, weight)
|
||||
return RateLimitDefinitions{Unset: rl, Auth: rl, UnAuth: rl}
|
||||
}
|
||||
|
||||
// InitiateRateLimit sleeps for designated end point rate limits
|
||||
// InitiateRateLimit sleeps for designated end point rate limits.
|
||||
func (r *Requester) InitiateRateLimit(ctx context.Context, e EndpointLimit) error {
|
||||
if r == nil {
|
||||
return ErrRequestSystemIsNil
|
||||
@@ -110,22 +96,16 @@ func (r *Requester) InitiateRateLimit(ctx context.Context, e EndpointLimit) erro
|
||||
if atomic.LoadInt32(&r.disableRateLimiter) == 1 {
|
||||
return nil
|
||||
}
|
||||
if r.limiter == nil {
|
||||
return fmt.Errorf("cannot rate limit request %w", errLimiterSystemIsNil)
|
||||
if err := common.NilGuard(r.limiter); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rateLimiter := r.limiter[e]
|
||||
|
||||
err := RateLimit(ctx, rateLimiter)
|
||||
if err != nil {
|
||||
if err := r.limiter[e].RateLimit(ctx); err != nil {
|
||||
return fmt.Errorf("cannot rate limit request %w for endpoint %d", err, e)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRateLimiterDefinitions returns the rate limiter definitions for the
|
||||
// requester
|
||||
// GetRateLimiterDefinitions returns the rate limiter definitions for the requester.
|
||||
func (r *Requester) GetRateLimiterDefinitions() RateLimitDefinitions {
|
||||
if r == nil {
|
||||
return nil
|
||||
@@ -133,47 +113,63 @@ func (r *Requester) GetRateLimiterDefinitions() RateLimitDefinitions {
|
||||
return r.limiter
|
||||
}
|
||||
|
||||
// RateLimit is a function that will rate limit a request based on the rate
|
||||
// limiter provided. It will return an error if the context is cancelled or
|
||||
// deadline exceeded.
|
||||
func RateLimit(ctx context.Context, rateLimiter *RateLimiterWithWeight) error {
|
||||
if rateLimiter == nil {
|
||||
return errSpecificRateLimiterIsNil
|
||||
// RateLimit throttles a request based on weight, delaying the request.
|
||||
// Errors if no delay is permitted via the context and a delay is required.
|
||||
func (r *RateLimiterWithWeight) RateLimit(ctx context.Context) error {
|
||||
if err := common.NilGuard(r); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rateLimiter.Weight <= 0 {
|
||||
return errInvalidWeightCount
|
||||
r.m.Lock()
|
||||
if r.weight == 0 {
|
||||
r.m.Unlock()
|
||||
return errInvalidWeight
|
||||
}
|
||||
|
||||
var finalDelay time.Duration
|
||||
reservations := make(Reservations, rateLimiter.Weight)
|
||||
for i := Weight(0); i < rateLimiter.Weight; i++ {
|
||||
// Consume 1 weight at a time as this avoids needing burst capacity in the limiter,
|
||||
// which would otherwise allow the rate limit to be exceeded over short periods
|
||||
reservations[i] = rateLimiter.Reserve()
|
||||
finalDelay = reservations[i].Delay()
|
||||
tn := time.Now()
|
||||
reserved := make([]*rate.Reservation, 0, r.weight)
|
||||
for range r.weight {
|
||||
// This avoids needing burst capacity in the limiter, which would otherwise allow the rate limit to be exceeded over short periods
|
||||
reserved = append(reserved, r.limiter.ReserveN(tn, 1))
|
||||
}
|
||||
finalDelay := reserved[len(reserved)-1].DelayFrom(tn)
|
||||
|
||||
if dl, ok := ctx.Deadline(); ok && dl.Before(time.Now().Add(finalDelay)) {
|
||||
reservations.CancelAll()
|
||||
return fmt.Errorf("rate limit delay of %s will exceed deadline: %w",
|
||||
finalDelay,
|
||||
context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
tick := time.NewTimer(finalDelay)
|
||||
select {
|
||||
case <-tick.C:
|
||||
if finalDelay == 0 {
|
||||
r.m.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
if hasDelayNotAllowed(ctx) {
|
||||
cancelAll(reserved, tn)
|
||||
r.m.Unlock()
|
||||
return ErrDelayNotAllowed
|
||||
}
|
||||
|
||||
if dl, ok := ctx.Deadline(); ok && dl.Before(tn.Add(finalDelay)) {
|
||||
cancelAll(reserved, tn)
|
||||
r.m.Unlock()
|
||||
return fmt.Errorf("rate limit delay of %s will exceed deadline: %w", finalDelay, context.DeadlineExceeded)
|
||||
}
|
||||
r.m.Unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(finalDelay):
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
tick.Stop()
|
||||
reservations.CancelAll()
|
||||
return ctx.Err()
|
||||
}
|
||||
// TODO: Shutdown case
|
||||
}
|
||||
|
||||
// DisableRateLimiter disables the rate limiting system for the exchange
|
||||
// cancelAll cancels all reservations at a specific time.
|
||||
// Does not provide locking protection, so callers can maintain a single lock throughout.
|
||||
func cancelAll(reservations []*rate.Reservation, at time.Time) {
|
||||
slices.Reverse(reservations) // cancel in reverse order for correct token reimbursement
|
||||
for _, r := range reservations {
|
||||
r.CancelAt(at)
|
||||
}
|
||||
}
|
||||
|
||||
// DisableRateLimiter disables the rate limiting system for the exchange.
|
||||
func (r *Requester) DisableRateLimiter() error {
|
||||
if r == nil {
|
||||
return ErrRequestSystemIsNil
|
||||
@@ -184,7 +180,7 @@ func (r *Requester) DisableRateLimiter() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnableRateLimiter enables the rate limiting system for the exchange
|
||||
// EnableRateLimiter enables the rate limiting system for the exchange.
|
||||
func (r *Requester) EnableRateLimiter() error {
|
||||
if r == nil {
|
||||
return ErrRequestSystemIsNil
|
||||
@@ -194,3 +190,15 @@ func (r *Requester) EnableRateLimiter() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type delayNotAllowedKey struct{}
|
||||
|
||||
// WithDelayNotAllowed adds a value to the context that indicates that no delay is allowed for rate limiting.
|
||||
func WithDelayNotAllowed(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, delayNotAllowedKey{}, struct{}{})
|
||||
}
|
||||
|
||||
func hasDelayNotAllowed(ctx context.Context) bool {
|
||||
_, ok := ctx.Value(delayNotAllowedKey{}).(struct{})
|
||||
return ok
|
||||
}
|
||||
|
||||
252
exchanges/request/limit_test.go
Normal file
252
exchanges/request/limit_test.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package request
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
func TestRateLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
synctest.Test(t, func(t *testing.T) { //nolint:thelper,nolintlint // false positive
|
||||
err := (*RateLimiterWithWeight)(nil).RateLimit(t.Context())
|
||||
assert.ErrorContains(t, err, "nil pointer: *request.RateLimiterWithWeight")
|
||||
|
||||
r := &RateLimiterWithWeight{limiter: rate.NewLimiter(rate.Limit(1), 1)}
|
||||
err = r.RateLimit(t.Context())
|
||||
assert.ErrorIs(t, err, errInvalidWeight, "should return errInvalidWeightCount for zero weight")
|
||||
|
||||
r = NewRateLimitWithWeight(time.Second, 10, 1)
|
||||
start := time.Now()
|
||||
err = r.RateLimit(t.Context())
|
||||
elapsed := time.Since(start)
|
||||
require.NoError(t, err, "rate limit must not error")
|
||||
assert.Zero(t, elapsed, "first call should be immediate")
|
||||
|
||||
r = NewRateLimitWithWeight(time.Second, 10, 5)
|
||||
start = time.Now()
|
||||
err = r.RateLimit(t.Context())
|
||||
elapsed = time.Since(start)
|
||||
require.NoError(t, err, "rate limit must not error")
|
||||
assert.Equal(t, 400*time.Millisecond, elapsed, "should wait 400ms (4 intervals) for weight 5")
|
||||
|
||||
r = NewRateLimitWithWeight(100*time.Millisecond, 1, 1)
|
||||
start = time.Now()
|
||||
err = r.RateLimit(WithDelayNotAllowed(t.Context()))
|
||||
synctest.Wait()
|
||||
elapsed = time.Since(start)
|
||||
require.NoError(t, err, "first rate limit call must not error and must be immediate")
|
||||
assert.Zero(t, elapsed, "first call should be immediate")
|
||||
|
||||
start = time.Now()
|
||||
err = r.RateLimit(t.Context())
|
||||
elapsed = time.Since(start)
|
||||
require.NoError(t, err, "second rate limit call must not error")
|
||||
assert.Equal(t, 100*time.Millisecond, elapsed, "second call should be delayed by exactly 100ms")
|
||||
|
||||
err = r.RateLimit(WithDelayNotAllowed(t.Context()))
|
||||
assert.ErrorIs(t, err, ErrDelayNotAllowed, "should return correct error")
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
cancel()
|
||||
err = r.RateLimit(ctx)
|
||||
assert.ErrorIs(t, err, context.Canceled, "should return correct error when context is canceled")
|
||||
|
||||
// Rate limit is 100ms. Set deadline for 50ms.
|
||||
ctx, cancel = context.WithTimeout(t.Context(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
err = r.RateLimit(ctx)
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded, "should return correct error when context deadline exceeded")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRateLimit_Concurrent_WithFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
synctest.Test(t, func(t *testing.T) { //nolint:thelper,nolintlint // false positive
|
||||
r := NewRateLimitWithWeight(time.Second, 10, 1)
|
||||
tn := time.Now()
|
||||
errs := common.ErrorCollector{}
|
||||
for i := range 10 {
|
||||
ctx := t.Context()
|
||||
if i%2 == 0 {
|
||||
ctx = WithDelayNotAllowed(ctx)
|
||||
}
|
||||
errs.Go(func() error { return r.RateLimit(ctx) })
|
||||
}
|
||||
|
||||
require.ErrorContains(t, errs.Collect(), "delay not allowed, delay not allowed, delay not allowed, delay not allowed", "must return correct error")
|
||||
assert.Less(t, time.Since(tn), time.Millisecond*600, "should complete within reasonable time")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRateLimit_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
synctest.Test(t, func(t *testing.T) { //nolint:thelper,nolintlint // false positive
|
||||
r := NewRateLimitWithWeight(time.Second, 10, 1)
|
||||
tn := time.Now()
|
||||
errs := common.ErrorCollector{}
|
||||
for range 10 {
|
||||
errs.Go(func() error { return r.RateLimit(t.Context()) })
|
||||
}
|
||||
require.NoError(t, errs.Collect(), "rate limit must not error")
|
||||
assert.Less(t, time.Since(tn), time.Second, "should complete within reasonable time")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRateLimit_Linear_WithFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
synctest.Test(t, func(t *testing.T) { //nolint:thelper,nolintlint // false positive
|
||||
r := NewRateLimitWithWeight(time.Second, 10, 1)
|
||||
tn := time.Now()
|
||||
for i := range 10 {
|
||||
ctx := t.Context()
|
||||
if i%2 == 0 {
|
||||
ctx = WithDelayNotAllowed(ctx)
|
||||
}
|
||||
if err := r.RateLimit(ctx); err != nil {
|
||||
require.ErrorIs(t, err, ErrDelayNotAllowed, "must return correct error")
|
||||
}
|
||||
}
|
||||
assert.Less(t, time.Since(tn), time.Millisecond*600, "should complete within reasonable time")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRateLimit_Linear(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
synctest.Test(t, func(t *testing.T) { //nolint:thelper,nolintlint // false positive
|
||||
r := NewRateLimitWithWeight(time.Second, 10, 1)
|
||||
tn := time.Now()
|
||||
for range 10 {
|
||||
require.NoError(t, r.RateLimit(t.Context()))
|
||||
}
|
||||
assert.Less(t, time.Since(tn), time.Second, "should complete within reasonable time")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewRateLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := NewRateLimit(time.Second, 10)
|
||||
require.NotNil(t, r, "limiter must not be nil")
|
||||
assert.Equal(t, rate.Limit(10), r.Limit(), "limit should be 10 per second")
|
||||
|
||||
r = NewRateLimit(time.Second, 0)
|
||||
require.NotNil(t, r, "limiter must not be nil")
|
||||
assert.Equal(t, rate.Inf, r.Limit(), "limit should be infinite on zero actions")
|
||||
|
||||
r = NewRateLimit(time.Second, -1)
|
||||
require.NotNil(t, r, "limiter must not be nil")
|
||||
assert.Equal(t, rate.Inf, r.Limit(), "limit should be infinite on negative actions")
|
||||
|
||||
r = NewRateLimit(0, 10)
|
||||
require.NotNil(t, r, "limiter must not be nil")
|
||||
assert.Equal(t, rate.Inf, r.Limit(), "limit should be infinite on zero interval")
|
||||
|
||||
r = NewRateLimit(-time.Second, 10)
|
||||
require.NotNil(t, r, "limiter must not be nil")
|
||||
assert.Equal(t, rate.Inf, r.Limit(), "limit should be infinite on negative interval")
|
||||
}
|
||||
|
||||
func TestNewRateLimitWithWeight(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := NewRateLimitWithWeight(time.Second, 10, 5)
|
||||
require.NotNil(t, r, "limiter must not be nil")
|
||||
assert.Equal(t, Weight(5), r.weight, "weight should be 5")
|
||||
assert.Equal(t, rate.Limit(10), r.limiter.Limit(), "limit should be 10 per second")
|
||||
}
|
||||
|
||||
func TestNewWeightedRateLimitByDuration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := NewWeightedRateLimitByDuration(time.Second)
|
||||
require.NotNil(t, r, "limiter must not be nil")
|
||||
assert.Equal(t, Weight(1), r.weight, "weight should be 1")
|
||||
assert.Equal(t, rate.Limit(1), r.limiter.Limit(), "limit should be 1 per second")
|
||||
}
|
||||
|
||||
func TestGetRateLimiterWithWeight(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := rate.NewLimiter(rate.Limit(10), 1)
|
||||
weighted := GetRateLimiterWithWeight(r, 5)
|
||||
require.NotNil(t, weighted, "weighted limiter must not be nil")
|
||||
assert.Equal(t, Weight(5), weighted.weight, "weight should be 5")
|
||||
assert.Equal(t, r, weighted.limiter, "should reference same limiter")
|
||||
}
|
||||
|
||||
func TestNewBasicRateLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
defs := NewBasicRateLimit(time.Second, 10, 5)
|
||||
require.NotNil(t, defs, "definitions must not be nil")
|
||||
require.Len(t, defs, 3, "must have 3 definitions")
|
||||
|
||||
for _, key := range []EndpointLimit{Unset, Auth, UnAuth} {
|
||||
r, ok := defs[key]
|
||||
require.Truef(t, ok, "must have definition for %v", key)
|
||||
assert.Equalf(t, Weight(5), r.weight, "weight should be 5 for %v", key)
|
||||
assert.Equalf(t, rate.Limit(10), r.limiter.Limit(), "limit should be 10 per second for %v", key)
|
||||
}
|
||||
|
||||
assert.Same(t, defs[Unset], defs[Auth], "Unset and Auth should be same instance")
|
||||
assert.Same(t, defs[Auth], defs[UnAuth], "Auth and UnAuth should be same instance")
|
||||
}
|
||||
|
||||
func TestWithDelayNotAllowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.True(t, hasDelayNotAllowed(WithDelayNotAllowed(t.Context())))
|
||||
assert.False(t, hasDelayNotAllowed(t.Context()))
|
||||
assert.False(t, hasDelayNotAllowed(WithVerbose(t.Context())))
|
||||
}
|
||||
|
||||
func TestCancelAll(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var reservations []*rate.Reservation
|
||||
cancelAll(reservations, time.Now())
|
||||
|
||||
r := rate.NewLimiter(rate.Limit(1), 1)
|
||||
tn := time.Now()
|
||||
reservations = append(reservations, r.ReserveN(tn, 1))
|
||||
require.Equal(t, 0.0, r.TokensAt(tn), "must have zero tokens remaining")
|
||||
reservations = append(reservations, r.ReserveN(tn, 1))
|
||||
require.Equal(t, time.Second, reservations[1].DelayFrom(tn), "second reservation must have 1 second delay")
|
||||
require.Equal(t, -1.0, r.TokensAt(tn), "must have negative tokens remaining")
|
||||
cancelAll(reservations, tn)
|
||||
require.Equal(t, 1.0, r.TokensAt(tn), "must have 1 token remaining after cancellation")
|
||||
}
|
||||
|
||||
func TestInitiateRateLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var r *Requester
|
||||
err := r.InitiateRateLimit(t.Context(), Unset)
|
||||
assert.ErrorIs(t, err, ErrRequestSystemIsNil, "should return correct error")
|
||||
|
||||
r = &Requester{}
|
||||
atomic.StoreInt32(&r.disableRateLimiter, 1)
|
||||
err = r.InitiateRateLimit(t.Context(), Unset)
|
||||
assert.NoError(t, err, "should not error when rate limiter is disabled")
|
||||
|
||||
atomic.StoreInt32(&r.disableRateLimiter, 0)
|
||||
err = r.InitiateRateLimit(t.Context(), Unset)
|
||||
assert.ErrorContains(t, err, "nil pointer: request.RateLimitDefinitions", "should return correct error when limiter is nil")
|
||||
|
||||
r.limiter = NewBasicRateLimit(time.Second, 10, 1)
|
||||
err = r.InitiateRateLimit(t.Context(), Unset)
|
||||
assert.NoError(t, err, "should not error on valid rate limit initiation")
|
||||
}
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/nonce"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
const unexpected = "unexpected values"
|
||||
@@ -54,7 +53,7 @@ func TestMain(m *testing.M) {
|
||||
w.WriteHeader(http.StatusGatewayTimeout)
|
||||
})
|
||||
sm.HandleFunc("/rate", func(w http.ResponseWriter, _ *http.Request) {
|
||||
if !serverLimit.Allow() {
|
||||
if !serverLimit.limiter.Allow() {
|
||||
http.Error(w,
|
||||
http.StatusText(http.StatusTooManyRequests),
|
||||
http.StatusTooManyRequests)
|
||||
@@ -70,7 +69,7 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
})
|
||||
sm.HandleFunc("/rate-retry", func(w http.ResponseWriter, _ *http.Request) {
|
||||
if !serverLimitRetry.Allow() {
|
||||
if !serverLimitRetry.limiter.Allow() {
|
||||
w.Header().Add("Retry-After", strconv.Itoa(int(math.Round(serverLimitInterval.Seconds()))))
|
||||
http.Error(w,
|
||||
http.StatusText(http.StatusTooManyRequests),
|
||||
@@ -102,31 +101,6 @@ func TestMain(m *testing.M) {
|
||||
os.Exit(issues)
|
||||
}
|
||||
|
||||
func TestNewRateLimitWithWeight(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := NewRateLimitWithWeight(time.Second*10, 5, 1)
|
||||
if r.Limit() != 0.5 {
|
||||
t.Fatal(unexpected)
|
||||
}
|
||||
|
||||
// Ensures rate limiting factor is the same
|
||||
r = NewRateLimitWithWeight(time.Second*2, 1, 1)
|
||||
if r.Limit() != 0.5 {
|
||||
t.Fatal(unexpected)
|
||||
}
|
||||
|
||||
// Test for open rate limit
|
||||
r = NewRateLimitWithWeight(time.Second*2, 0, 1)
|
||||
if r.Limit() != rate.Inf {
|
||||
t.Fatal(unexpected)
|
||||
}
|
||||
|
||||
r = NewRateLimitWithWeight(0, 69, 1)
|
||||
if r.Limit() != rate.Inf {
|
||||
t.Fatal(unexpected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -235,7 +209,7 @@ func TestDoRequest(t *testing.T) {
|
||||
|
||||
// Invalid/missing endpoint limit
|
||||
err = r.SendPayload(ctx, Unset, func() (*Item, error) { return &Item{Path: testURL}, nil }, UnauthenticatedRequest)
|
||||
require.ErrorIs(t, err, errSpecificRateLimiterIsNil)
|
||||
require.ErrorIs(t, err, common.ErrNilPointer)
|
||||
|
||||
// Force debug
|
||||
err = r.SendPayload(ctx, UnAuth, func() (*Item, error) {
|
||||
|
||||
Reference in New Issue
Block a user