Request package update & rate limit system expansion (#413)

* Initial rework of rework of requester - WIP

* Implementing and checking rate limits - WIP

* implemented coinbene rate limiting shenanigans

* add in remaining WIP

* fixy

* use authenticated rate limit

* drop ceiling as this can be done with a counter later

* add functionality to struct

* purge config options for rate limiting so as to keep things minimal

* prepare futures and swap rate limiting for implementation

* Address linter issues

* Addressed nits, fixed race

* fix linter issue

* remove global var as this was only setting when newrequester was called

* moved rate limit functionality into its own file

* Update Bitfinex with correct rate limit and test endpoints (WIP)

* finish off bitfinex adjustments

* fixes

* fix linter issues

* slowed rate for coinbasepro

* drop rate limit for huobi as the doc times have intermittent 429 issues.

* Set MACOSX_DEPLOYMENT_TARGET to remove linking warning

* Addr Thrasher nits

* Addr glorious nits

* unexport do request function

* fixed nitorinos

* Fixed something I missed

* move disabled rate limiter into loadexchange and use interface functionality

* Add temp quick fix
This commit is contained in:
Ryan O'Hara-Reid
2020-02-06 11:44:28 +11:00
committed by GitHub
parent 4625ef9b94
commit 0a84c5d97a
103 changed files with 3906 additions and 2581 deletions

View File

@@ -0,0 +1,88 @@
package request
import (
"errors"
"sync/atomic"
"time"
"golang.org/x/time/rate"
)
// Const here define individual functionality sub types for rate limiting
const (
Unset EndpointLimit = iota
Auth
UnAuth
)
// BasicLimit denotes basic rate limit that implements the Limiter interface
// does not need to set endpoint functionality.
type BasicLimit struct {
r *rate.Limiter
}
// Limit executes a single rate limit set by NewRateLimit
func (b *BasicLimit) Limit(_ EndpointLimit) error {
time.Sleep(b.r.Reserve().Delay())
return nil
}
// EndpointLimit defines individual endpoint rate limits that are set when
// New is called.
type EndpointLimit int
// Limiter interface groups rate limit functionality defined in the REST
// wrapper for extended rate limiting configuration i.e. Shells of rate
// limits with a global rate for sub rates.
type Limiter interface {
Limit(EndpointLimit) error
}
// NewRateLimit creates a new RateLimit based of time interval and how many
// actions allowed and breaks it down to an actions-per-second basis -- Burst
// rate is kept as one as this is not supported for out-bound requests.
func NewRateLimit(interval time.Duration, actions int) *rate.Limiter {
if actions <= 0 || interval <= 0 {
// Returns an un-restricted rate limiter
return rate.NewLimiter(rate.Inf, 1)
}
i := 1 / interval.Seconds()
rps := i * float64(actions)
return rate.NewLimiter(rate.Limit(rps), 1)
}
// NewBasicRateLimit returns an object that implements the limiter interface
// for basic rate limit
func NewBasicRateLimit(interval time.Duration, actions int) Limiter {
return &BasicLimit{NewRateLimit(interval, actions)}
}
// InitiateRateLimit sleeps for designated end point rate limits
func (r *Requester) InitiateRateLimit(e EndpointLimit) error {
if atomic.LoadInt32(&r.disableRateLimiter) == 1 {
return nil
}
if r.Limiter != nil {
return r.Limiter.Limit(e)
}
return nil
}
// DisableRateLimiter disables the rate limiting system for the exchange
func (r *Requester) DisableRateLimiter() error {
if !atomic.CompareAndSwapInt32(&r.disableRateLimiter, 0, 1) {
return errors.New("rate limiter already disabled")
}
return nil
}
// EnableRateLimiter enables the rate limiting system for the exchange
func (r *Requester) EnableRateLimiter() error {
if !atomic.CompareAndSwapInt32(&r.disableRateLimiter, 1, 0) {
return errors.New("rate limiter already enabled")
}
return nil
}

View File

@@ -1,300 +1,156 @@
package request
import (
"compress/gzip"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"sync/atomic"
"time"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/timedmutex"
"github.com/thrasher-corp/gocryptotrader/exchanges/mock"
"github.com/thrasher-corp/gocryptotrader/exchanges/nonce"
log "github.com/thrasher-corp/gocryptotrader/logger"
)
// NewRateLimit creates a new RateLimit
func NewRateLimit(d time.Duration, rate int) *RateLimit {
return &RateLimit{Duration: d, Rate: rate}
}
// String returns the rate limiter in string notation
func (r *RateLimit) String() string {
return fmt.Sprintf("Rate limiter set to %d requests per %v", r.Rate, r.Duration)
}
// GetRate returns the ratelimit rate
func (r *RateLimit) GetRate() int {
r.Mutex.Lock()
defer r.Mutex.Unlock()
return r.Rate
}
// SetRate sets the ratelimit rate
func (r *RateLimit) SetRate(rate int) {
r.Mutex.Lock()
defer r.Mutex.Unlock()
r.Rate = rate
}
// GetRequests returns the number of requests for the ratelimit
func (r *RateLimit) GetRequests() int {
r.Mutex.Lock()
defer r.Mutex.Unlock()
return r.Requests
}
// SetRequests sets requests counter for the rateliit
func (r *RateLimit) SetRequests(l int) {
r.Mutex.Lock()
defer r.Mutex.Unlock()
r.Requests = l
}
// SetDuration sets the duration for the ratelimit
func (r *RateLimit) SetDuration(d time.Duration) {
r.Mutex.Lock()
defer r.Mutex.Unlock()
r.Duration = d
}
// GetDuration gets the duration for the ratelimit
func (r *RateLimit) GetDuration() time.Duration {
r.Mutex.Lock()
defer r.Mutex.Unlock()
return r.Duration
}
// StartCycle restarts the cycle time and requests counters
func (r *Requester) StartCycle() {
r.Cycle = time.Now()
r.AuthLimit.SetRequests(0)
r.UnauthLimit.SetRequests(0)
}
// IsRateLimited returns whether or not the request Requester is rate limited
func (r *Requester) IsRateLimited(auth bool) bool {
if auth {
if r.AuthLimit.GetRequests() >= r.AuthLimit.GetRate() && r.IsValidCycle(auth) {
return true
}
} else {
if r.UnauthLimit.GetRequests() >= r.UnauthLimit.GetRate() && r.IsValidCycle(auth) {
return true
}
}
return false
}
// RequiresRateLimiter returns whether or not the request Requester requires a rate limiter
func (r *Requester) RequiresRateLimiter() bool {
if DisableRateLimiter {
return false
}
if r.AuthLimit.GetRate() != 0 || r.UnauthLimit.GetRate() != 0 {
return true
}
return false
}
// IncrementRequests increments the ratelimiter request counter for either auth or unauth
// requests
func (r *Requester) IncrementRequests(auth bool) {
if auth {
reqs := r.AuthLimit.GetRequests()
reqs++
r.AuthLimit.SetRequests(reqs)
return
}
reqs := r.UnauthLimit.GetRequests()
reqs++
r.UnauthLimit.SetRequests(reqs)
}
// DecrementRequests decrements the ratelimiter request counter for either auth or unauth
// requests
func (r *Requester) DecrementRequests(auth bool) {
if auth {
reqs := r.AuthLimit.GetRequests()
reqs--
r.AuthLimit.SetRequests(reqs)
return
}
reqs := r.AuthLimit.GetRequests()
reqs--
r.UnauthLimit.SetRequests(reqs)
}
// SetRateLimit sets the request Requester ratelimiter
func (r *Requester) SetRateLimit(auth bool, duration time.Duration, rate int) {
if auth {
r.AuthLimit.SetRate(rate)
r.AuthLimit.SetDuration(duration)
return
}
r.UnauthLimit.SetRate(rate)
r.UnauthLimit.SetDuration(duration)
}
// GetRateLimit gets the request Requester ratelimiter
func (r *Requester) GetRateLimit(auth bool) *RateLimit {
if auth {
return r.AuthLimit
}
return r.UnauthLimit
}
// SetTimeoutRetryAttempts sets the amount of times the job will be retried
// if it times out
func (r *Requester) SetTimeoutRetryAttempts(n int) error {
if n < 0 {
return errors.New("routines.go error - timeout retry attempts cannot be less than zero")
}
r.timeoutRetryAttempts = n
return nil
}
// New returns a new Requester
func New(name string, authLimit, unauthLimit *RateLimit, httpRequester *http.Client) *Requester {
func New(name string, httpRequester *http.Client, l Limiter) *Requester {
return &Requester{
HTTPClient: httpRequester,
UnauthLimit: unauthLimit,
AuthLimit: authLimit,
Limiter: l,
Name: name,
Jobs: make(chan Job, MaxRequestJobs),
timeoutRetryAttempts: TimeoutRetryAttempts,
timedLock: timedmutex.NewTimedMutex(DefaultMutexLockTimeout),
}
}
// IsValidMethod returns whether the supplied method is supported
func IsValidMethod(method string) bool {
return common.StringDataCompareInsensitive(supportedMethods, method)
}
// IsValidCycle checks to see whether the current request cycle is valid or not
func (r *Requester) IsValidCycle(auth bool) bool {
if auth {
if time.Since(r.Cycle) < r.AuthLimit.GetDuration() {
return true
}
} else {
if time.Since(r.Cycle) < r.UnauthLimit.GetDuration() {
return true
}
// SendPayload handles sending HTTP/HTTPS requests
func (r *Requester) SendPayload(i *Item) error {
if !i.NonceEnabled {
r.timedLock.LockForDuration()
}
r.StartCycle()
return false
req, err := i.validateRequest(r)
if err != nil {
r.timedLock.UnlockIfLocked()
return err
}
if i.HTTPDebugging {
// Err not evaluated due to validation check above
dump, _ := httputil.DumpRequestOut(req, true)
log.Debugf(log.RequestSys, "DumpRequest:\n%s", dump)
}
if atomic.LoadInt32(&r.jobs) >= MaxRequestJobs {
r.timedLock.UnlockIfLocked()
return errors.New("max request jobs reached")
}
atomic.AddInt32(&r.jobs, 1)
err = r.doRequest(req, i)
atomic.AddInt32(&r.jobs, -1)
r.timedLock.UnlockIfLocked()
return err
}
func (r *Requester) checkRequest(method, path string, body io.Reader, headers map[string]string) (*http.Request, error) {
req, err := http.NewRequest(method, path, body)
// validateRequest validates the requester item fields
func (i *Item) validateRequest(r *Requester) (*http.Request, error) {
if r == nil || r.Name == "" {
return nil, errors.New("not initialised, SetDefaults() called before making request?")
}
if i == nil {
return nil, errors.New("request item cannot be nil")
}
if i.Path == "" {
return nil, errors.New("invalid path")
}
req, err := http.NewRequest(i.Method, i.Path, i.Body)
if err != nil {
return nil, err
}
for k, v := range headers {
for k, v := range i.Headers {
req.Header.Add(k, v)
}
if r.UserAgent != "" && req.Header.Get("User-Agent") == "" {
req.Header.Add("User-Agent", r.UserAgent)
if r.UserAgent != "" && req.Header.Get(userAgent) == "" {
req.Header.Add(userAgent, r.UserAgent)
}
return req, nil
}
// DoRequest performs a HTTP/HTTPS request with the supplied params
func (r *Requester) DoRequest(req *http.Request, path string, body io.Reader, result interface{}, authRequest, verbose, httpDebug, httpRecord bool) error {
if verbose {
log.Debugf(log.Global,
"%s exchange request path: %s requires rate limiter: %v",
func (r *Requester) doRequest(req *http.Request, p *Item) error {
if p == nil {
return errors.New("request item cannot be nil")
}
if p.Verbose {
log.Debugf(log.RequestSys,
"%s request path: %s",
r.Name,
path,
r.RequiresRateLimiter())
p.Path)
for k, d := range req.Header {
log.Debugf(log.Global, "%s exchange request header [%s]: %s", r.Name, k, d)
log.Debugf(log.RequestSys,
"%s request header [%s]: %s",
r.Name,
k,
d)
}
log.Debugf(log.RequestSys,
"%s request type: %s",
r.Name,
req.Method)
if p.Body != nil {
log.Debugf(log.RequestSys,
"%s request body: %v",
r.Name,
p.Body)
}
log.Debugf(log.Global,
"%s exchange request type: %s", r.Name, req.Method)
log.Debugf(log.Global,
"%s exchange request body: %v", r.Name, body)
}
var timeoutError error
for i := 0; i < r.timeoutRetryAttempts+1; i++ {
// Initiate a rate limit reservation and sleep on requested endpoint
err := r.InitiateRateLimit(p.Endpoint)
if err != nil {
return err
}
resp, err := r.HTTPClient.Do(req)
if err != nil {
if timeoutErr, ok := err.(net.Error); ok && timeoutErr.Timeout() {
if verbose {
log.Errorf(log.ExchangeSys, "%s request has timed-out retrying request, count %d",
if p.Verbose {
log.Errorf(log.RequestSys,
"%s request has timed-out retrying request, count %d",
r.Name,
i)
}
timeoutError = err
continue
}
if r.RequiresRateLimiter() {
r.DecrementRequests(authRequest)
}
return err
}
if resp == nil {
if r.RequiresRateLimiter() {
r.DecrementRequests(authRequest)
}
return errors.New("resp is nil")
}
var reader io.ReadCloser
switch resp.Header.Get("Content-Encoding") {
case "gzip":
reader, err = gzip.NewReader(resp.Body)
defer reader.Close()
if err != nil {
return err
}
case "json":
reader = resp.Body
default:
switch {
case strings.Contains(resp.Header.Get("Content-Type"), "application/json"):
reader = resp.Body
default:
if verbose {
log.Warnf(log.ExchangeSys,
"%s request response content type differs from JSON; received %v [path: %s]\n",
r.Name,
resp.Header.Get("Content-Type"),
path)
}
reader = resp.Body
}
}
contents, err := ioutil.ReadAll(reader)
contents, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
if httpRecord {
if p.HTTPRecording {
// This dumps http responses for future mocking implementations
err = mock.HTTPRecord(resp, r.Name, contents)
if err != nil {
@@ -304,169 +160,40 @@ func (r *Requester) DoRequest(req *http.Request, path string, body io.Reader, re
if resp.StatusCode < http.StatusOK ||
resp.StatusCode > http.StatusAccepted {
return fmt.Errorf("unsuccessful HTTP status code: %d body: %s",
return fmt.Errorf("%s unsuccessful HTTP status code: %d raw response: %s",
r.Name,
resp.StatusCode,
string(contents))
}
if httpDebug {
if p.HTTPDebugging {
dump, err := httputil.DumpResponse(resp, false)
if err != nil {
log.Errorf(log.Global, "DumpResponse invalid response: %v:", err)
log.Errorf(log.RequestSys, "DumpResponse invalid response: %v:", err)
}
log.Debugf(log.Global, "DumpResponse Headers (%v):\n%s", path, dump)
log.Debugf(log.Global, "DumpResponse Body (%v):\n %s", path, string(contents))
log.Debugf(log.RequestSys, "DumpResponse Headers (%v):\n%s", p.Path, dump)
log.Debugf(log.RequestSys, "DumpResponse Body (%v):\n %s", p.Path, string(contents))
}
resp.Body.Close()
if verbose {
log.Debugf(log.ExchangeSys, "HTTP status: %s, Code: %v", resp.Status, resp.StatusCode)
if !httpDebug {
log.Debugf(log.ExchangeSys, "%s exchange raw response: %s", r.Name, string(contents))
if p.Verbose {
log.Debugf(log.RequestSys,
"HTTP status: %s, Code: %v",
resp.Status,
resp.StatusCode)
if !p.HTTPDebugging {
log.Debugf(log.RequestSys,
"%s raw response: %s",
r.Name,
string(contents))
}
}
if result != nil {
return json.Unmarshal(contents, result)
}
return nil
return json.Unmarshal(contents, p.Result)
}
return fmt.Errorf("request.go error - failed to retry request %s",
timeoutError)
}
func (r *Requester) worker() {
for {
for x := range r.Jobs {
if !r.IsRateLimited(x.AuthRequest) {
r.IncrementRequests(x.AuthRequest)
err := r.DoRequest(x.Request, x.Path, x.Body, x.Result, x.AuthRequest, x.Verbose, x.HTTPDebugging, x.Record)
x.JobResult <- &JobResult{
Error: err,
Result: x.Result,
}
} else {
limit := r.GetRateLimit(x.AuthRequest)
diff := limit.GetDuration() - time.Since(r.Cycle)
if x.Verbose {
log.Debugf(log.ExchangeSys, "%s request. Rate limited! Sleeping for %v", r.Name, diff)
}
time.Sleep(diff)
for {
if r.IsRateLimited(x.AuthRequest) {
time.Sleep(time.Millisecond)
continue
}
r.IncrementRequests(x.AuthRequest)
if x.Verbose {
log.Debugf(log.ExchangeSys, "%s request. No longer rate limited! Doing request", r.Name)
}
err := r.DoRequest(x.Request, x.Path, x.Body, x.Result, x.AuthRequest, x.Verbose, x.HTTPDebugging, x.Record)
x.JobResult <- &JobResult{
Error: err,
Result: x.Result,
}
break
}
}
}
}
}
// SendPayload handles sending HTTP/HTTPS requests
func (r *Requester) SendPayload(method, path string, headers map[string]string, body io.Reader, result interface{}, authRequest, nonceEnabled, verbose, httpDebugging, record bool) error {
if !nonceEnabled {
r.timedLock.LockForDuration()
}
if r == nil || r.Name == "" {
r.timedLock.UnlockIfLocked()
return errors.New("not initiliased, SetDefaults() called before making request?")
}
if !IsValidMethod(method) {
r.timedLock.UnlockIfLocked()
return fmt.Errorf("incorrect method supplied %s: supported %s", method, supportedMethods)
}
if path == "" {
r.timedLock.UnlockIfLocked()
return errors.New("invalid path")
}
req, err := r.checkRequest(method, path, body, headers)
if err != nil {
r.timedLock.UnlockIfLocked()
return err
}
if httpDebugging {
dump, err := httputil.DumpRequestOut(req, true)
if err != nil {
log.Errorf(log.Global,
"DumpRequest invalid response %v:", err)
}
log.Debugf(log.Global,
"DumpRequest:\n%s", dump)
}
if !r.RequiresRateLimiter() {
r.timedLock.UnlockIfLocked()
return r.DoRequest(req, path, body, result, authRequest, verbose, httpDebugging, record)
}
if len(r.Jobs) == MaxRequestJobs {
r.timedLock.UnlockIfLocked()
return errors.New("max request jobs reached")
}
r.m.Lock()
if !r.WorkerStarted {
r.StartCycle()
r.WorkerStarted = true
go r.worker()
}
r.m.Unlock()
jobResult := make(chan *JobResult)
newJob := Job{
Request: req,
Method: method,
Path: path,
Headers: headers,
Body: body,
Result: result,
JobResult: jobResult,
AuthRequest: authRequest,
Verbose: verbose,
HTTPDebugging: httpDebugging,
Record: record,
}
if verbose {
log.Debugf(log.ExchangeSys, "%s request. Attaching new job.", r.Name)
}
r.Jobs <- newJob
r.timedLock.UnlockIfLocked()
if verbose {
log.Debugf(log.ExchangeSys, "%s request. Waiting for job to complete.", r.Name)
}
resp := <-newJob.JobResult
if verbose {
log.Debugf(log.ExchangeSys, "%s request. Job complete.", r.Name)
}
return resp.Error
}
// GetNonce returns a nonce for requests. This locks and enforces concurrent
// nonce FIFO on the buffered job channel
func (r *Requester) GetNonce(isNano bool) nonce.Value {

View File

@@ -1,322 +1,437 @@
package request
import (
"errors"
"fmt"
"io"
"log"
"net/http"
"net/http/httptest"
"net/url"
"os"
"sync"
"testing"
"time"
"golang.org/x/time/rate"
)
func TestNewRateLimit(t *testing.T) {
r := NewRateLimit(time.Second*10, 5)
const unexpected = "unexpected values"
if r.Duration != time.Second*10 && r.Rate != 5 {
t.Fatal("unexpected values")
}
}
var testURL string
var serverLimit *rate.Limiter
func TestSetRate(t *testing.T) {
r := NewRateLimit(time.Second*10, 5)
r.SetRate(40)
if r.GetRate() != 40 {
t.Fatal("unexpected values")
}
}
func TestSetDuration(t *testing.T) {
r := NewRateLimit(time.Second*10, 5)
r.SetDuration(time.Second)
if r.GetDuration() != time.Second {
t.Fatal("unexpected values")
}
}
func TestDecerementRequests(t *testing.T) {
r := New("bitfinex", NewRateLimit(time.Second*10, 5), NewRateLimit(time.Second*20, 100), new(http.Client))
r.AuthLimit.SetRequests(99)
r.DecrementRequests(true)
if r.AuthLimit.GetRequests() != 98 {
t.Fatal("unexpected values")
}
}
func TestStartCycle(t *testing.T) {
r := New("bitfinex", NewRateLimit(time.Second*10, 5), NewRateLimit(time.Second*20, 100), new(http.Client))
if r.AuthLimit.Duration != time.Second*10 && r.AuthLimit.Rate != 5 {
t.Fatal("unexpected values")
}
if r.UnauthLimit.Duration != time.Second*20 && r.UnauthLimit.Rate != 100 {
t.Fatal("unexpected values")
}
r.AuthLimit.SetRequests(1)
r.UnauthLimit.SetRequests(1)
r.StartCycle()
if r.Cycle.IsZero() || r.AuthLimit.GetRequests() != 0 || r.UnauthLimit.GetRequests() != 0 {
t.Fatal("unexpcted values")
}
}
func TestIsRateLimited(t *testing.T) {
r := New("bitfinex", NewRateLimit(time.Second*10, 5), NewRateLimit(time.Second*20, 100), new(http.Client))
r.StartCycle()
if r.AuthLimit.String() != "Rate limiter set to 5 requests per 10s" {
t.Fatal("unexcpted values")
}
if r.UnauthLimit.String() != "Rate limiter set to 100 requests per 20s" {
t.Fatal("unexpected values")
}
if r.AuthLimit.String() != "Rate limiter set to 5 requests per 10s" {
t.Fatal("unexcpted values")
}
// FIXME: Need to account for unauth/auth/total requests
r.AuthLimit.SetRequests(4)
if r.AuthLimit.GetRequests() != 4 {
t.Fatal("unexpected values")
}
// test that we're not rate limited since 4 < 5
if r.IsRateLimited(true) {
t.Fatal("unexpected values")
}
// bump requests counter to 6 which would exceed the rate limiter
r.AuthLimit.SetRequests(6)
if !r.IsRateLimited(true) {
t.Fatal("unexpected values")
}
// FIXME: Need to account for unauth/auth/total requests
r.UnauthLimit.SetRequests(99)
if r.UnauthLimit.GetRequests() != 99 {
t.Fatal("unexpected values")
}
// test that we're not rate limited since 99 < 100
if r.IsRateLimited(false) {
t.Fatal("unexpected values")
}
// bump requests counter to 100 which would exceed the rate limiter
r.UnauthLimit.SetRequests(100)
if !r.IsRateLimited(false) {
t.Fatal("unexpected values")
}
}
func TestRequiresRateLimiter(t *testing.T) {
r := New("bitfinex", NewRateLimit(time.Second*10, 5), NewRateLimit(time.Second*20, 100), new(http.Client))
if !r.RequiresRateLimiter() {
t.Fatal("unexpected values")
}
r.AuthLimit.Rate = 0
r.UnauthLimit.Rate = 0
if r.RequiresRateLimiter() {
t.Fatal("unexpected values")
}
}
func TestSetLimit(t *testing.T) {
r := New("bitfinex", NewRateLimit(time.Second*10, 5), NewRateLimit(time.Second*20, 100), new(http.Client))
r.SetRateLimit(true, time.Minute, 20)
if r.AuthLimit.Rate != 20 && r.AuthLimit.Duration != time.Minute*20 {
t.Fatal("unexpected values")
}
r.SetRateLimit(false, time.Minute, 40)
if r.UnauthLimit.Rate != 40 && r.UnauthLimit.Duration != time.Minute {
t.Fatal("unexpected values")
}
}
func TestGetLimit(t *testing.T) {
r := New("bitfinex", NewRateLimit(time.Second*10, 5), NewRateLimit(time.Second*20, 100), new(http.Client))
if r.GetRateLimit(true).Duration != time.Second*10 && r.GetRateLimit(true).Rate != 5 {
t.Fatal("unexpected values")
}
if r.GetRateLimit(false).Duration != time.Second*10 && r.GetRateLimit(false).Rate != 100 {
t.Fatal("unexpected values")
}
}
func TestIsValidMethod(t *testing.T) {
for x := range supportedMethods {
if !IsValidMethod(supportedMethods[x]) {
t.Fatal("unexpected values")
func TestMain(m *testing.M) {
serverLimit = NewRateLimit(time.Millisecond*500, 1)
sm := http.NewServeMux()
sm.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
io.WriteString(w, `{"response":true}`)
})
sm.HandleFunc("/error", func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusBadRequest)
io.WriteString(w, `{"error":true}`)
})
sm.HandleFunc("/timeout", func(w http.ResponseWriter, req *http.Request) {
time.Sleep(time.Millisecond * 100)
w.WriteHeader(http.StatusGatewayTimeout)
})
sm.HandleFunc("/rate", func(w http.ResponseWriter, req *http.Request) {
if !serverLimit.Allow() {
http.Error(w,
http.StatusText(http.StatusTooManyRequests),
http.StatusTooManyRequests)
io.WriteString(w, `{"response":false}`)
return
}
}
io.WriteString(w, `{"response":true}`)
})
if IsValidMethod("BLAH") {
t.Fatal("unexpected values")
}
server := httptest.NewServer(sm)
testURL = server.URL
issues := m.Run()
server.Close()
os.Exit(issues)
}
func TestIsValidCycle(t *testing.T) {
r := New("bitfinex", NewRateLimit(time.Second*10, 5), NewRateLimit(time.Second*20, 100), new(http.Client))
r.Cycle = time.Now().Add(-9 * time.Second)
if !r.IsValidCycle(true) {
t.Fatal("unexpected values")
func TestNewRateLimit(t *testing.T) {
t.Parallel()
r := NewRateLimit(time.Second*10, 5)
if r.Limit() != 0.5 {
t.Fatal(unexpected)
}
r.Cycle = time.Now().Add(-11 * time.Second)
if r.IsValidCycle(true) {
t.Fatal("unexpected values")
// Ensures rate limiting factor is the same
r = NewRateLimit(time.Second*2, 1)
if r.Limit() != 0.5 {
t.Fatal(unexpected)
}
r.Cycle = time.Now().Add(-19 * time.Second)
if !r.IsValidCycle(false) {
t.Fatal("unexpected values")
// Test for open rate limit
r = NewRateLimit(time.Second*2, 0)
if r.Limit() != rate.Inf {
t.Fatal(unexpected)
}
r.Cycle = time.Now().Add(-21 * time.Second)
if r.IsValidCycle(false) {
t.Fatal("unexpected values")
r = NewRateLimit(0, 69)
if r.Limit() != rate.Inf {
t.Fatal(unexpected)
}
}
func TestCheckRequest(t *testing.T) {
r := New("", NewRateLimit(time.Second*10, 5), NewRateLimit(time.Second*20, 100), new(http.Client))
_, err := r.checkRequest("bad method, bad", "http://www.google.com", nil, nil)
t.Parallel()
r := New("TestRequest",
new(http.Client),
nil)
var check *Item
_, err := check.validateRequest(&Requester{})
if err == nil {
t.Fatal("unexpected values")
t.Fatal(unexpected)
}
_, err = check.validateRequest(nil)
if err == nil {
t.Fatal(unexpected)
}
_, err = check.validateRequest(r)
if err == nil {
t.Fatal(unexpected)
}
check = &Item{}
_, err = check.validateRequest(r)
if err == nil {
t.Fatal(unexpected)
}
check.Path = testURL
check.Method = " " // Forces method check; "" automatically converts to GET
_, err = check.validateRequest(r)
if err == nil {
t.Fatal(unexpected)
}
check.Method = http.MethodPost
_, err = check.validateRequest(r)
if err != nil {
t.Fatal(err)
}
// Test setting headers
check.Headers = map[string]string{
"Content-Type": "Super awesome HTTP party experience",
}
// Test user agent set
r.UserAgent = "r00t axxs"
req, err := check.validateRequest(r)
if err != nil {
t.Fatal(err)
}
if req.Header.Get("Content-Type") != "Super awesome HTTP party experience" {
t.Fatal(unexpected)
}
if req.UserAgent() != "r00t axxs" {
t.Fatal(unexpected)
}
}
type GlobalLimitTest struct {
Auth *rate.Limiter
UnAuth *rate.Limiter
}
func (g *GlobalLimitTest) Limit(e EndpointLimit) error {
switch e {
case Auth:
if g.Auth == nil {
return errors.New("auth rate not set")
}
time.Sleep(g.Auth.Reserve().Delay())
return nil
case UnAuth:
if g.UnAuth == nil {
return errors.New("unauth rate not set")
}
time.Sleep(g.UnAuth.Reserve().Delay())
return nil
default:
return fmt.Errorf("cannot execute functionality: %d not found", e)
}
}
var globalshell = GlobalLimitTest{
Auth: NewRateLimit(time.Millisecond*600, 1),
UnAuth: NewRateLimit(time.Second*1, 100)}
func TestDoRequest(t *testing.T) {
r := New("", NewRateLimit(time.Second*10, 5), NewRateLimit(time.Second*20, 100), new(http.Client))
r.Name = "bitfinex"
err := r.SendPayload("BLAH", "https://www.google.com", nil, nil, nil, false, false, true, false, false)
t.Parallel()
r := New("test",
new(http.Client),
&globalshell)
err := r.SendPayload(&Item{})
if err == nil {
t.Fatal("Expected error")
t.Fatal(unexpected)
}
err = r.SendPayload(http.MethodGet, "", nil, nil, nil, false, false, true, false, false)
err = r.SendPayload(&Item{Method: http.MethodGet})
if err == nil {
t.Fatal("Expected error")
t.Fatal(unexpected)
}
err = r.SendPayload(http.MethodGet, "https://www.google.com", nil, nil, nil, false, false, true, false, false)
err = r.SendPayload(&Item{
Method: http.MethodGet,
Path: testURL,
})
if err == nil {
t.Fatal(unexpected)
}
// force debug
err = r.SendPayload(&Item{
Method: http.MethodGet,
Path: testURL,
HTTPDebugging: true,
Verbose: true,
})
if err == nil {
t.Fatal(unexpected)
}
// max request job ceiling
r.jobs = MaxRequestJobs
err = r.SendPayload(&Item{
Method: http.MethodGet,
Path: testURL,
})
if err == nil {
t.Fatal(unexpected)
}
// reset jobs
r.jobs = 0
// timeout checker
r.HTTPClient.Timeout = time.Millisecond * 50
err = r.SendPayload(&Item{
Method: http.MethodGet,
Path: testURL + "/timeout",
})
if err == nil {
t.Fatal(unexpected)
}
// reset timeout
r.HTTPClient.Timeout = 0
// Check JSON
var resp struct {
Response bool `json:"response"`
}
err = r.SendPayload(&Item{
Method: http.MethodGet,
Path: testURL,
Result: &resp,
Endpoint: UnAuth,
})
if err != nil {
t.Fatal("unexpected values", err)
t.Fatal(err)
}
if !resp.Response {
t.Fatal(unexpected)
}
if !r.RequiresRateLimiter() {
t.Fatal("unexpected values")
// Check error
var respErr struct {
Error bool `json:"error"`
}
r.SetRateLimit(false, time.Second, 0)
r.SetRateLimit(true, time.Second, 0)
err = r.SendPayload(http.MethodGet, "https://www.google.com", nil, nil, nil, false, false, true, false, false)
err = r.SendPayload(&Item{
Method: http.MethodGet,
Path: testURL,
Result: &respErr,
Endpoint: UnAuth,
})
if err != nil {
t.Fatal("unexpected values", err)
t.Fatal(err)
}
if !resp.Response {
t.Fatal(unexpected)
}
if r.RequiresRateLimiter() {
t.Fatal("unexpected values")
// Check rate limit
var wg sync.WaitGroup
wg.Add(5)
for i := 0; i < 5; i++ {
go func(wg *sync.WaitGroup) {
var resp struct {
Response bool `json:"response"`
}
payloadError := r.SendPayload(&Item{
Method: http.MethodGet,
Path: testURL + "/rate",
Result: &resp,
AuthRequest: true,
Endpoint: Auth,
})
wg.Done()
if payloadError != nil {
log.Fatal(payloadError)
}
if !resp.Response {
log.Fatal(unexpected)
}
}(&wg)
}
wg.Wait()
}
func TestGetNonce(t *testing.T) {
t.Parallel()
r := New("test",
new(http.Client),
&globalshell)
n1 := r.GetNonce(false)
n2 := r.GetNonce(false)
if n1 == n2 {
t.Fatal(unexpected)
}
r.SetRateLimit(false, time.Millisecond*200, 100)
r.SetRateLimit(true, time.Millisecond*100, 100)
r.Cycle = time.Now().Add(time.Millisecond * -201)
if r.IsValidCycle(false) {
t.Fatal("unexpected values")
r2 := New("test",
new(http.Client),
&globalshell)
n3 := r2.GetNonce(true)
n4 := r2.GetNonce(true)
if n3 == n4 {
t.Fatal(unexpected)
}
}
err = r.SendPayload(http.MethodGet, "https://www.google.com", nil, nil, nil, false, false, true, false, false)
func TestGetNonceMillis(t *testing.T) {
t.Parallel()
r := New("test",
new(http.Client),
&globalshell)
m1 := r.GetNonceMilli()
m2 := r.GetNonceMilli()
if m1 == m2 {
log.Fatal(unexpected)
}
}
func TestSetProxy(t *testing.T) {
t.Parallel()
r := New("test",
new(http.Client),
&globalshell)
u, err := url.Parse("http://www.google.com")
if err != nil {
t.Fatal("unexpected values")
t.Fatal(err)
}
r.Cycle = time.Now().Add(time.Millisecond * -101)
if r.IsValidCycle(true) {
t.Fatal("unexepcted values")
}
err = r.SendPayload(http.MethodGet, "https://www.google.com", nil, nil, nil, true, false, true, false, false)
err = r.SetProxy(u)
if err != nil {
t.Fatal("unexpected values")
t.Fatal(err)
}
u, err = url.Parse("")
if err != nil {
t.Fatal(err)
}
err = r.SetProxy(u)
if err == nil {
t.Fatal("error cannot be nil")
}
}
func TestBasicLimiter(t *testing.T) {
r := New("test",
new(http.Client),
NewBasicRateLimit(time.Second, 1))
i := Item{
Path: "http://www.google.com",
Method: http.MethodGet,
}
var result interface{}
err = r.SendPayload(http.MethodGet, "https://www.google.com", nil, nil, result, false, false, true, false, false)
tn := time.Now()
_ = r.SendPayload(&i)
_ = r.SendPayload(&i)
if time.Since(tn) < time.Second {
t.Error("rate limit issues")
}
}
func TestEnableDisableRateLimit(t *testing.T) {
r := New("TestRequest",
new(http.Client),
NewBasicRateLimit(time.Minute, 1))
var resp interface{}
err := r.SendPayload(&Item{
Method: http.MethodGet,
Path: testURL,
Result: &resp,
AuthRequest: true,
Endpoint: Auth,
})
if err != nil {
t.Fatal(err)
}
headers := make(map[string]string)
headers["content-type"] = "content/text"
err = r.SendPayload(http.MethodPost, "https://bitfinex.com", headers, nil, result, false, false, true, false, false)
err = r.EnableRateLimiter()
if err == nil {
t.Fatal("error cannot be nil")
}
err = r.DisableRateLimiter()
if err != nil {
t.Fatal(err)
}
r.StartCycle()
r.UnauthLimit.SetRequests(100)
err = r.SendPayload(http.MethodGet, "https://www.google.com", nil, nil, result, false, false, false, false, false)
err = r.SendPayload(&Item{
Method: http.MethodGet,
Path: testURL,
Result: &resp,
AuthRequest: true,
Endpoint: Auth,
})
if err != nil {
t.Fatal("unexpected values")
t.Fatal(err)
}
err = r.SetTimeoutRetryAttempts(1)
if err != nil {
t.Fatal("setting timeout retry attempts")
}
err = r.SetTimeoutRetryAttempts(-1)
err = r.DisableRateLimiter()
if err == nil {
t.Fatal("setting timeout retry attempts with negative value")
t.Fatal("error cannot be nil")
}
r.HTTPClient.Timeout = 1 * time.Second
err = r.SendPayload(http.MethodPost, "https://httpstat.us/200?sleep=20000", nil, nil, nil, false, false, true, false, false)
if err == nil {
t.Fatal("Expected error")
}
proxy, err := url.Parse("")
err = r.EnableRateLimiter()
if err != nil {
t.Error("failed to parse proxy address")
t.Fatal(err)
}
err = r.SetProxy(proxy)
if err == nil {
t.Error("Expected error")
}
ti := time.NewTicker(time.Second)
c := make(chan struct{})
go func(c chan struct{}) {
err = r.SendPayload(&Item{
Method: http.MethodGet,
Path: testURL,
Result: &resp,
AuthRequest: true,
Endpoint: Auth,
})
if err != nil {
log.Fatal(err)
}
c <- struct{}{}
}(c)
proxy, err = url.Parse("https://192.0.0.1")
if err != nil {
t.Error("failed to parse proxy address")
}
err = r.SetProxy(proxy)
if err != nil {
t.Error("failed to set proxy")
}
}
func BenchmarkRequestLockMech(b *testing.B) {
r := New("", NewRateLimit(time.Second*10, 5), NewRateLimit(time.Second*20, 100), new(http.Client))
var meep interface{}
for n := 0; n < b.N; n++ {
r.SendPayload(http.MethodGet, "127.0.0.1", nil, nil, &meep, false, false, false, false, false)
select {
case <-c:
t.Fatal("rate limiting failure")
case <-ti.C:
// Correct test
}
}

View File

@@ -3,73 +3,52 @@ package request
import (
"io"
"net/http"
"sync"
"time"
"github.com/thrasher-corp/gocryptotrader/common/timedmutex"
"github.com/thrasher-corp/gocryptotrader/exchanges/nonce"
)
var supportedMethods = []string{http.MethodGet, http.MethodPost, http.MethodHead,
http.MethodPut, http.MethodDelete, http.MethodOptions, http.MethodConnect}
// Const vars for rate limiter
const (
DefaultMaxRequestJobs = 50
DefaultTimeoutRetryAttempts = 3
DefaultMutexLockTimeout = 50 * time.Millisecond
proxyTLSTimeout = 15 * time.Second
DefaultMaxRequestJobs int32 = 50
DefaultTimeoutRetryAttempts = 3
DefaultMutexLockTimeout = 50 * time.Millisecond
proxyTLSTimeout = 15 * time.Second
userAgent = "User-Agent"
)
// Vars for rate limiter
var (
MaxRequestJobs = DefaultMaxRequestJobs
TimeoutRetryAttempts = DefaultTimeoutRetryAttempts
DisableRateLimiter bool
)
// Requester struct for the request client
type Requester struct {
HTTPClient *http.Client
UnauthLimit *RateLimit
AuthLimit *RateLimit
Limiter Limiter
Name string
UserAgent string
Cycle time.Time
timeoutRetryAttempts int
m sync.Mutex
Jobs chan Job
WorkerStarted bool
jobs int32
Nonce nonce.Nonce
DisableRateLimiter bool
disableRateLimiter int32
timedLock *timedmutex.TimedMutex
}
// RateLimit struct
type RateLimit struct {
Duration time.Duration
Rate int
Requests int
Mutex sync.Mutex
}
// JobResult holds a request job result
type JobResult struct {
Error error
Result interface{}
}
// Job holds a request job
type Job struct {
Request *http.Request
// Item is a temp item for requests
type Item struct {
Method string
Path string
Headers map[string]string
Body io.Reader
Result interface{}
JobResult chan *JobResult
AuthRequest bool
NonceEnabled bool
Verbose bool
HTTPDebugging bool
Record bool
HTTPRecording bool
IsReserved bool
Endpoint EndpointLimit
}