Request: Fix http.Client race issue when setting transport layer proxy and timeouts (#885)

* backtester/request: trying to fix panic (WIP)

* request: fix race for transport layer

* request: linter issue fix

* request: more linter issues

* requester: Add function to remove the tracking of underlying http client and add to engine unload exchange.

* request: add more context to error return

* request: Fix after cherry pick issues

* request: fix niterinos

* exchanges: change return to package variable

* request: changed named

Co-authored-by: Ryan O'Hara-Reid <ryan.oharareid@thrasher.io>
This commit is contained in:
Ryan O'Hara-Reid
2022-02-18 09:22:10 +11:00
committed by GitHub
parent 11da520dc8
commit 6127e2ab73
53 changed files with 752 additions and 257 deletions

View File

@@ -1288,15 +1288,19 @@ func updateFile(name string) error {
// SendGetReq sends get req
func sendGetReq(path string, result interface{}) error {
var requester *request.Requester
var err error
if strings.Contains(path, "github") {
requester = request.New("Apichecker",
requester, err = request.New("Apichecker",
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(request.NewBasicRateLimit(time.Hour, 60)))
} else {
requester = request.New("Apichecker",
requester, err = request.New("Apichecker",
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(request.NewBasicRateLimit(time.Second, 100)))
}
if err != nil {
return err
}
item := &request.Item{
Method: http.MethodGet,
Path: path,
@@ -1309,9 +1313,12 @@ func sendGetReq(path string, result interface{}) error {
// sendAuthReq sends auth req
func sendAuthReq(method, path string, result interface{}) error {
requester := request.New("Apichecker",
requester, err := request.New("Apichecker",
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(request.NewBasicRateLimit(time.Second*10, 100)))
if err != nil {
return err
}
item := &request.Item{
Method: method,
Path: path,

View File

@@ -36,10 +36,14 @@ func (c *Coinmarketcap) SetDefaults() {
c.Verbose = false
c.APIUrl = baseURL
c.APIVersion = version
c.Requester = request.New(c.Name,
var err error
c.Requester, err = request.New(c.Name,
common.NewHTTPClientWithTimeout(defaultTimeOut),
request.WithLimiter(request.NewBasicRateLimit(RateInterval, BasicRequestRate)),
)
if err != nil {
log.Errorln(log.Global, err)
}
}
// Setup sets user configuration

View File

@@ -23,10 +23,11 @@ func (c *CurrencyConverter) Setup(config base.Settings) error {
c.Enabled = config.Enabled
c.Verbose = config.Verbose
c.PrimaryProvider = config.PrimaryProvider
c.Requester = request.New(c.Name,
var err error
c.Requester, err = request.New(c.Name,
common.NewHTTPClientWithTimeout(base.DefaultTimeOut),
request.WithLimiter(request.NewBasicRateLimit(rateInterval, requestRate)))
return nil
return err
}
// GetRates is a wrapper function to return rates

View File

@@ -43,10 +43,10 @@ func (c *CurrencyLayer) Setup(config base.Settings) error {
c.Verbose = config.Verbose
c.PrimaryProvider = config.PrimaryProvider
// Rate limit is based off a monthly counter - Open limit used.
c.Requester = request.New(c.Name,
var err error
c.Requester, err = request.New(c.Name,
common.NewHTTPClientWithTimeout(base.DefaultTimeOut))
return nil
return err
}
// GetRates is a wrapper function to return rates for GoCryptoTrader

View File

@@ -33,9 +33,10 @@ func (e *ExchangeRateHost) Setup(config base.Settings) error {
e.Enabled = config.Enabled
e.Verbose = config.Verbose
e.PrimaryProvider = config.PrimaryProvider
e.Requester = request.New(e.Name,
var err error
e.Requester, err = request.New(e.Name,
common.NewHTTPClientWithTimeout(base.DefaultTimeOut))
return nil
return err
}
// GetLatestRates returns a list of forex rates based on the supplied params

View File

@@ -29,10 +29,11 @@ func (e *ExchangeRates) Setup(config base.Settings) error {
e.PrimaryProvider = config.PrimaryProvider
e.APIKey = config.APIKey
e.APIKeyLvl = config.APIKeyLvl
e.Requester = request.New(e.Name,
var err error
e.Requester, err = request.New(e.Name,
common.NewHTTPClientWithTimeout(base.DefaultTimeOut),
request.WithLimiter(request.NewBasicRateLimit(rateLimitInterval, requestRate)))
return nil
return err
}
func (e *ExchangeRates) cleanCurrencies(baseCurrency, symbols string) string {

View File

@@ -36,9 +36,10 @@ func (f *Fixer) Setup(config base.Settings) error {
f.Name = config.Name
f.Verbose = config.Verbose
f.PrimaryProvider = config.PrimaryProvider
f.Requester = request.New(f.Name,
var err error
f.Requester, err = request.New(f.Name,
common.NewHTTPClientWithTimeout(base.DefaultTimeOut))
return nil
return err
}
// GetSupportedCurrencies returns supported currencies

View File

@@ -37,9 +37,10 @@ func (o *OXR) Setup(config base.Settings) error {
o.Name = config.Name
o.Verbose = config.Verbose
o.PrimaryProvider = config.PrimaryProvider
o.Requester = request.New(o.Name,
var err error
o.Requester, err = request.New(o.Name,
common.NewHTTPClientWithTimeout(base.DefaultTimeOut))
return nil
return err
}
// GetRates is a wrapper function to return rates

View File

@@ -94,12 +94,16 @@ func (m *ExchangeManager) RemoveExchange(exchName string) error {
if m.Len() == 0 {
return ErrNoExchangesLoaded
}
_, err := m.GetExchangeByName(exchName)
exch, err := m.GetExchangeByName(exchName)
if err != nil {
return err
}
m.m.Lock()
defer m.m.Unlock()
err = exch.GetBase().Requester.Shutdown()
if err != nil {
return err
}
delete(m.exchanges, strings.ToLower(exchName))
log.Infof(log.ExchangeSys, "%s exchange unloaded successfully.\n", exchName)
return nil

View File

@@ -73,8 +73,11 @@ func (a *Alphapoint) SetDefaults() {
},
}
a.Requester = request.New(a.Name,
a.Requester, err = request.New(a.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
}
// FetchTradablePairs returns a list of the exchanges tradable pairs

View File

@@ -46,7 +46,10 @@ func TestMain(m *testing.M) {
if err != nil {
log.Fatalf("Mock server error %s", err)
}
b.HTTPClient = newClient
err = b.SetHTTPClient(newClient)
if err != nil {
log.Fatalf("Mock server error %s", err)
}
endpointMap := b.API.Endpoints.GetURLMap()
for k := range endpointMap {
err = b.API.Endpoints.SetRunning(k, serverDetails)

View File

@@ -186,9 +186,12 @@ func (b *Binance) SetDefaults() {
},
}
b.Requester = request.New(b.Name,
b.Requester, err = request.New(b.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(SetRateLimit()))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
b.API.Endpoints = b.NewEndpoints()
err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: spotAPIURL,

View File

@@ -165,9 +165,12 @@ func (b *Bitfinex) SetDefaults() {
},
}
b.Requester = request.New(b.Name,
b.Requester, err = request.New(b.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(SetRateLimit()))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
b.API.Endpoints = b.NewEndpoints()
err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: bitfinexAPIURLBase,

View File

@@ -94,9 +94,12 @@ func (b *Bitflyer) SetDefaults() {
},
}
b.Requester = request.New(b.Name,
b.Requester, err = request.New(b.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(SetRateLimit()))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
b.API.Endpoints = b.NewEndpoints()
err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: japanURL,

View File

@@ -129,9 +129,12 @@ func (b *Bithumb) SetDefaults() {
},
},
}
b.Requester = request.New(b.Name,
b.Requester, err = request.New(b.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(SetRateLimit()))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
b.API.Endpoints = b.NewEndpoints()
err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: apiURL,

View File

@@ -124,9 +124,12 @@ func (b *Bitmex) SetDefaults() {
},
}
b.Requester = request.New(b.Name,
b.Requester, err = request.New(b.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(SetRateLimit()))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
b.API.Endpoints = b.NewEndpoints()
err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: bitmexAPIURL,

View File

@@ -46,7 +46,10 @@ func TestMain(m *testing.M) {
log.Fatalf("Mock server error %s", err)
}
b.HTTPClient = newClient
err = b.SetHTTPClient(newClient)
if err != nil {
log.Fatalf("Mock server error %s", err)
}
endpointMap := b.API.Endpoints.GetURLMap()
for k := range endpointMap {
err = b.API.Endpoints.SetRunning(k, serverDetails+"/api")

View File

@@ -129,9 +129,12 @@ func (b *Bitstamp) SetDefaults() {
},
}
b.Requester = request.New(b.Name,
b.Requester, err = request.New(b.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(request.NewBasicRateLimit(bitstampRateInterval, bitstampRequestRate)))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
b.API.Endpoints = b.NewEndpoints()
err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: bitstampAPIURL,

View File

@@ -120,9 +120,12 @@ func (b *Bittrex) SetDefaults() {
},
}
b.Requester = request.New(b.Name,
b.Requester, err = request.New(b.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(request.NewBasicRateLimit(ratePeriod, rateLimit)))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
b.API.Endpoints = b.NewEndpoints()

View File

@@ -119,9 +119,12 @@ func (b *BTCMarkets) SetDefaults() {
},
}
b.Requester = request.New(b.Name,
b.Requester, err = request.New(b.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(SetRateLimit()))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
b.API.Endpoints = b.NewEndpoints()
err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: btcMarketsAPIURL,

View File

@@ -149,9 +149,12 @@ func (b *BTSE) SetDefaults() {
},
}
b.Requester = request.New(b.Name,
b.Requester, err = request.New(b.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(SetRateLimit()))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
b.API.Endpoints = b.NewEndpoints()
err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: btseAPIURL,

View File

@@ -130,9 +130,12 @@ func (c *CoinbasePro) SetDefaults() {
},
}
c.Requester = request.New(c.Name,
c.Requester, err = request.New(c.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(SetRateLimit()))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
c.API.Endpoints = c.NewEndpoints()
err = c.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: coinbaseproAPIURL,

View File

@@ -113,8 +113,11 @@ func (c *COINUT) SetDefaults() {
},
}
c.Requester = request.New(c.Name,
c.Requester, err = request.New(c.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
c.API.Endpoints = c.NewEndpoints()
err = c.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: coinutAPIURL,

View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"net"
"net/http"
"net/url"
"strconv"
"strings"
@@ -20,7 +19,6 @@ import (
"github.com/thrasher-corp/gocryptotrader/exchanges/currencystate"
"github.com/thrasher-corp/gocryptotrader/exchanges/kline"
"github.com/thrasher-corp/gocryptotrader/exchanges/protocol"
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
"github.com/thrasher-corp/gocryptotrader/exchanges/stream"
"github.com/thrasher-corp/gocryptotrader/exchanges/trade"
"github.com/thrasher-corp/gocryptotrader/log"
@@ -46,56 +44,8 @@ var (
ErrAuthenticatedRequestWithoutCredentialsSet = errors.New("authenticated HTTP request called but not supported due to unset/default API keys")
errEndpointStringNotFound = errors.New("endpoint string not found")
errTransportNotSet = errors.New("transport not set, cannot set timeout")
// ErrPairNotFound is an error message for when unable to find a currency pair
ErrPairNotFound = errors.New("pair not found")
)
func (b *Base) checkAndInitRequester() {
if b.Requester == nil {
b.Requester = request.New(b.Name,
&http.Client{Transport: new(http.Transport)})
}
}
// SetHTTPClientTimeout sets the timeout value for the exchanges HTTP Client and
// also the underlying transports idle connection timeout
func (b *Base) SetHTTPClientTimeout(t time.Duration) error {
b.checkAndInitRequester()
b.Requester.HTTPClient.Timeout = t
tr, ok := b.Requester.HTTPClient.Transport.(*http.Transport)
if !ok {
return errTransportNotSet
}
tr.IdleConnTimeout = t
return nil
}
// SetHTTPClient sets exchanges HTTP client
func (b *Base) SetHTTPClient(h *http.Client) {
b.checkAndInitRequester()
b.Requester.HTTPClient = h
}
// GetHTTPClient gets the exchanges HTTP client
func (b *Base) GetHTTPClient() *http.Client {
b.checkAndInitRequester()
return b.Requester.HTTPClient
}
// SetHTTPClientUserAgent sets the exchanges HTTP user agent
func (b *Base) SetHTTPClientUserAgent(ua string) {
b.checkAndInitRequester()
b.Requester.UserAgent = ua
b.HTTPUserAgent = ua
}
// GetHTTPClientUserAgent gets the exchanges HTTP user agent
func (b *Base) GetHTTPClientUserAgent() string {
return b.HTTPUserAgent
}
// SetClientProxyAddress sets a proxy address for REST and websocket requests
func (b *Base) SetClientProxyAddress(addr string) error {
if addr == "" {
@@ -441,17 +391,16 @@ func (b *Base) GetEnabledPairs(a asset.Item) (currency.Pairs, error) {
// GetRequestFormattedPairAndAssetType is a method that returns the enabled currency pair of
// along with its asset type. Only use when there is no chance of the same name crossing over
func (b *Base) GetRequestFormattedPairAndAssetType(p string) (currency.Pair, asset.Item, error) {
assetTypes := b.GetAssetTypes(false)
var response currency.Pair
assetTypes := b.GetAssetTypes(true)
for i := range assetTypes {
format, err := b.GetPairFormat(assetTypes[i], true)
if err != nil {
return response, assetTypes[i], err
return currency.EMPTYPAIR, assetTypes[i], err
}
pairs, err := b.CurrencyPairs.GetPairs(assetTypes[i], true)
if err != nil {
return response, assetTypes[i], err
return currency.EMPTYPAIR, assetTypes[i], err
}
for j := range pairs {
@@ -461,8 +410,7 @@ func (b *Base) GetRequestFormattedPairAndAssetType(p string) (currency.Pair, ass
}
}
}
return response, "",
fmt.Errorf("%s %w", p, ErrPairNotFound)
return currency.EMPTYPAIR, "", fmt.Errorf("%s %w", p, currency.ErrPairNotFound)
}
// GetAvailablePairs is a method that returns the available currency pairs
@@ -607,7 +555,10 @@ func (b *Base) SetupDefaults(exch *config.Exchange) error {
b.HTTPDebugging = exch.HTTPDebugging
b.BypassConfigFormatUpgrades = exch.CurrencyPairs.BypassConfigFormatUpgrades
b.SetHTTPClientUserAgent(exch.HTTPUserAgent)
err = b.SetHTTPClientUserAgent(exch.HTTPUserAgent)
if err != nil {
return err
}
b.SetCurrencyPairFormat()
err = b.SetConfigPairs()

View File

@@ -5,9 +5,7 @@ import (
"errors"
"fmt"
"net"
"net/http"
"os"
"strings"
"testing"
"time"
@@ -193,76 +191,21 @@ func TestSetDefaultEndpoints(t *testing.T) {
}
}
func TestHTTPClient(t *testing.T) {
t.Parallel()
r := Base{Name: "asdf"}
err := r.SetHTTPClientTimeout(time.Second * 5)
if err != nil {
t.Fatal(err)
}
if r.GetHTTPClient().Timeout != time.Second*5 {
t.Fatalf("TestHTTPClient unexpected value")
}
r.Requester = nil
newClient := new(http.Client)
newClient.Timeout = time.Second * 10
r.SetHTTPClient(newClient)
if r.GetHTTPClient().Timeout != time.Second*10 {
t.Fatalf("TestHTTPClient unexpected value")
}
r.Requester = nil
if r.GetHTTPClient() == nil {
t.Fatalf("TestHTTPClient unexpected value")
}
b := Base{Name: "RAWR"}
b.Requester = request.New(b.Name, new(http.Client))
err = b.SetHTTPClientTimeout(time.Second * 5)
if !errors.Is(err, errTransportNotSet) {
t.Fatalf("received: %v but expected: %v", err, errTransportNotSet)
}
b.Requester = request.New(b.Name, &http.Client{Transport: new(http.Transport)})
err = b.SetHTTPClientTimeout(time.Second * 5)
if err != nil {
t.Fatal(err)
}
if b.GetHTTPClient().Timeout != time.Second*5 {
t.Fatalf("TestHTTPClient unexpected value")
}
newClient = new(http.Client)
newClient.Timeout = time.Second * 10
b.SetHTTPClient(newClient)
if b.GetHTTPClient().Timeout != time.Second*10 {
t.Fatalf("TestHTTPClient unexpected value")
}
b.SetHTTPClientUserAgent("epicUserAgent")
if !strings.Contains(b.GetHTTPClientUserAgent(), "epicUserAgent") {
t.Error("user agent not set properly")
}
}
func TestSetClientProxyAddress(t *testing.T) {
t.Parallel()
requester := request.New("rawr",
requester, err := request.New("rawr",
common.NewHTTPClientWithTimeout(time.Second*15))
if err != nil {
t.Fatal(err)
}
newBase := Base{
Name: "rawr",
Requester: requester}
newBase.Websocket = stream.New()
err := newBase.SetClientProxyAddress("")
err = newBase.SetClientProxyAddress("")
if err != nil {
t.Error(err)
}
@@ -289,13 +232,6 @@ func TestSetClientProxyAddress(t *testing.T) {
if newBase.Websocket.GetProxyAddress() != "http://www.valid.com" {
t.Error("SetClientProxyAddress error", err)
}
// Nil out transport
newBase.Requester.HTTPClient.Transport = nil
err = newBase.SetClientProxyAddress("http://www.valid.com")
if err == nil {
t.Error("error cannot be nil")
}
}
func TestSetFeatureDefaults(t *testing.T) {
@@ -1268,7 +1204,16 @@ func TestSetAPIKeys(t *testing.T) {
func TestSetupDefaults(t *testing.T) {
t.Parallel()
var b = Base{Name: "awesomeTest"}
newRequester, err := request.New("testSetupDefaults",
common.NewHTTPClientWithTimeout(0))
if err != nil {
t.Fatal(err)
}
var b = Base{
Name: "awesomeTest",
Requester: newRequester,
}
cfg := config.Exchange{
HTTPTimeout: time.Duration(-1),
API: config.APIConfig{
@@ -1276,7 +1221,7 @@ func TestSetupDefaults(t *testing.T) {
},
}
err := b.SetupDefaults(&cfg)
err = b.SetupDefaults(&cfg)
if err != nil {
t.Fatal(err)
}
@@ -2055,25 +2000,34 @@ func TestCheckTransientError(t *testing.T) {
func TestDisableEnableRateLimiter(t *testing.T) {
b := Base{}
b.checkAndInitRequester()
err := b.EnableRateLimiter()
if err == nil {
t.Fatal("error cannot be nil")
if !errors.Is(err, request.ErrRequestSystemIsNil) {
t.Fatalf("received: '%v' but expected: '%v'", err, request.ErrRequestSystemIsNil)
}
err = b.DisableRateLimiter()
b.Requester, err = request.New("testingRateLimiter", common.NewHTTPClientWithTimeout(0))
if err != nil {
t.Fatal(err)
}
err = b.DisableRateLimiter()
if err == nil {
t.Fatal("error cannot be nil")
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
err = b.DisableRateLimiter()
if !errors.Is(err, request.ErrRateLimiterAlreadyDisabled) {
t.Fatalf("received: '%v' but expected: '%v'", err, request.ErrRateLimiterAlreadyDisabled)
}
err = b.EnableRateLimiter()
if err != nil {
t.Fatal(err)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
err = b.EnableRateLimiter()
if !errors.Is(err, request.ErrRateLimiterAlreadyEnabled) {
t.Fatalf("received: '%v' but expected: '%v'", err, request.ErrRateLimiterAlreadyEnabled)
}
}

View File

@@ -219,7 +219,6 @@ type Base struct {
CurrencyPairs currency.PairsManager
Features Features
HTTPTimeout time.Duration
HTTPUserAgent string
HTTPRecording bool
HTTPDebugging bool
BypassConfigFormatUpgrades bool

View File

@@ -110,9 +110,12 @@ func (e *EXMO) SetDefaults() {
},
}
e.Requester = request.New(e.Name,
e.Requester, err = request.New(e.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(request.NewBasicRateLimit(exmoRateInterval, exmoRequestRate)))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
e.API.Endpoints = e.NewEndpoints()
err = e.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: exmoAPIURL,

View File

@@ -147,9 +147,12 @@ func (f *FTX) SetDefaults() {
},
}
f.Requester = request.New(f.Name,
f.Requester, err = request.New(f.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(request.NewBasicRateLimit(ratePeriod, rateLimit)))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
f.API.Endpoints = f.NewEndpoints()
err = f.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: ftxAPIURL,

View File

@@ -130,8 +130,11 @@ func (g *Gateio) SetDefaults() {
},
},
}
g.Requester = request.New(g.Name,
g.Requester, err = request.New(g.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
g.API.Endpoints = g.NewEndpoints()
err = g.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: gateioTradeURL,

View File

@@ -45,7 +45,10 @@ func TestMain(m *testing.M) {
log.Fatalf("Mock server error %s", err)
}
g.HTTPClient = newClient
err = g.SetHTTPClient(newClient)
if err != nil {
log.Fatalf("Mock server error %s", err)
}
endpointMap := g.API.Endpoints.GetURLMap()
for k := range endpointMap {
err = g.API.Endpoints.SetRunning(k, serverDetails)

View File

@@ -113,9 +113,12 @@ func (g *Gemini) SetDefaults() {
},
}
g.Requester = request.New(g.Name,
g.Requester, err = request.New(g.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(SetRateLimit()))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
g.API.Endpoints = g.NewEndpoints()
err = g.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: geminiAPIURL,

View File

@@ -129,9 +129,12 @@ func (h *HitBTC) SetDefaults() {
},
}
h.Requester = request.New(h.Name,
h.Requester, err = request.New(h.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(SetRateLimit()))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
h.API.Endpoints = h.NewEndpoints()
err = h.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: apiURL,

View File

@@ -158,9 +158,12 @@ func (h *HUOBI) SetDefaults() {
},
}
h.Requester = request.New(h.Name,
h.Requester, err = request.New(h.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(SetRateLimit()))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
h.API.Endpoints = h.NewEndpoints()
err = h.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: huobiAPIURL,

View File

@@ -68,8 +68,8 @@ type IBotExchange interface {
WithdrawCryptocurrencyFunds(ctx context.Context, withdrawRequest *withdraw.Request) (*withdraw.ExchangeResponse, error)
WithdrawFiatFunds(ctx context.Context, withdrawRequest *withdraw.Request) (*withdraw.ExchangeResponse, error)
WithdrawFiatFundsToInternationalBank(ctx context.Context, withdrawRequest *withdraw.Request) (*withdraw.ExchangeResponse, error)
SetHTTPClientUserAgent(ua string)
GetHTTPClientUserAgent() string
SetHTTPClientUserAgent(ua string) error
GetHTTPClientUserAgent() (string, error)
SetClientProxyAddress(addr string) error
SupportsREST() bool
GetSubscriptions() ([]stream.ChannelSubscription, error)

View File

@@ -94,8 +94,11 @@ func (i *ItBit) SetDefaults() {
},
}
i.Requester = request.New(i.Name,
i.Requester, err = request.New(i.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
i.API.Endpoints = i.NewEndpoints()
err = i.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: itbitAPIURL,

View File

@@ -171,9 +171,12 @@ func (k *Kraken) SetDefaults() {
},
}
k.Requester = request.New(k.Name,
k.Requester, err = request.New(k.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(request.NewBasicRateLimit(krakenRateInterval, krakenRequestRate)))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
k.API.Endpoints = k.NewEndpoints()
err = k.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: krakenAPIURL,

View File

@@ -108,8 +108,11 @@ func (l *Lbank) SetDefaults() {
},
},
}
l.Requester = request.New(l.Name,
l.Requester, err = request.New(l.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
l.API.Endpoints = l.NewEndpoints()
err = l.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: lbankAPIURL,

View File

@@ -44,7 +44,10 @@ func TestMain(m *testing.M) {
log.Fatalf("Mock server error %s", err)
}
l.HTTPClient = newClient
err = l.SetHTTPClient(newClient)
if err != nil {
log.Fatalf("Mock server error %s", err)
}
endpoints := l.API.Endpoints.GetURLMap()
for k := range endpoints {
err = l.API.Endpoints.SetRunning(k, serverDetails)

View File

@@ -93,8 +93,11 @@ func (l *LocalBitcoins) SetDefaults() {
},
}
l.Requester = request.New(l.Name,
l.Requester, err = request.New(l.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
l.API.Endpoints = l.NewEndpoints()
err = l.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: localbitcoinsAPIURL,

View File

@@ -133,11 +133,14 @@ func (o *OKCoin) SetDefaults() {
},
}
o.Requester = request.New(o.Name,
o.Requester, err = request.New(o.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
// TODO: Specify each individual endpoint rate limits as per docs
request.WithLimiter(request.NewBasicRateLimit(okCoinRateInterval, okCoinStandardRequestRate)),
)
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
o.API.Endpoints = o.NewEndpoints()
err = o.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: okCoinAPIURL,

View File

@@ -191,11 +191,14 @@ func (o *OKEX) SetDefaults() {
},
}
o.Requester = request.New(o.Name,
o.Requester, err = request.New(o.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
// TODO: Specify each individual endpoint rate limits as per docs
request.WithLimiter(request.NewBasicRateLimit(okExRateInterval, okExRequestRate)),
)
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
o.API.Endpoints = o.NewEndpoints()
err = o.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: okExAPIURL,

View File

@@ -45,7 +45,10 @@ func TestMain(m *testing.M) {
log.Fatalf("Mock server error %s", err)
}
p.HTTPClient = newClient
err = p.SetHTTPClient(newClient)
if err != nil {
log.Fatalf("Mock server error %s", err)
}
endpoints := p.API.Endpoints.GetURLMap()
for k := range endpoints {
err = p.API.Endpoints.SetRunning(k, serverDetails)

View File

@@ -133,9 +133,12 @@ func (p *Poloniex) SetDefaults() {
},
}
p.Requester = request.New(p.Name,
p.Requester, err = request.New(p.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(SetRateLimit()))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
p.API.Endpoints = p.NewEndpoints()
err = p.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: poloniexAPIURL,

131
exchanges/request/client.go Normal file
View File

@@ -0,0 +1,131 @@
package request
import (
"errors"
"net/http"
"net/url"
"sync"
"time"
)
var (
// tracker is the global to maintain sanity between clients across all
// services using the request package.
tracker clientTracker
errNoProxyURLSupplied = errors.New("no proxy URL supplied")
errCannotReuseHTTPClient = errors.New("cannot reuse http client")
errHTTPClientIsNil = errors.New("http client is nil")
errHTTPClientNotFound = errors.New("http client not found")
)
// clientTracker attempts to maintain service/http.Client segregation
type clientTracker struct {
clients []*http.Client
sync.Mutex
}
// checkAndRegister stops the sharing of the same http.Client between services.
func (c *clientTracker) checkAndRegister(newClient *http.Client) error {
if newClient == nil {
return errHTTPClientIsNil
}
c.Lock()
defer c.Unlock()
for x := range c.clients {
if newClient == c.clients[x] {
return errCannotReuseHTTPClient
}
}
c.clients = append(c.clients, newClient)
return nil
}
// deRegister removes the *http.Client from being tracked
func (c *clientTracker) deRegister(oldClient *http.Client) error {
if oldClient == nil {
return errHTTPClientIsNil
}
c.Lock()
defer c.Unlock()
for x := range c.clients {
if oldClient != c.clients[x] {
continue
}
c.clients[x] = c.clients[len(c.clients)-1]
c.clients[len(c.clients)-1] = nil
c.clients = c.clients[:len(c.clients)-1]
return nil
}
return errHTTPClientNotFound
}
// client wraps over a http client for better protection
type client struct {
protected *http.Client
m sync.RWMutex
}
// newProtectedClient registers a http.Client to inhibit cross service usage and
// return a thread safe holder (*request.Client) with getter and setters for
// timeouts and transports.
func newProtectedClient(newClient *http.Client) (*client, error) {
if err := tracker.checkAndRegister(newClient); err != nil {
return nil, err
}
return &client{protected: newClient}, nil
}
// setProxy sets a proxy address for the client transport
func (c *client) setProxy(p *url.URL) error {
if p == nil || p.String() == "" {
return errNoProxyURLSupplied
}
c.m.Lock()
defer c.m.Unlock()
// Check transport first so we don't set something and then error.
tr, ok := c.protected.Transport.(*http.Transport)
if !ok {
return errTransportNotSet
}
// This closes idle connections before an attempt at reassignment and
// boots any dangly routines.
tr.CloseIdleConnections()
tr.Proxy = http.ProxyURL(p)
tr.TLSHandshakeTimeout = proxyTLSTimeout
return nil
}
// setHTTPClientTimeout sets the timeout value for the exchanges HTTP Client and
// also the underlying transports idle connection timeout
func (c *client) setHTTPClientTimeout(timeout time.Duration) error {
c.m.Lock()
defer c.m.Unlock()
// Check transport first so we don't set something and then error.
tr, ok := c.protected.Transport.(*http.Transport)
if !ok {
return errTransportNotSet
}
// This closes idle connections before an attempt at reassignment and
// boots any dangly routines.
tr.CloseIdleConnections()
tr.IdleConnTimeout = timeout
c.protected.Timeout = timeout
return nil
}
// do sends request in a protected manner
func (c *client) do(request *http.Request) (resp *http.Response, err error) {
c.m.RLock()
resp, err = c.protected.Do(request)
c.m.RUnlock()
return
}
// release de-registers the underlying client
func (c *client) release() error {
c.m.Lock()
err := tracker.deRegister(c.protected)
c.m.Unlock()
return err
}

View File

@@ -0,0 +1,148 @@
package request
import (
"errors"
"net/http"
"net/url"
"testing"
"time"
"github.com/thrasher-corp/gocryptotrader/common"
)
// this doesn't need to be included in binary
func (c *clientTracker) contains(check *http.Client) bool {
c.Lock()
defer c.Unlock()
for x := range c.clients {
if check == c.clients[x] {
return true
}
}
return false
}
func TestCheckAndRegister(t *testing.T) {
t.Parallel()
err := tracker.checkAndRegister(nil)
if !errors.Is(err, errHTTPClientIsNil) {
t.Fatalf("received: '%v' but expected: '%v'", err, errHTTPClientIsNil)
}
newLovelyClient := new(http.Client)
err = tracker.checkAndRegister(newLovelyClient)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
if !tracker.contains(newLovelyClient) {
t.Fatalf("received: '%v' but expected: '%v'", false, true)
}
err = tracker.checkAndRegister(newLovelyClient)
if !errors.Is(err, errCannotReuseHTTPClient) {
t.Fatalf("received: '%v' but expected: '%v'", err, errCannotReuseHTTPClient)
}
}
func TestDeRegister(t *testing.T) {
t.Parallel()
err := tracker.deRegister(nil)
if !errors.Is(err, errHTTPClientIsNil) {
t.Fatalf("received: '%v' but expected: '%v'", err, errHTTPClientIsNil)
}
newLovelyClient := new(http.Client)
err = tracker.deRegister(newLovelyClient)
if !errors.Is(err, errHTTPClientNotFound) {
t.Fatalf("received: '%v' but expected: '%v'", err, errHTTPClientNotFound)
}
err = tracker.checkAndRegister(newLovelyClient)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
if !tracker.contains(newLovelyClient) {
t.Fatalf("received: '%v' but expected: '%v'", false, true)
}
err = tracker.deRegister(newLovelyClient)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
if tracker.contains(newLovelyClient) {
t.Fatalf("received: '%v' but expected: '%v'", true, false)
}
}
func TestNewProtectedClient(t *testing.T) {
t.Parallel()
if _, err := newProtectedClient(nil); !errors.Is(err, errHTTPClientIsNil) {
t.Fatalf("received: '%v' but expected: '%v'", err, errHTTPClientIsNil)
}
newLovelyClient := new(http.Client)
protec, err := newProtectedClient(newLovelyClient)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
if protec.protected != newLovelyClient {
t.Fatal("unexpected value")
}
}
func TestClientSetProxy(t *testing.T) {
t.Parallel()
err := (&client{}).setProxy(nil)
if !errors.Is(err, errNoProxyURLSupplied) {
t.Fatalf("received: '%v' but expected: '%v'", err, errNoProxyURLSupplied)
}
pp, err := url.Parse("lol.com")
if err != nil {
t.Fatal(err)
}
err = (&client{protected: new(http.Client)}).setProxy(pp)
if !errors.Is(err, errTransportNotSet) {
t.Fatalf("received: '%v' but expected: '%v'", err, errTransportNotSet)
}
err = (&client{protected: common.NewHTTPClientWithTimeout(0)}).setProxy(pp)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
}
func TestClientSetHTTPClientTimeout(t *testing.T) {
t.Parallel()
err := (&client{protected: new(http.Client)}).setHTTPClientTimeout(time.Second)
if !errors.Is(err, errTransportNotSet) {
t.Fatalf("received: '%v' but expected: '%v'", err, errTransportNotSet)
}
err = (&client{protected: common.NewHTTPClientWithTimeout(0)}).setHTTPClientTimeout(time.Second)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
}
func TestRelease(t *testing.T) {
t.Parallel()
newLovelyClient, err := newProtectedClient(common.NewHTTPClientWithTimeout(0))
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
if !tracker.contains(newLovelyClient.protected) {
t.Fatalf("received: '%v' but expected: '%v'", false, true)
}
err = newLovelyClient.release()
if !errors.Is(err, nil) {
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
}
if tracker.contains(newLovelyClient.protected) {
t.Fatalf("received: '%v' but expected: '%v'", true, false)
}
}

View File

@@ -3,12 +3,19 @@ package request
import (
"context"
"errors"
"fmt"
"sync/atomic"
"time"
"golang.org/x/time/rate"
)
// Defines rate limiting errors
var (
ErrRateLimiterAlreadyDisabled = errors.New("rate limiter already disabled")
ErrRateLimiterAlreadyEnabled = errors.New("rate limiter already enabled")
)
// Const here define individual functionality sub types for rate limiting
const (
Unset EndpointLimit = iota
@@ -60,6 +67,9 @@ func NewBasicRateLimit(interval time.Duration, actions int) Limiter {
// InitiateRateLimit sleeps for designated end point rate limits
func (r *Requester) InitiateRateLimit(ctx context.Context, e EndpointLimit) error {
if r == nil {
return ErrRequestSystemIsNil
}
if atomic.LoadInt32(&r.disableRateLimiter) == 1 {
return nil
}
@@ -73,16 +83,22 @@ func (r *Requester) InitiateRateLimit(ctx context.Context, e EndpointLimit) erro
// DisableRateLimiter disables the rate limiting system for the exchange
func (r *Requester) DisableRateLimiter() error {
if r == nil {
return ErrRequestSystemIsNil
}
if !atomic.CompareAndSwapInt32(&r.disableRateLimiter, 0, 1) {
return errors.New("rate limiter already disabled")
return fmt.Errorf("%s %w", r.name, ErrRateLimiterAlreadyDisabled)
}
return nil
}
// EnableRateLimiter enables the rate limiting system for the exchange
func (r *Requester) EnableRateLimiter() error {
if r == nil {
return ErrRequestSystemIsNil
}
if !atomic.CompareAndSwapInt32(&r.disableRateLimiter, 1, 0) {
return errors.New("rate limiter already enabled")
return fmt.Errorf("%s %w", r.name, ErrRateLimiterAlreadyEnabled)
}
return nil
}

View File

@@ -20,7 +20,9 @@ import (
)
var (
errRequestSystemIsNil = errors.New("request system is nil")
// ErrRequestSystemIsNil defines and error if the request system has not
// been set up yet.
ErrRequestSystemIsNil = errors.New("request system is nil")
errMaxRequestJobs = errors.New("max request jobs reached")
errRequestFunctionIsNil = errors.New("request function is nil")
errRequestItemNil = errors.New("request item is nil")
@@ -28,13 +30,18 @@ var (
errHeaderResponseMapIsNil = errors.New("header response map is nil")
errFailedToRetryRequest = errors.New("failed to retry request")
errContextRequired = errors.New("context is required")
errTransportNotSet = errors.New("transport not set, cannot set timeout")
)
// New returns a new Requester
func New(name string, httpRequester *http.Client, opts ...RequesterOption) *Requester {
func New(name string, httpRequester *http.Client, opts ...RequesterOption) (*Requester, error) {
protectedClient, err := newProtectedClient(httpRequester)
if err != nil {
return nil, fmt.Errorf("cannot set up a new requester for %s: %w", name, err)
}
r := &Requester{
HTTPClient: httpRequester,
Name: name,
_HTTPClient: protectedClient,
name: name,
backoff: DefaultBackoff(),
retryPolicy: DefaultRetryPolicy,
maxRetries: MaxRetryAttempts,
@@ -46,13 +53,13 @@ func New(name string, httpRequester *http.Client, opts ...RequesterOption) *Requ
o(r)
}
return r
return r, nil
}
// SendPayload handles sending HTTP/HTTPS requests
func (r *Requester) SendPayload(ctx context.Context, ep EndpointLimit, newRequest Generate) error {
if r == nil {
return errRequestSystemIsNil
return ErrRequestSystemIsNil
}
if ctx == nil {
@@ -107,8 +114,8 @@ func (i *Item) validateRequest(ctx context.Context, r *Requester) (*http.Request
req.Header.Add(k, v)
}
if r.UserAgent != "" && req.Header.Get(userAgent) == "" {
req.Header.Add(userAgent, r.UserAgent)
if r.userAgent != "" && req.Header.Get(userAgent) == "" {
req.Header.Add(userAgent, r.userAgent)
}
return req, nil
@@ -141,22 +148,22 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe
}
if p.Verbose {
log.Debugf(log.RequestSys, "%s attempt %d request path: %s", r.Name, attempt, p.Path)
log.Debugf(log.RequestSys, "%s attempt %d request path: %s", r.name, attempt, p.Path)
for k, d := range req.Header {
log.Debugf(log.RequestSys, "%s request header [%s]: %s", r.Name, k, d)
log.Debugf(log.RequestSys, "%s request header [%s]: %s", r.name, k, d)
}
log.Debugf(log.RequestSys, "%s request type: %s", r.Name, p.Method)
log.Debugf(log.RequestSys, "%s request type: %s", r.name, p.Method)
if p.Body != nil {
log.Debugf(log.RequestSys, "%s request body: %v", r.Name, p.Body)
log.Debugf(log.RequestSys, "%s request body: %v", r.name, p.Body)
}
}
start := time.Now()
resp, err := r.HTTPClient.Do(req)
resp, err := r._HTTPClient.do(req)
if r.reporter != nil {
r.reporter.Latency(r.Name, p.Method, p.Path, time.Since(start))
r.reporter.Latency(r.name, p.Method, p.Path, time.Since(start))
}
if retry, checkErr := r.retryPolicy(resp, err); checkErr != nil {
@@ -191,7 +198,7 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe
if p.Verbose {
log.Errorf(log.RequestSys,
"%s request has failed. Retrying request in %s, attempt %d",
r.Name,
r.name,
delay,
attempt)
}
@@ -213,7 +220,7 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe
if p.HTTPRecording {
// This dumps http responses for future mocking implementations
err = mock.HTTPRecord(resp, r.Name, contents)
err = mock.HTTPRecord(resp, r.name, contents)
if err != nil {
return fmt.Errorf("mock recording failure %s", err)
}
@@ -228,21 +235,27 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe
if resp.StatusCode < http.StatusOK ||
resp.StatusCode > http.StatusAccepted {
return fmt.Errorf("%s unsuccessful HTTP status code: %d raw response: %s",
r.Name,
r.name,
resp.StatusCode,
string(contents))
}
if p.HTTPDebugging {
dump, err := httputil.DumpResponse(resp, false)
dump, dumpErr := httputil.DumpResponse(resp, false)
if err != nil {
log.Errorf(log.RequestSys, "DumpResponse invalid response: %v:", err)
log.Errorf(log.RequestSys, "DumpResponse invalid response: %v:", dumpErr)
}
log.Debugf(log.RequestSys, "DumpResponse Headers (%v):\n%s", p.Path, dump)
log.Debugf(log.RequestSys, "DumpResponse Body (%v):\n %s", p.Path, string(contents))
}
resp.Body.Close()
err = resp.Body.Close()
if err != nil {
log.Errorf(log.RequestSys,
"%s failed to close request body %s",
r.name,
err)
}
if p.Verbose {
log.Debugf(log.RequestSys,
"HTTP status: %s, Code: %v",
@@ -251,7 +264,7 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe
if !p.HTTPDebugging {
log.Debugf(log.RequestSys,
"%s raw response: %s",
r.Name,
r.name,
string(contents))
}
}
@@ -259,6 +272,22 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe
}
}
func (r *Requester) drainBody(body io.ReadCloser) {
if _, err := io.Copy(ioutil.Discard, io.LimitReader(body, drainBodyLimit)); err != nil {
log.Errorf(log.RequestSys,
"%s failed to drain request body %s",
r.name,
err)
}
if err := body.Close(); err != nil {
log.Errorf(log.RequestSys,
"%s failed to close request body %s",
r.name,
err)
}
}
// GetNonce returns a nonce for requests. This locks and enforces concurrent
// nonce FIFO on the buffered job channel
func (r *Requester) GetNonce(isNano bool) nonce.Value {
@@ -285,27 +314,57 @@ func (r *Requester) GetNonceMilli() nonce.Value {
return r.Nonce.GetInc()
}
// SetProxy sets a proxy address to the client transport
// SetProxy sets a proxy address for the client transport
func (r *Requester) SetProxy(p *url.URL) error {
if p.String() == "" {
return errors.New("no proxy URL supplied")
if r == nil {
return ErrRequestSystemIsNil
}
return r._HTTPClient.setProxy(p)
}
t, ok := r.HTTPClient.Transport.(*http.Transport)
if !ok {
return errors.New("transport not set, cannot set proxy")
// SetHTTPClient sets exchanges HTTP client
func (r *Requester) SetHTTPClient(newClient *http.Client) error {
if r == nil {
return ErrRequestSystemIsNil
}
t.Proxy = http.ProxyURL(p)
t.TLSHandshakeTimeout = proxyTLSTimeout
protectedClient, err := newProtectedClient(newClient)
if err != nil {
return err
}
r._HTTPClient = protectedClient
return nil
}
func (r *Requester) drainBody(body io.ReadCloser) {
defer body.Close()
if _, err := io.Copy(ioutil.Discard, io.LimitReader(body, drainBodyLimit)); err != nil {
log.Errorf(log.RequestSys,
"%s failed to drain request body %s",
r.Name,
err)
// SetClientTimeout sets the timeout value for the exchanges HTTP Client and
// also the underlying transports idle connection timeout
func (r *Requester) SetHTTPClientTimeout(timeout time.Duration) error {
if r == nil {
return ErrRequestSystemIsNil
}
return r._HTTPClient.setHTTPClientTimeout(timeout)
}
// SetHTTPClientUserAgent sets the exchanges HTTP user agent
func (r *Requester) SetHTTPClientUserAgent(userAgent string) error {
if r == nil {
return ErrRequestSystemIsNil
}
r.userAgent = userAgent
return nil
}
// GetHTTPClientUserAgent gets the exchanges HTTP user agent
func (r *Requester) GetHTTPClientUserAgent() (string, error) {
if r == nil {
return "", ErrRequestSystemIsNil
}
return r.userAgent, nil
}
// Shutdown releases persistent memory for garbage collection.
func (r *Requester) Shutdown() error {
if r == nil {
return ErrRequestSystemIsNil
}
return r._HTTPClient.release()
}

View File

@@ -18,6 +18,7 @@ import (
"testing"
"time"
"github.com/thrasher-corp/gocryptotrader/common"
"golang.org/x/time/rate"
)
@@ -126,12 +127,15 @@ func TestNewRateLimit(t *testing.T) {
func TestCheckRequest(t *testing.T) {
t.Parallel()
r := New("TestRequest",
r, err := New("TestRequest",
new(http.Client))
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
var check *Item
_, err := check.validateRequest(ctx, &Requester{})
_, err = check.validateRequest(ctx, &Requester{})
if err == nil {
t.Fatal(unexpected)
}
@@ -179,7 +183,7 @@ func TestCheckRequest(t *testing.T) {
}
// Test user agent set
r.UserAgent = "r00t axxs"
r.userAgent = "r00t axxs"
req, err := check.validateRequest(ctx, r)
if err != nil {
t.Fatal(err)
@@ -226,12 +230,15 @@ var globalshell = GlobalLimitTest{
func TestDoRequest(t *testing.T) {
t.Parallel()
r := New("test", new(http.Client), WithLimiter(&globalshell))
r, err := New("test", new(http.Client), WithLimiter(&globalshell))
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
err := (*Requester)(nil).SendPayload(ctx, Unset, nil)
if !errors.Is(errRequestSystemIsNil, err) {
t.Fatalf("expected: %v but received: %v", errRequestSystemIsNil, err)
err = (*Requester)(nil).SendPayload(ctx, Unset, nil)
if !errors.Is(ErrRequestSystemIsNil, err) {
t.Fatalf("expected: %v but received: %v", ErrRequestSystemIsNil, err)
}
err = r.SendPayload(ctx, Unset, nil)
if !errors.Is(errRequestFunctionIsNil, err) {
@@ -305,8 +312,16 @@ func TestDoRequest(t *testing.T) {
// reset jobs
r.jobs = 0
r._HTTPClient, err = newProtectedClient(common.NewHTTPClientWithTimeout(0))
if err != nil {
t.Fatal(err)
}
// timeout checker
r.HTTPClient.Timeout = time.Millisecond * 50
err = r._HTTPClient.setHTTPClientTimeout(time.Millisecond * 50)
if err != nil {
t.Fatal(err)
}
err = r.SendPayload(ctx, UnAuth, func() (*Item, error) {
return &Item{Path: testURL + "/timeout"}, nil
})
@@ -314,7 +329,10 @@ func TestDoRequest(t *testing.T) {
t.Fatalf("received: %v but expected: %v", err, errFailedToRetryRequest)
}
// reset timeout
r.HTTPClient.Timeout = 0
err = r._HTTPClient.setHTTPClientTimeout(0)
if err != nil {
t.Fatal(err)
}
// Check JSON
var resp struct {
@@ -403,7 +421,10 @@ func TestDoRequest_Retries(t *testing.T) {
backoff := func(n int) time.Duration {
return 0
}
r := New("test", new(http.Client), WithBackoff(backoff))
r, err := New("test", new(http.Client), WithBackoff(backoff))
if err != nil {
t.Fatal(err)
}
var failed int32
var wg sync.WaitGroup
wg.Add(4)
@@ -444,8 +465,11 @@ func TestDoRequest_RetryNonRecoverable(t *testing.T) {
backoff := func(n int) time.Duration {
return 0
}
r := New("test", new(http.Client), WithBackoff(backoff))
err := r.SendPayload(context.Background(), Unset, func() (*Item, error) {
r, err := New("test", new(http.Client), WithBackoff(backoff))
if err != nil {
t.Fatal(err)
}
err = r.SendPayload(context.Background(), Unset, func() (*Item, error) {
return &Item{
Method: http.MethodGet,
Path: testURL + "/always-retry",
@@ -466,8 +490,11 @@ func TestDoRequest_NotRetryable(t *testing.T) {
backoff := func(n int) time.Duration {
return time.Duration(n) * time.Millisecond
}
r := New("test", new(http.Client), WithRetryPolicy(retry), WithBackoff(backoff))
err := r.SendPayload(context.Background(), Unset, func() (*Item, error) {
r, err := New("test", new(http.Client), WithRetryPolicy(retry), WithBackoff(backoff))
if err != nil {
t.Fatal(err)
}
err = r.SendPayload(context.Background(), Unset, func() (*Item, error) {
return &Item{
Method: http.MethodGet,
Path: testURL + "/always-retry",
@@ -480,17 +507,22 @@ func TestDoRequest_NotRetryable(t *testing.T) {
func TestGetNonce(t *testing.T) {
t.Parallel()
r := New("test",
r, err := New("test",
new(http.Client),
WithLimiter(&globalshell))
if err != nil {
t.Fatal(err)
}
if n1, n2 := r.GetNonce(false), r.GetNonce(false); n1 == n2 {
t.Fatal(unexpected)
}
r2 := New("test",
r2, err := New("test",
new(http.Client),
WithLimiter(&globalshell))
if err != nil {
t.Fatal(err)
}
if n1, n2 := r2.GetNonce(true), r2.GetNonce(true); n1 == n2 {
t.Fatal(unexpected)
}
@@ -498,9 +530,12 @@ func TestGetNonce(t *testing.T) {
func TestGetNonceMillis(t *testing.T) {
t.Parallel()
r := New("test",
r, err := New("test",
new(http.Client),
WithLimiter(&globalshell))
if err != nil {
t.Fatal(err)
}
if m1, m2 := r.GetNonceMilli(), r.GetNonceMilli(); m1 == m2 {
log.Fatal(unexpected)
}
@@ -508,9 +543,17 @@ func TestGetNonceMillis(t *testing.T) {
func TestSetProxy(t *testing.T) {
t.Parallel()
r := New("test",
var r *Requester
err := r.SetProxy(nil)
if !errors.Is(err, ErrRequestSystemIsNil) {
t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil)
}
r, err = New("test",
&http.Client{Transport: new(http.Transport)},
WithLimiter(&globalshell))
if err != nil {
t.Fatal(err)
}
u, err := url.Parse("http://www.google.com")
if err != nil {
t.Fatal(err)
@@ -530,9 +573,12 @@ func TestSetProxy(t *testing.T) {
}
func TestBasicLimiter(t *testing.T) {
r := New("test",
r, err := New("test",
new(http.Client),
WithLimiter(NewBasicRateLimit(time.Second, 1)))
if err != nil {
t.Fatal(err)
}
i := Item{
Path: "http://www.google.com",
Method: http.MethodGet,
@@ -540,7 +586,7 @@ func TestBasicLimiter(t *testing.T) {
ctx := context.Background()
tn := time.Now()
err := r.SendPayload(ctx, Unset, func() (*Item, error) { return &i, nil })
err = r.SendPayload(ctx, Unset, func() (*Item, error) { return &i, nil })
if err != nil {
t.Fatal(err)
}
@@ -561,13 +607,16 @@ func TestBasicLimiter(t *testing.T) {
}
func TestEnableDisableRateLimit(t *testing.T) {
r := New("TestRequest",
r, err := New("TestRequest",
new(http.Client),
WithLimiter(NewBasicRateLimit(time.Minute, 1)))
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
var resp interface{}
err := r.SendPayload(ctx, Auth, func() (*Item, error) {
err = r.SendPayload(ctx, Auth, func() (*Item, error) {
return &Item{
Method: http.MethodGet,
Path: testURL,
@@ -635,3 +684,71 @@ func TestEnableDisableRateLimit(t *testing.T) {
// Correct test
}
}
func TestSetHTTPClient(t *testing.T) {
var r *Requester
err := r.SetHTTPClient(nil)
if !errors.Is(err, ErrRequestSystemIsNil) {
t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil)
}
client := new(http.Client)
r = new(Requester)
err = r.SetHTTPClient(client)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v', but expected: '%v'", err, nil)
}
err = r.SetHTTPClient(client)
if !errors.Is(err, errCannotReuseHTTPClient) {
t.Fatalf("received: '%v', but expected: '%v'", err, errCannotReuseHTTPClient)
}
}
func TestSetHTTPClientTimeout(t *testing.T) {
var r *Requester
err := r.SetHTTPClientTimeout(0)
if !errors.Is(err, ErrRequestSystemIsNil) {
t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil)
}
r = new(Requester)
err = r.SetHTTPClient(common.NewHTTPClientWithTimeout(2))
if err != nil {
t.Fatal(err)
}
err = r.SetHTTPClientTimeout(time.Second)
if !errors.Is(err, nil) {
t.Fatalf("received: '%v', but expected: '%v'", err, nil)
}
}
func TestSetHTTPClientUserAgent(t *testing.T) {
var r *Requester
err := r.SetHTTPClientUserAgent("")
if !errors.Is(err, ErrRequestSystemIsNil) {
t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil)
}
r = new(Requester)
err = r.SetHTTPClientUserAgent("")
if !errors.Is(err, nil) {
t.Fatalf("received: '%v', but expected: '%v'", err, nil)
}
}
func TestGetHTTPClientUserAgent(t *testing.T) {
var r *Requester
_, err := r.GetHTTPClientUserAgent()
if !errors.Is(err, ErrRequestSystemIsNil) {
t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil)
}
r = new(Requester)
err = r.SetHTTPClientUserAgent("sillyness")
if !errors.Is(err, nil) {
t.Fatalf("received: '%v', but expected: '%v'", err, nil)
}
ua, err := r.GetHTTPClientUserAgent()
if !errors.Is(err, nil) {
t.Fatalf("received: '%v', but expected: '%v'", err, nil)
}
if ua != "sillyness" {
t.Fatal("unexpected value")
}
}

View File

@@ -28,11 +28,11 @@ var (
// Requester struct for the request client
type Requester struct {
HTTPClient *http.Client
_HTTPClient *client
limiter Limiter
reporter Reporter
Name string
UserAgent string
name string
userAgent string
maxRetries int
jobs int32
Nonce nonce.Nonce

View File

@@ -198,11 +198,12 @@ func (c *CustomEx) WithdrawFiatFundsToInternationalBank(ctx context.Context, wit
return nil, nil
}
func (c *CustomEx) SetHTTPClientUserAgent(ua string) {
func (c *CustomEx) SetHTTPClientUserAgent(ua string) error {
return nil
}
func (c *CustomEx) GetHTTPClientUserAgent() string {
return ""
func (c *CustomEx) GetHTTPClientUserAgent() (string, error) {
return "", nil
}
func (c *CustomEx) SetClientProxyAddress(addr string) error {

View File

@@ -97,10 +97,13 @@ func (y *Yobit) SetDefaults() {
},
}
y.Requester = request.New(y.Name,
y.Requester, err = request.New(y.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
// Server responses are cached every 2 seconds.
request.WithLimiter(request.NewBasicRateLimit(time.Second, 1)))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
y.API.Endpoints = y.NewEndpoints()
err = y.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: apiPublicURL,

View File

@@ -47,7 +47,10 @@ func TestMain(m *testing.M) {
log.Fatalf("Mock server error %s", err)
}
z.HTTPClient = newClient
err = z.SetHTTPClient(newClient)
if err != nil {
log.Fatalf("Mock server error %s", err)
}
endpoints := z.API.Endpoints.GetURLMap()
for k := range endpoints {
err = z.API.Endpoints.SetRunning(k, serverDetails)

View File

@@ -130,9 +130,12 @@ func (z *ZB) SetDefaults() {
},
}
z.Requester = request.New(z.Name,
z.Requester, err = request.New(z.Name,
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
request.WithLimiter(SetRateLimit()))
if err != nil {
log.Errorln(log.ExchangeSys, err)
}
z.API.Endpoints = z.NewEndpoints()
err = z.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
exchange.RestSpot: zbTradeURL,