request/ratelimit: Add context value check and fix bug (#2073)

* Add WithNoDelayPermitted and fix bug on cancel all

* rm reservations as it is only for last reservation when cancelling and needed to take into account of the actual offset delay for correct returning of tokens, update tests

* export error

* misc fix

* more misc fix

* Add concurrent protection, cancel in reverse and add tests

* lint: fix

* Update exchanges/request/limit.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* Update exchanges/request/limit.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* Update exchanges/request/limit.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* Update exchanges/request/limit.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* gk: nits doo

* linter: fix

* boss king: nits

* crank: nits

* crank: test patch which was cooked and had to be done manually

* Update exchanges/request/limit.go

Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>

* gk: nits

* linter: fix

* thrasher: nits

* use error collector in tests

* nolint: direction

* gk: fixup!

* my life has elapsed

* thrasher-: Because of synctest, we can now be deterministic with values. This rids a lot of the redundant wait calls which served no purpose

* thrasher-: patched

---------

Co-authored-by: shazbert <ryan.oharareid@thrasher.io>
Co-authored-by: Gareth Kirwan <gbjkirwan@gmail.com>
Co-authored-by: shazbert <shazbert@DESKTOP-3QKKR6J.localdomain>
This commit is contained in:
Ryan O'Hara-Reid
2025-11-27 11:10:11 +11:00
committed by GitHub
parent 2943a7f800
commit 719e6bebfe
4 changed files with 348 additions and 114 deletions

View File

@@ -4,66 +4,54 @@ import (
"context"
"errors"
"fmt"
"slices"
"sync"
"sync/atomic"
"time"
"github.com/thrasher-corp/gocryptotrader/common"
"golang.org/x/time/rate"
)
// Defines rate limiting errors
// Rate limiting errors.
var (
ErrRateLimiterAlreadyDisabled = errors.New("rate limiter already disabled")
ErrRateLimiterAlreadyEnabled = errors.New("rate limiter already enabled")
ErrDelayNotAllowed = errors.New("delay not allowed")
errLimiterSystemIsNil = errors.New("limiter system is nil")
errInvalidWeightCount = errors.New("invalid weight count must equal or greater than 1")
errSpecificRateLimiterIsNil = errors.New("specific rate limiter is nil")
errInvalidWeight = errors.New("weight must be equal-or-greater than 1")
)
// RateLimitNotRequired is a no-op rate limiter
// RateLimitNotRequired is a no-op rate limiter.
var RateLimitNotRequired *RateLimiterWithWeight
// Const here define individual functionality sub types for rate limiting
// Const here define individual functionality sub types for rate limiting.
const (
Unset EndpointLimit = iota
Auth
UnAuth
)
// EndpointLimit defines individual endpoint rate limits that are set when
// New is called.
// EndpointLimit defines individual endpoint rate limits.
type EndpointLimit uint16
// Weight defines the number of reservations to be used. This is a generalised
// weight for rate limiting. e.g. n weight = n request. i.e. 50 Weight = 50
// requests.
// Weight defines the number of reservations to be used. This is a generalised weight for rate limiting.
// e.g. n weight = n request. i.e. 50 Weight = 50 requests.
type Weight uint8
// RateLimitDefinitions is a map of endpoint limits to rate limiters
// RateLimitDefinitions is a map of endpoint limits to rate limiters.
type RateLimitDefinitions map[any]*RateLimiterWithWeight
// RateLimiterWithWeight is a rate limiter coupled with a weight count which
// refers to the number or weighting of the request. This is used to define
// the rate limit for a specific endpoint.
// RateLimiterWithWeight is a rate limiter coupled with a weight which refers to the number or weighting of the request.
// This is used to define the rate limit for a specific endpoint.
type RateLimiterWithWeight struct {
*rate.Limiter
Weight
limiter *rate.Limiter
weight Weight
m sync.Mutex
}
// Reservations is a slice of rate reservations
type Reservations []*rate.Reservation
// CancelAll cancels all potential reservations to free up rate limiter for
// context cancellations and deadline exceeded cases.
func (r Reservations) CancelAll() {
for x := range r {
r[x].Cancel()
}
}
// NewRateLimit creates a new RateLimit based of time interval and how many
// actions allowed and breaks it down to an actions-per-second basis -- Burst
// rate is kept as one as this is not supported for out-bound requests.
// NewRateLimit creates a new RateLimit based of time interval and how many actions allowed and breaks it down to an
// actions-per-second basis -- Burst rate is kept as one as this is not supported for out-bound requests.
func NewRateLimit(interval time.Duration, actions int) *rate.Limiter {
if actions <= 0 || interval <= 0 {
// Returns an un-restricted rate limiter
@@ -75,34 +63,32 @@ func NewRateLimit(interval time.Duration, actions int) *rate.Limiter {
return rate.NewLimiter(rate.Limit(rps), 1)
}
// NewRateLimitWithWeight creates a new RateLimit based of time interval and how
// many actions allowed. This also has a weight count which refers to the number
// or weighting of the request. This is used to define the rate limit for a
// NewRateLimitWithWeight creates a new RateLimit based of time interval and how many actions allowed. This also has a
// weight count which refers to the number or weighting of the request. This is used to define the rate limit for a
// specific endpoint.
func NewRateLimitWithWeight(interval time.Duration, actions int, weight Weight) *RateLimiterWithWeight {
return GetRateLimiterWithWeight(NewRateLimit(interval, actions), weight)
}
// NewWeightedRateLimitByDuration creates a new RateLimit based of time
// interval. This equates to 1 action per interval. The weight is set to 1.
// NewWeightedRateLimitByDuration creates a new RateLimit based of time interval. This equates to 1 action per interval.
// The weight is set to 1.
func NewWeightedRateLimitByDuration(interval time.Duration) *RateLimiterWithWeight {
return NewRateLimitWithWeight(interval, 1, 1)
}
// GetRateLimiterWithWeight couples a rate limiter with a weight count into an
// accepted defined rate limiter with weight struct
// GetRateLimiterWithWeight couples a rate limiter with a weight count into an accepted defined rate limiter with weight
// struct.
func GetRateLimiterWithWeight(l *rate.Limiter, weight Weight) *RateLimiterWithWeight {
return &RateLimiterWithWeight{l, weight}
return &RateLimiterWithWeight{limiter: l, weight: weight}
}
// NewBasicRateLimit returns an object that implements the limiter interface
// for basic rate limit
// NewBasicRateLimit returns an object that implements the limiter interface for basic rate limit.
func NewBasicRateLimit(interval time.Duration, actions int, weight Weight) RateLimitDefinitions {
rl := NewRateLimitWithWeight(interval, actions, weight)
return RateLimitDefinitions{Unset: rl, Auth: rl, UnAuth: rl}
}
// InitiateRateLimit sleeps for designated end point rate limits
// InitiateRateLimit sleeps for designated end point rate limits.
func (r *Requester) InitiateRateLimit(ctx context.Context, e EndpointLimit) error {
if r == nil {
return ErrRequestSystemIsNil
@@ -110,22 +96,16 @@ func (r *Requester) InitiateRateLimit(ctx context.Context, e EndpointLimit) erro
if atomic.LoadInt32(&r.disableRateLimiter) == 1 {
return nil
}
if r.limiter == nil {
return fmt.Errorf("cannot rate limit request %w", errLimiterSystemIsNil)
if err := common.NilGuard(r.limiter); err != nil {
return err
}
rateLimiter := r.limiter[e]
err := RateLimit(ctx, rateLimiter)
if err != nil {
if err := r.limiter[e].RateLimit(ctx); err != nil {
return fmt.Errorf("cannot rate limit request %w for endpoint %d", err, e)
}
return nil
}
// GetRateLimiterDefinitions returns the rate limiter definitions for the
// requester
// GetRateLimiterDefinitions returns the rate limiter definitions for the requester.
func (r *Requester) GetRateLimiterDefinitions() RateLimitDefinitions {
if r == nil {
return nil
@@ -133,47 +113,63 @@ func (r *Requester) GetRateLimiterDefinitions() RateLimitDefinitions {
return r.limiter
}
// RateLimit is a function that will rate limit a request based on the rate
// limiter provided. It will return an error if the context is cancelled or
// deadline exceeded.
func RateLimit(ctx context.Context, rateLimiter *RateLimiterWithWeight) error {
if rateLimiter == nil {
return errSpecificRateLimiterIsNil
// RateLimit throttles a request based on weight, delaying the request.
// Errors if no delay is permitted via the context and a delay is required.
func (r *RateLimiterWithWeight) RateLimit(ctx context.Context) error {
if err := common.NilGuard(r); err != nil {
return err
}
if rateLimiter.Weight <= 0 {
return errInvalidWeightCount
r.m.Lock()
if r.weight == 0 {
r.m.Unlock()
return errInvalidWeight
}
var finalDelay time.Duration
reservations := make(Reservations, rateLimiter.Weight)
for i := Weight(0); i < rateLimiter.Weight; i++ {
// Consume 1 weight at a time as this avoids needing burst capacity in the limiter,
// which would otherwise allow the rate limit to be exceeded over short periods
reservations[i] = rateLimiter.Reserve()
finalDelay = reservations[i].Delay()
tn := time.Now()
reserved := make([]*rate.Reservation, 0, r.weight)
for range r.weight {
// This avoids needing burst capacity in the limiter, which would otherwise allow the rate limit to be exceeded over short periods
reserved = append(reserved, r.limiter.ReserveN(tn, 1))
}
finalDelay := reserved[len(reserved)-1].DelayFrom(tn)
if dl, ok := ctx.Deadline(); ok && dl.Before(time.Now().Add(finalDelay)) {
reservations.CancelAll()
return fmt.Errorf("rate limit delay of %s will exceed deadline: %w",
finalDelay,
context.DeadlineExceeded)
}
tick := time.NewTimer(finalDelay)
select {
case <-tick.C:
if finalDelay == 0 {
r.m.Unlock()
return nil
}
if hasDelayNotAllowed(ctx) {
cancelAll(reserved, tn)
r.m.Unlock()
return ErrDelayNotAllowed
}
if dl, ok := ctx.Deadline(); ok && dl.Before(tn.Add(finalDelay)) {
cancelAll(reserved, tn)
r.m.Unlock()
return fmt.Errorf("rate limit delay of %s will exceed deadline: %w", finalDelay, context.DeadlineExceeded)
}
r.m.Unlock()
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(finalDelay):
return nil
case <-ctx.Done():
tick.Stop()
reservations.CancelAll()
return ctx.Err()
}
// TODO: Shutdown case
}
// DisableRateLimiter disables the rate limiting system for the exchange
// cancelAll cancels all reservations at a specific time.
// Does not provide locking protection, so callers can maintain a single lock throughout.
func cancelAll(reservations []*rate.Reservation, at time.Time) {
slices.Reverse(reservations) // cancel in reverse order for correct token reimbursement
for _, r := range reservations {
r.CancelAt(at)
}
}
// DisableRateLimiter disables the rate limiting system for the exchange.
func (r *Requester) DisableRateLimiter() error {
if r == nil {
return ErrRequestSystemIsNil
@@ -184,7 +180,7 @@ func (r *Requester) DisableRateLimiter() error {
return nil
}
// EnableRateLimiter enables the rate limiting system for the exchange
// EnableRateLimiter enables the rate limiting system for the exchange.
func (r *Requester) EnableRateLimiter() error {
if r == nil {
return ErrRequestSystemIsNil
@@ -194,3 +190,15 @@ func (r *Requester) EnableRateLimiter() error {
}
return nil
}
type delayNotAllowedKey struct{}
// WithDelayNotAllowed adds a value to the context that indicates that no delay is allowed for rate limiting.
func WithDelayNotAllowed(ctx context.Context) context.Context {
return context.WithValue(ctx, delayNotAllowedKey{}, struct{}{})
}
func hasDelayNotAllowed(ctx context.Context) bool {
_, ok := ctx.Value(delayNotAllowedKey{}).(struct{})
return ok
}