mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-24 07:26:47 +00:00
Request: Fix http.Client race issue when setting transport layer proxy and timeouts (#885)
* backtester/request: trying to fix panic (WIP) * request: fix race for transport layer * request: linter issue fix * request: more linter issues * requester: Add function to remove the tracking of underlying http client and add to engine unload exchange. * request: add more context to error return * request: Fix after cherry pick issues * request: fix niterinos * exchanges: change return to package variable * request: changed named Co-authored-by: Ryan O'Hara-Reid <ryan.oharareid@thrasher.io>
This commit is contained in:
@@ -1288,15 +1288,19 @@ func updateFile(name string) error {
|
||||
// SendGetReq sends get req
|
||||
func sendGetReq(path string, result interface{}) error {
|
||||
var requester *request.Requester
|
||||
var err error
|
||||
if strings.Contains(path, "github") {
|
||||
requester = request.New("Apichecker",
|
||||
requester, err = request.New("Apichecker",
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(request.NewBasicRateLimit(time.Hour, 60)))
|
||||
} else {
|
||||
requester = request.New("Apichecker",
|
||||
requester, err = request.New("Apichecker",
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(request.NewBasicRateLimit(time.Second, 100)))
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
item := &request.Item{
|
||||
Method: http.MethodGet,
|
||||
Path: path,
|
||||
@@ -1309,9 +1313,12 @@ func sendGetReq(path string, result interface{}) error {
|
||||
|
||||
// sendAuthReq sends auth req
|
||||
func sendAuthReq(method, path string, result interface{}) error {
|
||||
requester := request.New("Apichecker",
|
||||
requester, err := request.New("Apichecker",
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(request.NewBasicRateLimit(time.Second*10, 100)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
item := &request.Item{
|
||||
Method: method,
|
||||
Path: path,
|
||||
|
||||
@@ -36,10 +36,14 @@ func (c *Coinmarketcap) SetDefaults() {
|
||||
c.Verbose = false
|
||||
c.APIUrl = baseURL
|
||||
c.APIVersion = version
|
||||
c.Requester = request.New(c.Name,
|
||||
var err error
|
||||
c.Requester, err = request.New(c.Name,
|
||||
common.NewHTTPClientWithTimeout(defaultTimeOut),
|
||||
request.WithLimiter(request.NewBasicRateLimit(RateInterval, BasicRequestRate)),
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorln(log.Global, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Setup sets user configuration
|
||||
|
||||
@@ -23,10 +23,11 @@ func (c *CurrencyConverter) Setup(config base.Settings) error {
|
||||
c.Enabled = config.Enabled
|
||||
c.Verbose = config.Verbose
|
||||
c.PrimaryProvider = config.PrimaryProvider
|
||||
c.Requester = request.New(c.Name,
|
||||
var err error
|
||||
c.Requester, err = request.New(c.Name,
|
||||
common.NewHTTPClientWithTimeout(base.DefaultTimeOut),
|
||||
request.WithLimiter(request.NewBasicRateLimit(rateInterval, requestRate)))
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// GetRates is a wrapper function to return rates
|
||||
|
||||
@@ -43,10 +43,10 @@ func (c *CurrencyLayer) Setup(config base.Settings) error {
|
||||
c.Verbose = config.Verbose
|
||||
c.PrimaryProvider = config.PrimaryProvider
|
||||
// Rate limit is based off a monthly counter - Open limit used.
|
||||
c.Requester = request.New(c.Name,
|
||||
var err error
|
||||
c.Requester, err = request.New(c.Name,
|
||||
common.NewHTTPClientWithTimeout(base.DefaultTimeOut))
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// GetRates is a wrapper function to return rates for GoCryptoTrader
|
||||
|
||||
@@ -33,9 +33,10 @@ func (e *ExchangeRateHost) Setup(config base.Settings) error {
|
||||
e.Enabled = config.Enabled
|
||||
e.Verbose = config.Verbose
|
||||
e.PrimaryProvider = config.PrimaryProvider
|
||||
e.Requester = request.New(e.Name,
|
||||
var err error
|
||||
e.Requester, err = request.New(e.Name,
|
||||
common.NewHTTPClientWithTimeout(base.DefaultTimeOut))
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// GetLatestRates returns a list of forex rates based on the supplied params
|
||||
|
||||
@@ -29,10 +29,11 @@ func (e *ExchangeRates) Setup(config base.Settings) error {
|
||||
e.PrimaryProvider = config.PrimaryProvider
|
||||
e.APIKey = config.APIKey
|
||||
e.APIKeyLvl = config.APIKeyLvl
|
||||
e.Requester = request.New(e.Name,
|
||||
var err error
|
||||
e.Requester, err = request.New(e.Name,
|
||||
common.NewHTTPClientWithTimeout(base.DefaultTimeOut),
|
||||
request.WithLimiter(request.NewBasicRateLimit(rateLimitInterval, requestRate)))
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (e *ExchangeRates) cleanCurrencies(baseCurrency, symbols string) string {
|
||||
|
||||
@@ -36,9 +36,10 @@ func (f *Fixer) Setup(config base.Settings) error {
|
||||
f.Name = config.Name
|
||||
f.Verbose = config.Verbose
|
||||
f.PrimaryProvider = config.PrimaryProvider
|
||||
f.Requester = request.New(f.Name,
|
||||
var err error
|
||||
f.Requester, err = request.New(f.Name,
|
||||
common.NewHTTPClientWithTimeout(base.DefaultTimeOut))
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// GetSupportedCurrencies returns supported currencies
|
||||
|
||||
@@ -37,9 +37,10 @@ func (o *OXR) Setup(config base.Settings) error {
|
||||
o.Name = config.Name
|
||||
o.Verbose = config.Verbose
|
||||
o.PrimaryProvider = config.PrimaryProvider
|
||||
o.Requester = request.New(o.Name,
|
||||
var err error
|
||||
o.Requester, err = request.New(o.Name,
|
||||
common.NewHTTPClientWithTimeout(base.DefaultTimeOut))
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// GetRates is a wrapper function to return rates
|
||||
|
||||
@@ -94,12 +94,16 @@ func (m *ExchangeManager) RemoveExchange(exchName string) error {
|
||||
if m.Len() == 0 {
|
||||
return ErrNoExchangesLoaded
|
||||
}
|
||||
_, err := m.GetExchangeByName(exchName)
|
||||
exch, err := m.GetExchangeByName(exchName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.m.Lock()
|
||||
defer m.m.Unlock()
|
||||
err = exch.GetBase().Requester.Shutdown()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
delete(m.exchanges, strings.ToLower(exchName))
|
||||
log.Infof(log.ExchangeSys, "%s exchange unloaded successfully.\n", exchName)
|
||||
return nil
|
||||
|
||||
@@ -73,8 +73,11 @@ func (a *Alphapoint) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
a.Requester = request.New(a.Name,
|
||||
a.Requester, err = request.New(a.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
}
|
||||
|
||||
// FetchTradablePairs returns a list of the exchanges tradable pairs
|
||||
|
||||
@@ -46,7 +46,10 @@ func TestMain(m *testing.M) {
|
||||
if err != nil {
|
||||
log.Fatalf("Mock server error %s", err)
|
||||
}
|
||||
b.HTTPClient = newClient
|
||||
err = b.SetHTTPClient(newClient)
|
||||
if err != nil {
|
||||
log.Fatalf("Mock server error %s", err)
|
||||
}
|
||||
endpointMap := b.API.Endpoints.GetURLMap()
|
||||
for k := range endpointMap {
|
||||
err = b.API.Endpoints.SetRunning(k, serverDetails)
|
||||
|
||||
@@ -186,9 +186,12 @@ func (b *Binance) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
b.Requester = request.New(b.Name,
|
||||
b.Requester, err = request.New(b.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(SetRateLimit()))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
b.API.Endpoints = b.NewEndpoints()
|
||||
err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: spotAPIURL,
|
||||
|
||||
@@ -165,9 +165,12 @@ func (b *Bitfinex) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
b.Requester = request.New(b.Name,
|
||||
b.Requester, err = request.New(b.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(SetRateLimit()))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
b.API.Endpoints = b.NewEndpoints()
|
||||
err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: bitfinexAPIURLBase,
|
||||
|
||||
@@ -94,9 +94,12 @@ func (b *Bitflyer) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
b.Requester = request.New(b.Name,
|
||||
b.Requester, err = request.New(b.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(SetRateLimit()))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
b.API.Endpoints = b.NewEndpoints()
|
||||
err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: japanURL,
|
||||
|
||||
@@ -129,9 +129,12 @@ func (b *Bithumb) SetDefaults() {
|
||||
},
|
||||
},
|
||||
}
|
||||
b.Requester = request.New(b.Name,
|
||||
b.Requester, err = request.New(b.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(SetRateLimit()))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
b.API.Endpoints = b.NewEndpoints()
|
||||
err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: apiURL,
|
||||
|
||||
@@ -124,9 +124,12 @@ func (b *Bitmex) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
b.Requester = request.New(b.Name,
|
||||
b.Requester, err = request.New(b.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(SetRateLimit()))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
b.API.Endpoints = b.NewEndpoints()
|
||||
err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: bitmexAPIURL,
|
||||
|
||||
@@ -46,7 +46,10 @@ func TestMain(m *testing.M) {
|
||||
log.Fatalf("Mock server error %s", err)
|
||||
}
|
||||
|
||||
b.HTTPClient = newClient
|
||||
err = b.SetHTTPClient(newClient)
|
||||
if err != nil {
|
||||
log.Fatalf("Mock server error %s", err)
|
||||
}
|
||||
endpointMap := b.API.Endpoints.GetURLMap()
|
||||
for k := range endpointMap {
|
||||
err = b.API.Endpoints.SetRunning(k, serverDetails+"/api")
|
||||
|
||||
@@ -129,9 +129,12 @@ func (b *Bitstamp) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
b.Requester = request.New(b.Name,
|
||||
b.Requester, err = request.New(b.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(request.NewBasicRateLimit(bitstampRateInterval, bitstampRequestRate)))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
b.API.Endpoints = b.NewEndpoints()
|
||||
err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: bitstampAPIURL,
|
||||
|
||||
@@ -120,9 +120,12 @@ func (b *Bittrex) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
b.Requester = request.New(b.Name,
|
||||
b.Requester, err = request.New(b.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(request.NewBasicRateLimit(ratePeriod, rateLimit)))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
|
||||
b.API.Endpoints = b.NewEndpoints()
|
||||
|
||||
|
||||
@@ -119,9 +119,12 @@ func (b *BTCMarkets) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
b.Requester = request.New(b.Name,
|
||||
b.Requester, err = request.New(b.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(SetRateLimit()))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
b.API.Endpoints = b.NewEndpoints()
|
||||
err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: btcMarketsAPIURL,
|
||||
|
||||
@@ -149,9 +149,12 @@ func (b *BTSE) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
b.Requester = request.New(b.Name,
|
||||
b.Requester, err = request.New(b.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(SetRateLimit()))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
b.API.Endpoints = b.NewEndpoints()
|
||||
err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: btseAPIURL,
|
||||
|
||||
@@ -130,9 +130,12 @@ func (c *CoinbasePro) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
c.Requester = request.New(c.Name,
|
||||
c.Requester, err = request.New(c.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(SetRateLimit()))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
c.API.Endpoints = c.NewEndpoints()
|
||||
err = c.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: coinbaseproAPIURL,
|
||||
|
||||
@@ -113,8 +113,11 @@ func (c *COINUT) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
c.Requester = request.New(c.Name,
|
||||
c.Requester, err = request.New(c.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
c.API.Endpoints = c.NewEndpoints()
|
||||
err = c.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: coinutAPIURL,
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -20,7 +19,6 @@ import (
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/currencystate"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/kline"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/protocol"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/request"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/stream"
|
||||
"github.com/thrasher-corp/gocryptotrader/exchanges/trade"
|
||||
"github.com/thrasher-corp/gocryptotrader/log"
|
||||
@@ -46,56 +44,8 @@ var (
|
||||
ErrAuthenticatedRequestWithoutCredentialsSet = errors.New("authenticated HTTP request called but not supported due to unset/default API keys")
|
||||
|
||||
errEndpointStringNotFound = errors.New("endpoint string not found")
|
||||
errTransportNotSet = errors.New("transport not set, cannot set timeout")
|
||||
|
||||
// ErrPairNotFound is an error message for when unable to find a currency pair
|
||||
ErrPairNotFound = errors.New("pair not found")
|
||||
)
|
||||
|
||||
func (b *Base) checkAndInitRequester() {
|
||||
if b.Requester == nil {
|
||||
b.Requester = request.New(b.Name,
|
||||
&http.Client{Transport: new(http.Transport)})
|
||||
}
|
||||
}
|
||||
|
||||
// SetHTTPClientTimeout sets the timeout value for the exchanges HTTP Client and
|
||||
// also the underlying transports idle connection timeout
|
||||
func (b *Base) SetHTTPClientTimeout(t time.Duration) error {
|
||||
b.checkAndInitRequester()
|
||||
b.Requester.HTTPClient.Timeout = t
|
||||
tr, ok := b.Requester.HTTPClient.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
return errTransportNotSet
|
||||
}
|
||||
tr.IdleConnTimeout = t
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetHTTPClient sets exchanges HTTP client
|
||||
func (b *Base) SetHTTPClient(h *http.Client) {
|
||||
b.checkAndInitRequester()
|
||||
b.Requester.HTTPClient = h
|
||||
}
|
||||
|
||||
// GetHTTPClient gets the exchanges HTTP client
|
||||
func (b *Base) GetHTTPClient() *http.Client {
|
||||
b.checkAndInitRequester()
|
||||
return b.Requester.HTTPClient
|
||||
}
|
||||
|
||||
// SetHTTPClientUserAgent sets the exchanges HTTP user agent
|
||||
func (b *Base) SetHTTPClientUserAgent(ua string) {
|
||||
b.checkAndInitRequester()
|
||||
b.Requester.UserAgent = ua
|
||||
b.HTTPUserAgent = ua
|
||||
}
|
||||
|
||||
// GetHTTPClientUserAgent gets the exchanges HTTP user agent
|
||||
func (b *Base) GetHTTPClientUserAgent() string {
|
||||
return b.HTTPUserAgent
|
||||
}
|
||||
|
||||
// SetClientProxyAddress sets a proxy address for REST and websocket requests
|
||||
func (b *Base) SetClientProxyAddress(addr string) error {
|
||||
if addr == "" {
|
||||
@@ -441,17 +391,16 @@ func (b *Base) GetEnabledPairs(a asset.Item) (currency.Pairs, error) {
|
||||
// GetRequestFormattedPairAndAssetType is a method that returns the enabled currency pair of
|
||||
// along with its asset type. Only use when there is no chance of the same name crossing over
|
||||
func (b *Base) GetRequestFormattedPairAndAssetType(p string) (currency.Pair, asset.Item, error) {
|
||||
assetTypes := b.GetAssetTypes(false)
|
||||
var response currency.Pair
|
||||
assetTypes := b.GetAssetTypes(true)
|
||||
for i := range assetTypes {
|
||||
format, err := b.GetPairFormat(assetTypes[i], true)
|
||||
if err != nil {
|
||||
return response, assetTypes[i], err
|
||||
return currency.EMPTYPAIR, assetTypes[i], err
|
||||
}
|
||||
|
||||
pairs, err := b.CurrencyPairs.GetPairs(assetTypes[i], true)
|
||||
if err != nil {
|
||||
return response, assetTypes[i], err
|
||||
return currency.EMPTYPAIR, assetTypes[i], err
|
||||
}
|
||||
|
||||
for j := range pairs {
|
||||
@@ -461,8 +410,7 @@ func (b *Base) GetRequestFormattedPairAndAssetType(p string) (currency.Pair, ass
|
||||
}
|
||||
}
|
||||
}
|
||||
return response, "",
|
||||
fmt.Errorf("%s %w", p, ErrPairNotFound)
|
||||
return currency.EMPTYPAIR, "", fmt.Errorf("%s %w", p, currency.ErrPairNotFound)
|
||||
}
|
||||
|
||||
// GetAvailablePairs is a method that returns the available currency pairs
|
||||
@@ -607,7 +555,10 @@ func (b *Base) SetupDefaults(exch *config.Exchange) error {
|
||||
|
||||
b.HTTPDebugging = exch.HTTPDebugging
|
||||
b.BypassConfigFormatUpgrades = exch.CurrencyPairs.BypassConfigFormatUpgrades
|
||||
b.SetHTTPClientUserAgent(exch.HTTPUserAgent)
|
||||
err = b.SetHTTPClientUserAgent(exch.HTTPUserAgent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.SetCurrencyPairFormat()
|
||||
|
||||
err = b.SetConfigPairs()
|
||||
|
||||
@@ -5,9 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -193,76 +191,21 @@ func TestSetDefaultEndpoints(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPClient(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := Base{Name: "asdf"}
|
||||
err := r.SetHTTPClientTimeout(time.Second * 5)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if r.GetHTTPClient().Timeout != time.Second*5 {
|
||||
t.Fatalf("TestHTTPClient unexpected value")
|
||||
}
|
||||
|
||||
r.Requester = nil
|
||||
newClient := new(http.Client)
|
||||
newClient.Timeout = time.Second * 10
|
||||
|
||||
r.SetHTTPClient(newClient)
|
||||
if r.GetHTTPClient().Timeout != time.Second*10 {
|
||||
t.Fatalf("TestHTTPClient unexpected value")
|
||||
}
|
||||
|
||||
r.Requester = nil
|
||||
if r.GetHTTPClient() == nil {
|
||||
t.Fatalf("TestHTTPClient unexpected value")
|
||||
}
|
||||
|
||||
b := Base{Name: "RAWR"}
|
||||
|
||||
b.Requester = request.New(b.Name, new(http.Client))
|
||||
err = b.SetHTTPClientTimeout(time.Second * 5)
|
||||
if !errors.Is(err, errTransportNotSet) {
|
||||
t.Fatalf("received: %v but expected: %v", err, errTransportNotSet)
|
||||
}
|
||||
|
||||
b.Requester = request.New(b.Name, &http.Client{Transport: new(http.Transport)})
|
||||
err = b.SetHTTPClientTimeout(time.Second * 5)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if b.GetHTTPClient().Timeout != time.Second*5 {
|
||||
t.Fatalf("TestHTTPClient unexpected value")
|
||||
}
|
||||
|
||||
newClient = new(http.Client)
|
||||
newClient.Timeout = time.Second * 10
|
||||
|
||||
b.SetHTTPClient(newClient)
|
||||
if b.GetHTTPClient().Timeout != time.Second*10 {
|
||||
t.Fatalf("TestHTTPClient unexpected value")
|
||||
}
|
||||
|
||||
b.SetHTTPClientUserAgent("epicUserAgent")
|
||||
if !strings.Contains(b.GetHTTPClientUserAgent(), "epicUserAgent") {
|
||||
t.Error("user agent not set properly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetClientProxyAddress(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
requester := request.New("rawr",
|
||||
requester, err := request.New("rawr",
|
||||
common.NewHTTPClientWithTimeout(time.Second*15))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newBase := Base{
|
||||
Name: "rawr",
|
||||
Requester: requester}
|
||||
|
||||
newBase.Websocket = stream.New()
|
||||
err := newBase.SetClientProxyAddress("")
|
||||
err = newBase.SetClientProxyAddress("")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -289,13 +232,6 @@ func TestSetClientProxyAddress(t *testing.T) {
|
||||
if newBase.Websocket.GetProxyAddress() != "http://www.valid.com" {
|
||||
t.Error("SetClientProxyAddress error", err)
|
||||
}
|
||||
|
||||
// Nil out transport
|
||||
newBase.Requester.HTTPClient.Transport = nil
|
||||
err = newBase.SetClientProxyAddress("http://www.valid.com")
|
||||
if err == nil {
|
||||
t.Error("error cannot be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetFeatureDefaults(t *testing.T) {
|
||||
@@ -1268,7 +1204,16 @@ func TestSetAPIKeys(t *testing.T) {
|
||||
func TestSetupDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var b = Base{Name: "awesomeTest"}
|
||||
newRequester, err := request.New("testSetupDefaults",
|
||||
common.NewHTTPClientWithTimeout(0))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var b = Base{
|
||||
Name: "awesomeTest",
|
||||
Requester: newRequester,
|
||||
}
|
||||
cfg := config.Exchange{
|
||||
HTTPTimeout: time.Duration(-1),
|
||||
API: config.APIConfig{
|
||||
@@ -1276,7 +1221,7 @@ func TestSetupDefaults(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := b.SetupDefaults(&cfg)
|
||||
err = b.SetupDefaults(&cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -2055,25 +2000,34 @@ func TestCheckTransientError(t *testing.T) {
|
||||
|
||||
func TestDisableEnableRateLimiter(t *testing.T) {
|
||||
b := Base{}
|
||||
b.checkAndInitRequester()
|
||||
err := b.EnableRateLimiter()
|
||||
if err == nil {
|
||||
t.Fatal("error cannot be nil")
|
||||
if !errors.Is(err, request.ErrRequestSystemIsNil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, request.ErrRequestSystemIsNil)
|
||||
}
|
||||
|
||||
err = b.DisableRateLimiter()
|
||||
b.Requester, err = request.New("testingRateLimiter", common.NewHTTPClientWithTimeout(0))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = b.DisableRateLimiter()
|
||||
if err == nil {
|
||||
t.Fatal("error cannot be nil")
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
err = b.DisableRateLimiter()
|
||||
if !errors.Is(err, request.ErrRateLimiterAlreadyDisabled) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, request.ErrRateLimiterAlreadyDisabled)
|
||||
}
|
||||
|
||||
err = b.EnableRateLimiter()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
err = b.EnableRateLimiter()
|
||||
if !errors.Is(err, request.ErrRateLimiterAlreadyEnabled) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, request.ErrRateLimiterAlreadyEnabled)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -219,7 +219,6 @@ type Base struct {
|
||||
CurrencyPairs currency.PairsManager
|
||||
Features Features
|
||||
HTTPTimeout time.Duration
|
||||
HTTPUserAgent string
|
||||
HTTPRecording bool
|
||||
HTTPDebugging bool
|
||||
BypassConfigFormatUpgrades bool
|
||||
|
||||
@@ -110,9 +110,12 @@ func (e *EXMO) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
e.Requester = request.New(e.Name,
|
||||
e.Requester, err = request.New(e.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(request.NewBasicRateLimit(exmoRateInterval, exmoRequestRate)))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
e.API.Endpoints = e.NewEndpoints()
|
||||
err = e.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: exmoAPIURL,
|
||||
|
||||
@@ -147,9 +147,12 @@ func (f *FTX) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
f.Requester = request.New(f.Name,
|
||||
f.Requester, err = request.New(f.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(request.NewBasicRateLimit(ratePeriod, rateLimit)))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
f.API.Endpoints = f.NewEndpoints()
|
||||
err = f.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: ftxAPIURL,
|
||||
|
||||
@@ -130,8 +130,11 @@ func (g *Gateio) SetDefaults() {
|
||||
},
|
||||
},
|
||||
}
|
||||
g.Requester = request.New(g.Name,
|
||||
g.Requester, err = request.New(g.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
g.API.Endpoints = g.NewEndpoints()
|
||||
err = g.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: gateioTradeURL,
|
||||
|
||||
@@ -45,7 +45,10 @@ func TestMain(m *testing.M) {
|
||||
log.Fatalf("Mock server error %s", err)
|
||||
}
|
||||
|
||||
g.HTTPClient = newClient
|
||||
err = g.SetHTTPClient(newClient)
|
||||
if err != nil {
|
||||
log.Fatalf("Mock server error %s", err)
|
||||
}
|
||||
endpointMap := g.API.Endpoints.GetURLMap()
|
||||
for k := range endpointMap {
|
||||
err = g.API.Endpoints.SetRunning(k, serverDetails)
|
||||
|
||||
@@ -113,9 +113,12 @@ func (g *Gemini) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
g.Requester = request.New(g.Name,
|
||||
g.Requester, err = request.New(g.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(SetRateLimit()))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
g.API.Endpoints = g.NewEndpoints()
|
||||
err = g.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: geminiAPIURL,
|
||||
|
||||
@@ -129,9 +129,12 @@ func (h *HitBTC) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
h.Requester = request.New(h.Name,
|
||||
h.Requester, err = request.New(h.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(SetRateLimit()))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
h.API.Endpoints = h.NewEndpoints()
|
||||
err = h.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: apiURL,
|
||||
|
||||
@@ -158,9 +158,12 @@ func (h *HUOBI) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
h.Requester = request.New(h.Name,
|
||||
h.Requester, err = request.New(h.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(SetRateLimit()))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
h.API.Endpoints = h.NewEndpoints()
|
||||
err = h.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: huobiAPIURL,
|
||||
|
||||
@@ -68,8 +68,8 @@ type IBotExchange interface {
|
||||
WithdrawCryptocurrencyFunds(ctx context.Context, withdrawRequest *withdraw.Request) (*withdraw.ExchangeResponse, error)
|
||||
WithdrawFiatFunds(ctx context.Context, withdrawRequest *withdraw.Request) (*withdraw.ExchangeResponse, error)
|
||||
WithdrawFiatFundsToInternationalBank(ctx context.Context, withdrawRequest *withdraw.Request) (*withdraw.ExchangeResponse, error)
|
||||
SetHTTPClientUserAgent(ua string)
|
||||
GetHTTPClientUserAgent() string
|
||||
SetHTTPClientUserAgent(ua string) error
|
||||
GetHTTPClientUserAgent() (string, error)
|
||||
SetClientProxyAddress(addr string) error
|
||||
SupportsREST() bool
|
||||
GetSubscriptions() ([]stream.ChannelSubscription, error)
|
||||
|
||||
@@ -94,8 +94,11 @@ func (i *ItBit) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
i.Requester = request.New(i.Name,
|
||||
i.Requester, err = request.New(i.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
i.API.Endpoints = i.NewEndpoints()
|
||||
err = i.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: itbitAPIURL,
|
||||
|
||||
@@ -171,9 +171,12 @@ func (k *Kraken) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
k.Requester = request.New(k.Name,
|
||||
k.Requester, err = request.New(k.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(request.NewBasicRateLimit(krakenRateInterval, krakenRequestRate)))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
k.API.Endpoints = k.NewEndpoints()
|
||||
err = k.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: krakenAPIURL,
|
||||
|
||||
@@ -108,8 +108,11 @@ func (l *Lbank) SetDefaults() {
|
||||
},
|
||||
},
|
||||
}
|
||||
l.Requester = request.New(l.Name,
|
||||
l.Requester, err = request.New(l.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
l.API.Endpoints = l.NewEndpoints()
|
||||
err = l.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: lbankAPIURL,
|
||||
|
||||
@@ -44,7 +44,10 @@ func TestMain(m *testing.M) {
|
||||
log.Fatalf("Mock server error %s", err)
|
||||
}
|
||||
|
||||
l.HTTPClient = newClient
|
||||
err = l.SetHTTPClient(newClient)
|
||||
if err != nil {
|
||||
log.Fatalf("Mock server error %s", err)
|
||||
}
|
||||
endpoints := l.API.Endpoints.GetURLMap()
|
||||
for k := range endpoints {
|
||||
err = l.API.Endpoints.SetRunning(k, serverDetails)
|
||||
|
||||
@@ -93,8 +93,11 @@ func (l *LocalBitcoins) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
l.Requester = request.New(l.Name,
|
||||
l.Requester, err = request.New(l.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
l.API.Endpoints = l.NewEndpoints()
|
||||
err = l.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: localbitcoinsAPIURL,
|
||||
|
||||
@@ -133,11 +133,14 @@ func (o *OKCoin) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
o.Requester = request.New(o.Name,
|
||||
o.Requester, err = request.New(o.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
// TODO: Specify each individual endpoint rate limits as per docs
|
||||
request.WithLimiter(request.NewBasicRateLimit(okCoinRateInterval, okCoinStandardRequestRate)),
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
o.API.Endpoints = o.NewEndpoints()
|
||||
err = o.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: okCoinAPIURL,
|
||||
|
||||
@@ -191,11 +191,14 @@ func (o *OKEX) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
o.Requester = request.New(o.Name,
|
||||
o.Requester, err = request.New(o.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
// TODO: Specify each individual endpoint rate limits as per docs
|
||||
request.WithLimiter(request.NewBasicRateLimit(okExRateInterval, okExRequestRate)),
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
o.API.Endpoints = o.NewEndpoints()
|
||||
err = o.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: okExAPIURL,
|
||||
|
||||
@@ -45,7 +45,10 @@ func TestMain(m *testing.M) {
|
||||
log.Fatalf("Mock server error %s", err)
|
||||
}
|
||||
|
||||
p.HTTPClient = newClient
|
||||
err = p.SetHTTPClient(newClient)
|
||||
if err != nil {
|
||||
log.Fatalf("Mock server error %s", err)
|
||||
}
|
||||
endpoints := p.API.Endpoints.GetURLMap()
|
||||
for k := range endpoints {
|
||||
err = p.API.Endpoints.SetRunning(k, serverDetails)
|
||||
|
||||
@@ -133,9 +133,12 @@ func (p *Poloniex) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
p.Requester = request.New(p.Name,
|
||||
p.Requester, err = request.New(p.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(SetRateLimit()))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
p.API.Endpoints = p.NewEndpoints()
|
||||
err = p.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: poloniexAPIURL,
|
||||
|
||||
131
exchanges/request/client.go
Normal file
131
exchanges/request/client.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package request
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// tracker is the global to maintain sanity between clients across all
|
||||
// services using the request package.
|
||||
tracker clientTracker
|
||||
|
||||
errNoProxyURLSupplied = errors.New("no proxy URL supplied")
|
||||
errCannotReuseHTTPClient = errors.New("cannot reuse http client")
|
||||
errHTTPClientIsNil = errors.New("http client is nil")
|
||||
errHTTPClientNotFound = errors.New("http client not found")
|
||||
)
|
||||
|
||||
// clientTracker attempts to maintain service/http.Client segregation
|
||||
type clientTracker struct {
|
||||
clients []*http.Client
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
// checkAndRegister stops the sharing of the same http.Client between services.
|
||||
func (c *clientTracker) checkAndRegister(newClient *http.Client) error {
|
||||
if newClient == nil {
|
||||
return errHTTPClientIsNil
|
||||
}
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
for x := range c.clients {
|
||||
if newClient == c.clients[x] {
|
||||
return errCannotReuseHTTPClient
|
||||
}
|
||||
}
|
||||
c.clients = append(c.clients, newClient)
|
||||
return nil
|
||||
}
|
||||
|
||||
// deRegister removes the *http.Client from being tracked
|
||||
func (c *clientTracker) deRegister(oldClient *http.Client) error {
|
||||
if oldClient == nil {
|
||||
return errHTTPClientIsNil
|
||||
}
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
for x := range c.clients {
|
||||
if oldClient != c.clients[x] {
|
||||
continue
|
||||
}
|
||||
c.clients[x] = c.clients[len(c.clients)-1]
|
||||
c.clients[len(c.clients)-1] = nil
|
||||
c.clients = c.clients[:len(c.clients)-1]
|
||||
return nil
|
||||
}
|
||||
return errHTTPClientNotFound
|
||||
}
|
||||
|
||||
// client wraps over a http client for better protection
|
||||
type client struct {
|
||||
protected *http.Client
|
||||
m sync.RWMutex
|
||||
}
|
||||
|
||||
// newProtectedClient registers a http.Client to inhibit cross service usage and
|
||||
// return a thread safe holder (*request.Client) with getter and setters for
|
||||
// timeouts and transports.
|
||||
func newProtectedClient(newClient *http.Client) (*client, error) {
|
||||
if err := tracker.checkAndRegister(newClient); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &client{protected: newClient}, nil
|
||||
}
|
||||
|
||||
// setProxy sets a proxy address for the client transport
|
||||
func (c *client) setProxy(p *url.URL) error {
|
||||
if p == nil || p.String() == "" {
|
||||
return errNoProxyURLSupplied
|
||||
}
|
||||
c.m.Lock()
|
||||
defer c.m.Unlock()
|
||||
// Check transport first so we don't set something and then error.
|
||||
tr, ok := c.protected.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
return errTransportNotSet
|
||||
}
|
||||
// This closes idle connections before an attempt at reassignment and
|
||||
// boots any dangly routines.
|
||||
tr.CloseIdleConnections()
|
||||
tr.Proxy = http.ProxyURL(p)
|
||||
tr.TLSHandshakeTimeout = proxyTLSTimeout
|
||||
return nil
|
||||
}
|
||||
|
||||
// setHTTPClientTimeout sets the timeout value for the exchanges HTTP Client and
|
||||
// also the underlying transports idle connection timeout
|
||||
func (c *client) setHTTPClientTimeout(timeout time.Duration) error {
|
||||
c.m.Lock()
|
||||
defer c.m.Unlock()
|
||||
// Check transport first so we don't set something and then error.
|
||||
tr, ok := c.protected.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
return errTransportNotSet
|
||||
}
|
||||
// This closes idle connections before an attempt at reassignment and
|
||||
// boots any dangly routines.
|
||||
tr.CloseIdleConnections()
|
||||
tr.IdleConnTimeout = timeout
|
||||
c.protected.Timeout = timeout
|
||||
return nil
|
||||
}
|
||||
|
||||
// do sends request in a protected manner
|
||||
func (c *client) do(request *http.Request) (resp *http.Response, err error) {
|
||||
c.m.RLock()
|
||||
resp, err = c.protected.Do(request)
|
||||
c.m.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
// release de-registers the underlying client
|
||||
func (c *client) release() error {
|
||||
c.m.Lock()
|
||||
err := tracker.deRegister(c.protected)
|
||||
c.m.Unlock()
|
||||
return err
|
||||
}
|
||||
148
exchanges/request/client_test.go
Normal file
148
exchanges/request/client_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package request
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
)
|
||||
|
||||
// this doesn't need to be included in binary
|
||||
func (c *clientTracker) contains(check *http.Client) bool {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
for x := range c.clients {
|
||||
if check == c.clients[x] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestCheckAndRegister(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := tracker.checkAndRegister(nil)
|
||||
if !errors.Is(err, errHTTPClientIsNil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errHTTPClientIsNil)
|
||||
}
|
||||
|
||||
newLovelyClient := new(http.Client)
|
||||
err = tracker.checkAndRegister(newLovelyClient)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
if !tracker.contains(newLovelyClient) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", false, true)
|
||||
}
|
||||
|
||||
err = tracker.checkAndRegister(newLovelyClient)
|
||||
if !errors.Is(err, errCannotReuseHTTPClient) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errCannotReuseHTTPClient)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeRegister(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := tracker.deRegister(nil)
|
||||
if !errors.Is(err, errHTTPClientIsNil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errHTTPClientIsNil)
|
||||
}
|
||||
|
||||
newLovelyClient := new(http.Client)
|
||||
err = tracker.deRegister(newLovelyClient)
|
||||
if !errors.Is(err, errHTTPClientNotFound) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errHTTPClientNotFound)
|
||||
}
|
||||
|
||||
err = tracker.checkAndRegister(newLovelyClient)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
if !tracker.contains(newLovelyClient) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", false, true)
|
||||
}
|
||||
|
||||
err = tracker.deRegister(newLovelyClient)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
if tracker.contains(newLovelyClient) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", true, false)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewProtectedClient(t *testing.T) {
|
||||
t.Parallel()
|
||||
if _, err := newProtectedClient(nil); !errors.Is(err, errHTTPClientIsNil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errHTTPClientIsNil)
|
||||
}
|
||||
|
||||
newLovelyClient := new(http.Client)
|
||||
protec, err := newProtectedClient(newLovelyClient)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
if protec.protected != newLovelyClient {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientSetProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := (&client{}).setProxy(nil)
|
||||
if !errors.Is(err, errNoProxyURLSupplied) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errNoProxyURLSupplied)
|
||||
}
|
||||
pp, err := url.Parse("lol.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = (&client{protected: new(http.Client)}).setProxy(pp)
|
||||
if !errors.Is(err, errTransportNotSet) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errTransportNotSet)
|
||||
}
|
||||
err = (&client{protected: common.NewHTTPClientWithTimeout(0)}).setProxy(pp)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientSetHTTPClientTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := (&client{protected: new(http.Client)}).setHTTPClientTimeout(time.Second)
|
||||
if !errors.Is(err, errTransportNotSet) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, errTransportNotSet)
|
||||
}
|
||||
err = (&client{protected: common.NewHTTPClientWithTimeout(0)}).setHTTPClientTimeout(time.Second)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelease(t *testing.T) {
|
||||
t.Parallel()
|
||||
newLovelyClient, err := newProtectedClient(common.NewHTTPClientWithTimeout(0))
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
if !tracker.contains(newLovelyClient.protected) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", false, true)
|
||||
}
|
||||
|
||||
err = newLovelyClient.release()
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", err, nil)
|
||||
}
|
||||
|
||||
if tracker.contains(newLovelyClient.protected) {
|
||||
t.Fatalf("received: '%v' but expected: '%v'", true, false)
|
||||
}
|
||||
}
|
||||
@@ -3,12 +3,19 @@ package request
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// Defines rate limiting errors
|
||||
var (
|
||||
ErrRateLimiterAlreadyDisabled = errors.New("rate limiter already disabled")
|
||||
ErrRateLimiterAlreadyEnabled = errors.New("rate limiter already enabled")
|
||||
)
|
||||
|
||||
// Const here define individual functionality sub types for rate limiting
|
||||
const (
|
||||
Unset EndpointLimit = iota
|
||||
@@ -60,6 +67,9 @@ func NewBasicRateLimit(interval time.Duration, actions int) Limiter {
|
||||
|
||||
// InitiateRateLimit sleeps for designated end point rate limits
|
||||
func (r *Requester) InitiateRateLimit(ctx context.Context, e EndpointLimit) error {
|
||||
if r == nil {
|
||||
return ErrRequestSystemIsNil
|
||||
}
|
||||
if atomic.LoadInt32(&r.disableRateLimiter) == 1 {
|
||||
return nil
|
||||
}
|
||||
@@ -73,16 +83,22 @@ func (r *Requester) InitiateRateLimit(ctx context.Context, e EndpointLimit) erro
|
||||
|
||||
// DisableRateLimiter disables the rate limiting system for the exchange
|
||||
func (r *Requester) DisableRateLimiter() error {
|
||||
if r == nil {
|
||||
return ErrRequestSystemIsNil
|
||||
}
|
||||
if !atomic.CompareAndSwapInt32(&r.disableRateLimiter, 0, 1) {
|
||||
return errors.New("rate limiter already disabled")
|
||||
return fmt.Errorf("%s %w", r.name, ErrRateLimiterAlreadyDisabled)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnableRateLimiter enables the rate limiting system for the exchange
|
||||
func (r *Requester) EnableRateLimiter() error {
|
||||
if r == nil {
|
||||
return ErrRequestSystemIsNil
|
||||
}
|
||||
if !atomic.CompareAndSwapInt32(&r.disableRateLimiter, 1, 0) {
|
||||
return errors.New("rate limiter already enabled")
|
||||
return fmt.Errorf("%s %w", r.name, ErrRateLimiterAlreadyEnabled)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
errRequestSystemIsNil = errors.New("request system is nil")
|
||||
// ErrRequestSystemIsNil defines and error if the request system has not
|
||||
// been set up yet.
|
||||
ErrRequestSystemIsNil = errors.New("request system is nil")
|
||||
errMaxRequestJobs = errors.New("max request jobs reached")
|
||||
errRequestFunctionIsNil = errors.New("request function is nil")
|
||||
errRequestItemNil = errors.New("request item is nil")
|
||||
@@ -28,13 +30,18 @@ var (
|
||||
errHeaderResponseMapIsNil = errors.New("header response map is nil")
|
||||
errFailedToRetryRequest = errors.New("failed to retry request")
|
||||
errContextRequired = errors.New("context is required")
|
||||
errTransportNotSet = errors.New("transport not set, cannot set timeout")
|
||||
)
|
||||
|
||||
// New returns a new Requester
|
||||
func New(name string, httpRequester *http.Client, opts ...RequesterOption) *Requester {
|
||||
func New(name string, httpRequester *http.Client, opts ...RequesterOption) (*Requester, error) {
|
||||
protectedClient, err := newProtectedClient(httpRequester)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot set up a new requester for %s: %w", name, err)
|
||||
}
|
||||
r := &Requester{
|
||||
HTTPClient: httpRequester,
|
||||
Name: name,
|
||||
_HTTPClient: protectedClient,
|
||||
name: name,
|
||||
backoff: DefaultBackoff(),
|
||||
retryPolicy: DefaultRetryPolicy,
|
||||
maxRetries: MaxRetryAttempts,
|
||||
@@ -46,13 +53,13 @@ func New(name string, httpRequester *http.Client, opts ...RequesterOption) *Requ
|
||||
o(r)
|
||||
}
|
||||
|
||||
return r
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// SendPayload handles sending HTTP/HTTPS requests
|
||||
func (r *Requester) SendPayload(ctx context.Context, ep EndpointLimit, newRequest Generate) error {
|
||||
if r == nil {
|
||||
return errRequestSystemIsNil
|
||||
return ErrRequestSystemIsNil
|
||||
}
|
||||
|
||||
if ctx == nil {
|
||||
@@ -107,8 +114,8 @@ func (i *Item) validateRequest(ctx context.Context, r *Requester) (*http.Request
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
|
||||
if r.UserAgent != "" && req.Header.Get(userAgent) == "" {
|
||||
req.Header.Add(userAgent, r.UserAgent)
|
||||
if r.userAgent != "" && req.Header.Get(userAgent) == "" {
|
||||
req.Header.Add(userAgent, r.userAgent)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -141,22 +148,22 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe
|
||||
}
|
||||
|
||||
if p.Verbose {
|
||||
log.Debugf(log.RequestSys, "%s attempt %d request path: %s", r.Name, attempt, p.Path)
|
||||
log.Debugf(log.RequestSys, "%s attempt %d request path: %s", r.name, attempt, p.Path)
|
||||
for k, d := range req.Header {
|
||||
log.Debugf(log.RequestSys, "%s request header [%s]: %s", r.Name, k, d)
|
||||
log.Debugf(log.RequestSys, "%s request header [%s]: %s", r.name, k, d)
|
||||
}
|
||||
log.Debugf(log.RequestSys, "%s request type: %s", r.Name, p.Method)
|
||||
log.Debugf(log.RequestSys, "%s request type: %s", r.name, p.Method)
|
||||
if p.Body != nil {
|
||||
log.Debugf(log.RequestSys, "%s request body: %v", r.Name, p.Body)
|
||||
log.Debugf(log.RequestSys, "%s request body: %v", r.name, p.Body)
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
resp, err := r.HTTPClient.Do(req)
|
||||
resp, err := r._HTTPClient.do(req)
|
||||
|
||||
if r.reporter != nil {
|
||||
r.reporter.Latency(r.Name, p.Method, p.Path, time.Since(start))
|
||||
r.reporter.Latency(r.name, p.Method, p.Path, time.Since(start))
|
||||
}
|
||||
|
||||
if retry, checkErr := r.retryPolicy(resp, err); checkErr != nil {
|
||||
@@ -191,7 +198,7 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe
|
||||
if p.Verbose {
|
||||
log.Errorf(log.RequestSys,
|
||||
"%s request has failed. Retrying request in %s, attempt %d",
|
||||
r.Name,
|
||||
r.name,
|
||||
delay,
|
||||
attempt)
|
||||
}
|
||||
@@ -213,7 +220,7 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe
|
||||
|
||||
if p.HTTPRecording {
|
||||
// This dumps http responses for future mocking implementations
|
||||
err = mock.HTTPRecord(resp, r.Name, contents)
|
||||
err = mock.HTTPRecord(resp, r.name, contents)
|
||||
if err != nil {
|
||||
return fmt.Errorf("mock recording failure %s", err)
|
||||
}
|
||||
@@ -228,21 +235,27 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe
|
||||
if resp.StatusCode < http.StatusOK ||
|
||||
resp.StatusCode > http.StatusAccepted {
|
||||
return fmt.Errorf("%s unsuccessful HTTP status code: %d raw response: %s",
|
||||
r.Name,
|
||||
r.name,
|
||||
resp.StatusCode,
|
||||
string(contents))
|
||||
}
|
||||
|
||||
if p.HTTPDebugging {
|
||||
dump, err := httputil.DumpResponse(resp, false)
|
||||
dump, dumpErr := httputil.DumpResponse(resp, false)
|
||||
if err != nil {
|
||||
log.Errorf(log.RequestSys, "DumpResponse invalid response: %v:", err)
|
||||
log.Errorf(log.RequestSys, "DumpResponse invalid response: %v:", dumpErr)
|
||||
}
|
||||
log.Debugf(log.RequestSys, "DumpResponse Headers (%v):\n%s", p.Path, dump)
|
||||
log.Debugf(log.RequestSys, "DumpResponse Body (%v):\n %s", p.Path, string(contents))
|
||||
}
|
||||
|
||||
resp.Body.Close()
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf(log.RequestSys,
|
||||
"%s failed to close request body %s",
|
||||
r.name,
|
||||
err)
|
||||
}
|
||||
if p.Verbose {
|
||||
log.Debugf(log.RequestSys,
|
||||
"HTTP status: %s, Code: %v",
|
||||
@@ -251,7 +264,7 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe
|
||||
if !p.HTTPDebugging {
|
||||
log.Debugf(log.RequestSys,
|
||||
"%s raw response: %s",
|
||||
r.Name,
|
||||
r.name,
|
||||
string(contents))
|
||||
}
|
||||
}
|
||||
@@ -259,6 +272,22 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Requester) drainBody(body io.ReadCloser) {
|
||||
if _, err := io.Copy(ioutil.Discard, io.LimitReader(body, drainBodyLimit)); err != nil {
|
||||
log.Errorf(log.RequestSys,
|
||||
"%s failed to drain request body %s",
|
||||
r.name,
|
||||
err)
|
||||
}
|
||||
|
||||
if err := body.Close(); err != nil {
|
||||
log.Errorf(log.RequestSys,
|
||||
"%s failed to close request body %s",
|
||||
r.name,
|
||||
err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetNonce returns a nonce for requests. This locks and enforces concurrent
|
||||
// nonce FIFO on the buffered job channel
|
||||
func (r *Requester) GetNonce(isNano bool) nonce.Value {
|
||||
@@ -285,27 +314,57 @@ func (r *Requester) GetNonceMilli() nonce.Value {
|
||||
return r.Nonce.GetInc()
|
||||
}
|
||||
|
||||
// SetProxy sets a proxy address to the client transport
|
||||
// SetProxy sets a proxy address for the client transport
|
||||
func (r *Requester) SetProxy(p *url.URL) error {
|
||||
if p.String() == "" {
|
||||
return errors.New("no proxy URL supplied")
|
||||
if r == nil {
|
||||
return ErrRequestSystemIsNil
|
||||
}
|
||||
return r._HTTPClient.setProxy(p)
|
||||
}
|
||||
|
||||
t, ok := r.HTTPClient.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
return errors.New("transport not set, cannot set proxy")
|
||||
// SetHTTPClient sets exchanges HTTP client
|
||||
func (r *Requester) SetHTTPClient(newClient *http.Client) error {
|
||||
if r == nil {
|
||||
return ErrRequestSystemIsNil
|
||||
}
|
||||
t.Proxy = http.ProxyURL(p)
|
||||
t.TLSHandshakeTimeout = proxyTLSTimeout
|
||||
protectedClient, err := newProtectedClient(newClient)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r._HTTPClient = protectedClient
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Requester) drainBody(body io.ReadCloser) {
|
||||
defer body.Close()
|
||||
if _, err := io.Copy(ioutil.Discard, io.LimitReader(body, drainBodyLimit)); err != nil {
|
||||
log.Errorf(log.RequestSys,
|
||||
"%s failed to drain request body %s",
|
||||
r.Name,
|
||||
err)
|
||||
// SetClientTimeout sets the timeout value for the exchanges HTTP Client and
|
||||
// also the underlying transports idle connection timeout
|
||||
func (r *Requester) SetHTTPClientTimeout(timeout time.Duration) error {
|
||||
if r == nil {
|
||||
return ErrRequestSystemIsNil
|
||||
}
|
||||
return r._HTTPClient.setHTTPClientTimeout(timeout)
|
||||
}
|
||||
|
||||
// SetHTTPClientUserAgent sets the exchanges HTTP user agent
|
||||
func (r *Requester) SetHTTPClientUserAgent(userAgent string) error {
|
||||
if r == nil {
|
||||
return ErrRequestSystemIsNil
|
||||
}
|
||||
r.userAgent = userAgent
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetHTTPClientUserAgent gets the exchanges HTTP user agent
|
||||
func (r *Requester) GetHTTPClientUserAgent() (string, error) {
|
||||
if r == nil {
|
||||
return "", ErrRequestSystemIsNil
|
||||
}
|
||||
return r.userAgent, nil
|
||||
}
|
||||
|
||||
// Shutdown releases persistent memory for garbage collection.
|
||||
func (r *Requester) Shutdown() error {
|
||||
if r == nil {
|
||||
return ErrRequestSystemIsNil
|
||||
}
|
||||
return r._HTTPClient.release()
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
@@ -126,12 +127,15 @@ func TestNewRateLimit(t *testing.T) {
|
||||
func TestCheckRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := New("TestRequest",
|
||||
r, err := New("TestRequest",
|
||||
new(http.Client))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
var check *Item
|
||||
_, err := check.validateRequest(ctx, &Requester{})
|
||||
_, err = check.validateRequest(ctx, &Requester{})
|
||||
if err == nil {
|
||||
t.Fatal(unexpected)
|
||||
}
|
||||
@@ -179,7 +183,7 @@ func TestCheckRequest(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test user agent set
|
||||
r.UserAgent = "r00t axxs"
|
||||
r.userAgent = "r00t axxs"
|
||||
req, err := check.validateRequest(ctx, r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -226,12 +230,15 @@ var globalshell = GlobalLimitTest{
|
||||
|
||||
func TestDoRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := New("test", new(http.Client), WithLimiter(&globalshell))
|
||||
r, err := New("test", new(http.Client), WithLimiter(&globalshell))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err := (*Requester)(nil).SendPayload(ctx, Unset, nil)
|
||||
if !errors.Is(errRequestSystemIsNil, err) {
|
||||
t.Fatalf("expected: %v but received: %v", errRequestSystemIsNil, err)
|
||||
err = (*Requester)(nil).SendPayload(ctx, Unset, nil)
|
||||
if !errors.Is(ErrRequestSystemIsNil, err) {
|
||||
t.Fatalf("expected: %v but received: %v", ErrRequestSystemIsNil, err)
|
||||
}
|
||||
err = r.SendPayload(ctx, Unset, nil)
|
||||
if !errors.Is(errRequestFunctionIsNil, err) {
|
||||
@@ -305,8 +312,16 @@ func TestDoRequest(t *testing.T) {
|
||||
// reset jobs
|
||||
r.jobs = 0
|
||||
|
||||
r._HTTPClient, err = newProtectedClient(common.NewHTTPClientWithTimeout(0))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// timeout checker
|
||||
r.HTTPClient.Timeout = time.Millisecond * 50
|
||||
err = r._HTTPClient.setHTTPClientTimeout(time.Millisecond * 50)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = r.SendPayload(ctx, UnAuth, func() (*Item, error) {
|
||||
return &Item{Path: testURL + "/timeout"}, nil
|
||||
})
|
||||
@@ -314,7 +329,10 @@ func TestDoRequest(t *testing.T) {
|
||||
t.Fatalf("received: %v but expected: %v", err, errFailedToRetryRequest)
|
||||
}
|
||||
// reset timeout
|
||||
r.HTTPClient.Timeout = 0
|
||||
err = r._HTTPClient.setHTTPClientTimeout(0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Check JSON
|
||||
var resp struct {
|
||||
@@ -403,7 +421,10 @@ func TestDoRequest_Retries(t *testing.T) {
|
||||
backoff := func(n int) time.Duration {
|
||||
return 0
|
||||
}
|
||||
r := New("test", new(http.Client), WithBackoff(backoff))
|
||||
r, err := New("test", new(http.Client), WithBackoff(backoff))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var failed int32
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(4)
|
||||
@@ -444,8 +465,11 @@ func TestDoRequest_RetryNonRecoverable(t *testing.T) {
|
||||
backoff := func(n int) time.Duration {
|
||||
return 0
|
||||
}
|
||||
r := New("test", new(http.Client), WithBackoff(backoff))
|
||||
err := r.SendPayload(context.Background(), Unset, func() (*Item, error) {
|
||||
r, err := New("test", new(http.Client), WithBackoff(backoff))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = r.SendPayload(context.Background(), Unset, func() (*Item, error) {
|
||||
return &Item{
|
||||
Method: http.MethodGet,
|
||||
Path: testURL + "/always-retry",
|
||||
@@ -466,8 +490,11 @@ func TestDoRequest_NotRetryable(t *testing.T) {
|
||||
backoff := func(n int) time.Duration {
|
||||
return time.Duration(n) * time.Millisecond
|
||||
}
|
||||
r := New("test", new(http.Client), WithRetryPolicy(retry), WithBackoff(backoff))
|
||||
err := r.SendPayload(context.Background(), Unset, func() (*Item, error) {
|
||||
r, err := New("test", new(http.Client), WithRetryPolicy(retry), WithBackoff(backoff))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = r.SendPayload(context.Background(), Unset, func() (*Item, error) {
|
||||
return &Item{
|
||||
Method: http.MethodGet,
|
||||
Path: testURL + "/always-retry",
|
||||
@@ -480,17 +507,22 @@ func TestDoRequest_NotRetryable(t *testing.T) {
|
||||
|
||||
func TestGetNonce(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := New("test",
|
||||
r, err := New("test",
|
||||
new(http.Client),
|
||||
WithLimiter(&globalshell))
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n1, n2 := r.GetNonce(false), r.GetNonce(false); n1 == n2 {
|
||||
t.Fatal(unexpected)
|
||||
}
|
||||
|
||||
r2 := New("test",
|
||||
r2, err := New("test",
|
||||
new(http.Client),
|
||||
WithLimiter(&globalshell))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n1, n2 := r2.GetNonce(true), r2.GetNonce(true); n1 == n2 {
|
||||
t.Fatal(unexpected)
|
||||
}
|
||||
@@ -498,9 +530,12 @@ func TestGetNonce(t *testing.T) {
|
||||
|
||||
func TestGetNonceMillis(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := New("test",
|
||||
r, err := New("test",
|
||||
new(http.Client),
|
||||
WithLimiter(&globalshell))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if m1, m2 := r.GetNonceMilli(), r.GetNonceMilli(); m1 == m2 {
|
||||
log.Fatal(unexpected)
|
||||
}
|
||||
@@ -508,9 +543,17 @@ func TestGetNonceMillis(t *testing.T) {
|
||||
|
||||
func TestSetProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := New("test",
|
||||
var r *Requester
|
||||
err := r.SetProxy(nil)
|
||||
if !errors.Is(err, ErrRequestSystemIsNil) {
|
||||
t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil)
|
||||
}
|
||||
r, err = New("test",
|
||||
&http.Client{Transport: new(http.Transport)},
|
||||
WithLimiter(&globalshell))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
u, err := url.Parse("http://www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -530,9 +573,12 @@ func TestSetProxy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBasicLimiter(t *testing.T) {
|
||||
r := New("test",
|
||||
r, err := New("test",
|
||||
new(http.Client),
|
||||
WithLimiter(NewBasicRateLimit(time.Second, 1)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
i := Item{
|
||||
Path: "http://www.google.com",
|
||||
Method: http.MethodGet,
|
||||
@@ -540,7 +586,7 @@ func TestBasicLimiter(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tn := time.Now()
|
||||
err := r.SendPayload(ctx, Unset, func() (*Item, error) { return &i, nil })
|
||||
err = r.SendPayload(ctx, Unset, func() (*Item, error) { return &i, nil })
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -561,13 +607,16 @@ func TestBasicLimiter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEnableDisableRateLimit(t *testing.T) {
|
||||
r := New("TestRequest",
|
||||
r, err := New("TestRequest",
|
||||
new(http.Client),
|
||||
WithLimiter(NewBasicRateLimit(time.Minute, 1)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
var resp interface{}
|
||||
err := r.SendPayload(ctx, Auth, func() (*Item, error) {
|
||||
err = r.SendPayload(ctx, Auth, func() (*Item, error) {
|
||||
return &Item{
|
||||
Method: http.MethodGet,
|
||||
Path: testURL,
|
||||
@@ -635,3 +684,71 @@ func TestEnableDisableRateLimit(t *testing.T) {
|
||||
// Correct test
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHTTPClient(t *testing.T) {
|
||||
var r *Requester
|
||||
err := r.SetHTTPClient(nil)
|
||||
if !errors.Is(err, ErrRequestSystemIsNil) {
|
||||
t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil)
|
||||
}
|
||||
client := new(http.Client)
|
||||
r = new(Requester)
|
||||
err = r.SetHTTPClient(client)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v', but expected: '%v'", err, nil)
|
||||
}
|
||||
err = r.SetHTTPClient(client)
|
||||
if !errors.Is(err, errCannotReuseHTTPClient) {
|
||||
t.Fatalf("received: '%v', but expected: '%v'", err, errCannotReuseHTTPClient)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHTTPClientTimeout(t *testing.T) {
|
||||
var r *Requester
|
||||
err := r.SetHTTPClientTimeout(0)
|
||||
if !errors.Is(err, ErrRequestSystemIsNil) {
|
||||
t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil)
|
||||
}
|
||||
r = new(Requester)
|
||||
err = r.SetHTTPClient(common.NewHTTPClientWithTimeout(2))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = r.SetHTTPClientTimeout(time.Second)
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v', but expected: '%v'", err, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHTTPClientUserAgent(t *testing.T) {
|
||||
var r *Requester
|
||||
err := r.SetHTTPClientUserAgent("")
|
||||
if !errors.Is(err, ErrRequestSystemIsNil) {
|
||||
t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil)
|
||||
}
|
||||
r = new(Requester)
|
||||
err = r.SetHTTPClientUserAgent("")
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v', but expected: '%v'", err, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetHTTPClientUserAgent(t *testing.T) {
|
||||
var r *Requester
|
||||
_, err := r.GetHTTPClientUserAgent()
|
||||
if !errors.Is(err, ErrRequestSystemIsNil) {
|
||||
t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil)
|
||||
}
|
||||
r = new(Requester)
|
||||
err = r.SetHTTPClientUserAgent("sillyness")
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v', but expected: '%v'", err, nil)
|
||||
}
|
||||
ua, err := r.GetHTTPClientUserAgent()
|
||||
if !errors.Is(err, nil) {
|
||||
t.Fatalf("received: '%v', but expected: '%v'", err, nil)
|
||||
}
|
||||
if ua != "sillyness" {
|
||||
t.Fatal("unexpected value")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,11 +28,11 @@ var (
|
||||
|
||||
// Requester struct for the request client
|
||||
type Requester struct {
|
||||
HTTPClient *http.Client
|
||||
_HTTPClient *client
|
||||
limiter Limiter
|
||||
reporter Reporter
|
||||
Name string
|
||||
UserAgent string
|
||||
name string
|
||||
userAgent string
|
||||
maxRetries int
|
||||
jobs int32
|
||||
Nonce nonce.Nonce
|
||||
|
||||
@@ -198,11 +198,12 @@ func (c *CustomEx) WithdrawFiatFundsToInternationalBank(ctx context.Context, wit
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *CustomEx) SetHTTPClientUserAgent(ua string) {
|
||||
func (c *CustomEx) SetHTTPClientUserAgent(ua string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CustomEx) GetHTTPClientUserAgent() string {
|
||||
return ""
|
||||
func (c *CustomEx) GetHTTPClientUserAgent() (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (c *CustomEx) SetClientProxyAddress(addr string) error {
|
||||
|
||||
@@ -97,10 +97,13 @@ func (y *Yobit) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
y.Requester = request.New(y.Name,
|
||||
y.Requester, err = request.New(y.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
// Server responses are cached every 2 seconds.
|
||||
request.WithLimiter(request.NewBasicRateLimit(time.Second, 1)))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
y.API.Endpoints = y.NewEndpoints()
|
||||
err = y.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: apiPublicURL,
|
||||
|
||||
@@ -47,7 +47,10 @@ func TestMain(m *testing.M) {
|
||||
log.Fatalf("Mock server error %s", err)
|
||||
}
|
||||
|
||||
z.HTTPClient = newClient
|
||||
err = z.SetHTTPClient(newClient)
|
||||
if err != nil {
|
||||
log.Fatalf("Mock server error %s", err)
|
||||
}
|
||||
endpoints := z.API.Endpoints.GetURLMap()
|
||||
for k := range endpoints {
|
||||
err = z.API.Endpoints.SetRunning(k, serverDetails)
|
||||
|
||||
@@ -130,9 +130,12 @@ func (z *ZB) SetDefaults() {
|
||||
},
|
||||
}
|
||||
|
||||
z.Requester = request.New(z.Name,
|
||||
z.Requester, err = request.New(z.Name,
|
||||
common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout),
|
||||
request.WithLimiter(SetRateLimit()))
|
||||
if err != nil {
|
||||
log.Errorln(log.ExchangeSys, err)
|
||||
}
|
||||
z.API.Endpoints = z.NewEndpoints()
|
||||
err = z.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{
|
||||
exchange.RestSpot: zbTradeURL,
|
||||
|
||||
Reference in New Issue
Block a user