mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-06-07 15:11:03 +00:00
exchanges/request: abstract and consolidate rate limiting code to request package (#1477)
* initial consolidation of rate limiting code to request package to reduce bespoke code implementation * continued * finish abstraction * lint * exchanges: fix tests * linter: fix * poloniex: fix auth rate limit not being set * ratelimiter: convert from token to weight * glorious: nits addressed with fire * linter: rip * change func name set -> get * fix test * derbit: impl --------- Co-authored-by: Ryan O'Hara-Reid <ryan.oharareid@thrasher.io>
This commit is contained in:
@@ -14,6 +14,10 @@ import (
|
||||
var (
|
||||
ErrRateLimiterAlreadyDisabled = errors.New("rate limiter already disabled")
|
||||
ErrRateLimiterAlreadyEnabled = errors.New("rate limiter already enabled")
|
||||
|
||||
errLimiterSystemIsNil = errors.New("limiter system is nil")
|
||||
errInvalidWeightCount = errors.New("invalid weight count must equal or greater than 1")
|
||||
errSpecificRateLimiterIsNil = errors.New("specific rate limiter is nil")
|
||||
)
|
||||
|
||||
// Const here define individual functionality sub types for rate limiting
|
||||
@@ -23,26 +27,35 @@ const (
|
||||
UnAuth
|
||||
)
|
||||
|
||||
// BasicLimit denotes basic rate limit that implements the Limiter interface
|
||||
// does not need to set endpoint functionality.
|
||||
type BasicLimit struct {
|
||||
r *rate.Limiter
|
||||
}
|
||||
|
||||
// Limit executes a single rate limit set by NewRateLimit
|
||||
func (b *BasicLimit) Limit(ctx context.Context, _ EndpointLimit) error {
|
||||
return b.r.Wait(ctx)
|
||||
}
|
||||
|
||||
// EndpointLimit defines individual endpoint rate limits that are set when
|
||||
// New is called.
|
||||
type EndpointLimit int
|
||||
type EndpointLimit uint16
|
||||
|
||||
// Limiter interface groups rate limit functionality defined in the REST
|
||||
// wrapper for extended rate limiting configuration i.e. Shells of rate
|
||||
// limits with a global rate for sub rates.
|
||||
type Limiter interface {
|
||||
Limit(context.Context, EndpointLimit) error
|
||||
// Weight defines the number of reservations to be used. This is a generalised
|
||||
// weight for rate limiting. e.g. n weight = n request. i.e. 50 Weight = 50
|
||||
// requests.
|
||||
type Weight uint8
|
||||
|
||||
// RateLimitDefinitions is a map of endpoint limits to rate limiters
|
||||
type RateLimitDefinitions map[interface{}]*RateLimiterWithWeight
|
||||
|
||||
// RateLimiterWithWeight is a rate limiter coupled with a weight count which
|
||||
// refers to the number or weighting of the request. This is used to define
|
||||
// the rate limit for a specific endpoint.
|
||||
type RateLimiterWithWeight struct {
|
||||
*rate.Limiter
|
||||
Weight
|
||||
}
|
||||
|
||||
// Reservations is a slice of rate reservations
|
||||
type Reservations []*rate.Reservation
|
||||
|
||||
// CancelAll cancels all potential reservations to free up rate limiter for
|
||||
// context cancellations and deadline exceeded cases.
|
||||
func (r Reservations) CancelAll() {
|
||||
for x := range r {
|
||||
r[x].Cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// NewRateLimit creates a new RateLimit based of time interval and how many
|
||||
@@ -59,10 +72,25 @@ func NewRateLimit(interval time.Duration, actions int) *rate.Limiter {
|
||||
return rate.NewLimiter(rate.Limit(rps), 1)
|
||||
}
|
||||
|
||||
// NewRateLimitWithWeight creates a new RateLimit based of time interval and how
|
||||
// many actions allowed. This also has a weight count which refers to the number
|
||||
// or weighting of the request. This is used to define the rate limit for a
|
||||
// specific endpoint.
|
||||
func NewRateLimitWithWeight(interval time.Duration, actions int, weight Weight) *RateLimiterWithWeight {
|
||||
return GetRateLimiterWithWeight(NewRateLimit(interval, actions), weight)
|
||||
}
|
||||
|
||||
// GetRateLimiterWithWeight couples a rate limiter with a weight count into an
|
||||
// accepted defined rate limiter with weight struct
|
||||
func GetRateLimiterWithWeight(l *rate.Limiter, weight Weight) *RateLimiterWithWeight {
|
||||
return &RateLimiterWithWeight{l, weight}
|
||||
}
|
||||
|
||||
// NewBasicRateLimit returns an object that implements the limiter interface
|
||||
// for basic rate limit
|
||||
func NewBasicRateLimit(interval time.Duration, actions int) Limiter {
|
||||
return &BasicLimit{NewRateLimit(interval, actions)}
|
||||
func NewBasicRateLimit(interval time.Duration, actions int, weight Weight) RateLimitDefinitions {
|
||||
rl := NewRateLimitWithWeight(interval, actions, weight)
|
||||
return RateLimitDefinitions{Unset: rl, Auth: rl, UnAuth: rl}
|
||||
}
|
||||
|
||||
// InitiateRateLimit sleeps for designated end point rate limits
|
||||
@@ -73,12 +101,46 @@ func (r *Requester) InitiateRateLimit(ctx context.Context, e EndpointLimit) erro
|
||||
if atomic.LoadInt32(&r.disableRateLimiter) == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if r.limiter != nil {
|
||||
return r.limiter.Limit(ctx, e)
|
||||
if r.limiter == nil {
|
||||
return fmt.Errorf("cannot rate limit request %w", errLimiterSystemIsNil)
|
||||
}
|
||||
|
||||
return nil
|
||||
rateLimiter := r.limiter[e]
|
||||
|
||||
if rateLimiter == nil {
|
||||
return fmt.Errorf("cannot rate limit request %w for endpoint %d", errSpecificRateLimiterIsNil, e)
|
||||
}
|
||||
|
||||
if rateLimiter.Weight <= 0 {
|
||||
return fmt.Errorf("cannot rate limit request %w for endpoint %d", errInvalidWeightCount, e)
|
||||
}
|
||||
|
||||
var finalDelay time.Duration
|
||||
var reservations = make(Reservations, rateLimiter.Weight)
|
||||
for i := Weight(0); i < rateLimiter.Weight; i++ {
|
||||
// Consume 1 weight at a time as this avoids needing burst capacity in the limiter,
|
||||
// which would otherwise allow the rate limit to be exceeded over short periods
|
||||
reservations[i] = rateLimiter.Reserve()
|
||||
finalDelay = reservations[i].Delay()
|
||||
}
|
||||
|
||||
if dl, ok := ctx.Deadline(); ok && dl.Before(time.Now().Add(finalDelay)) {
|
||||
reservations.CancelAll()
|
||||
return fmt.Errorf("rate limit delay of %s will exceed deadline: %w",
|
||||
finalDelay,
|
||||
context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
tick := time.NewTimer(finalDelay)
|
||||
select {
|
||||
case <-tick.C:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
tick.Stop()
|
||||
reservations.CancelAll()
|
||||
return ctx.Err()
|
||||
}
|
||||
// TODO: Shutdown case
|
||||
}
|
||||
|
||||
// DisableRateLimiter disables the rate limiting system for the exchange
|
||||
|
||||
@@ -8,9 +8,9 @@ func WithBackoff(b Backoff) RequesterOption {
|
||||
}
|
||||
|
||||
// WithLimiter configures the rate limiter for a Requester.
|
||||
func WithLimiter(l Limiter) RequesterOption {
|
||||
func WithLimiter(def RateLimitDefinitions) RequesterOption {
|
||||
return func(r *Requester) {
|
||||
r.limiter = l
|
||||
r.limiter = def
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -149,10 +149,12 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe
|
||||
default:
|
||||
}
|
||||
|
||||
// Initiate a rate limit reservation and sleep on requested endpoint
|
||||
err := r.InitiateRateLimit(ctx, endpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to rate limit HTTP request: %w", err)
|
||||
if r.limiter != nil {
|
||||
// Initiate a rate limit reservation and sleep on requested endpoint
|
||||
err := r.InitiateRateLimit(ctx, endpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to rate limit HTTP request: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
p, err := newRequest()
|
||||
@@ -231,7 +233,15 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe
|
||||
log.Errorf(log.RequestSys, "%s request has failed. Retrying request in %s, attempt %d", r.name, delay, attempt)
|
||||
}
|
||||
|
||||
time.Sleep(delay)
|
||||
if delay > 0 {
|
||||
// Allow for context cancellation while delaying the retry.
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package request
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math"
|
||||
@@ -28,12 +27,12 @@ import (
|
||||
const unexpected = "unexpected values"
|
||||
|
||||
var testURL string
|
||||
var serverLimit *rate.Limiter
|
||||
var serverLimit *RateLimiterWithWeight
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
serverLimitInterval := time.Millisecond * 500
|
||||
serverLimit = NewRateLimit(serverLimitInterval, 1)
|
||||
serverLimitRetry := NewRateLimit(serverLimitInterval, 1)
|
||||
serverLimit = NewRateLimitWithWeight(serverLimitInterval, 1, 1)
|
||||
serverLimitRetry := NewRateLimitWithWeight(serverLimitInterval, 1, 1)
|
||||
sm := http.NewServeMux()
|
||||
sm.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -102,26 +101,26 @@ func TestMain(m *testing.M) {
|
||||
os.Exit(issues)
|
||||
}
|
||||
|
||||
func TestNewRateLimit(t *testing.T) {
|
||||
func TestNewRateLimitWithWeight(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := NewRateLimit(time.Second*10, 5)
|
||||
r := NewRateLimitWithWeight(time.Second*10, 5, 1)
|
||||
if r.Limit() != 0.5 {
|
||||
t.Fatal(unexpected)
|
||||
}
|
||||
|
||||
// Ensures rate limiting factor is the same
|
||||
r = NewRateLimit(time.Second*2, 1)
|
||||
r = NewRateLimitWithWeight(time.Second*2, 1, 1)
|
||||
if r.Limit() != 0.5 {
|
||||
t.Fatal(unexpected)
|
||||
}
|
||||
|
||||
// Test for open rate limit
|
||||
r = NewRateLimit(time.Second*2, 0)
|
||||
r = NewRateLimitWithWeight(time.Second*2, 0, 1)
|
||||
if r.Limit() != rate.Inf {
|
||||
t.Fatal(unexpected)
|
||||
}
|
||||
|
||||
r = NewRateLimit(0, 69)
|
||||
r = NewRateLimitWithWeight(0, 69, 1)
|
||||
if r.Limit() != rate.Inf {
|
||||
t.Fatal(unexpected)
|
||||
}
|
||||
@@ -201,39 +200,13 @@ func TestCheckRequest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type GlobalLimitTest struct {
|
||||
Auth *rate.Limiter
|
||||
UnAuth *rate.Limiter
|
||||
}
|
||||
|
||||
var errEndpointLimitNotFound = errors.New("endpoint limit not found")
|
||||
|
||||
func (g *GlobalLimitTest) Limit(ctx context.Context, e EndpointLimit) error {
|
||||
switch e {
|
||||
case Auth:
|
||||
if g.Auth == nil {
|
||||
return errors.New("auth rate not set")
|
||||
}
|
||||
return g.Auth.Wait(ctx)
|
||||
case UnAuth:
|
||||
if g.UnAuth == nil {
|
||||
return errors.New("unauth rate not set")
|
||||
}
|
||||
return g.UnAuth.Wait(ctx)
|
||||
default:
|
||||
return fmt.Errorf("cannot execute functionality: %d %w",
|
||||
e,
|
||||
errEndpointLimitNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
var globalshell = GlobalLimitTest{
|
||||
Auth: NewRateLimit(time.Millisecond*600, 1),
|
||||
UnAuth: NewRateLimit(time.Second*1, 100)}
|
||||
var globalshell = RateLimitDefinitions{
|
||||
Auth: NewRateLimitWithWeight(time.Millisecond*600, 1, 1),
|
||||
UnAuth: NewRateLimitWithWeight(time.Second*1, 100, 1)}
|
||||
|
||||
func TestDoRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
r, err := New("test", new(http.Client), WithLimiter(&globalshell))
|
||||
r, err := New("test", new(http.Client), WithLimiter(globalshell))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -270,13 +243,9 @@ func TestDoRequest(t *testing.T) {
|
||||
}
|
||||
|
||||
// Invalid/missing endpoint limit
|
||||
err = r.SendPayload(ctx, Unset, func() (*Item, error) {
|
||||
return &Item{
|
||||
Path: testURL,
|
||||
}, nil
|
||||
}, UnauthenticatedRequest)
|
||||
if !errors.Is(err, errEndpointLimitNotFound) {
|
||||
t.Fatalf("expected: %v but received: %v", errEndpointLimitNotFound, err)
|
||||
err = r.SendPayload(ctx, Unset, func() (*Item, error) { return &Item{Path: testURL}, nil }, UnauthenticatedRequest)
|
||||
if !errors.Is(err, errSpecificRateLimiterIsNil) {
|
||||
t.Fatalf("expected: %v but received: %v", errSpecificRateLimiterIsNil, err)
|
||||
}
|
||||
|
||||
// Force debug
|
||||
@@ -497,7 +466,7 @@ func TestDoRequest_NotRetryable(t *testing.T) {
|
||||
|
||||
func TestGetNonce(t *testing.T) {
|
||||
t.Parallel()
|
||||
r, err := New("test", new(http.Client), WithLimiter(&globalshell))
|
||||
r, err := New("test", new(http.Client), WithLimiter(globalshell))
|
||||
require.NoError(t, err)
|
||||
n1 := r.GetNonce(nonce.Unix)
|
||||
assert.NotZero(t, n1)
|
||||
@@ -505,7 +474,7 @@ func TestGetNonce(t *testing.T) {
|
||||
assert.NotZero(t, n2)
|
||||
assert.NotEqual(t, n1, n2)
|
||||
|
||||
r2, err := New("test", new(http.Client), WithLimiter(&globalshell))
|
||||
r2, err := New("test", new(http.Client), WithLimiter(globalshell))
|
||||
require.NoError(t, err)
|
||||
n3 := r2.GetNonce(nonce.UnixNano)
|
||||
assert.NotZero(t, n3)
|
||||
@@ -520,7 +489,7 @@ func TestGetNonce(t *testing.T) {
|
||||
// 40532461 30.29 ns/op 0 B/op 0 allocs/op (prev)
|
||||
// 45329203 26.53 ns/op 0 B/op 0 allocs/op
|
||||
func BenchmarkGetNonce(b *testing.B) {
|
||||
r, err := New("test", new(http.Client), WithLimiter(&globalshell))
|
||||
r, err := New("test", new(http.Client), WithLimiter(globalshell))
|
||||
require.NoError(b, err)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
@@ -536,9 +505,7 @@ func TestSetProxy(t *testing.T) {
|
||||
if !errors.Is(err, ErrRequestSystemIsNil) {
|
||||
t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil)
|
||||
}
|
||||
r, err = New("test",
|
||||
&http.Client{Transport: new(http.Transport)},
|
||||
WithLimiter(&globalshell))
|
||||
r, err = New("test", &http.Client{Transport: new(http.Transport)}, WithLimiter(globalshell))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -561,16 +528,11 @@ func TestSetProxy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBasicLimiter(t *testing.T) {
|
||||
r, err := New("test",
|
||||
new(http.Client),
|
||||
WithLimiter(NewBasicRateLimit(time.Second, 1)))
|
||||
r, err := New("test", new(http.Client), WithLimiter(NewBasicRateLimit(time.Second, 1, 1)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
i := Item{
|
||||
Path: "http://www.google.com",
|
||||
Method: http.MethodGet,
|
||||
}
|
||||
i := Item{Path: "http://www.google.com", Method: http.MethodGet}
|
||||
ctx := context.Background()
|
||||
|
||||
tn := time.Now()
|
||||
@@ -595,9 +557,7 @@ func TestBasicLimiter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEnableDisableRateLimit(t *testing.T) {
|
||||
r, err := New("TestRequest",
|
||||
new(http.Client),
|
||||
WithLimiter(NewBasicRateLimit(time.Minute, 1)))
|
||||
r, err := New("TestRequest", new(http.Client), WithLimiter(NewBasicRateLimit(time.Minute, 1, 1)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ var (
|
||||
// Requester struct for the request client
|
||||
type Requester struct {
|
||||
_HTTPClient *client
|
||||
limiter Limiter
|
||||
limiter RateLimitDefinitions
|
||||
reporter Reporter
|
||||
name string
|
||||
userAgent string
|
||||
|
||||
Reference in New Issue
Block a user