diff --git a/cmd/apichecker/apicheck.go b/cmd/apichecker/apicheck.go index 12eaaa38..dc5dddd0 100644 --- a/cmd/apichecker/apicheck.go +++ b/cmd/apichecker/apicheck.go @@ -1288,15 +1288,19 @@ func updateFile(name string) error { // SendGetReq sends get req func sendGetReq(path string, result interface{}) error { var requester *request.Requester + var err error if strings.Contains(path, "github") { - requester = request.New("Apichecker", + requester, err = request.New("Apichecker", common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(request.NewBasicRateLimit(time.Hour, 60))) } else { - requester = request.New("Apichecker", + requester, err = request.New("Apichecker", common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(request.NewBasicRateLimit(time.Second, 100))) } + if err != nil { + return err + } item := &request.Item{ Method: http.MethodGet, Path: path, @@ -1309,9 +1313,12 @@ func sendGetReq(path string, result interface{}) error { // sendAuthReq sends auth req func sendAuthReq(method, path string, result interface{}) error { - requester := request.New("Apichecker", + requester, err := request.New("Apichecker", common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(request.NewBasicRateLimit(time.Second*10, 100))) + if err != nil { + return err + } item := &request.Item{ Method: method, Path: path, diff --git a/currency/coinmarketcap/coinmarketcap.go b/currency/coinmarketcap/coinmarketcap.go index 69340a7e..980233ee 100644 --- a/currency/coinmarketcap/coinmarketcap.go +++ b/currency/coinmarketcap/coinmarketcap.go @@ -36,10 +36,14 @@ func (c *Coinmarketcap) SetDefaults() { c.Verbose = false c.APIUrl = baseURL c.APIVersion = version - c.Requester = request.New(c.Name, + var err error + c.Requester, err = request.New(c.Name, common.NewHTTPClientWithTimeout(defaultTimeOut), request.WithLimiter(request.NewBasicRateLimit(RateInterval, BasicRequestRate)), ) + if err != nil { + log.Errorln(log.Global, err) + } } // Setup sets user configuration diff --git a/currency/forexprovider/currencyconverterapi/currencyconverterapi.go b/currency/forexprovider/currencyconverterapi/currencyconverterapi.go index 89f442de..5e65e3df 100644 --- a/currency/forexprovider/currencyconverterapi/currencyconverterapi.go +++ b/currency/forexprovider/currencyconverterapi/currencyconverterapi.go @@ -23,10 +23,11 @@ func (c *CurrencyConverter) Setup(config base.Settings) error { c.Enabled = config.Enabled c.Verbose = config.Verbose c.PrimaryProvider = config.PrimaryProvider - c.Requester = request.New(c.Name, + var err error + c.Requester, err = request.New(c.Name, common.NewHTTPClientWithTimeout(base.DefaultTimeOut), request.WithLimiter(request.NewBasicRateLimit(rateInterval, requestRate))) - return nil + return err } // GetRates is a wrapper function to return rates diff --git a/currency/forexprovider/currencylayer/currencylayer.go b/currency/forexprovider/currencylayer/currencylayer.go index 7827fabe..5fb5ccea 100644 --- a/currency/forexprovider/currencylayer/currencylayer.go +++ b/currency/forexprovider/currencylayer/currencylayer.go @@ -43,10 +43,10 @@ func (c *CurrencyLayer) Setup(config base.Settings) error { c.Verbose = config.Verbose c.PrimaryProvider = config.PrimaryProvider // Rate limit is based off a monthly counter - Open limit used. - c.Requester = request.New(c.Name, + var err error + c.Requester, err = request.New(c.Name, common.NewHTTPClientWithTimeout(base.DefaultTimeOut)) - - return nil + return err } // GetRates is a wrapper function to return rates for GoCryptoTrader diff --git a/currency/forexprovider/exchangerate.host/exchangerate.go b/currency/forexprovider/exchangerate.host/exchangerate.go index b0d89780..e4b78bd9 100644 --- a/currency/forexprovider/exchangerate.host/exchangerate.go +++ b/currency/forexprovider/exchangerate.host/exchangerate.go @@ -33,9 +33,10 @@ func (e *ExchangeRateHost) Setup(config base.Settings) error { e.Enabled = config.Enabled e.Verbose = config.Verbose e.PrimaryProvider = config.PrimaryProvider - e.Requester = request.New(e.Name, + var err error + e.Requester, err = request.New(e.Name, common.NewHTTPClientWithTimeout(base.DefaultTimeOut)) - return nil + return err } // GetLatestRates returns a list of forex rates based on the supplied params diff --git a/currency/forexprovider/exchangeratesapi.io/exchangeratesapi.go b/currency/forexprovider/exchangeratesapi.io/exchangeratesapi.go index edcbc33e..77f565c0 100644 --- a/currency/forexprovider/exchangeratesapi.io/exchangeratesapi.go +++ b/currency/forexprovider/exchangeratesapi.io/exchangeratesapi.go @@ -29,10 +29,11 @@ func (e *ExchangeRates) Setup(config base.Settings) error { e.PrimaryProvider = config.PrimaryProvider e.APIKey = config.APIKey e.APIKeyLvl = config.APIKeyLvl - e.Requester = request.New(e.Name, + var err error + e.Requester, err = request.New(e.Name, common.NewHTTPClientWithTimeout(base.DefaultTimeOut), request.WithLimiter(request.NewBasicRateLimit(rateLimitInterval, requestRate))) - return nil + return err } func (e *ExchangeRates) cleanCurrencies(baseCurrency, symbols string) string { diff --git a/currency/forexprovider/fixer.io/fixer.go b/currency/forexprovider/fixer.io/fixer.go index 8c1568a8..8730fc08 100644 --- a/currency/forexprovider/fixer.io/fixer.go +++ b/currency/forexprovider/fixer.io/fixer.go @@ -36,9 +36,10 @@ func (f *Fixer) Setup(config base.Settings) error { f.Name = config.Name f.Verbose = config.Verbose f.PrimaryProvider = config.PrimaryProvider - f.Requester = request.New(f.Name, + var err error + f.Requester, err = request.New(f.Name, common.NewHTTPClientWithTimeout(base.DefaultTimeOut)) - return nil + return err } // GetSupportedCurrencies returns supported currencies diff --git a/currency/forexprovider/openexchangerates/openexchangerates.go b/currency/forexprovider/openexchangerates/openexchangerates.go index 5e64202a..94473cce 100644 --- a/currency/forexprovider/openexchangerates/openexchangerates.go +++ b/currency/forexprovider/openexchangerates/openexchangerates.go @@ -37,9 +37,10 @@ func (o *OXR) Setup(config base.Settings) error { o.Name = config.Name o.Verbose = config.Verbose o.PrimaryProvider = config.PrimaryProvider - o.Requester = request.New(o.Name, + var err error + o.Requester, err = request.New(o.Name, common.NewHTTPClientWithTimeout(base.DefaultTimeOut)) - return nil + return err } // GetRates is a wrapper function to return rates diff --git a/engine/exchange_manager.go b/engine/exchange_manager.go index a9dab9ba..c754bbee 100644 --- a/engine/exchange_manager.go +++ b/engine/exchange_manager.go @@ -94,12 +94,16 @@ func (m *ExchangeManager) RemoveExchange(exchName string) error { if m.Len() == 0 { return ErrNoExchangesLoaded } - _, err := m.GetExchangeByName(exchName) + exch, err := m.GetExchangeByName(exchName) if err != nil { return err } m.m.Lock() defer m.m.Unlock() + err = exch.GetBase().Requester.Shutdown() + if err != nil { + return err + } delete(m.exchanges, strings.ToLower(exchName)) log.Infof(log.ExchangeSys, "%s exchange unloaded successfully.\n", exchName) return nil diff --git a/exchanges/alphapoint/alphapoint_wrapper.go b/exchanges/alphapoint/alphapoint_wrapper.go index 1e84e1be..0d067fed 100644 --- a/exchanges/alphapoint/alphapoint_wrapper.go +++ b/exchanges/alphapoint/alphapoint_wrapper.go @@ -73,8 +73,11 @@ func (a *Alphapoint) SetDefaults() { }, } - a.Requester = request.New(a.Name, + a.Requester, err = request.New(a.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout)) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } } // FetchTradablePairs returns a list of the exchanges tradable pairs diff --git a/exchanges/binance/binance_mock_test.go b/exchanges/binance/binance_mock_test.go index cefeb9ae..2e54e6bc 100644 --- a/exchanges/binance/binance_mock_test.go +++ b/exchanges/binance/binance_mock_test.go @@ -46,7 +46,10 @@ func TestMain(m *testing.M) { if err != nil { log.Fatalf("Mock server error %s", err) } - b.HTTPClient = newClient + err = b.SetHTTPClient(newClient) + if err != nil { + log.Fatalf("Mock server error %s", err) + } endpointMap := b.API.Endpoints.GetURLMap() for k := range endpointMap { err = b.API.Endpoints.SetRunning(k, serverDetails) diff --git a/exchanges/binance/binance_wrapper.go b/exchanges/binance/binance_wrapper.go index 3452153a..e554aefd 100644 --- a/exchanges/binance/binance_wrapper.go +++ b/exchanges/binance/binance_wrapper.go @@ -186,9 +186,12 @@ func (b *Binance) SetDefaults() { }, } - b.Requester = request.New(b.Name, + b.Requester, err = request.New(b.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(SetRateLimit())) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } b.API.Endpoints = b.NewEndpoints() err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: spotAPIURL, diff --git a/exchanges/bitfinex/bitfinex_wrapper.go b/exchanges/bitfinex/bitfinex_wrapper.go index 1424d6e8..c421073a 100644 --- a/exchanges/bitfinex/bitfinex_wrapper.go +++ b/exchanges/bitfinex/bitfinex_wrapper.go @@ -165,9 +165,12 @@ func (b *Bitfinex) SetDefaults() { }, } - b.Requester = request.New(b.Name, + b.Requester, err = request.New(b.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(SetRateLimit())) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } b.API.Endpoints = b.NewEndpoints() err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: bitfinexAPIURLBase, diff --git a/exchanges/bitflyer/bitflyer_wrapper.go b/exchanges/bitflyer/bitflyer_wrapper.go index 8b0bb911..5007fe14 100644 --- a/exchanges/bitflyer/bitflyer_wrapper.go +++ b/exchanges/bitflyer/bitflyer_wrapper.go @@ -94,9 +94,12 @@ func (b *Bitflyer) SetDefaults() { }, } - b.Requester = request.New(b.Name, + b.Requester, err = request.New(b.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(SetRateLimit())) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } b.API.Endpoints = b.NewEndpoints() err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: japanURL, diff --git a/exchanges/bithumb/bithumb_wrapper.go b/exchanges/bithumb/bithumb_wrapper.go index 1a5223d2..d4ec06a9 100644 --- a/exchanges/bithumb/bithumb_wrapper.go +++ b/exchanges/bithumb/bithumb_wrapper.go @@ -129,9 +129,12 @@ func (b *Bithumb) SetDefaults() { }, }, } - b.Requester = request.New(b.Name, + b.Requester, err = request.New(b.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(SetRateLimit())) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } b.API.Endpoints = b.NewEndpoints() err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: apiURL, diff --git a/exchanges/bitmex/bitmex_wrapper.go b/exchanges/bitmex/bitmex_wrapper.go index 90e20ca9..90fbbb04 100644 --- a/exchanges/bitmex/bitmex_wrapper.go +++ b/exchanges/bitmex/bitmex_wrapper.go @@ -124,9 +124,12 @@ func (b *Bitmex) SetDefaults() { }, } - b.Requester = request.New(b.Name, + b.Requester, err = request.New(b.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(SetRateLimit())) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } b.API.Endpoints = b.NewEndpoints() err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: bitmexAPIURL, diff --git a/exchanges/bitstamp/bitstamp_mock_test.go b/exchanges/bitstamp/bitstamp_mock_test.go index 29c1cc58..5d0755c7 100644 --- a/exchanges/bitstamp/bitstamp_mock_test.go +++ b/exchanges/bitstamp/bitstamp_mock_test.go @@ -46,7 +46,10 @@ func TestMain(m *testing.M) { log.Fatalf("Mock server error %s", err) } - b.HTTPClient = newClient + err = b.SetHTTPClient(newClient) + if err != nil { + log.Fatalf("Mock server error %s", err) + } endpointMap := b.API.Endpoints.GetURLMap() for k := range endpointMap { err = b.API.Endpoints.SetRunning(k, serverDetails+"/api") diff --git a/exchanges/bitstamp/bitstamp_wrapper.go b/exchanges/bitstamp/bitstamp_wrapper.go index c502acf5..f56e78b3 100644 --- a/exchanges/bitstamp/bitstamp_wrapper.go +++ b/exchanges/bitstamp/bitstamp_wrapper.go @@ -129,9 +129,12 @@ func (b *Bitstamp) SetDefaults() { }, } - b.Requester = request.New(b.Name, + b.Requester, err = request.New(b.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(request.NewBasicRateLimit(bitstampRateInterval, bitstampRequestRate))) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } b.API.Endpoints = b.NewEndpoints() err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: bitstampAPIURL, diff --git a/exchanges/bittrex/bittrex_wrapper.go b/exchanges/bittrex/bittrex_wrapper.go index b9292930..c91f7ac7 100644 --- a/exchanges/bittrex/bittrex_wrapper.go +++ b/exchanges/bittrex/bittrex_wrapper.go @@ -120,9 +120,12 @@ func (b *Bittrex) SetDefaults() { }, } - b.Requester = request.New(b.Name, + b.Requester, err = request.New(b.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(request.NewBasicRateLimit(ratePeriod, rateLimit))) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } b.API.Endpoints = b.NewEndpoints() diff --git a/exchanges/btcmarkets/btcmarkets_wrapper.go b/exchanges/btcmarkets/btcmarkets_wrapper.go index 816083f7..aceda50d 100644 --- a/exchanges/btcmarkets/btcmarkets_wrapper.go +++ b/exchanges/btcmarkets/btcmarkets_wrapper.go @@ -119,9 +119,12 @@ func (b *BTCMarkets) SetDefaults() { }, } - b.Requester = request.New(b.Name, + b.Requester, err = request.New(b.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(SetRateLimit())) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } b.API.Endpoints = b.NewEndpoints() err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: btcMarketsAPIURL, diff --git a/exchanges/btse/btse_wrapper.go b/exchanges/btse/btse_wrapper.go index 9c59178b..087a6aff 100644 --- a/exchanges/btse/btse_wrapper.go +++ b/exchanges/btse/btse_wrapper.go @@ -149,9 +149,12 @@ func (b *BTSE) SetDefaults() { }, } - b.Requester = request.New(b.Name, + b.Requester, err = request.New(b.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(SetRateLimit())) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } b.API.Endpoints = b.NewEndpoints() err = b.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: btseAPIURL, diff --git a/exchanges/coinbasepro/coinbasepro_wrapper.go b/exchanges/coinbasepro/coinbasepro_wrapper.go index c1a007b8..9d117d34 100644 --- a/exchanges/coinbasepro/coinbasepro_wrapper.go +++ b/exchanges/coinbasepro/coinbasepro_wrapper.go @@ -130,9 +130,12 @@ func (c *CoinbasePro) SetDefaults() { }, } - c.Requester = request.New(c.Name, + c.Requester, err = request.New(c.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(SetRateLimit())) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } c.API.Endpoints = c.NewEndpoints() err = c.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: coinbaseproAPIURL, diff --git a/exchanges/coinut/coinut_wrapper.go b/exchanges/coinut/coinut_wrapper.go index 1ba6771e..115f03d5 100644 --- a/exchanges/coinut/coinut_wrapper.go +++ b/exchanges/coinut/coinut_wrapper.go @@ -113,8 +113,11 @@ func (c *COINUT) SetDefaults() { }, } - c.Requester = request.New(c.Name, + c.Requester, err = request.New(c.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout)) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } c.API.Endpoints = c.NewEndpoints() err = c.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: coinutAPIURL, diff --git a/exchanges/exchange.go b/exchanges/exchange.go index 013385c0..d1812370 100644 --- a/exchanges/exchange.go +++ b/exchanges/exchange.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net" - "net/http" "net/url" "strconv" "strings" @@ -20,7 +19,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/currencystate" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" - "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" "github.com/thrasher-corp/gocryptotrader/log" @@ -46,56 +44,8 @@ var ( ErrAuthenticatedRequestWithoutCredentialsSet = errors.New("authenticated HTTP request called but not supported due to unset/default API keys") errEndpointStringNotFound = errors.New("endpoint string not found") - errTransportNotSet = errors.New("transport not set, cannot set timeout") - - // ErrPairNotFound is an error message for when unable to find a currency pair - ErrPairNotFound = errors.New("pair not found") ) -func (b *Base) checkAndInitRequester() { - if b.Requester == nil { - b.Requester = request.New(b.Name, - &http.Client{Transport: new(http.Transport)}) - } -} - -// SetHTTPClientTimeout sets the timeout value for the exchanges HTTP Client and -// also the underlying transports idle connection timeout -func (b *Base) SetHTTPClientTimeout(t time.Duration) error { - b.checkAndInitRequester() - b.Requester.HTTPClient.Timeout = t - tr, ok := b.Requester.HTTPClient.Transport.(*http.Transport) - if !ok { - return errTransportNotSet - } - tr.IdleConnTimeout = t - return nil -} - -// SetHTTPClient sets exchanges HTTP client -func (b *Base) SetHTTPClient(h *http.Client) { - b.checkAndInitRequester() - b.Requester.HTTPClient = h -} - -// GetHTTPClient gets the exchanges HTTP client -func (b *Base) GetHTTPClient() *http.Client { - b.checkAndInitRequester() - return b.Requester.HTTPClient -} - -// SetHTTPClientUserAgent sets the exchanges HTTP user agent -func (b *Base) SetHTTPClientUserAgent(ua string) { - b.checkAndInitRequester() - b.Requester.UserAgent = ua - b.HTTPUserAgent = ua -} - -// GetHTTPClientUserAgent gets the exchanges HTTP user agent -func (b *Base) GetHTTPClientUserAgent() string { - return b.HTTPUserAgent -} - // SetClientProxyAddress sets a proxy address for REST and websocket requests func (b *Base) SetClientProxyAddress(addr string) error { if addr == "" { @@ -441,17 +391,16 @@ func (b *Base) GetEnabledPairs(a asset.Item) (currency.Pairs, error) { // GetRequestFormattedPairAndAssetType is a method that returns the enabled currency pair of // along with its asset type. Only use when there is no chance of the same name crossing over func (b *Base) GetRequestFormattedPairAndAssetType(p string) (currency.Pair, asset.Item, error) { - assetTypes := b.GetAssetTypes(false) - var response currency.Pair + assetTypes := b.GetAssetTypes(true) for i := range assetTypes { format, err := b.GetPairFormat(assetTypes[i], true) if err != nil { - return response, assetTypes[i], err + return currency.EMPTYPAIR, assetTypes[i], err } pairs, err := b.CurrencyPairs.GetPairs(assetTypes[i], true) if err != nil { - return response, assetTypes[i], err + return currency.EMPTYPAIR, assetTypes[i], err } for j := range pairs { @@ -461,8 +410,7 @@ func (b *Base) GetRequestFormattedPairAndAssetType(p string) (currency.Pair, ass } } } - return response, "", - fmt.Errorf("%s %w", p, ErrPairNotFound) + return currency.EMPTYPAIR, "", fmt.Errorf("%s %w", p, currency.ErrPairNotFound) } // GetAvailablePairs is a method that returns the available currency pairs @@ -607,7 +555,10 @@ func (b *Base) SetupDefaults(exch *config.Exchange) error { b.HTTPDebugging = exch.HTTPDebugging b.BypassConfigFormatUpgrades = exch.CurrencyPairs.BypassConfigFormatUpgrades - b.SetHTTPClientUserAgent(exch.HTTPUserAgent) + err = b.SetHTTPClientUserAgent(exch.HTTPUserAgent) + if err != nil { + return err + } b.SetCurrencyPairFormat() err = b.SetConfigPairs() diff --git a/exchanges/exchange_test.go b/exchanges/exchange_test.go index f6e914ec..c5df4a7b 100644 --- a/exchanges/exchange_test.go +++ b/exchanges/exchange_test.go @@ -5,9 +5,7 @@ import ( "errors" "fmt" "net" - "net/http" "os" - "strings" "testing" "time" @@ -193,76 +191,21 @@ func TestSetDefaultEndpoints(t *testing.T) { } } -func TestHTTPClient(t *testing.T) { - t.Parallel() - r := Base{Name: "asdf"} - err := r.SetHTTPClientTimeout(time.Second * 5) - if err != nil { - t.Fatal(err) - } - - if r.GetHTTPClient().Timeout != time.Second*5 { - t.Fatalf("TestHTTPClient unexpected value") - } - - r.Requester = nil - newClient := new(http.Client) - newClient.Timeout = time.Second * 10 - - r.SetHTTPClient(newClient) - if r.GetHTTPClient().Timeout != time.Second*10 { - t.Fatalf("TestHTTPClient unexpected value") - } - - r.Requester = nil - if r.GetHTTPClient() == nil { - t.Fatalf("TestHTTPClient unexpected value") - } - - b := Base{Name: "RAWR"} - - b.Requester = request.New(b.Name, new(http.Client)) - err = b.SetHTTPClientTimeout(time.Second * 5) - if !errors.Is(err, errTransportNotSet) { - t.Fatalf("received: %v but expected: %v", err, errTransportNotSet) - } - - b.Requester = request.New(b.Name, &http.Client{Transport: new(http.Transport)}) - err = b.SetHTTPClientTimeout(time.Second * 5) - if err != nil { - t.Fatal(err) - } - - if b.GetHTTPClient().Timeout != time.Second*5 { - t.Fatalf("TestHTTPClient unexpected value") - } - - newClient = new(http.Client) - newClient.Timeout = time.Second * 10 - - b.SetHTTPClient(newClient) - if b.GetHTTPClient().Timeout != time.Second*10 { - t.Fatalf("TestHTTPClient unexpected value") - } - - b.SetHTTPClientUserAgent("epicUserAgent") - if !strings.Contains(b.GetHTTPClientUserAgent(), "epicUserAgent") { - t.Error("user agent not set properly") - } -} - func TestSetClientProxyAddress(t *testing.T) { t.Parallel() - requester := request.New("rawr", + requester, err := request.New("rawr", common.NewHTTPClientWithTimeout(time.Second*15)) + if err != nil { + t.Fatal(err) + } newBase := Base{ Name: "rawr", Requester: requester} newBase.Websocket = stream.New() - err := newBase.SetClientProxyAddress("") + err = newBase.SetClientProxyAddress("") if err != nil { t.Error(err) } @@ -289,13 +232,6 @@ func TestSetClientProxyAddress(t *testing.T) { if newBase.Websocket.GetProxyAddress() != "http://www.valid.com" { t.Error("SetClientProxyAddress error", err) } - - // Nil out transport - newBase.Requester.HTTPClient.Transport = nil - err = newBase.SetClientProxyAddress("http://www.valid.com") - if err == nil { - t.Error("error cannot be nil") - } } func TestSetFeatureDefaults(t *testing.T) { @@ -1268,7 +1204,16 @@ func TestSetAPIKeys(t *testing.T) { func TestSetupDefaults(t *testing.T) { t.Parallel() - var b = Base{Name: "awesomeTest"} + newRequester, err := request.New("testSetupDefaults", + common.NewHTTPClientWithTimeout(0)) + if err != nil { + t.Fatal(err) + } + + var b = Base{ + Name: "awesomeTest", + Requester: newRequester, + } cfg := config.Exchange{ HTTPTimeout: time.Duration(-1), API: config.APIConfig{ @@ -1276,7 +1221,7 @@ func TestSetupDefaults(t *testing.T) { }, } - err := b.SetupDefaults(&cfg) + err = b.SetupDefaults(&cfg) if err != nil { t.Fatal(err) } @@ -2055,25 +2000,34 @@ func TestCheckTransientError(t *testing.T) { func TestDisableEnableRateLimiter(t *testing.T) { b := Base{} - b.checkAndInitRequester() err := b.EnableRateLimiter() - if err == nil { - t.Fatal("error cannot be nil") + if !errors.Is(err, request.ErrRequestSystemIsNil) { + t.Fatalf("received: '%v' but expected: '%v'", err, request.ErrRequestSystemIsNil) } - err = b.DisableRateLimiter() + b.Requester, err = request.New("testingRateLimiter", common.NewHTTPClientWithTimeout(0)) if err != nil { t.Fatal(err) } err = b.DisableRateLimiter() - if err == nil { - t.Fatal("error cannot be nil") + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + err = b.DisableRateLimiter() + if !errors.Is(err, request.ErrRateLimiterAlreadyDisabled) { + t.Fatalf("received: '%v' but expected: '%v'", err, request.ErrRateLimiterAlreadyDisabled) } err = b.EnableRateLimiter() - if err != nil { - t.Fatal(err) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + err = b.EnableRateLimiter() + if !errors.Is(err, request.ErrRateLimiterAlreadyEnabled) { + t.Fatalf("received: '%v' but expected: '%v'", err, request.ErrRateLimiterAlreadyEnabled) } } diff --git a/exchanges/exchange_types.go b/exchanges/exchange_types.go index 4f35119e..887df7ea 100644 --- a/exchanges/exchange_types.go +++ b/exchanges/exchange_types.go @@ -219,7 +219,6 @@ type Base struct { CurrencyPairs currency.PairsManager Features Features HTTPTimeout time.Duration - HTTPUserAgent string HTTPRecording bool HTTPDebugging bool BypassConfigFormatUpgrades bool diff --git a/exchanges/exmo/exmo_wrapper.go b/exchanges/exmo/exmo_wrapper.go index 89e5e999..24f23eb6 100644 --- a/exchanges/exmo/exmo_wrapper.go +++ b/exchanges/exmo/exmo_wrapper.go @@ -110,9 +110,12 @@ func (e *EXMO) SetDefaults() { }, } - e.Requester = request.New(e.Name, + e.Requester, err = request.New(e.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(request.NewBasicRateLimit(exmoRateInterval, exmoRequestRate))) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } e.API.Endpoints = e.NewEndpoints() err = e.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: exmoAPIURL, diff --git a/exchanges/ftx/ftx_wrapper.go b/exchanges/ftx/ftx_wrapper.go index c5564061..2d935836 100644 --- a/exchanges/ftx/ftx_wrapper.go +++ b/exchanges/ftx/ftx_wrapper.go @@ -147,9 +147,12 @@ func (f *FTX) SetDefaults() { }, } - f.Requester = request.New(f.Name, + f.Requester, err = request.New(f.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(request.NewBasicRateLimit(ratePeriod, rateLimit))) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } f.API.Endpoints = f.NewEndpoints() err = f.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: ftxAPIURL, diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 4491da84..dd18e448 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -130,8 +130,11 @@ func (g *Gateio) SetDefaults() { }, }, } - g.Requester = request.New(g.Name, + g.Requester, err = request.New(g.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout)) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } g.API.Endpoints = g.NewEndpoints() err = g.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: gateioTradeURL, diff --git a/exchanges/gemini/gemini_mock_test.go b/exchanges/gemini/gemini_mock_test.go index b2ee314d..05b19df2 100644 --- a/exchanges/gemini/gemini_mock_test.go +++ b/exchanges/gemini/gemini_mock_test.go @@ -45,7 +45,10 @@ func TestMain(m *testing.M) { log.Fatalf("Mock server error %s", err) } - g.HTTPClient = newClient + err = g.SetHTTPClient(newClient) + if err != nil { + log.Fatalf("Mock server error %s", err) + } endpointMap := g.API.Endpoints.GetURLMap() for k := range endpointMap { err = g.API.Endpoints.SetRunning(k, serverDetails) diff --git a/exchanges/gemini/gemini_wrapper.go b/exchanges/gemini/gemini_wrapper.go index 0a6c1e75..95621d5d 100644 --- a/exchanges/gemini/gemini_wrapper.go +++ b/exchanges/gemini/gemini_wrapper.go @@ -113,9 +113,12 @@ func (g *Gemini) SetDefaults() { }, } - g.Requester = request.New(g.Name, + g.Requester, err = request.New(g.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(SetRateLimit())) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } g.API.Endpoints = g.NewEndpoints() err = g.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: geminiAPIURL, diff --git a/exchanges/hitbtc/hitbtc_wrapper.go b/exchanges/hitbtc/hitbtc_wrapper.go index cc438d0c..b1e0036c 100644 --- a/exchanges/hitbtc/hitbtc_wrapper.go +++ b/exchanges/hitbtc/hitbtc_wrapper.go @@ -129,9 +129,12 @@ func (h *HitBTC) SetDefaults() { }, } - h.Requester = request.New(h.Name, + h.Requester, err = request.New(h.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(SetRateLimit())) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } h.API.Endpoints = h.NewEndpoints() err = h.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: apiURL, diff --git a/exchanges/huobi/huobi_wrapper.go b/exchanges/huobi/huobi_wrapper.go index ea457912..ad0cb099 100644 --- a/exchanges/huobi/huobi_wrapper.go +++ b/exchanges/huobi/huobi_wrapper.go @@ -158,9 +158,12 @@ func (h *HUOBI) SetDefaults() { }, } - h.Requester = request.New(h.Name, + h.Requester, err = request.New(h.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(SetRateLimit())) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } h.API.Endpoints = h.NewEndpoints() err = h.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: huobiAPIURL, diff --git a/exchanges/interfaces.go b/exchanges/interfaces.go index e97baecd..26c2e07c 100644 --- a/exchanges/interfaces.go +++ b/exchanges/interfaces.go @@ -68,8 +68,8 @@ type IBotExchange interface { WithdrawCryptocurrencyFunds(ctx context.Context, withdrawRequest *withdraw.Request) (*withdraw.ExchangeResponse, error) WithdrawFiatFunds(ctx context.Context, withdrawRequest *withdraw.Request) (*withdraw.ExchangeResponse, error) WithdrawFiatFundsToInternationalBank(ctx context.Context, withdrawRequest *withdraw.Request) (*withdraw.ExchangeResponse, error) - SetHTTPClientUserAgent(ua string) - GetHTTPClientUserAgent() string + SetHTTPClientUserAgent(ua string) error + GetHTTPClientUserAgent() (string, error) SetClientProxyAddress(addr string) error SupportsREST() bool GetSubscriptions() ([]stream.ChannelSubscription, error) diff --git a/exchanges/itbit/itbit_wrapper.go b/exchanges/itbit/itbit_wrapper.go index 79c6524c..1c0915a0 100644 --- a/exchanges/itbit/itbit_wrapper.go +++ b/exchanges/itbit/itbit_wrapper.go @@ -94,8 +94,11 @@ func (i *ItBit) SetDefaults() { }, } - i.Requester = request.New(i.Name, + i.Requester, err = request.New(i.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout)) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } i.API.Endpoints = i.NewEndpoints() err = i.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: itbitAPIURL, diff --git a/exchanges/kraken/kraken_wrapper.go b/exchanges/kraken/kraken_wrapper.go index a959e908..a2c66f6b 100644 --- a/exchanges/kraken/kraken_wrapper.go +++ b/exchanges/kraken/kraken_wrapper.go @@ -171,9 +171,12 @@ func (k *Kraken) SetDefaults() { }, } - k.Requester = request.New(k.Name, + k.Requester, err = request.New(k.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(request.NewBasicRateLimit(krakenRateInterval, krakenRequestRate))) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } k.API.Endpoints = k.NewEndpoints() err = k.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: krakenAPIURL, diff --git a/exchanges/lbank/lbank_wrapper.go b/exchanges/lbank/lbank_wrapper.go index 9aab8bcc..c4348735 100644 --- a/exchanges/lbank/lbank_wrapper.go +++ b/exchanges/lbank/lbank_wrapper.go @@ -108,8 +108,11 @@ func (l *Lbank) SetDefaults() { }, }, } - l.Requester = request.New(l.Name, + l.Requester, err = request.New(l.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout)) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } l.API.Endpoints = l.NewEndpoints() err = l.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: lbankAPIURL, diff --git a/exchanges/localbitcoins/localbitcoins_mock_test.go b/exchanges/localbitcoins/localbitcoins_mock_test.go index e9799289..688d85b8 100644 --- a/exchanges/localbitcoins/localbitcoins_mock_test.go +++ b/exchanges/localbitcoins/localbitcoins_mock_test.go @@ -44,7 +44,10 @@ func TestMain(m *testing.M) { log.Fatalf("Mock server error %s", err) } - l.HTTPClient = newClient + err = l.SetHTTPClient(newClient) + if err != nil { + log.Fatalf("Mock server error %s", err) + } endpoints := l.API.Endpoints.GetURLMap() for k := range endpoints { err = l.API.Endpoints.SetRunning(k, serverDetails) diff --git a/exchanges/localbitcoins/localbitcoins_wrapper.go b/exchanges/localbitcoins/localbitcoins_wrapper.go index 0dbafa91..7c60cdf2 100644 --- a/exchanges/localbitcoins/localbitcoins_wrapper.go +++ b/exchanges/localbitcoins/localbitcoins_wrapper.go @@ -93,8 +93,11 @@ func (l *LocalBitcoins) SetDefaults() { }, } - l.Requester = request.New(l.Name, + l.Requester, err = request.New(l.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout)) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } l.API.Endpoints = l.NewEndpoints() err = l.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: localbitcoinsAPIURL, diff --git a/exchanges/okcoin/okcoin_wrapper.go b/exchanges/okcoin/okcoin_wrapper.go index e110e3f6..f689de5c 100644 --- a/exchanges/okcoin/okcoin_wrapper.go +++ b/exchanges/okcoin/okcoin_wrapper.go @@ -133,11 +133,14 @@ func (o *OKCoin) SetDefaults() { }, } - o.Requester = request.New(o.Name, + o.Requester, err = request.New(o.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), // TODO: Specify each individual endpoint rate limits as per docs request.WithLimiter(request.NewBasicRateLimit(okCoinRateInterval, okCoinStandardRequestRate)), ) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } o.API.Endpoints = o.NewEndpoints() err = o.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: okCoinAPIURL, diff --git a/exchanges/okex/okex_wrapper.go b/exchanges/okex/okex_wrapper.go index ca67f1e5..c2638da2 100644 --- a/exchanges/okex/okex_wrapper.go +++ b/exchanges/okex/okex_wrapper.go @@ -191,11 +191,14 @@ func (o *OKEX) SetDefaults() { }, } - o.Requester = request.New(o.Name, + o.Requester, err = request.New(o.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), // TODO: Specify each individual endpoint rate limits as per docs request.WithLimiter(request.NewBasicRateLimit(okExRateInterval, okExRequestRate)), ) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } o.API.Endpoints = o.NewEndpoints() err = o.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: okExAPIURL, diff --git a/exchanges/poloniex/poloniex_mock_test.go b/exchanges/poloniex/poloniex_mock_test.go index c0507731..8f1af4db 100644 --- a/exchanges/poloniex/poloniex_mock_test.go +++ b/exchanges/poloniex/poloniex_mock_test.go @@ -45,7 +45,10 @@ func TestMain(m *testing.M) { log.Fatalf("Mock server error %s", err) } - p.HTTPClient = newClient + err = p.SetHTTPClient(newClient) + if err != nil { + log.Fatalf("Mock server error %s", err) + } endpoints := p.API.Endpoints.GetURLMap() for k := range endpoints { err = p.API.Endpoints.SetRunning(k, serverDetails) diff --git a/exchanges/poloniex/poloniex_wrapper.go b/exchanges/poloniex/poloniex_wrapper.go index ed675308..462f91fc 100644 --- a/exchanges/poloniex/poloniex_wrapper.go +++ b/exchanges/poloniex/poloniex_wrapper.go @@ -133,9 +133,12 @@ func (p *Poloniex) SetDefaults() { }, } - p.Requester = request.New(p.Name, + p.Requester, err = request.New(p.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(SetRateLimit())) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } p.API.Endpoints = p.NewEndpoints() err = p.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: poloniexAPIURL, diff --git a/exchanges/request/client.go b/exchanges/request/client.go new file mode 100644 index 00000000..53647281 --- /dev/null +++ b/exchanges/request/client.go @@ -0,0 +1,131 @@ +package request + +import ( + "errors" + "net/http" + "net/url" + "sync" + "time" +) + +var ( + // tracker is the global to maintain sanity between clients across all + // services using the request package. + tracker clientTracker + + errNoProxyURLSupplied = errors.New("no proxy URL supplied") + errCannotReuseHTTPClient = errors.New("cannot reuse http client") + errHTTPClientIsNil = errors.New("http client is nil") + errHTTPClientNotFound = errors.New("http client not found") +) + +// clientTracker attempts to maintain service/http.Client segregation +type clientTracker struct { + clients []*http.Client + sync.Mutex +} + +// checkAndRegister stops the sharing of the same http.Client between services. +func (c *clientTracker) checkAndRegister(newClient *http.Client) error { + if newClient == nil { + return errHTTPClientIsNil + } + c.Lock() + defer c.Unlock() + for x := range c.clients { + if newClient == c.clients[x] { + return errCannotReuseHTTPClient + } + } + c.clients = append(c.clients, newClient) + return nil +} + +// deRegister removes the *http.Client from being tracked +func (c *clientTracker) deRegister(oldClient *http.Client) error { + if oldClient == nil { + return errHTTPClientIsNil + } + c.Lock() + defer c.Unlock() + for x := range c.clients { + if oldClient != c.clients[x] { + continue + } + c.clients[x] = c.clients[len(c.clients)-1] + c.clients[len(c.clients)-1] = nil + c.clients = c.clients[:len(c.clients)-1] + return nil + } + return errHTTPClientNotFound +} + +// client wraps over a http client for better protection +type client struct { + protected *http.Client + m sync.RWMutex +} + +// newProtectedClient registers a http.Client to inhibit cross service usage and +// return a thread safe holder (*request.Client) with getter and setters for +// timeouts and transports. +func newProtectedClient(newClient *http.Client) (*client, error) { + if err := tracker.checkAndRegister(newClient); err != nil { + return nil, err + } + return &client{protected: newClient}, nil +} + +// setProxy sets a proxy address for the client transport +func (c *client) setProxy(p *url.URL) error { + if p == nil || p.String() == "" { + return errNoProxyURLSupplied + } + c.m.Lock() + defer c.m.Unlock() + // Check transport first so we don't set something and then error. + tr, ok := c.protected.Transport.(*http.Transport) + if !ok { + return errTransportNotSet + } + // This closes idle connections before an attempt at reassignment and + // boots any dangly routines. + tr.CloseIdleConnections() + tr.Proxy = http.ProxyURL(p) + tr.TLSHandshakeTimeout = proxyTLSTimeout + return nil +} + +// setHTTPClientTimeout sets the timeout value for the exchanges HTTP Client and +// also the underlying transports idle connection timeout +func (c *client) setHTTPClientTimeout(timeout time.Duration) error { + c.m.Lock() + defer c.m.Unlock() + // Check transport first so we don't set something and then error. + tr, ok := c.protected.Transport.(*http.Transport) + if !ok { + return errTransportNotSet + } + // This closes idle connections before an attempt at reassignment and + // boots any dangly routines. + tr.CloseIdleConnections() + tr.IdleConnTimeout = timeout + c.protected.Timeout = timeout + return nil +} + +// do sends request in a protected manner +func (c *client) do(request *http.Request) (resp *http.Response, err error) { + c.m.RLock() + resp, err = c.protected.Do(request) + c.m.RUnlock() + return +} + +// release de-registers the underlying client +func (c *client) release() error { + c.m.Lock() + err := tracker.deRegister(c.protected) + c.m.Unlock() + return err +} diff --git a/exchanges/request/client_test.go b/exchanges/request/client_test.go new file mode 100644 index 00000000..a68f7f87 --- /dev/null +++ b/exchanges/request/client_test.go @@ -0,0 +1,148 @@ +package request + +import ( + "errors" + "net/http" + "net/url" + "testing" + "time" + + "github.com/thrasher-corp/gocryptotrader/common" +) + +// this doesn't need to be included in binary +func (c *clientTracker) contains(check *http.Client) bool { + c.Lock() + defer c.Unlock() + for x := range c.clients { + if check == c.clients[x] { + return true + } + } + return false +} + +func TestCheckAndRegister(t *testing.T) { + t.Parallel() + err := tracker.checkAndRegister(nil) + if !errors.Is(err, errHTTPClientIsNil) { + t.Fatalf("received: '%v' but expected: '%v'", err, errHTTPClientIsNil) + } + + newLovelyClient := new(http.Client) + err = tracker.checkAndRegister(newLovelyClient) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + if !tracker.contains(newLovelyClient) { + t.Fatalf("received: '%v' but expected: '%v'", false, true) + } + + err = tracker.checkAndRegister(newLovelyClient) + if !errors.Is(err, errCannotReuseHTTPClient) { + t.Fatalf("received: '%v' but expected: '%v'", err, errCannotReuseHTTPClient) + } +} + +func TestDeRegister(t *testing.T) { + t.Parallel() + err := tracker.deRegister(nil) + if !errors.Is(err, errHTTPClientIsNil) { + t.Fatalf("received: '%v' but expected: '%v'", err, errHTTPClientIsNil) + } + + newLovelyClient := new(http.Client) + err = tracker.deRegister(newLovelyClient) + if !errors.Is(err, errHTTPClientNotFound) { + t.Fatalf("received: '%v' but expected: '%v'", err, errHTTPClientNotFound) + } + + err = tracker.checkAndRegister(newLovelyClient) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + if !tracker.contains(newLovelyClient) { + t.Fatalf("received: '%v' but expected: '%v'", false, true) + } + + err = tracker.deRegister(newLovelyClient) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + if tracker.contains(newLovelyClient) { + t.Fatalf("received: '%v' but expected: '%v'", true, false) + } +} + +func TestNewProtectedClient(t *testing.T) { + t.Parallel() + if _, err := newProtectedClient(nil); !errors.Is(err, errHTTPClientIsNil) { + t.Fatalf("received: '%v' but expected: '%v'", err, errHTTPClientIsNil) + } + + newLovelyClient := new(http.Client) + protec, err := newProtectedClient(newLovelyClient) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + if protec.protected != newLovelyClient { + t.Fatal("unexpected value") + } +} + +func TestClientSetProxy(t *testing.T) { + t.Parallel() + err := (&client{}).setProxy(nil) + if !errors.Is(err, errNoProxyURLSupplied) { + t.Fatalf("received: '%v' but expected: '%v'", err, errNoProxyURLSupplied) + } + pp, err := url.Parse("lol.com") + if err != nil { + t.Fatal(err) + } + err = (&client{protected: new(http.Client)}).setProxy(pp) + if !errors.Is(err, errTransportNotSet) { + t.Fatalf("received: '%v' but expected: '%v'", err, errTransportNotSet) + } + err = (&client{protected: common.NewHTTPClientWithTimeout(0)}).setProxy(pp) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } +} + +func TestClientSetHTTPClientTimeout(t *testing.T) { + t.Parallel() + err := (&client{protected: new(http.Client)}).setHTTPClientTimeout(time.Second) + if !errors.Is(err, errTransportNotSet) { + t.Fatalf("received: '%v' but expected: '%v'", err, errTransportNotSet) + } + err = (&client{protected: common.NewHTTPClientWithTimeout(0)}).setHTTPClientTimeout(time.Second) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } +} + +func TestRelease(t *testing.T) { + t.Parallel() + newLovelyClient, err := newProtectedClient(common.NewHTTPClientWithTimeout(0)) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + if !tracker.contains(newLovelyClient.protected) { + t.Fatalf("received: '%v' but expected: '%v'", false, true) + } + + err = newLovelyClient.release() + if !errors.Is(err, nil) { + t.Fatalf("received: '%v' but expected: '%v'", err, nil) + } + + if tracker.contains(newLovelyClient.protected) { + t.Fatalf("received: '%v' but expected: '%v'", true, false) + } +} diff --git a/exchanges/request/limit.go b/exchanges/request/limit.go index 8ff7a3f7..e2179399 100644 --- a/exchanges/request/limit.go +++ b/exchanges/request/limit.go @@ -3,12 +3,19 @@ package request import ( "context" "errors" + "fmt" "sync/atomic" "time" "golang.org/x/time/rate" ) +// Defines rate limiting errors +var ( + ErrRateLimiterAlreadyDisabled = errors.New("rate limiter already disabled") + ErrRateLimiterAlreadyEnabled = errors.New("rate limiter already enabled") +) + // Const here define individual functionality sub types for rate limiting const ( Unset EndpointLimit = iota @@ -60,6 +67,9 @@ func NewBasicRateLimit(interval time.Duration, actions int) Limiter { // InitiateRateLimit sleeps for designated end point rate limits func (r *Requester) InitiateRateLimit(ctx context.Context, e EndpointLimit) error { + if r == nil { + return ErrRequestSystemIsNil + } if atomic.LoadInt32(&r.disableRateLimiter) == 1 { return nil } @@ -73,16 +83,22 @@ func (r *Requester) InitiateRateLimit(ctx context.Context, e EndpointLimit) erro // DisableRateLimiter disables the rate limiting system for the exchange func (r *Requester) DisableRateLimiter() error { + if r == nil { + return ErrRequestSystemIsNil + } if !atomic.CompareAndSwapInt32(&r.disableRateLimiter, 0, 1) { - return errors.New("rate limiter already disabled") + return fmt.Errorf("%s %w", r.name, ErrRateLimiterAlreadyDisabled) } return nil } // EnableRateLimiter enables the rate limiting system for the exchange func (r *Requester) EnableRateLimiter() error { + if r == nil { + return ErrRequestSystemIsNil + } if !atomic.CompareAndSwapInt32(&r.disableRateLimiter, 1, 0) { - return errors.New("rate limiter already enabled") + return fmt.Errorf("%s %w", r.name, ErrRateLimiterAlreadyEnabled) } return nil } diff --git a/exchanges/request/request.go b/exchanges/request/request.go index 15c77938..99b678cb 100644 --- a/exchanges/request/request.go +++ b/exchanges/request/request.go @@ -20,7 +20,9 @@ import ( ) var ( - errRequestSystemIsNil = errors.New("request system is nil") + // ErrRequestSystemIsNil defines and error if the request system has not + // been set up yet. + ErrRequestSystemIsNil = errors.New("request system is nil") errMaxRequestJobs = errors.New("max request jobs reached") errRequestFunctionIsNil = errors.New("request function is nil") errRequestItemNil = errors.New("request item is nil") @@ -28,13 +30,18 @@ var ( errHeaderResponseMapIsNil = errors.New("header response map is nil") errFailedToRetryRequest = errors.New("failed to retry request") errContextRequired = errors.New("context is required") + errTransportNotSet = errors.New("transport not set, cannot set timeout") ) // New returns a new Requester -func New(name string, httpRequester *http.Client, opts ...RequesterOption) *Requester { +func New(name string, httpRequester *http.Client, opts ...RequesterOption) (*Requester, error) { + protectedClient, err := newProtectedClient(httpRequester) + if err != nil { + return nil, fmt.Errorf("cannot set up a new requester for %s: %w", name, err) + } r := &Requester{ - HTTPClient: httpRequester, - Name: name, + _HTTPClient: protectedClient, + name: name, backoff: DefaultBackoff(), retryPolicy: DefaultRetryPolicy, maxRetries: MaxRetryAttempts, @@ -46,13 +53,13 @@ func New(name string, httpRequester *http.Client, opts ...RequesterOption) *Requ o(r) } - return r + return r, nil } // SendPayload handles sending HTTP/HTTPS requests func (r *Requester) SendPayload(ctx context.Context, ep EndpointLimit, newRequest Generate) error { if r == nil { - return errRequestSystemIsNil + return ErrRequestSystemIsNil } if ctx == nil { @@ -107,8 +114,8 @@ func (i *Item) validateRequest(ctx context.Context, r *Requester) (*http.Request req.Header.Add(k, v) } - if r.UserAgent != "" && req.Header.Get(userAgent) == "" { - req.Header.Add(userAgent, r.UserAgent) + if r.userAgent != "" && req.Header.Get(userAgent) == "" { + req.Header.Add(userAgent, r.userAgent) } return req, nil @@ -141,22 +148,22 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe } if p.Verbose { - log.Debugf(log.RequestSys, "%s attempt %d request path: %s", r.Name, attempt, p.Path) + log.Debugf(log.RequestSys, "%s attempt %d request path: %s", r.name, attempt, p.Path) for k, d := range req.Header { - log.Debugf(log.RequestSys, "%s request header [%s]: %s", r.Name, k, d) + log.Debugf(log.RequestSys, "%s request header [%s]: %s", r.name, k, d) } - log.Debugf(log.RequestSys, "%s request type: %s", r.Name, p.Method) + log.Debugf(log.RequestSys, "%s request type: %s", r.name, p.Method) if p.Body != nil { - log.Debugf(log.RequestSys, "%s request body: %v", r.Name, p.Body) + log.Debugf(log.RequestSys, "%s request body: %v", r.name, p.Body) } } start := time.Now() - resp, err := r.HTTPClient.Do(req) + resp, err := r._HTTPClient.do(req) if r.reporter != nil { - r.reporter.Latency(r.Name, p.Method, p.Path, time.Since(start)) + r.reporter.Latency(r.name, p.Method, p.Path, time.Since(start)) } if retry, checkErr := r.retryPolicy(resp, err); checkErr != nil { @@ -191,7 +198,7 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe if p.Verbose { log.Errorf(log.RequestSys, "%s request has failed. Retrying request in %s, attempt %d", - r.Name, + r.name, delay, attempt) } @@ -213,7 +220,7 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe if p.HTTPRecording { // This dumps http responses for future mocking implementations - err = mock.HTTPRecord(resp, r.Name, contents) + err = mock.HTTPRecord(resp, r.name, contents) if err != nil { return fmt.Errorf("mock recording failure %s", err) } @@ -228,21 +235,27 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusAccepted { return fmt.Errorf("%s unsuccessful HTTP status code: %d raw response: %s", - r.Name, + r.name, resp.StatusCode, string(contents)) } if p.HTTPDebugging { - dump, err := httputil.DumpResponse(resp, false) + dump, dumpErr := httputil.DumpResponse(resp, false) if err != nil { - log.Errorf(log.RequestSys, "DumpResponse invalid response: %v:", err) + log.Errorf(log.RequestSys, "DumpResponse invalid response: %v:", dumpErr) } log.Debugf(log.RequestSys, "DumpResponse Headers (%v):\n%s", p.Path, dump) log.Debugf(log.RequestSys, "DumpResponse Body (%v):\n %s", p.Path, string(contents)) } - resp.Body.Close() + err = resp.Body.Close() + if err != nil { + log.Errorf(log.RequestSys, + "%s failed to close request body %s", + r.name, + err) + } if p.Verbose { log.Debugf(log.RequestSys, "HTTP status: %s, Code: %v", @@ -251,7 +264,7 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe if !p.HTTPDebugging { log.Debugf(log.RequestSys, "%s raw response: %s", - r.Name, + r.name, string(contents)) } } @@ -259,6 +272,22 @@ func (r *Requester) doRequest(ctx context.Context, endpoint EndpointLimit, newRe } } +func (r *Requester) drainBody(body io.ReadCloser) { + if _, err := io.Copy(ioutil.Discard, io.LimitReader(body, drainBodyLimit)); err != nil { + log.Errorf(log.RequestSys, + "%s failed to drain request body %s", + r.name, + err) + } + + if err := body.Close(); err != nil { + log.Errorf(log.RequestSys, + "%s failed to close request body %s", + r.name, + err) + } +} + // GetNonce returns a nonce for requests. This locks and enforces concurrent // nonce FIFO on the buffered job channel func (r *Requester) GetNonce(isNano bool) nonce.Value { @@ -285,27 +314,57 @@ func (r *Requester) GetNonceMilli() nonce.Value { return r.Nonce.GetInc() } -// SetProxy sets a proxy address to the client transport +// SetProxy sets a proxy address for the client transport func (r *Requester) SetProxy(p *url.URL) error { - if p.String() == "" { - return errors.New("no proxy URL supplied") + if r == nil { + return ErrRequestSystemIsNil } + return r._HTTPClient.setProxy(p) +} - t, ok := r.HTTPClient.Transport.(*http.Transport) - if !ok { - return errors.New("transport not set, cannot set proxy") +// SetHTTPClient sets exchanges HTTP client +func (r *Requester) SetHTTPClient(newClient *http.Client) error { + if r == nil { + return ErrRequestSystemIsNil } - t.Proxy = http.ProxyURL(p) - t.TLSHandshakeTimeout = proxyTLSTimeout + protectedClient, err := newProtectedClient(newClient) + if err != nil { + return err + } + r._HTTPClient = protectedClient return nil } -func (r *Requester) drainBody(body io.ReadCloser) { - defer body.Close() - if _, err := io.Copy(ioutil.Discard, io.LimitReader(body, drainBodyLimit)); err != nil { - log.Errorf(log.RequestSys, - "%s failed to drain request body %s", - r.Name, - err) +// SetClientTimeout sets the timeout value for the exchanges HTTP Client and +// also the underlying transports idle connection timeout +func (r *Requester) SetHTTPClientTimeout(timeout time.Duration) error { + if r == nil { + return ErrRequestSystemIsNil } + return r._HTTPClient.setHTTPClientTimeout(timeout) +} + +// SetHTTPClientUserAgent sets the exchanges HTTP user agent +func (r *Requester) SetHTTPClientUserAgent(userAgent string) error { + if r == nil { + return ErrRequestSystemIsNil + } + r.userAgent = userAgent + return nil +} + +// GetHTTPClientUserAgent gets the exchanges HTTP user agent +func (r *Requester) GetHTTPClientUserAgent() (string, error) { + if r == nil { + return "", ErrRequestSystemIsNil + } + return r.userAgent, nil +} + +// Shutdown releases persistent memory for garbage collection. +func (r *Requester) Shutdown() error { + if r == nil { + return ErrRequestSystemIsNil + } + return r._HTTPClient.release() } diff --git a/exchanges/request/request_test.go b/exchanges/request/request_test.go index e4d0dbd7..8eccdf3e 100644 --- a/exchanges/request/request_test.go +++ b/exchanges/request/request_test.go @@ -18,6 +18,7 @@ import ( "testing" "time" + "github.com/thrasher-corp/gocryptotrader/common" "golang.org/x/time/rate" ) @@ -126,12 +127,15 @@ func TestNewRateLimit(t *testing.T) { func TestCheckRequest(t *testing.T) { t.Parallel() - r := New("TestRequest", + r, err := New("TestRequest", new(http.Client)) + if err != nil { + t.Fatal(err) + } ctx := context.Background() var check *Item - _, err := check.validateRequest(ctx, &Requester{}) + _, err = check.validateRequest(ctx, &Requester{}) if err == nil { t.Fatal(unexpected) } @@ -179,7 +183,7 @@ func TestCheckRequest(t *testing.T) { } // Test user agent set - r.UserAgent = "r00t axxs" + r.userAgent = "r00t axxs" req, err := check.validateRequest(ctx, r) if err != nil { t.Fatal(err) @@ -226,12 +230,15 @@ var globalshell = GlobalLimitTest{ func TestDoRequest(t *testing.T) { t.Parallel() - r := New("test", new(http.Client), WithLimiter(&globalshell)) + r, err := New("test", new(http.Client), WithLimiter(&globalshell)) + if err != nil { + t.Fatal(err) + } ctx := context.Background() - err := (*Requester)(nil).SendPayload(ctx, Unset, nil) - if !errors.Is(errRequestSystemIsNil, err) { - t.Fatalf("expected: %v but received: %v", errRequestSystemIsNil, err) + err = (*Requester)(nil).SendPayload(ctx, Unset, nil) + if !errors.Is(ErrRequestSystemIsNil, err) { + t.Fatalf("expected: %v but received: %v", ErrRequestSystemIsNil, err) } err = r.SendPayload(ctx, Unset, nil) if !errors.Is(errRequestFunctionIsNil, err) { @@ -305,8 +312,16 @@ func TestDoRequest(t *testing.T) { // reset jobs r.jobs = 0 + r._HTTPClient, err = newProtectedClient(common.NewHTTPClientWithTimeout(0)) + if err != nil { + t.Fatal(err) + } + // timeout checker - r.HTTPClient.Timeout = time.Millisecond * 50 + err = r._HTTPClient.setHTTPClientTimeout(time.Millisecond * 50) + if err != nil { + t.Fatal(err) + } err = r.SendPayload(ctx, UnAuth, func() (*Item, error) { return &Item{Path: testURL + "/timeout"}, nil }) @@ -314,7 +329,10 @@ func TestDoRequest(t *testing.T) { t.Fatalf("received: %v but expected: %v", err, errFailedToRetryRequest) } // reset timeout - r.HTTPClient.Timeout = 0 + err = r._HTTPClient.setHTTPClientTimeout(0) + if err != nil { + t.Fatal(err) + } // Check JSON var resp struct { @@ -403,7 +421,10 @@ func TestDoRequest_Retries(t *testing.T) { backoff := func(n int) time.Duration { return 0 } - r := New("test", new(http.Client), WithBackoff(backoff)) + r, err := New("test", new(http.Client), WithBackoff(backoff)) + if err != nil { + t.Fatal(err) + } var failed int32 var wg sync.WaitGroup wg.Add(4) @@ -444,8 +465,11 @@ func TestDoRequest_RetryNonRecoverable(t *testing.T) { backoff := func(n int) time.Duration { return 0 } - r := New("test", new(http.Client), WithBackoff(backoff)) - err := r.SendPayload(context.Background(), Unset, func() (*Item, error) { + r, err := New("test", new(http.Client), WithBackoff(backoff)) + if err != nil { + t.Fatal(err) + } + err = r.SendPayload(context.Background(), Unset, func() (*Item, error) { return &Item{ Method: http.MethodGet, Path: testURL + "/always-retry", @@ -466,8 +490,11 @@ func TestDoRequest_NotRetryable(t *testing.T) { backoff := func(n int) time.Duration { return time.Duration(n) * time.Millisecond } - r := New("test", new(http.Client), WithRetryPolicy(retry), WithBackoff(backoff)) - err := r.SendPayload(context.Background(), Unset, func() (*Item, error) { + r, err := New("test", new(http.Client), WithRetryPolicy(retry), WithBackoff(backoff)) + if err != nil { + t.Fatal(err) + } + err = r.SendPayload(context.Background(), Unset, func() (*Item, error) { return &Item{ Method: http.MethodGet, Path: testURL + "/always-retry", @@ -480,17 +507,22 @@ func TestDoRequest_NotRetryable(t *testing.T) { func TestGetNonce(t *testing.T) { t.Parallel() - r := New("test", + r, err := New("test", new(http.Client), WithLimiter(&globalshell)) - + if err != nil { + t.Fatal(err) + } if n1, n2 := r.GetNonce(false), r.GetNonce(false); n1 == n2 { t.Fatal(unexpected) } - r2 := New("test", + r2, err := New("test", new(http.Client), WithLimiter(&globalshell)) + if err != nil { + t.Fatal(err) + } if n1, n2 := r2.GetNonce(true), r2.GetNonce(true); n1 == n2 { t.Fatal(unexpected) } @@ -498,9 +530,12 @@ func TestGetNonce(t *testing.T) { func TestGetNonceMillis(t *testing.T) { t.Parallel() - r := New("test", + r, err := New("test", new(http.Client), WithLimiter(&globalshell)) + if err != nil { + t.Fatal(err) + } if m1, m2 := r.GetNonceMilli(), r.GetNonceMilli(); m1 == m2 { log.Fatal(unexpected) } @@ -508,9 +543,17 @@ func TestGetNonceMillis(t *testing.T) { func TestSetProxy(t *testing.T) { t.Parallel() - r := New("test", + var r *Requester + err := r.SetProxy(nil) + if !errors.Is(err, ErrRequestSystemIsNil) { + t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil) + } + r, err = New("test", &http.Client{Transport: new(http.Transport)}, WithLimiter(&globalshell)) + if err != nil { + t.Fatal(err) + } u, err := url.Parse("http://www.google.com") if err != nil { t.Fatal(err) @@ -530,9 +573,12 @@ func TestSetProxy(t *testing.T) { } func TestBasicLimiter(t *testing.T) { - r := New("test", + r, err := New("test", new(http.Client), WithLimiter(NewBasicRateLimit(time.Second, 1))) + if err != nil { + t.Fatal(err) + } i := Item{ Path: "http://www.google.com", Method: http.MethodGet, @@ -540,7 +586,7 @@ func TestBasicLimiter(t *testing.T) { ctx := context.Background() tn := time.Now() - err := r.SendPayload(ctx, Unset, func() (*Item, error) { return &i, nil }) + err = r.SendPayload(ctx, Unset, func() (*Item, error) { return &i, nil }) if err != nil { t.Fatal(err) } @@ -561,13 +607,16 @@ func TestBasicLimiter(t *testing.T) { } func TestEnableDisableRateLimit(t *testing.T) { - r := New("TestRequest", + r, err := New("TestRequest", new(http.Client), WithLimiter(NewBasicRateLimit(time.Minute, 1))) + if err != nil { + t.Fatal(err) + } ctx := context.Background() var resp interface{} - err := r.SendPayload(ctx, Auth, func() (*Item, error) { + err = r.SendPayload(ctx, Auth, func() (*Item, error) { return &Item{ Method: http.MethodGet, Path: testURL, @@ -635,3 +684,71 @@ func TestEnableDisableRateLimit(t *testing.T) { // Correct test } } + +func TestSetHTTPClient(t *testing.T) { + var r *Requester + err := r.SetHTTPClient(nil) + if !errors.Is(err, ErrRequestSystemIsNil) { + t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil) + } + client := new(http.Client) + r = new(Requester) + err = r.SetHTTPClient(client) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v', but expected: '%v'", err, nil) + } + err = r.SetHTTPClient(client) + if !errors.Is(err, errCannotReuseHTTPClient) { + t.Fatalf("received: '%v', but expected: '%v'", err, errCannotReuseHTTPClient) + } +} + +func TestSetHTTPClientTimeout(t *testing.T) { + var r *Requester + err := r.SetHTTPClientTimeout(0) + if !errors.Is(err, ErrRequestSystemIsNil) { + t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil) + } + r = new(Requester) + err = r.SetHTTPClient(common.NewHTTPClientWithTimeout(2)) + if err != nil { + t.Fatal(err) + } + err = r.SetHTTPClientTimeout(time.Second) + if !errors.Is(err, nil) { + t.Fatalf("received: '%v', but expected: '%v'", err, nil) + } +} + +func TestSetHTTPClientUserAgent(t *testing.T) { + var r *Requester + err := r.SetHTTPClientUserAgent("") + if !errors.Is(err, ErrRequestSystemIsNil) { + t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil) + } + r = new(Requester) + err = r.SetHTTPClientUserAgent("") + if !errors.Is(err, nil) { + t.Fatalf("received: '%v', but expected: '%v'", err, nil) + } +} + +func TestGetHTTPClientUserAgent(t *testing.T) { + var r *Requester + _, err := r.GetHTTPClientUserAgent() + if !errors.Is(err, ErrRequestSystemIsNil) { + t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil) + } + r = new(Requester) + err = r.SetHTTPClientUserAgent("sillyness") + if !errors.Is(err, nil) { + t.Fatalf("received: '%v', but expected: '%v'", err, nil) + } + ua, err := r.GetHTTPClientUserAgent() + if !errors.Is(err, nil) { + t.Fatalf("received: '%v', but expected: '%v'", err, nil) + } + if ua != "sillyness" { + t.Fatal("unexpected value") + } +} diff --git a/exchanges/request/request_types.go b/exchanges/request/request_types.go index 6279d8fc..7bc58701 100644 --- a/exchanges/request/request_types.go +++ b/exchanges/request/request_types.go @@ -28,11 +28,11 @@ var ( // Requester struct for the request client type Requester struct { - HTTPClient *http.Client + _HTTPClient *client limiter Limiter reporter Reporter - Name string - UserAgent string + name string + userAgent string maxRetries int jobs int32 Nonce nonce.Nonce diff --git a/exchanges/sharedtestvalues/customex.go b/exchanges/sharedtestvalues/customex.go index 155ebc8b..7ea2cb22 100644 --- a/exchanges/sharedtestvalues/customex.go +++ b/exchanges/sharedtestvalues/customex.go @@ -198,11 +198,12 @@ func (c *CustomEx) WithdrawFiatFundsToInternationalBank(ctx context.Context, wit return nil, nil } -func (c *CustomEx) SetHTTPClientUserAgent(ua string) { +func (c *CustomEx) SetHTTPClientUserAgent(ua string) error { + return nil } -func (c *CustomEx) GetHTTPClientUserAgent() string { - return "" +func (c *CustomEx) GetHTTPClientUserAgent() (string, error) { + return "", nil } func (c *CustomEx) SetClientProxyAddress(addr string) error { diff --git a/exchanges/yobit/yobit_wrapper.go b/exchanges/yobit/yobit_wrapper.go index 3dbeea5a..3ab7e3e5 100644 --- a/exchanges/yobit/yobit_wrapper.go +++ b/exchanges/yobit/yobit_wrapper.go @@ -97,10 +97,13 @@ func (y *Yobit) SetDefaults() { }, } - y.Requester = request.New(y.Name, + y.Requester, err = request.New(y.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), // Server responses are cached every 2 seconds. request.WithLimiter(request.NewBasicRateLimit(time.Second, 1))) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } y.API.Endpoints = y.NewEndpoints() err = y.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: apiPublicURL, diff --git a/exchanges/zb/zb_mock_test.go b/exchanges/zb/zb_mock_test.go index f1ac27e7..158efa02 100644 --- a/exchanges/zb/zb_mock_test.go +++ b/exchanges/zb/zb_mock_test.go @@ -47,7 +47,10 @@ func TestMain(m *testing.M) { log.Fatalf("Mock server error %s", err) } - z.HTTPClient = newClient + err = z.SetHTTPClient(newClient) + if err != nil { + log.Fatalf("Mock server error %s", err) + } endpoints := z.API.Endpoints.GetURLMap() for k := range endpoints { err = z.API.Endpoints.SetRunning(k, serverDetails) diff --git a/exchanges/zb/zb_wrapper.go b/exchanges/zb/zb_wrapper.go index e0103993..68032917 100644 --- a/exchanges/zb/zb_wrapper.go +++ b/exchanges/zb/zb_wrapper.go @@ -130,9 +130,12 @@ func (z *ZB) SetDefaults() { }, } - z.Requester = request.New(z.Name, + z.Requester, err = request.New(z.Name, common.NewHTTPClientWithTimeout(exchange.DefaultHTTPTimeout), request.WithLimiter(SetRateLimit())) + if err != nil { + log.Errorln(log.ExchangeSys, err) + } z.API.Endpoints = z.NewEndpoints() err = z.API.Endpoints.SetDefaultEndpoints(map[exchange.URL]string{ exchange.RestSpot: zbTradeURL,