diff --git a/exchange/websocket/connection.go b/exchange/websocket/connection.go index c5b9376b..a237916b 100644 --- a/exchange/websocket/connection.go +++ b/exchange/websocket/connection.go @@ -204,7 +204,7 @@ func (c *connection) writeToConn(ctx context.Context, epl request.EndpointLimit, } if rl != nil { - if err := request.RateLimit(ctx, rl); err != nil { + if err := rl.RateLimit(ctx); err != nil { return fmt.Errorf("%s websocket connection: rate limit error: %w", c.ExchangeName, err) } } diff --git a/exchanges/request/limit.go b/exchanges/request/limit.go index 539e31a8..cab633d7 100644 --- a/exchanges/request/limit.go +++ b/exchanges/request/limit.go @@ -4,66 +4,54 @@ import ( "context" "errors" "fmt" + "slices" + "sync" "sync/atomic" "time" + "github.com/thrasher-corp/gocryptotrader/common" "golang.org/x/time/rate" ) -// Defines rate limiting errors +// Rate limiting errors. var ( ErrRateLimiterAlreadyDisabled = errors.New("rate limiter already disabled") ErrRateLimiterAlreadyEnabled = errors.New("rate limiter already enabled") + ErrDelayNotAllowed = errors.New("delay not allowed") - errLimiterSystemIsNil = errors.New("limiter system is nil") - errInvalidWeightCount = errors.New("invalid weight count must equal or greater than 1") - errSpecificRateLimiterIsNil = errors.New("specific rate limiter is nil") + errInvalidWeight = errors.New("weight must be equal-or-greater than 1") ) -// RateLimitNotRequired is a no-op rate limiter +// RateLimitNotRequired is a no-op rate limiter. var RateLimitNotRequired *RateLimiterWithWeight -// Const here define individual functionality sub types for rate limiting +// Const here define individual functionality sub types for rate limiting. const ( Unset EndpointLimit = iota Auth UnAuth ) -// EndpointLimit defines individual endpoint rate limits that are set when -// New is called. +// EndpointLimit defines individual endpoint rate limits. type EndpointLimit uint16 -// Weight defines the number of reservations to be used. This is a generalised -// weight for rate limiting. e.g. n weight = n request. i.e. 50 Weight = 50 -// requests. +// Weight defines the number of reservations to be used. This is a generalised weight for rate limiting. +// e.g. n weight = n request. i.e. 50 Weight = 50 requests. type Weight uint8 -// RateLimitDefinitions is a map of endpoint limits to rate limiters +// RateLimitDefinitions is a map of endpoint limits to rate limiters. type RateLimitDefinitions map[any]*RateLimiterWithWeight -// RateLimiterWithWeight is a rate limiter coupled with a weight count which -// refers to the number or weighting of the request. This is used to define -// the rate limit for a specific endpoint. +// RateLimiterWithWeight is a rate limiter coupled with a weight which refers to the number or weighting of the request. +// This is used to define the rate limit for a specific endpoint. type RateLimiterWithWeight struct { - *rate.Limiter - Weight + limiter *rate.Limiter + weight Weight + m sync.Mutex } -// Reservations is a slice of rate reservations -type Reservations []*rate.Reservation - -// CancelAll cancels all potential reservations to free up rate limiter for -// context cancellations and deadline exceeded cases. -func (r Reservations) CancelAll() { - for x := range r { - r[x].Cancel() - } -} - -// NewRateLimit creates a new RateLimit based of time interval and how many -// actions allowed and breaks it down to an actions-per-second basis -- Burst -// rate is kept as one as this is not supported for out-bound requests. +// NewRateLimit creates a new RateLimit based of time interval and how many actions allowed and breaks it down to an +// actions-per-second basis -- Burst rate is kept as one as this is not supported for out-bound requests. func NewRateLimit(interval time.Duration, actions int) *rate.Limiter { if actions <= 0 || interval <= 0 { // Returns an un-restricted rate limiter @@ -75,34 +63,32 @@ func NewRateLimit(interval time.Duration, actions int) *rate.Limiter { return rate.NewLimiter(rate.Limit(rps), 1) } -// NewRateLimitWithWeight creates a new RateLimit based of time interval and how -// many actions allowed. This also has a weight count which refers to the number -// or weighting of the request. This is used to define the rate limit for a +// NewRateLimitWithWeight creates a new RateLimit based of time interval and how many actions allowed. This also has a +// weight count which refers to the number or weighting of the request. This is used to define the rate limit for a // specific endpoint. func NewRateLimitWithWeight(interval time.Duration, actions int, weight Weight) *RateLimiterWithWeight { return GetRateLimiterWithWeight(NewRateLimit(interval, actions), weight) } -// NewWeightedRateLimitByDuration creates a new RateLimit based of time -// interval. This equates to 1 action per interval. The weight is set to 1. +// NewWeightedRateLimitByDuration creates a new RateLimit based of time interval. This equates to 1 action per interval. +// The weight is set to 1. func NewWeightedRateLimitByDuration(interval time.Duration) *RateLimiterWithWeight { return NewRateLimitWithWeight(interval, 1, 1) } -// GetRateLimiterWithWeight couples a rate limiter with a weight count into an -// accepted defined rate limiter with weight struct +// GetRateLimiterWithWeight couples a rate limiter with a weight count into an accepted defined rate limiter with weight +// struct. func GetRateLimiterWithWeight(l *rate.Limiter, weight Weight) *RateLimiterWithWeight { - return &RateLimiterWithWeight{l, weight} + return &RateLimiterWithWeight{limiter: l, weight: weight} } -// NewBasicRateLimit returns an object that implements the limiter interface -// for basic rate limit +// NewBasicRateLimit returns an object that implements the limiter interface for basic rate limit. func NewBasicRateLimit(interval time.Duration, actions int, weight Weight) RateLimitDefinitions { rl := NewRateLimitWithWeight(interval, actions, weight) return RateLimitDefinitions{Unset: rl, Auth: rl, UnAuth: rl} } -// InitiateRateLimit sleeps for designated end point rate limits +// InitiateRateLimit sleeps for designated end point rate limits. func (r *Requester) InitiateRateLimit(ctx context.Context, e EndpointLimit) error { if r == nil { return ErrRequestSystemIsNil @@ -110,22 +96,16 @@ func (r *Requester) InitiateRateLimit(ctx context.Context, e EndpointLimit) erro if atomic.LoadInt32(&r.disableRateLimiter) == 1 { return nil } - if r.limiter == nil { - return fmt.Errorf("cannot rate limit request %w", errLimiterSystemIsNil) + if err := common.NilGuard(r.limiter); err != nil { + return err } - - rateLimiter := r.limiter[e] - - err := RateLimit(ctx, rateLimiter) - if err != nil { + if err := r.limiter[e].RateLimit(ctx); err != nil { return fmt.Errorf("cannot rate limit request %w for endpoint %d", err, e) } - return nil } -// GetRateLimiterDefinitions returns the rate limiter definitions for the -// requester +// GetRateLimiterDefinitions returns the rate limiter definitions for the requester. func (r *Requester) GetRateLimiterDefinitions() RateLimitDefinitions { if r == nil { return nil @@ -133,47 +113,63 @@ func (r *Requester) GetRateLimiterDefinitions() RateLimitDefinitions { return r.limiter } -// RateLimit is a function that will rate limit a request based on the rate -// limiter provided. It will return an error if the context is cancelled or -// deadline exceeded. -func RateLimit(ctx context.Context, rateLimiter *RateLimiterWithWeight) error { - if rateLimiter == nil { - return errSpecificRateLimiterIsNil +// RateLimit throttles a request based on weight, delaying the request. +// Errors if no delay is permitted via the context and a delay is required. +func (r *RateLimiterWithWeight) RateLimit(ctx context.Context) error { + if err := common.NilGuard(r); err != nil { + return err } - if rateLimiter.Weight <= 0 { - return errInvalidWeightCount + r.m.Lock() + if r.weight == 0 { + r.m.Unlock() + return errInvalidWeight } - var finalDelay time.Duration - reservations := make(Reservations, rateLimiter.Weight) - for i := Weight(0); i < rateLimiter.Weight; i++ { - // Consume 1 weight at a time as this avoids needing burst capacity in the limiter, - // which would otherwise allow the rate limit to be exceeded over short periods - reservations[i] = rateLimiter.Reserve() - finalDelay = reservations[i].Delay() + tn := time.Now() + reserved := make([]*rate.Reservation, 0, r.weight) + for range r.weight { + // This avoids needing burst capacity in the limiter, which would otherwise allow the rate limit to be exceeded over short periods + reserved = append(reserved, r.limiter.ReserveN(tn, 1)) } + finalDelay := reserved[len(reserved)-1].DelayFrom(tn) - if dl, ok := ctx.Deadline(); ok && dl.Before(time.Now().Add(finalDelay)) { - reservations.CancelAll() - return fmt.Errorf("rate limit delay of %s will exceed deadline: %w", - finalDelay, - context.DeadlineExceeded) - } - - tick := time.NewTimer(finalDelay) - select { - case <-tick.C: + if finalDelay == 0 { + r.m.Unlock() + return nil + } + + if hasDelayNotAllowed(ctx) { + cancelAll(reserved, tn) + r.m.Unlock() + return ErrDelayNotAllowed + } + + if dl, ok := ctx.Deadline(); ok && dl.Before(tn.Add(finalDelay)) { + cancelAll(reserved, tn) + r.m.Unlock() + return fmt.Errorf("rate limit delay of %s will exceed deadline: %w", finalDelay, context.DeadlineExceeded) + } + r.m.Unlock() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(finalDelay): return nil - case <-ctx.Done(): - tick.Stop() - reservations.CancelAll() - return ctx.Err() } - // TODO: Shutdown case } -// DisableRateLimiter disables the rate limiting system for the exchange +// cancelAll cancels all reservations at a specific time. +// Does not provide locking protection, so callers can maintain a single lock throughout. +func cancelAll(reservations []*rate.Reservation, at time.Time) { + slices.Reverse(reservations) // cancel in reverse order for correct token reimbursement + for _, r := range reservations { + r.CancelAt(at) + } +} + +// DisableRateLimiter disables the rate limiting system for the exchange. func (r *Requester) DisableRateLimiter() error { if r == nil { return ErrRequestSystemIsNil @@ -184,7 +180,7 @@ func (r *Requester) DisableRateLimiter() error { return nil } -// EnableRateLimiter enables the rate limiting system for the exchange +// EnableRateLimiter enables the rate limiting system for the exchange. func (r *Requester) EnableRateLimiter() error { if r == nil { return ErrRequestSystemIsNil @@ -194,3 +190,15 @@ func (r *Requester) EnableRateLimiter() error { } return nil } + +type delayNotAllowedKey struct{} + +// WithDelayNotAllowed adds a value to the context that indicates that no delay is allowed for rate limiting. +func WithDelayNotAllowed(ctx context.Context) context.Context { + return context.WithValue(ctx, delayNotAllowedKey{}, struct{}{}) +} + +func hasDelayNotAllowed(ctx context.Context) bool { + _, ok := ctx.Value(delayNotAllowedKey{}).(struct{}) + return ok +} diff --git a/exchanges/request/limit_test.go b/exchanges/request/limit_test.go new file mode 100644 index 00000000..af244c3d --- /dev/null +++ b/exchanges/request/limit_test.go @@ -0,0 +1,252 @@ +package request + +import ( + "context" + "sync/atomic" + "testing" + "testing/synctest" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thrasher-corp/gocryptotrader/common" + "golang.org/x/time/rate" +) + +func TestRateLimit(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { //nolint:thelper,nolintlint // false positive + err := (*RateLimiterWithWeight)(nil).RateLimit(t.Context()) + assert.ErrorContains(t, err, "nil pointer: *request.RateLimiterWithWeight") + + r := &RateLimiterWithWeight{limiter: rate.NewLimiter(rate.Limit(1), 1)} + err = r.RateLimit(t.Context()) + assert.ErrorIs(t, err, errInvalidWeight, "should return errInvalidWeightCount for zero weight") + + r = NewRateLimitWithWeight(time.Second, 10, 1) + start := time.Now() + err = r.RateLimit(t.Context()) + elapsed := time.Since(start) + require.NoError(t, err, "rate limit must not error") + assert.Zero(t, elapsed, "first call should be immediate") + + r = NewRateLimitWithWeight(time.Second, 10, 5) + start = time.Now() + err = r.RateLimit(t.Context()) + elapsed = time.Since(start) + require.NoError(t, err, "rate limit must not error") + assert.Equal(t, 400*time.Millisecond, elapsed, "should wait 400ms (4 intervals) for weight 5") + + r = NewRateLimitWithWeight(100*time.Millisecond, 1, 1) + start = time.Now() + err = r.RateLimit(WithDelayNotAllowed(t.Context())) + synctest.Wait() + elapsed = time.Since(start) + require.NoError(t, err, "first rate limit call must not error and must be immediate") + assert.Zero(t, elapsed, "first call should be immediate") + + start = time.Now() + err = r.RateLimit(t.Context()) + elapsed = time.Since(start) + require.NoError(t, err, "second rate limit call must not error") + assert.Equal(t, 100*time.Millisecond, elapsed, "second call should be delayed by exactly 100ms") + + err = r.RateLimit(WithDelayNotAllowed(t.Context())) + assert.ErrorIs(t, err, ErrDelayNotAllowed, "should return correct error") + + ctx, cancel := context.WithCancel(t.Context()) + cancel() + err = r.RateLimit(ctx) + assert.ErrorIs(t, err, context.Canceled, "should return correct error when context is canceled") + + // Rate limit is 100ms. Set deadline for 50ms. + ctx, cancel = context.WithTimeout(t.Context(), 50*time.Millisecond) + defer cancel() + err = r.RateLimit(ctx) + assert.ErrorIs(t, err, context.DeadlineExceeded, "should return correct error when context deadline exceeded") + }) +} + +func TestRateLimit_Concurrent_WithFailure(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { //nolint:thelper,nolintlint // false positive + r := NewRateLimitWithWeight(time.Second, 10, 1) + tn := time.Now() + errs := common.ErrorCollector{} + for i := range 10 { + ctx := t.Context() + if i%2 == 0 { + ctx = WithDelayNotAllowed(ctx) + } + errs.Go(func() error { return r.RateLimit(ctx) }) + } + + require.ErrorContains(t, errs.Collect(), "delay not allowed, delay not allowed, delay not allowed, delay not allowed", "must return correct error") + assert.Less(t, time.Since(tn), time.Millisecond*600, "should complete within reasonable time") + }) +} + +func TestRateLimit_Concurrent(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { //nolint:thelper,nolintlint // false positive + r := NewRateLimitWithWeight(time.Second, 10, 1) + tn := time.Now() + errs := common.ErrorCollector{} + for range 10 { + errs.Go(func() error { return r.RateLimit(t.Context()) }) + } + require.NoError(t, errs.Collect(), "rate limit must not error") + assert.Less(t, time.Since(tn), time.Second, "should complete within reasonable time") + }) +} + +func TestRateLimit_Linear_WithFailure(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { //nolint:thelper,nolintlint // false positive + r := NewRateLimitWithWeight(time.Second, 10, 1) + tn := time.Now() + for i := range 10 { + ctx := t.Context() + if i%2 == 0 { + ctx = WithDelayNotAllowed(ctx) + } + if err := r.RateLimit(ctx); err != nil { + require.ErrorIs(t, err, ErrDelayNotAllowed, "must return correct error") + } + } + assert.Less(t, time.Since(tn), time.Millisecond*600, "should complete within reasonable time") + }) +} + +func TestRateLimit_Linear(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { //nolint:thelper,nolintlint // false positive + r := NewRateLimitWithWeight(time.Second, 10, 1) + tn := time.Now() + for range 10 { + require.NoError(t, r.RateLimit(t.Context())) + } + assert.Less(t, time.Since(tn), time.Second, "should complete within reasonable time") + }) +} + +func TestNewRateLimit(t *testing.T) { + t.Parallel() + + r := NewRateLimit(time.Second, 10) + require.NotNil(t, r, "limiter must not be nil") + assert.Equal(t, rate.Limit(10), r.Limit(), "limit should be 10 per second") + + r = NewRateLimit(time.Second, 0) + require.NotNil(t, r, "limiter must not be nil") + assert.Equal(t, rate.Inf, r.Limit(), "limit should be infinite on zero actions") + + r = NewRateLimit(time.Second, -1) + require.NotNil(t, r, "limiter must not be nil") + assert.Equal(t, rate.Inf, r.Limit(), "limit should be infinite on negative actions") + + r = NewRateLimit(0, 10) + require.NotNil(t, r, "limiter must not be nil") + assert.Equal(t, rate.Inf, r.Limit(), "limit should be infinite on zero interval") + + r = NewRateLimit(-time.Second, 10) + require.NotNil(t, r, "limiter must not be nil") + assert.Equal(t, rate.Inf, r.Limit(), "limit should be infinite on negative interval") +} + +func TestNewRateLimitWithWeight(t *testing.T) { + t.Parallel() + + r := NewRateLimitWithWeight(time.Second, 10, 5) + require.NotNil(t, r, "limiter must not be nil") + assert.Equal(t, Weight(5), r.weight, "weight should be 5") + assert.Equal(t, rate.Limit(10), r.limiter.Limit(), "limit should be 10 per second") +} + +func TestNewWeightedRateLimitByDuration(t *testing.T) { + t.Parallel() + + r := NewWeightedRateLimitByDuration(time.Second) + require.NotNil(t, r, "limiter must not be nil") + assert.Equal(t, Weight(1), r.weight, "weight should be 1") + assert.Equal(t, rate.Limit(1), r.limiter.Limit(), "limit should be 1 per second") +} + +func TestGetRateLimiterWithWeight(t *testing.T) { + t.Parallel() + + r := rate.NewLimiter(rate.Limit(10), 1) + weighted := GetRateLimiterWithWeight(r, 5) + require.NotNil(t, weighted, "weighted limiter must not be nil") + assert.Equal(t, Weight(5), weighted.weight, "weight should be 5") + assert.Equal(t, r, weighted.limiter, "should reference same limiter") +} + +func TestNewBasicRateLimit(t *testing.T) { + t.Parallel() + + defs := NewBasicRateLimit(time.Second, 10, 5) + require.NotNil(t, defs, "definitions must not be nil") + require.Len(t, defs, 3, "must have 3 definitions") + + for _, key := range []EndpointLimit{Unset, Auth, UnAuth} { + r, ok := defs[key] + require.Truef(t, ok, "must have definition for %v", key) + assert.Equalf(t, Weight(5), r.weight, "weight should be 5 for %v", key) + assert.Equalf(t, rate.Limit(10), r.limiter.Limit(), "limit should be 10 per second for %v", key) + } + + assert.Same(t, defs[Unset], defs[Auth], "Unset and Auth should be same instance") + assert.Same(t, defs[Auth], defs[UnAuth], "Auth and UnAuth should be same instance") +} + +func TestWithDelayNotAllowed(t *testing.T) { + t.Parallel() + + assert.True(t, hasDelayNotAllowed(WithDelayNotAllowed(t.Context()))) + assert.False(t, hasDelayNotAllowed(t.Context())) + assert.False(t, hasDelayNotAllowed(WithVerbose(t.Context()))) +} + +func TestCancelAll(t *testing.T) { + t.Parallel() + + var reservations []*rate.Reservation + cancelAll(reservations, time.Now()) + + r := rate.NewLimiter(rate.Limit(1), 1) + tn := time.Now() + reservations = append(reservations, r.ReserveN(tn, 1)) + require.Equal(t, 0.0, r.TokensAt(tn), "must have zero tokens remaining") + reservations = append(reservations, r.ReserveN(tn, 1)) + require.Equal(t, time.Second, reservations[1].DelayFrom(tn), "second reservation must have 1 second delay") + require.Equal(t, -1.0, r.TokensAt(tn), "must have negative tokens remaining") + cancelAll(reservations, tn) + require.Equal(t, 1.0, r.TokensAt(tn), "must have 1 token remaining after cancellation") +} + +func TestInitiateRateLimit(t *testing.T) { + t.Parallel() + + var r *Requester + err := r.InitiateRateLimit(t.Context(), Unset) + assert.ErrorIs(t, err, ErrRequestSystemIsNil, "should return correct error") + + r = &Requester{} + atomic.StoreInt32(&r.disableRateLimiter, 1) + err = r.InitiateRateLimit(t.Context(), Unset) + assert.NoError(t, err, "should not error when rate limiter is disabled") + + atomic.StoreInt32(&r.disableRateLimiter, 0) + err = r.InitiateRateLimit(t.Context(), Unset) + assert.ErrorContains(t, err, "nil pointer: request.RateLimitDefinitions", "should return correct error when limiter is nil") + + r.limiter = NewBasicRateLimit(time.Second, 10, 1) + err = r.InitiateRateLimit(t.Context(), Unset) + assert.NoError(t, err, "should not error on valid rate limit initiation") +} diff --git a/exchanges/request/request_test.go b/exchanges/request/request_test.go index 1cea4e58..24307ff2 100644 --- a/exchanges/request/request_test.go +++ b/exchanges/request/request_test.go @@ -20,7 +20,6 @@ import ( "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/exchanges/nonce" - "golang.org/x/time/rate" ) const unexpected = "unexpected values" @@ -54,7 +53,7 @@ func TestMain(m *testing.M) { w.WriteHeader(http.StatusGatewayTimeout) }) sm.HandleFunc("/rate", func(w http.ResponseWriter, _ *http.Request) { - if !serverLimit.Allow() { + if !serverLimit.limiter.Allow() { http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) @@ -70,7 +69,7 @@ func TestMain(m *testing.M) { } }) sm.HandleFunc("/rate-retry", func(w http.ResponseWriter, _ *http.Request) { - if !serverLimitRetry.Allow() { + if !serverLimitRetry.limiter.Allow() { w.Header().Add("Retry-After", strconv.Itoa(int(math.Round(serverLimitInterval.Seconds())))) http.Error(w, http.StatusText(http.StatusTooManyRequests), @@ -102,31 +101,6 @@ func TestMain(m *testing.M) { os.Exit(issues) } -func TestNewRateLimitWithWeight(t *testing.T) { - t.Parallel() - r := NewRateLimitWithWeight(time.Second*10, 5, 1) - if r.Limit() != 0.5 { - t.Fatal(unexpected) - } - - // Ensures rate limiting factor is the same - r = NewRateLimitWithWeight(time.Second*2, 1, 1) - if r.Limit() != 0.5 { - t.Fatal(unexpected) - } - - // Test for open rate limit - r = NewRateLimitWithWeight(time.Second*2, 0, 1) - if r.Limit() != rate.Inf { - t.Fatal(unexpected) - } - - r = NewRateLimitWithWeight(0, 69, 1) - if r.Limit() != rate.Inf { - t.Fatal(unexpected) - } -} - func TestCheckRequest(t *testing.T) { t.Parallel() @@ -235,7 +209,7 @@ func TestDoRequest(t *testing.T) { // Invalid/missing endpoint limit err = r.SendPayload(ctx, Unset, func() (*Item, error) { return &Item{Path: testURL}, nil }, UnauthenticatedRequest) - require.ErrorIs(t, err, errSpecificRateLimiterIsNil) + require.ErrorIs(t, err, common.ErrNilPointer) // Force debug err = r.SendPayload(ctx, UnAuth, func() (*Item, error) {