From 19b8957f3fd0f67cc587f5bc2f6fe322979e2f31 Mon Sep 17 00:00:00 2001 From: Adrian Gallagher Date: Tue, 10 Jun 2025 16:29:57 +1000 Subject: [PATCH] codebase: Replace !errors.Is(err, target) with testify (#1931) * tests: Replace !errors.Is(err, target) with testify equivalents * codebase: Manual !errors.Is(err, target) replacements * typo: Replace errMisMatchedEvent with errMismatchedEvent * tests: Enhance error messages for better output * tests: Refactor error assertions in various test cases to use require and improve clarity * misc linter: Fix assert should wording * tests: Simplify assertions in TestCreateSignals for clarity and conciseness * tests: Enhance assertion message in TestCreateSignals --- .github/workflows/misc.yml | 7 + backtester/common/common_test.go | 17 +- backtester/config/batcktesterconfig_test.go | 6 +- backtester/config/strategyconfig_test.go | 149 ++-- backtester/data/data.go | 4 +- backtester/data/data_test.go | 61 +- backtester/data/data_types.go | 2 +- backtester/data/kline/csv/csv_test.go | 9 +- .../data/kline/database/database_test.go | 9 +- backtester/data/kline/kline_test.go | 14 +- backtester/engine/backtest_test.go | 324 +++----- backtester/engine/live_test.go | 131 +-- backtester/engine/taskmanager_test.go | 83 +- .../eventholder/eventholder_test.go | 5 +- .../eventhandlers/exchange/exchange_test.go | 161 ++-- .../portfolio/compliance/compliance_test.go | 9 +- .../portfolio/holdings/holdings_test.go | 10 +- .../eventhandlers/portfolio/portfolio_test.go | 416 +++------- .../eventhandlers/portfolio/risk/risk_test.go | 27 +- .../eventhandlers/portfolio/size/size_test.go | 45 +- .../statistics/currencystatistics_test.go | 5 +- .../statistics/fundingstatistics_test.go | 44 +- .../statistics/statistics_test.go | 127 +-- .../strategies/base/base_test.go | 14 +- .../binancecashandcarry_test.go | 125 +-- .../dollarcostaverage_test.go | 14 +- .../eventhandlers/strategies/rsi/rsi_test.go | 32 +- .../strategies/strategies_test.go | 19 +- .../top2bottom2/top2bottom2_test.go | 27 +- backtester/funding/collateralpair_test.go | 107 +-- backtester/funding/funding_test.go | 207 ++--- backtester/funding/item_test.go | 50 +- backtester/funding/spotpair_test.go | 96 +-- .../trackingcurrencies_test.go | 18 +- backtester/plugins/strategies/loader_test.go | 9 +- backtester/report/chart_test.go | 24 +- backtester/report/report_test.go | 11 +- .../exchange_wrapper_standards_test.go | 12 +- common/common_test.go | 44 +- common/math/math_test.go | 180 ++-- config/config_test.go | 9 +- currency/code_test.go | 13 +- currency/currency_test.go | 21 +- .../exchangeratesapi_test.go | 35 +- currency/manager_test.go | 62 +- currency/pair_test.go | 16 +- currency/pairs_test.go | 80 +- database/database_test.go | 40 +- database/repository/candle/candle_test.go | 20 +- .../datahistoryjob/datahistoryjob_test.go | 6 +- database/repository/withdraw/withdraw_test.go | 13 +- engine/apiserver_test.go | 39 +- engine/communication_manager_test.go | 26 +- engine/connection_manager_test.go | 28 +- engine/currency_state_manager_test.go | 80 +- engine/database_connection_test.go | 60 +- engine/datahistory_manager_test.go | 335 +++----- engine/depositaddress_test.go | 40 +- engine/engine_test.go | 25 +- engine/event_manager_test.go | 44 +- engine/exchange_manager_test.go | 56 +- engine/helpers_test.go | 40 +- engine/ntp_manager_test.go | 44 +- engine/order_manager_test.go | 149 +--- engine/portfolio_manager_test.go | 25 +- engine/rpcserver_test.go | 456 +++------- engine/sync_manager_test.go | 68 +- engine/websocketroutine_manager_test.go | 54 +- engine/withdraw_manager_test.go | 62 +- exchanges/account/account_test.go | 9 +- exchanges/account/credentials_test.go | 13 +- exchanges/alert/alert_test.go | 5 +- exchanges/asset/asset_test.go | 10 +- exchanges/binance/binance_test.go | 223 ++--- exchanges/binance/ratelimit_test.go | 5 +- exchanges/binanceus/binanceus_test.go | 289 +++---- exchanges/bitfinex/bitfinex_test.go | 17 +- exchanges/btcmarkets/btcmarkets_test.go | 46 +- exchanges/bybit/bybit.go | 77 +- exchanges/bybit/bybit_test.go | 777 +++++++----------- exchanges/collateral/collateral_test.go | 12 +- exchanges/credentials_test.go | 33 +- .../currencystate/currency_state_test.go | 104 +-- exchanges/exchange_test.go | 145 ++-- exchanges/fill/fill_test.go | 7 +- exchanges/futures/futures_test.go | 377 +++------ exchanges/gateio/gateio_test.go | 57 +- exchanges/gemini/gemini_test.go | 9 +- exchanges/kline/kline_test.go | 79 +- exchanges/kline/request_test.go | 61 +- exchanges/kline/technical_analysis_test.go | 201 ++--- exchanges/kline/weighted_price_test.go | 8 +- exchanges/kraken/kraken_test.go | 11 +- exchanges/kucoin/kucoin_test.go | 4 +- exchanges/okx/okx_test.go | 5 +- exchanges/orderbook/calculator_test.go | 91 +- exchanges/orderbook/orderbook_test.go | 72 +- exchanges/orderbook/tranches_test.go | 47 +- exchanges/poloniex/currency_details_test.go | 77 +- exchanges/poloniex/poloniex_test.go | 145 +--- exchanges/request/client_test.go | 37 +- exchanges/request/request_test.go | 41 +- exchanges/ticker/ticker_test.go | 5 +- gctscript/modules/gct/errors_test.go | 14 +- gctscript/modules/gct/gct_test.go | 160 +--- .../modules/ta/indicators/indicators_test.go | 33 +- gctscript/vm/vm_test.go | 90 +- gctscript/wrappers/gct/gctwrapper_test.go | 104 +-- log/logger_test.go | 55 +- 109 files changed, 2485 insertions(+), 5670 deletions(-) diff --git a/.github/workflows/misc.yml b/.github/workflows/misc.yml index b53a2501..1c0160a9 100644 --- a/.github/workflows/misc.yml +++ b/.github/workflows/misc.yml @@ -44,3 +44,10 @@ jobs: grep -r -n --include='*_test.go' --color=always -E "errors.Is\([^,]+, nil" . || exit 0 echo "::error::Replace errors.Is(err, nil) with testify equivalents" exit 1 + + - name: Check for !errors.Is(err, target) usage + run: | + grep -r -n --include='*_test.go' --color=always -P '!errors\.Is\(\s*[^,]+\s*,\s*[^)]+\s*\)' . || exit 0 + echo "::error::Replace !errors.Is(err, target) with testify equivalents" + exit 1 + diff --git a/backtester/common/common_test.go b/backtester/common/common_test.go index c61845bd..66f88c64 100644 --- a/backtester/common/common_test.go +++ b/backtester/common/common_test.go @@ -1,7 +1,6 @@ package common import ( - "errors" "fmt" "testing" @@ -220,19 +219,13 @@ func TestPurgeColours(t *testing.T) { func TestGenerateFileName(t *testing.T) { t.Parallel() _, err := GenerateFileName("", "") - if !errors.Is(err, errCannotGenerateFileName) { - t.Errorf("received '%v' expected '%v'", err, errCannotGenerateFileName) - } + assert.ErrorIs(t, err, errCannotGenerateFileName) _, err = GenerateFileName("hello", "") - if !errors.Is(err, errCannotGenerateFileName) { - t.Errorf("received '%v' expected '%v'", err, errCannotGenerateFileName) - } + assert.ErrorIs(t, err, errCannotGenerateFileName) _, err = GenerateFileName("", "moto") - if !errors.Is(err, errCannotGenerateFileName) { - t.Errorf("received '%v' expected '%v'", err, errCannotGenerateFileName) - } + assert.ErrorIs(t, err, errCannotGenerateFileName) _, err = GenerateFileName("hello", "moto") assert.NoError(t, err) @@ -248,7 +241,5 @@ func TestRegisterBacktesterSubLoggers(t *testing.T) { assert.NoError(t, err) err = RegisterBacktesterSubLoggers() - if !errors.Is(err, log.ErrSubLoggerAlreadyRegistered) { - t.Errorf("received '%v' expected '%v'", err, log.ErrSubLoggerAlreadyRegistered) - } + assert.ErrorIs(t, err, log.ErrSubLoggerAlreadyRegistered) } diff --git a/backtester/config/batcktesterconfig_test.go b/backtester/config/batcktesterconfig_test.go index d88872d6..f4a8f953 100644 --- a/backtester/config/batcktesterconfig_test.go +++ b/backtester/config/batcktesterconfig_test.go @@ -1,10 +1,10 @@ package config import ( - "errors" "path/filepath" "testing" + "github.com/stretchr/testify/assert" "github.com/thrasher-corp/gocryptotrader/backtester/common" "github.com/thrasher-corp/gocryptotrader/common/file" "github.com/thrasher-corp/gocryptotrader/encoding/json" @@ -32,9 +32,7 @@ func TestLoadBacktesterConfig(t *testing.T) { } _, err = ReadBacktesterConfigFromPath("test") - if !errors.Is(err, common.ErrFileNotFound) { - t.Errorf("received '%v' expected '%v'", err, common.ErrFileNotFound) - } + assert.ErrorIs(t, err, common.ErrFileNotFound) } func TestGenerateDefaultConfig(t *testing.T) { diff --git a/backtester/config/strategyconfig_test.go b/backtester/config/strategyconfig_test.go index 4a83172f..eb384427 100644 --- a/backtester/config/strategyconfig_test.go +++ b/backtester/config/strategyconfig_test.go @@ -1,7 +1,6 @@ package config import ( - "errors" "os" "path/filepath" "testing" @@ -74,30 +73,26 @@ func TestValidateDate(t *testing.T) { DatabaseData: &DatabaseData{}, } err = c.validateDate() - if !errors.Is(err, gctcommon.ErrDateUnset) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrDateUnset) - } + assert.ErrorIs(t, err, gctcommon.ErrDateUnset) + c.DataSettings.DatabaseData.StartDate = time.Now() c.DataSettings.DatabaseData.EndDate = c.DataSettings.DatabaseData.StartDate err = c.validateDate() - if !errors.Is(err, gctcommon.ErrStartEqualsEnd) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrStartEqualsEnd) - } + assert.ErrorIs(t, err, gctcommon.ErrStartEqualsEnd) + c.DataSettings.DatabaseData.EndDate = c.DataSettings.DatabaseData.StartDate.Add(time.Minute) err = c.validateDate() assert.NoError(t, err) c.DataSettings.APIData = &APIData{} err = c.validateDate() - if !errors.Is(err, gctcommon.ErrDateUnset) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrDateUnset) - } + assert.ErrorIs(t, err, gctcommon.ErrDateUnset) + c.DataSettings.APIData.StartDate = time.Now() c.DataSettings.APIData.EndDate = c.DataSettings.APIData.StartDate err = c.validateDate() - if !errors.Is(err, gctcommon.ErrStartEqualsEnd) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrStartEqualsEnd) - } + assert.ErrorIs(t, err, gctcommon.ErrStartEqualsEnd) + c.DataSettings.APIData.EndDate = c.DataSettings.APIData.StartDate.Add(time.Minute) err = c.validateDate() assert.NoError(t, err) @@ -107,102 +102,79 @@ func TestValidateCurrencySettings(t *testing.T) { t.Parallel() c := Config{} err := c.validateCurrencySettings() - if !errors.Is(err, errNoCurrencySettings) { - t.Errorf("received: %v, expected: %v", err, errNoCurrencySettings) - } + assert.ErrorIs(t, err, errNoCurrencySettings) + c.CurrencySettings = append(c.CurrencySettings, CurrencySettings{}) err = c.validateCurrencySettings() - if !errors.Is(err, errUnsetCurrency) { - t.Errorf("received: %v, expected: %v", err, errUnsetCurrency) - } + assert.ErrorIs(t, err, errUnsetCurrency) + leet := decimal.NewFromInt(1337) c.CurrencySettings[0].SpotDetails = &SpotDetails{InitialQuoteFunds: &leet} err = c.validateCurrencySettings() - if !errors.Is(err, errUnsetCurrency) { - t.Errorf("received: %v, expected: %v", err, errUnsetCurrency) - } + assert.ErrorIs(t, err, errUnsetCurrency) + c.CurrencySettings[0].Base = currency.NewCode("lol") err = c.validateCurrencySettings() - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("received: %v, expected: %v", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) + c.CurrencySettings[0].Asset = asset.Spot err = c.validateCurrencySettings() - if !errors.Is(err, errUnsetExchange) { - t.Errorf("received: %v, expected: %v", err, errUnsetExchange) - } + assert.ErrorIs(t, err, errUnsetExchange) + c.CurrencySettings[0].ExchangeName = "lol" err = c.validateCurrencySettings() assert.NoError(t, err) c.CurrencySettings[0].Asset = asset.PerpetualSwap err = c.validateCurrencySettings() - if !errors.Is(err, errPerpetualsUnsupported) { - t.Errorf("received: %v, expected: %v", err, errPerpetualsUnsupported) - } + assert.ErrorIs(t, err, errPerpetualsUnsupported) c.CurrencySettings[0].Asset = asset.USDTMarginedFutures c.CurrencySettings[0].Quote = currency.NewCode("PERP") err = c.validateCurrencySettings() - if !errors.Is(err, errPerpetualsUnsupported) { - t.Errorf("received: %v, expected: %v", err, errPerpetualsUnsupported) - } + assert.ErrorIs(t, err, errPerpetualsUnsupported) c.CurrencySettings[0].MinimumSlippagePercent = decimal.NewFromInt(2) c.CurrencySettings[0].MaximumSlippagePercent = decimal.NewFromInt(3) c.CurrencySettings[0].Quote = currency.NewCode("USD") err = c.validateCurrencySettings() - if !errors.Is(err, errFeatureIncompatible) { - t.Errorf("received: %v, expected: %v", err, errFeatureIncompatible) - } + assert.ErrorIs(t, err, errFeatureIncompatible) c.CurrencySettings[0].Asset = asset.Spot c.CurrencySettings[0].MinimumSlippagePercent = decimal.NewFromInt(-1) err = c.validateCurrencySettings() - if !errors.Is(err, errBadSlippageRates) { - t.Errorf("received: %v, expected: %v", err, errBadSlippageRates) - } + assert.ErrorIs(t, err, errBadSlippageRates) + c.CurrencySettings[0].MinimumSlippagePercent = decimal.NewFromInt(2) c.CurrencySettings[0].MaximumSlippagePercent = decimal.NewFromInt(-1) err = c.validateCurrencySettings() - if !errors.Is(err, errBadSlippageRates) { - t.Errorf("received: %v, expected: %v", err, errBadSlippageRates) - } + assert.ErrorIs(t, err, errBadSlippageRates) + c.CurrencySettings[0].MinimumSlippagePercent = decimal.NewFromInt(2) c.CurrencySettings[0].MaximumSlippagePercent = decimal.NewFromInt(1) err = c.validateCurrencySettings() - if !errors.Is(err, errBadSlippageRates) { - t.Errorf("received: %v, expected: %v", err, errBadSlippageRates) - } + assert.ErrorIs(t, err, errBadSlippageRates) c.CurrencySettings[0].SpotDetails = &SpotDetails{} err = c.validateCurrencySettings() - if !errors.Is(err, errBadInitialFunds) { - t.Errorf("received: %v, expected: %v", err, errBadInitialFunds) - } + assert.ErrorIs(t, err, errBadInitialFunds) z := decimal.Zero c.CurrencySettings[0].SpotDetails.InitialQuoteFunds = &z c.CurrencySettings[0].SpotDetails.InitialBaseFunds = &z err = c.validateCurrencySettings() - if !errors.Is(err, errBadInitialFunds) { - t.Errorf("received: %v, expected: %v", err, errBadInitialFunds) - } + assert.ErrorIs(t, err, errBadInitialFunds) c.CurrencySettings[0].SpotDetails.InitialQuoteFunds = &leet c.FundingSettings.UseExchangeLevelFunding = true err = c.validateCurrencySettings() - if !errors.Is(err, errBadInitialFunds) { - t.Errorf("received: %v, expected: %v", err, errBadInitialFunds) - } + assert.ErrorIs(t, err, errBadInitialFunds) c.CurrencySettings[0].SpotDetails.InitialQuoteFunds = &z c.CurrencySettings[0].SpotDetails.InitialBaseFunds = &leet c.FundingSettings.UseExchangeLevelFunding = true err = c.validateCurrencySettings() - if !errors.Is(err, errBadInitialFunds) { - t.Errorf("received: %v, expected: %v", err, errBadInitialFunds) - } + assert.ErrorIs(t, err, errBadInitialFunds) } func TestValidateMinMaxes(t *testing.T) { @@ -219,9 +191,8 @@ func TestValidateMinMaxes(t *testing.T) { }, } err = c.validateMinMaxes() - if !errors.Is(err, errSizeLessThanZero) { - t.Errorf("received %v expected %v", err, errSizeLessThanZero) - } + assert.ErrorIs(t, err, errSizeLessThanZero) + c.CurrencySettings = []CurrencySettings{ { SellSide: MinMax{ @@ -230,9 +201,8 @@ func TestValidateMinMaxes(t *testing.T) { }, } err = c.validateMinMaxes() - if !errors.Is(err, errSizeLessThanZero) { - t.Errorf("received %v expected %v", err, errSizeLessThanZero) - } + assert.ErrorIs(t, err, errSizeLessThanZero) + c.CurrencySettings = []CurrencySettings{ { SellSide: MinMax{ @@ -241,9 +211,7 @@ func TestValidateMinMaxes(t *testing.T) { }, } err = c.validateMinMaxes() - if !errors.Is(err, errSizeLessThanZero) { - t.Errorf("received %v expected %v", err, errSizeLessThanZero) - } + assert.ErrorIs(t, err, errSizeLessThanZero) c.CurrencySettings = []CurrencySettings{ { @@ -255,9 +223,7 @@ func TestValidateMinMaxes(t *testing.T) { }, } err = c.validateMinMaxes() - if !errors.Is(err, errMaxSizeMinSizeMismatch) { - t.Errorf("received %v expected %v", err, errMaxSizeMinSizeMismatch) - } + assert.ErrorIs(t, err, errMaxSizeMinSizeMismatch) c.CurrencySettings = []CurrencySettings{ { @@ -268,9 +234,7 @@ func TestValidateMinMaxes(t *testing.T) { }, } err = c.validateMinMaxes() - if !errors.Is(err, errMinMaxEqual) { - t.Errorf("received %v expected %v", err, errMinMaxEqual) - } + assert.ErrorIs(t, err, errMinMaxEqual) c.CurrencySettings = []CurrencySettings{ { @@ -287,27 +251,23 @@ func TestValidateMinMaxes(t *testing.T) { }, } err = c.validateMinMaxes() - if !errors.Is(err, errSizeLessThanZero) { - t.Errorf("received %v expected %v", err, errSizeLessThanZero) - } + assert.ErrorIs(t, err, errSizeLessThanZero) + c.PortfolioSettings = PortfolioSettings{ SellSide: MinMax{ MinimumSize: decimal.NewFromInt(-1), }, } err = c.validateMinMaxes() - if !errors.Is(err, errSizeLessThanZero) { - t.Errorf("received %v expected %v", err, errSizeLessThanZero) - } + assert.ErrorIs(t, err, errSizeLessThanZero) } func TestValidateStrategySettings(t *testing.T) { t.Parallel() c := &Config{} err := c.validateStrategySettings() - if !errors.Is(err, base.ErrStrategyNotFound) { - t.Errorf("received %v expected %v", err, base.ErrStrategyNotFound) - } + assert.ErrorIs(t, err, base.ErrStrategyNotFound) + c.StrategySettings = StrategySettings{Name: dca} err = c.validateStrategySettings() assert.NoError(t, err) @@ -319,30 +279,23 @@ func TestValidateStrategySettings(t *testing.T) { c.FundingSettings = FundingSettings{} c.FundingSettings.UseExchangeLevelFunding = true err = c.validateStrategySettings() - if !errors.Is(err, errExchangeLevelFundingDataRequired) { - t.Errorf("received %v expected %v", err, errExchangeLevelFundingDataRequired) - } + assert.ErrorIs(t, err, errExchangeLevelFundingDataRequired) + c.FundingSettings.ExchangeLevelFunding = []ExchangeLevelFunding{ { InitialFunds: decimal.NewFromInt(-1), }, } err = c.validateStrategySettings() - if !errors.Is(err, errBadInitialFunds) { - t.Errorf("received %v expected %v", err, errBadInitialFunds) - } + assert.ErrorIs(t, err, errBadInitialFunds) c.StrategySettings.SimultaneousSignalProcessing = false err = c.validateStrategySettings() - if !errors.Is(err, errSimultaneousProcessingRequired) { - t.Errorf("received %v expected %v", err, errSimultaneousProcessingRequired) - } + assert.ErrorIs(t, err, errSimultaneousProcessingRequired) c.FundingSettings.UseExchangeLevelFunding = false err = c.validateStrategySettings() - if !errors.Is(err, errExchangeLevelFundingRequired) { - t.Errorf("received %v expected %v", err, errExchangeLevelFundingRequired) - } + assert.ErrorIs(t, err, errExchangeLevelFundingRequired) } func TestPrintSettings(t *testing.T) { @@ -435,9 +388,7 @@ func TestValidate(t *testing.T) { c = nil err = c.Validate() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received %v expected %v", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestReadStrategyConfigFromFile(t *testing.T) { @@ -456,9 +407,7 @@ func TestReadStrategyConfigFromFile(t *testing.T) { assert.NoError(t, err) _, err = ReadStrategyConfigFromFile("test") - if !errors.Is(err, common.ErrFileNotFound) { - t.Errorf("received '%v' expected '%v'", err, common.ErrFileNotFound) - } + assert.ErrorIs(t, err, common.ErrFileNotFound) } func TestGenerateConfigForDCAAPICandles(t *testing.T) { diff --git a/backtester/data/data.go b/backtester/data/data.go index 5f94ea4e..41c575d3 100644 --- a/backtester/data/data.go +++ b/backtester/data/data.go @@ -158,7 +158,7 @@ func (b *Base) SetStream(s []Event) error { if s[x].GetExchange() != b.stream[0].GetExchange() || s[x].GetAssetType() != b.stream[0].GetAssetType() || !s[x].Pair().Equal(b.stream[0].Pair()) { - return fmt.Errorf("%w cannot set base stream from %v %v %v to %v %v %v", errMisMatchedEvent, s[x].GetExchange(), s[x].GetAssetType(), s[x].Pair(), b.stream[0].GetExchange(), b.stream[0].GetAssetType(), b.stream[0].Pair()) + return fmt.Errorf("%w cannot set base stream from %v %v %v to %v %v %v", errMismatchedEvent, s[x].GetExchange(), s[x].GetAssetType(), s[x].Pair(), b.stream[0].GetExchange(), b.stream[0].GetAssetType(), b.stream[0].Pair()) } } // due to the Next() function, we cannot take @@ -193,7 +193,7 @@ candles: if s[x].GetExchange() != b.stream[0].GetExchange() || s[x].GetAssetType() != b.stream[0].GetAssetType() || !s[x].Pair().Equal(b.stream[0].Pair()) { - return fmt.Errorf("%w %v %v %v received %v %v %v", errMisMatchedEvent, b.stream[0].GetExchange(), b.stream[0].GetAssetType(), b.stream[0].Pair(), s[x].GetExchange(), s[x].GetAssetType(), s[x].Pair()) + return fmt.Errorf("%w %v %v %v received %v %v %v", errMismatchedEvent, b.stream[0].GetExchange(), b.stream[0].GetAssetType(), b.stream[0].Pair(), s[x].GetExchange(), s[x].GetAssetType(), s[x].Pair()) } // todo change b.stream to map for y := len(b.stream) - 1; y >= 0; y-- { diff --git a/backtester/data/data_test.go b/backtester/data/data_test.go index 81bd4b19..ad5473cf 100644 --- a/backtester/data/data_test.go +++ b/backtester/data/data_test.go @@ -1,7 +1,6 @@ package data import ( - "errors" "strings" "testing" "time" @@ -75,18 +74,14 @@ func TestGetDataForCurrency(t *testing.T) { assert.NoError(t, err) _, err = d.GetDataForCurrency(nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received '%v' expected '%v'", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) _, err = d.GetDataForCurrency(&fakeEvent{Base: &event.Base{ Exchange: "lol", AssetType: asset.USDTMarginedFutures, CurrencyPair: currency.NewPair(currency.EMB, currency.DOGE), }}) - if !errors.Is(err, ErrHandlerNotFound) { - t.Errorf("received '%v' expected '%v'", err, ErrHandlerNotFound) - } + assert.ErrorIs(t, err, ErrHandlerNotFound) _, err = d.GetDataForCurrency(&fakeEvent{Base: &event.Base{ Exchange: exch, @@ -211,26 +206,18 @@ func TestSetStream(t *testing.T) { }, } err = b.SetStream([]Event{misMatchEvent}) - if !errors.Is(err, ErrInvalidEventSupplied) { - t.Fatalf("received '%v' expected '%v'", err, ErrInvalidEventSupplied) - } + require.ErrorIs(t, err, ErrInvalidEventSupplied) misMatchEvent.Time = time.Now() err = b.SetStream([]Event{misMatchEvent}) - if !errors.Is(err, errMisMatchedEvent) { - t.Fatalf("received '%v' expected '%v'", err, errMisMatchedEvent) - } + require.ErrorIs(t, err, errMismatchedEvent) err = b.SetStream([]Event{nil}) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Fatalf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + require.ErrorIs(t, err, gctcommon.ErrNilPointer) b = nil err = b.SetStream(nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestNext(t *testing.T) { @@ -464,9 +451,7 @@ func TestIsLive(t *testing.T) { b = nil _, err = b.IsLive() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestSetLive(t *testing.T) { @@ -488,9 +473,7 @@ func TestSetLive(t *testing.T) { b = nil err = b.SetLive(false) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestAppendStream(t *testing.T) { @@ -500,9 +483,8 @@ func TestAppendStream(t *testing.T) { Base: &event.Base{}, } err := b.AppendStream(e) - if !errors.Is(err, ErrInvalidEventSupplied) { - t.Errorf("received '%v' expected '%v'", err, ErrInvalidEventSupplied) - } + assert.ErrorIs(t, err, ErrInvalidEventSupplied) + if len(b.stream) != 0 { t.Errorf("received '%v' expected '%v'", len(b.stream), 0) } @@ -512,9 +494,7 @@ func TestAppendStream(t *testing.T) { e.AssetType = asset.Spot e.CurrencyPair = cp err = b.AppendStream(e) - if !errors.Is(err, ErrInvalidEventSupplied) { - t.Fatalf("received '%v' expected '%v'", err, ErrInvalidEventSupplied) - } + require.ErrorIs(t, err, ErrInvalidEventSupplied) e.Time = tt err = b.AppendStream(e, e) @@ -554,34 +534,29 @@ func TestAppendStream(t *testing.T) { }, } err = b.AppendStream(misMatchEvent) - if !errors.Is(err, errMisMatchedEvent) { - t.Fatalf("received '%v' expected '%v'", err, errMisMatchedEvent) - } + require.ErrorIs(t, err, errMismatchedEvent) + if len(b.stream) != 2 { t.Errorf("received '%v' expected '%v'", len(b.stream), 2) } err = b.AppendStream(nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Fatalf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + require.ErrorIs(t, err, gctcommon.ErrNilPointer) + if len(b.stream) != 2 { t.Errorf("received '%v' expected '%v'", len(b.stream), 2) } err = b.AppendStream() - if !errors.Is(err, errNothingToAdd) { - t.Fatalf("received '%v' expected '%v'", err, errNothingToAdd) - } + require.ErrorIs(t, err, errNothingToAdd) + if len(b.stream) != 2 { t.Errorf("received '%v' expected '%v'", len(b.stream), 2) } b = nil err = b.AppendStream() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestFirst(t *testing.T) { diff --git a/backtester/data/data_types.go b/backtester/data/data_types.go index 60f42e52..921823b9 100644 --- a/backtester/data/data_types.go +++ b/backtester/data/data_types.go @@ -23,7 +23,7 @@ var ( ErrEndOfData = errors.New("no more data to retrieve") errNothingToAdd = errors.New("cannot append empty event to stream") - errMisMatchedEvent = errors.New("cannot add event to stream, does not match") + errMismatchedEvent = errors.New("cannot add event to stream, does not match") ) // HandlerHolder stores an event handler per exchange asset pair diff --git a/backtester/data/kline/csv/csv_test.go b/backtester/data/kline/csv/csv_test.go index e8c66d55..5b3551b4 100644 --- a/backtester/data/kline/csv/csv_test.go +++ b/backtester/data/kline/csv/csv_test.go @@ -1,7 +1,6 @@ package csv import ( - "errors" "path/filepath" "testing" @@ -56,9 +55,7 @@ func TestLoadDataInvalid(t *testing.T) { p, a, false) - if !errors.Is(err, common.ErrInvalidDataType) { - t.Errorf("received: %v, expected: %v", err, common.ErrInvalidDataType) - } + assert.ErrorIs(t, err, common.ErrInvalidDataType) _, err = LoadData( -1, @@ -68,7 +65,5 @@ func TestLoadDataInvalid(t *testing.T) { p, a, true) - if !errors.Is(err, errNoUSDData) { - t.Errorf("received: %v, expected: %v", err, errNoUSDData) - } + assert.ErrorIs(t, err, errNoUSDData) } diff --git a/backtester/data/kline/database/database_test.go b/backtester/data/kline/database/database_test.go index 4bf8d7ee..7b55c7a3 100644 --- a/backtester/data/kline/database/database_test.go +++ b/backtester/data/kline/database/database_test.go @@ -1,7 +1,6 @@ package database import ( - "errors" "fmt" "os" "path/filepath" @@ -192,12 +191,8 @@ func TestLoadDataInvalid(t *testing.T) { dStart := time.Date(2020, 1, 0, 0, 0, 0, 0, time.UTC) dEnd := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) _, err := LoadData(dStart, dEnd, gctkline.FifteenMin.Duration(), exch, -1, p, a, false) - if !errors.Is(err, common.ErrInvalidDataType) { - t.Errorf("received: %v, expected: %v", err, common.ErrInvalidDataType) - } + assert.ErrorIs(t, err, common.ErrInvalidDataType) _, err = LoadData(dStart, dEnd, gctkline.FifteenMin.Duration(), exch, -1, p, a, true) - if !errors.Is(err, errNoUSDData) { - t.Errorf("received: %v, expected: %v", err, errNoUSDData) - } + assert.ErrorIs(t, err, errNoUSDData) } diff --git a/backtester/data/kline/kline_test.go b/backtester/data/kline/kline_test.go index 77f06104..79cc0b0c 100644 --- a/backtester/data/kline/kline_test.go +++ b/backtester/data/kline/kline_test.go @@ -1,7 +1,6 @@ package kline import ( - "errors" "testing" "time" @@ -31,9 +30,8 @@ func TestLoad(t *testing.T) { Base: &data.Base{}, } err := d.Load() - if !errors.Is(err, errNoCandleData) { - t.Errorf("received: %v, expected: %v", err, errNoCandleData) - } + assert.ErrorIs(t, err, errNoCandleData) + d.Item = &gctkline.Item{ Exchange: exch, Pair: p, @@ -155,9 +153,7 @@ func TestAppend(t *testing.T) { }, } err := d.AppendResults(&item) - if !errors.Is(err, gctkline.ErrItemNotEqual) { - t.Errorf("received: %v, expected: %v", err, gctkline.ErrItemNotEqual) - } + assert.ErrorIs(t, err, gctkline.ErrItemNotEqual) item.Exchange = testExchange item.Pair = p @@ -170,9 +166,7 @@ func TestAppend(t *testing.T) { assert.NoError(t, err) err = d.AppendResults(nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestStreamOpen(t *testing.T) { diff --git a/backtester/engine/backtest_test.go b/backtester/engine/backtest_test.go index 539ad3ea..5a0f8c9f 100644 --- a/backtester/engine/backtest_test.go +++ b/backtester/engine/backtest_test.go @@ -1,7 +1,6 @@ package engine import ( - "errors" "path/filepath" "strings" "sync" @@ -59,20 +58,15 @@ func TestSetupFromConfig(t *testing.T) { require.NoError(t, err) err = bt.SetupFromConfig(nil, "", "", false) - if !errors.Is(err, errNilConfig) { - t.Errorf("received %v, expected %v", err, errNilConfig) - } + assert.ErrorIs(t, err, errNilConfig) + cfg := &config.Config{} err = bt.SetupFromConfig(cfg, "", "", false) - if !errors.Is(err, gctkline.ErrInvalidInterval) { - t.Errorf("received: %v, expected: %v", err, gctkline.ErrInvalidInterval) - } + assert.ErrorIs(t, err, gctkline.ErrInvalidInterval) cfg.DataSettings.Interval = gctkline.OneMonth err = bt.SetupFromConfig(cfg, "", "", false) - if !errors.Is(err, base.ErrStrategyNotFound) { - t.Errorf("received: %v, expected: %v", err, base.ErrStrategyNotFound) - } + assert.ErrorIs(t, err, base.ErrStrategyNotFound) const testExchange = "bitfinex" @@ -85,9 +79,7 @@ func TestSetupFromConfig(t *testing.T) { }, } err = bt.SetupFromConfig(cfg, "", "", false) - if !errors.Is(err, base.ErrStrategyNotFound) { - t.Errorf("received: %v, expected: %v", err, base.ErrStrategyNotFound) - } + assert.ErrorIs(t, err, base.ErrStrategyNotFound) cfg.StrategySettings = config.StrategySettings{ Name: dollarcostaverage.Name, @@ -103,24 +95,20 @@ func TestSetupFromConfig(t *testing.T) { } cfg.DataSettings.DataType = common.CandleStr err = bt.SetupFromConfig(cfg, "", "", false) - if !errors.Is(err, gctcommon.ErrDateUnset) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrDateUnset) - } + assert.ErrorIs(t, err, gctcommon.ErrDateUnset) + cfg.DataSettings.Interval = gctkline.OneMin cfg.CurrencySettings[0].MakerFee = &decimal.Zero cfg.CurrencySettings[0].TakerFee = &decimal.Zero err = bt.SetupFromConfig(cfg, "", "", false) - if !errors.Is(err, gctcommon.ErrDateUnset) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrDateUnset) - } + assert.ErrorIs(t, err, gctcommon.ErrDateUnset) cfg.DataSettings.APIData.StartDate = time.Now().Truncate(gctkline.OneMin.Duration()).Add(-gctkline.OneMin.Duration() * 10) cfg.DataSettings.APIData.EndDate = cfg.DataSettings.APIData.StartDate.Add(gctkline.OneMin.Duration() * 5) cfg.DataSettings.APIData.InclusiveEndDate = true err = bt.SetupFromConfig(cfg, "", "", false) - if !errors.Is(err, holdings.ErrInitialFundsZero) { - t.Errorf("received: %v, expected: %v", err, holdings.ErrInitialFundsZero) - } + assert.ErrorIs(t, err, holdings.ErrInitialFundsZero) + cfg.FundingSettings.UseExchangeLevelFunding = true cfg.FundingSettings.ExchangeLevelFunding = []config.ExchangeLevelFunding{ { @@ -390,9 +378,7 @@ func TestLoadDataLive(t *testing.T) { RequestFormat: ¤cy.PairFormat{Uppercase: true}, } _, err = bt.loadData(cfg, exch, cp, asset.Spot, false) - if !errors.Is(err, gctkline.ErrCannotConstructInterval) { - t.Errorf("received: %v, expected: %v", err, gctkline.ErrCannotConstructInterval) - } + assert.ErrorIs(t, err, gctkline.ErrCannotConstructInterval) cfg.DataSettings.Interval = gctkline.OneMin _, err = bt.loadData(cfg, exch, cp, asset.Spot, false) @@ -427,9 +413,7 @@ func TestReset(t *testing.T) { bt = nil err = bt.Reset() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestFullCycle(t *testing.T) { @@ -539,18 +523,15 @@ func TestStop(t *testing.T) { tt := bt.MetaData.DateEnded err = bt.Stop() - if !errors.Is(err, errAlreadyRan) { - t.Errorf("received: %v, expected: %v", err, errAlreadyRan) - } + assert.ErrorIs(t, err, errAlreadyRan) + if !tt.Equal(bt.MetaData.DateEnded) { t.Errorf("received '%v' expected '%v'", bt.MetaData.DateEnded, tt) } bt = nil err = bt.Stop() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestFullCycleMulti(t *testing.T) { @@ -644,9 +625,7 @@ func TestFullCycleMulti(t *testing.T) { assert.NoError(t, err) err = bt.Run() - if !errors.Is(err, errNotSetup) { - t.Errorf("received: %v, expected: %v", err, errNotSetup) - } + assert.ErrorIs(t, err, errNotSetup) bt.MetaData.DateLoaded = time.Now() err = bt.Run() @@ -676,15 +655,11 @@ func TestTriggerLiquidationsForExchange(t *testing.T) { bt := BackTest{ shutdown: make(chan struct{}), } - expectedError := common.ErrNilEvent err := bt.triggerLiquidationsForExchange(nil, nil) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, common.ErrNilEvent) cp := currency.NewBTCUSDT() a := asset.USDTMarginedFutures - expectedError = gctcommon.ErrNilPointer ev := &evkline.Kline{ Base: &event.Base{ Exchange: testExchange, @@ -693,9 +668,7 @@ func TestTriggerLiquidationsForExchange(t *testing.T) { }, } err = bt.triggerLiquidationsForExchange(ev, nil) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) bt.Portfolio = &portfolioOverride{} pnl := &portfolio.PNLSummary{} @@ -730,7 +703,6 @@ func TestTriggerLiquidationsForExchange(t *testing.T) { RangeHolder: &gctkline.IntervalRangeHolder{}, } bt.Statistic = &statistics.Statistic{} - expectedError = nil bt.EventQueue = &eventholder.Holder{} bt.Funding = &funding.FundManager{} @@ -738,24 +710,18 @@ func TestTriggerLiquidationsForExchange(t *testing.T) { assert.NoError(t, err) err = bt.Statistic.SetEventForOffset(ev) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.NoError(t, err, "SetEventForOffset should not error") + pnl.Exchange = ev.Exchange pnl.Asset = ev.AssetType pnl.Pair = ev.CurrencyPair err = bt.triggerLiquidationsForExchange(ev, pnl) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.NoError(t, err, "triggerLiquidationsForExchange should not error") + ev2 := bt.EventQueue.NextEvent() ev2o, ok := ev2.(order.Event) - if !ok { - t.Fatal("expected order event") - } - if ev2o.GetDirection() != gctorder.Short { - t.Error("expected liquidation order") - } + require.True(t, ok, "NextEvent must return an order event") + assert.Equal(t, gctorder.Short, ev2o.GetDirection()) } func TestUpdateStatsForDataEvent(t *testing.T) { @@ -766,11 +732,9 @@ func TestUpdateStatsForDataEvent(t *testing.T) { Portfolio: &fakeFolio{}, shutdown: make(chan struct{}), } - expectedError := common.ErrNilEvent + err := bt.updateStatsForDataEvent(nil, nil) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, common.ErrNilEvent) cp := currency.NewBTCUSDT() a := asset.Futures @@ -782,28 +746,21 @@ func TestUpdateStatsForDataEvent(t *testing.T) { }, } - expectedError = gctcommon.ErrNilPointer err = bt.updateStatsForDataEvent(ev, nil) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } - expectedError = nil + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) + f, err := funding.SetupFundingManager(&engine.ExchangeManager{}, false, true, false) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "SetupFundingManager must not error") + b, err := funding.CreateItem(testExchange, a, cp.Base, decimal.Zero, decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "CreateItem must not error") + quote, err := funding.CreateItem(testExchange, a, cp.Quote, decimal.NewFromInt(1337), decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "CreateItem must not error") + pair, err := funding.CreateCollateral(b, quote) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "CreateCollateral must not error") + bt.Funding = f exch := &binance.Binance{} exch.Name = testExchange @@ -829,19 +786,14 @@ func TestUpdateStatsForDataEvent(t *testing.T) { }, } _, err = bt.Portfolio.TrackFuturesOrder(fl, pair) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.NoError(t, err, "TrackFuturesOrder should not error") err = bt.updateStatsForDataEvent(ev, pair) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.NoError(t, err, "updateStatsForDataEvent should not error") } func TestProcessSignalEvent(t *testing.T) { t.Parallel() - var expectedError error bt := &BackTest{ Statistic: &fakeStats{}, Funding: &funding.FundManager{}, @@ -860,29 +812,24 @@ func TestProcessSignalEvent(t *testing.T) { }, } err := bt.Statistic.SetEventForOffset(de) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "SetEventForOffset must not error") + ev := &signal.Signal{ Base: de.Base, } f, err := funding.SetupFundingManager(&engine.ExchangeManager{}, false, true, false) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "SetupFundingManager must not error") + b, err := funding.CreateItem(testExchange, a, cp.Base, decimal.Zero, decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "CreateItem must not error") + quote, err := funding.CreateItem(testExchange, a, cp.Quote, decimal.NewFromInt(1337), decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "CreateItem must not error") + pair, err := funding.CreateCollateral(b, quote) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "CreateCollateral must not error") + bt.Funding = f exch := &binance.Binance{} exch.Name = testExchange @@ -893,22 +840,17 @@ func TestProcessSignalEvent(t *testing.T) { }) ev.Direction = gctorder.Short err = bt.Statistic.SetEventForOffset(ev) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "SetEventForOffset must not error") + err = bt.processSignalEvent(ev, pair) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.NoError(t, err, "processSignalEvent should not error") } func TestProcessOrderEvent(t *testing.T) { t.Parallel() - var expectedError error pt, err := portfolio.Setup(&size.Size{}, &risk.Risk{}, decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "Setup must not error") + bt := &BackTest{ Statistic: &statistics.Statistic{}, Funding: &funding.FundManager{}, @@ -928,29 +870,24 @@ func TestProcessOrderEvent(t *testing.T) { }, } err = bt.Statistic.SetEventForOffset(de) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "SetEventForOffset must not error") + ev := &order.Order{ Base: de.Base, } f, err := funding.SetupFundingManager(&engine.ExchangeManager{}, false, true, false) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "SetupFundingManager must not error") + b, err := funding.CreateItem(testExchange, a, cp.Base, decimal.Zero, decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "CreateItem must not error") + quote, err := funding.CreateItem(testExchange, a, cp.Quote, decimal.NewFromInt(1337), decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "CreateItem must not error") + pair, err := funding.CreateCollateral(b, quote) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "CreateCollateral must not error") + bt.Funding = f exch := &binance.Binance{} exch.Name = testExchange @@ -959,9 +896,7 @@ func TestProcessOrderEvent(t *testing.T) { Pair: cp, Asset: a, }) - if !errors.Is(err, gctcommon.ErrNotYetImplemented) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNotYetImplemented) - } + assert.ErrorIs(t, err, gctcommon.ErrNotYetImplemented) bt.Exchange.SetExchangeAssetCurrencySettings(a, cp, &exchange.Settings{ Exchange: exch, @@ -970,9 +905,8 @@ func TestProcessOrderEvent(t *testing.T) { }) ev.Direction = gctorder.Short err = bt.Statistic.SetEventForOffset(ev) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "SetEventForOffset must not error") + tt := time.Now() bt.DataHolder = data.NewHandlerHolder() k := &kline.DataFromKline{ @@ -1016,13 +950,11 @@ func TestProcessOrderEvent(t *testing.T) { assert.NoError(t, err) err = bt.processOrderEvent(ev, pair) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "processOrderEvent must not error") + ev2 := bt.EventQueue.NextEvent() - if _, ok := ev2.(fill.Event); !ok { - t.Fatal("expected fill event") - } + _, ok := ev2.(fill.Event) + require.True(t, ok, "NextEvent must return a fill event") } func TestProcessFillEvent(t *testing.T) { @@ -1127,7 +1059,6 @@ func TestProcessFillEvent(t *testing.T) { func TestProcessFuturesFillEvent(t *testing.T) { t.Parallel() - var expectedError error bt := &BackTest{ Statistic: &fakeStats{}, Funding: &funding.FundManager{}, @@ -1147,33 +1078,26 @@ func TestProcessFuturesFillEvent(t *testing.T) { }, } err := bt.Statistic.SetEventForOffset(de) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "SetEventForOffset must note error") + ev := &fill.Fill{ Base: de.Base, } em := engine.NewExchangeManager() exch, err := em.NewExchangeByName(testExchange) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) exch.SetDefaults() err = em.Add(exch) require.NoError(t, err) b, err := funding.CreateItem(testExchange, a, cp.Base, decimal.Zero, decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "CreateItem must not error") + quote, err := funding.CreateItem(testExchange, a, cp.Quote, decimal.NewFromInt(1337), decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "CreateItem must not error") + pair, err := funding.CreateCollateral(b, quote) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "CreateCollateral must not error") bt.exchangeManager = em bt.Exchange.SetExchangeAssetCurrencySettings(a, cp, &exchange.Settings{ @@ -1183,9 +1107,8 @@ func TestProcessFuturesFillEvent(t *testing.T) { }) ev.Direction = gctorder.Short err = bt.Statistic.SetEventForOffset(ev) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "SetEventForOffset must not error") + tt := time.Now() bt.DataHolder = data.NewHandlerHolder() k := &kline.DataFromKline{ @@ -1223,7 +1146,7 @@ func TestProcessFuturesFillEvent(t *testing.T) { }, } err = k.Load() - assert.NoError(t, err) + require.NoError(t, err) ev.Order = &gctorder.Detail{ Exchange: testExchange, @@ -1236,12 +1159,10 @@ func TestProcessFuturesFillEvent(t *testing.T) { Date: time.Now(), } err = bt.DataHolder.SetDataForCurrency(testExchange, a, cp, k) - assert.NoError(t, err) + require.NoError(t, err) err = bt.processFuturesFillEvent(ev, pair) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.NoError(t, err, "processFuturesFillEvent should not error") } func TestCloseAllPositions(t *testing.T) { @@ -1254,9 +1175,7 @@ func TestCloseAllPositions(t *testing.T) { bt.Strategy = &dollarcostaverage.Strategy{} err = bt.CloseAllPositions() - if !errors.Is(err, errLiveOnly) { - t.Errorf("received '%v' expected '%v'", err, errLiveOnly) - } + assert.ErrorIs(t, err, errLiveOnly) bt.shutdown = make(chan struct{}) dc := &dataChecker{ @@ -1265,16 +1184,12 @@ func TestCloseAllPositions(t *testing.T) { } bt.LiveDataHandler = dc err = bt.CloseAllPositions() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) bt.shutdown = make(chan struct{}) bt.Strategy = &binancecashandcarry.Strategy{} err = bt.CloseAllPositions() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) bt.shutdown = make(chan struct{}) bt.Portfolio = &fakeFolio{} @@ -1327,9 +1242,7 @@ func TestRunLive(t *testing.T) { assert.NoError(t, err) err = bt.RunLive() - if !errors.Is(err, errLiveOnly) { - t.Errorf("received '%v' expected '%v'", err, errLiveOnly) - } + assert.ErrorIs(t, err, errLiveOnly) bt.Funding = &funding.FundManager{} bt.Reports = &report.Data{} @@ -1457,17 +1370,14 @@ func TestLiveLoop(t *testing.T) { func TestSetExchangeCredentials(t *testing.T) { t.Parallel() err := setExchangeCredentials(nil, nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) + cfg := &config.Config{} f := &binanceus.Binanceus{} f.SetDefaults() b := f.GetBase() err = setExchangeCredentials(cfg, b) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) ld := &config.LiveData{} cfg.DataSettings = config.DataSettings{ @@ -1478,21 +1388,15 @@ func TestSetExchangeCredentials(t *testing.T) { ld.RealOrders = true err = setExchangeCredentials(cfg, b) - if !errors.Is(err, errIntervalUnset) { - t.Errorf("received '%v' expected '%v'", err, errIntervalUnset) - } + assert.ErrorIs(t, err, errIntervalUnset) cfg.DataSettings.Interval = gctkline.OneMin err = setExchangeCredentials(cfg, b) - if !errors.Is(err, errNoCredsNoLive) { - t.Errorf("received '%v' expected '%v'", err, errNoCredsNoLive) - } + assert.ErrorIs(t, err, errNoCredsNoLive) cfg.DataSettings.LiveData.ExchangeCredentials = []config.Credentials{{}} err = setExchangeCredentials(cfg, b) - if !errors.Is(err, gctexchange.ErrCredentialsAreEmpty) { - t.Errorf("received '%v' expected '%v'", err, gctexchange.ErrCredentialsAreEmpty) - } + assert.ErrorIs(t, err, gctexchange.ErrCredentialsAreEmpty) // requires valid credentials here to get complete coverage // enter them here @@ -1547,9 +1451,7 @@ func TestGenerateSummary(t *testing.T) { bt = nil _, err = bt.GenerateSummary() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestSetupMetaData(t *testing.T) { @@ -1573,9 +1475,7 @@ func TestSetupMetaData(t *testing.T) { bt = nil err = bt.SetupMetaData() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestIsRunning(t *testing.T) { @@ -1713,9 +1613,8 @@ func TestExecuteStrategy(t *testing.T) { shutdown: make(chan struct{}), } err := bt.ExecuteStrategy(false) - if !errors.Is(err, errNotSetup) { - t.Errorf("received '%v' expected '%v'", err, errNotSetup) - } + assert.ErrorIs(t, err, errNotSetup) + id, err := uuid.NewV4() assert.NoError(t, err) @@ -1725,17 +1624,13 @@ func TestExecuteStrategy(t *testing.T) { bt.MetaData.DateStarted = time.Now() bt.m.Unlock() err = bt.ExecuteStrategy(false) - if !errors.Is(err, errTaskIsRunning) { - t.Errorf("received '%v' expected '%v'", err, errTaskIsRunning) - } + assert.ErrorIs(t, err, errTaskIsRunning) err = bt.Stop() assert.NoError(t, err) err = bt.ExecuteStrategy(true) - if !errors.Is(err, errAlreadyRan) { - t.Errorf("received '%v' expected '%v'", err, errAlreadyRan) - } + assert.ErrorIs(t, err, errAlreadyRan) bt.m.Lock() bt.MetaData.DateStarted = time.Time{} @@ -1764,20 +1659,14 @@ func TestExecuteStrategy(t *testing.T) { bt.shutdown = make(chan struct{}) bt.m.Unlock() err = bt.ExecuteStrategy(true) - if !errors.Is(err, errCannotHandleRequest) { - t.Errorf("received '%v' expected '%v'", err, errCannotHandleRequest) - } + assert.ErrorIs(t, err, errCannotHandleRequest) err = bt.ExecuteStrategy(false) - if !errors.Is(err, errLiveOnly) { - t.Errorf("received '%v' expected '%v'", err, errLiveOnly) - } + assert.ErrorIs(t, err, errLiveOnly) bt = nil err = bt.ExecuteStrategy(false) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestNewBacktesterFromConfigs(t *testing.T) { @@ -1816,9 +1705,8 @@ func TestProcessSingleDataEvent(t *testing.T) { } err := bt.processSingleDataEvent(nil, nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received '%v' expected '%v'", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + cp := currency.NewBTCUSDT() a := asset.Spot ev := &evkline.Kline{ @@ -1831,9 +1719,7 @@ func TestProcessSingleDataEvent(t *testing.T) { }, } err = bt.processSingleDataEvent(ev, nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) f, err := funding.SetupFundingManager(&engine.ExchangeManager{}, false, true, false) assert.NoError(t, err) diff --git a/backtester/engine/live_test.go b/backtester/engine/live_test.go index bac936cc..7de05d25 100644 --- a/backtester/engine/live_test.go +++ b/backtester/engine/live_test.go @@ -1,7 +1,6 @@ package engine import ( - "errors" "sync" "sync/atomic" "testing" @@ -30,27 +29,19 @@ func TestSetupLiveDataHandler(t *testing.T) { bt := &BackTest{} var err error err = bt.SetupLiveDataHandler(-1, -1, false, false) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) bt.exchangeManager = engine.NewExchangeManager() err = bt.SetupLiveDataHandler(-1, -1, false, false) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) bt.DataHolder = &data.HandlerHolder{} err = bt.SetupLiveDataHandler(-1, -1, false, false) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) bt.Reports = &report.Data{} err = bt.SetupLiveDataHandler(-1, -1, false, false) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) bt.Funding = &funding.FundManager{} err = bt.SetupLiveDataHandler(-1, -1, false, false) @@ -69,9 +60,7 @@ func TestSetupLiveDataHandler(t *testing.T) { bt = nil err = bt.SetupLiveDataHandler(-1, -1, false, false) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestStart(t *testing.T) { @@ -86,15 +75,11 @@ func TestStart(t *testing.T) { dc.wg.Wait() atomic.CompareAndSwapUint32(&dc.started, 0, 1) err = dc.Start() - if !errors.Is(err, engine.ErrSubSystemAlreadyStarted) { - t.Errorf("received '%v' expected '%v'", err, engine.ErrSubSystemAlreadyStarted) - } + assert.ErrorIs(t, err, engine.ErrSubSystemAlreadyStarted) var dh *dataChecker err = dh.Start() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestDataCheckerIsRunning(t *testing.T) { @@ -122,9 +107,7 @@ func TestLiveHandlerStop(t *testing.T) { shutdown: make(chan bool), } err := dc.Stop() - if !errors.Is(err, engine.ErrSubSystemNotStarted) { - t.Errorf("received '%v' expected '%v'", err, engine.ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, engine.ErrSubSystemNotStarted) dc.started = 1 err = dc.Stop() @@ -132,15 +115,11 @@ func TestLiveHandlerStop(t *testing.T) { dc.shutdown = make(chan bool) err = dc.Stop() - if !errors.Is(err, engine.ErrSubSystemNotStarted) { - t.Errorf("received '%v' expected '%v'", err, engine.ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, engine.ErrSubSystemNotStarted) var dh *dataChecker err = dh.Stop() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestLiveHandlerStopFromError(t *testing.T) { @@ -149,14 +128,11 @@ func TestLiveHandlerStopFromError(t *testing.T) { shutdownErr: make(chan bool, 10), } err := dc.SignalStopFromError(errNoCredsNoLive) - if !errors.Is(err, engine.ErrSubSystemNotStarted) { - t.Errorf("received '%v' expected '%v'", err, engine.ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, engine.ErrSubSystemNotStarted) err = dc.SignalStopFromError(nil) - if !errors.Is(err, errNilError) { - t.Errorf("received '%v' expected '%v'", err, errNilError) - } + assert.ErrorIs(t, err, errNilError) + dc.started = 1 var wg sync.WaitGroup wg.Add(1) @@ -169,9 +145,7 @@ func TestLiveHandlerStopFromError(t *testing.T) { var dh *dataChecker err = dh.SignalStopFromError(errNoCredsNoLive) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestDataFetcher(t *testing.T) { @@ -185,22 +159,16 @@ func TestDataFetcher(t *testing.T) { } dc.wg.Add(1) err := dc.DataFetcher() - if !errors.Is(err, engine.ErrSubSystemNotStarted) { - t.Errorf("received '%v' expected '%v'", err, engine.ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, engine.ErrSubSystemNotStarted) dc.started = 1 dc.wg.Add(1) err = dc.DataFetcher() - if !errors.Is(err, ErrLiveDataTimeout) { - t.Errorf("received '%v' expected '%v'", err, ErrLiveDataTimeout) - } + assert.ErrorIs(t, err, ErrLiveDataTimeout) var dh *dataChecker err = dh.DataFetcher() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestUpdated(t *testing.T) { @@ -238,48 +206,34 @@ func TestLiveHandlerReset(t *testing.T) { } var dh *dataChecker err = dh.Reset() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestAppendDataSource(t *testing.T) { t.Parallel() dataHandler := &dataChecker{} err := dataHandler.AppendDataSource(nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) setup := &liveDataSourceSetup{} err = dataHandler.AppendDataSource(setup) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) setup.exchange = &binance.Binance{} err = dataHandler.AppendDataSource(setup) - if !errors.Is(err, common.ErrInvalidDataType) { - t.Errorf("received '%v' expected '%v'", err, common.ErrInvalidDataType) - } + assert.ErrorIs(t, err, common.ErrInvalidDataType) setup.dataType = common.DataCandle err = dataHandler.AppendDataSource(setup) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("received '%v' expected '%v'", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) setup.asset = asset.Spot err = dataHandler.AppendDataSource(setup) - if !errors.Is(err, currency.ErrCurrencyPairEmpty) { - t.Errorf("received '%v' expected '%v'", err, currency.ErrCurrencyPairEmpty) - } + assert.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) setup.pair = currency.NewBTCUSDT() err = dataHandler.AppendDataSource(setup) - if !errors.Is(err, kline.ErrInvalidInterval) { - t.Errorf("received '%v' expected '%v'", err, kline.ErrInvalidInterval) - } + assert.ErrorIs(t, err, kline.ErrInvalidInterval) setup.interval = kline.OneDay err = dataHandler.AppendDataSource(setup) @@ -290,15 +244,11 @@ func TestAppendDataSource(t *testing.T) { } err = dataHandler.AppendDataSource(setup) - if !errors.Is(err, errDataSourceExists) { - t.Errorf("received '%v' expected '%v'", err, errDataSourceExists) - } + assert.ErrorIs(t, err, errDataSourceExists) dataHandler = nil err = dataHandler.AppendDataSource(setup) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestFetchLatestData(t *testing.T) { @@ -370,9 +320,7 @@ func TestLoadCandleData(t *testing.T) { processedData: make(map[int64]struct{}), } _, err := l.loadCandleData(time.Now()) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) exch := &binanceus.Binanceus{} exch.SetDefaults() @@ -407,9 +355,7 @@ func TestLoadCandleData(t *testing.T) { var ldh *liveDataSourceDataHandler _, err = ldh.loadCandleData(time.Now()) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestSetDataForClosingAllPositions(t *testing.T) { @@ -464,14 +410,11 @@ func TestSetDataForClosingAllPositions(t *testing.T) { assert.NoError(t, err) err = dataHandler.SetDataForClosingAllPositions() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) err = dataHandler.SetDataForClosingAllPositions(nil) - if !errors.Is(err, errNilData) { - t.Errorf("received '%v' expected '%v'", err, errNilData) - } + assert.ErrorIs(t, err, errNilData) + err = dataHandler.SetDataForClosingAllPositions(&signal.Signal{ Base: &event.Base{ Offset: 3, @@ -516,9 +459,7 @@ func TestSetDataForClosingAllPositions(t *testing.T) { dataHandler = nil err = dataHandler.SetDataForClosingAllPositions() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestIsRealOrders(t *testing.T) { @@ -537,9 +478,7 @@ func TestUpdateFunding(t *testing.T) { t.Parallel() d := &dataChecker{} err := d.UpdateFunding(false) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) ff := &fakeFunding{} d.funding = ff @@ -567,9 +506,7 @@ func TestUpdateFunding(t *testing.T) { d = nil err = d.UpdateFunding(false) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestClosedChan(t *testing.T) { diff --git a/backtester/engine/taskmanager_test.go b/backtester/engine/taskmanager_test.go index 2f305bf4..6d896129 100644 --- a/backtester/engine/taskmanager_test.go +++ b/backtester/engine/taskmanager_test.go @@ -1,7 +1,6 @@ package engine import ( - "errors" "testing" "time" @@ -26,9 +25,7 @@ func TestAddRun(t *testing.T) { t.Parallel() rm := NewTaskManager() err := rm.AddTask(nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) bt := &BackTest{} err = rm.AddTask(bt) @@ -42,18 +39,15 @@ func TestAddRun(t *testing.T) { } err = rm.AddTask(bt) - if !errors.Is(err, errTaskAlreadyMonitored) { - t.Errorf("received '%v' expected '%v'", err, errTaskAlreadyMonitored) - } + assert.ErrorIs(t, err, errTaskAlreadyMonitored) + if len(rm.tasks) != 1 { t.Errorf("received '%v' expected '%v'", len(rm.tasks), 1) } rm = nil err = rm.AddTask(bt) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestGetSummary(t *testing.T) { @@ -63,9 +57,7 @@ func TestGetSummary(t *testing.T) { assert.NoError(t, err) _, err = rm.GetSummary(id) - if !errors.Is(err, errTaskNotFound) { - t.Errorf("received '%v' expected '%v'", err, errTaskNotFound) - } + assert.ErrorIs(t, err, errTaskNotFound) bt := &BackTest{ Strategy: &binancecashandcarry.Strategy{}, @@ -83,9 +75,7 @@ func TestGetSummary(t *testing.T) { rm = nil _, err = rm.GetSummary(id) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestList(t *testing.T) { @@ -114,9 +104,7 @@ func TestList(t *testing.T) { rm = nil _, err = rm.List() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestStopRun(t *testing.T) { @@ -133,9 +121,7 @@ func TestStopRun(t *testing.T) { assert.NoError(t, err) err = rm.StopTask(id) - if !errors.Is(err, errTaskNotFound) { - t.Errorf("received '%v' expected '%v'", err, errTaskNotFound) - } + assert.ErrorIs(t, err, errTaskNotFound) bt := &BackTest{ Strategy: &fakeStrat{}, @@ -147,9 +133,7 @@ func TestStopRun(t *testing.T) { assert.NoError(t, err) err = rm.StopTask(bt.MetaData.ID) - if !errors.Is(err, errTaskHasNotRan) { - t.Errorf("received '%v' expected '%v'", err, errTaskHasNotRan) - } + assert.ErrorIs(t, err, errTaskHasNotRan) bt.m.Lock() bt.MetaData.DateStarted = time.Now() @@ -158,15 +142,11 @@ func TestStopRun(t *testing.T) { assert.NoError(t, err) err = rm.StopTask(bt.MetaData.ID) - if !errors.Is(err, errAlreadyRan) { - t.Errorf("received '%v' expected '%v'", err, errAlreadyRan) - } + assert.ErrorIs(t, err, errAlreadyRan) rm = nil err = rm.StopTask(id) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestStopAllRuns(t *testing.T) { @@ -200,9 +180,7 @@ func TestStopAllRuns(t *testing.T) { rm = nil _, err = rm.StopAllTasks() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestStartRun(t *testing.T) { @@ -219,9 +197,7 @@ func TestStartRun(t *testing.T) { assert.NoError(t, err) err = rm.StartTask(id) - if !errors.Is(err, errTaskNotFound) { - t.Errorf("received '%v' expected '%v'", err, errTaskNotFound) - } + assert.ErrorIs(t, err, errTaskNotFound) bt := &BackTest{ Strategy: &binancecashandcarry.Strategy{}, @@ -237,9 +213,8 @@ func TestStartRun(t *testing.T) { assert.NoError(t, err) err = rm.StartTask(bt.MetaData.ID) - if !errors.Is(err, errTaskIsRunning) { - t.Errorf("received '%v' expected '%v'", err, errTaskIsRunning) - } + assert.ErrorIs(t, err, errTaskIsRunning) + bt.m.Lock() bt.MetaData.DateEnded = time.Now() bt.MetaData.Closed = true @@ -247,15 +222,11 @@ func TestStartRun(t *testing.T) { bt.m.Unlock() err = rm.StartTask(bt.MetaData.ID) - if !errors.Is(err, errAlreadyRan) { - t.Errorf("received '%v' expected '%v'", err, errAlreadyRan) - } + assert.ErrorIs(t, err, errAlreadyRan) rm = nil err = rm.StartTask(id) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestStartAllRuns(t *testing.T) { @@ -287,9 +258,7 @@ func TestStartAllRuns(t *testing.T) { rm = nil _, err = rm.StartAllTasks() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestClearRun(t *testing.T) { @@ -300,9 +269,7 @@ func TestClearRun(t *testing.T) { assert.NoError(t, err) err = rm.ClearTask(id) - if !errors.Is(err, errTaskNotFound) { - t.Errorf("received '%v' expected '%v'", err, errTaskNotFound) - } + assert.ErrorIs(t, err, errTaskNotFound) bt := &BackTest{ Strategy: &binancecashandcarry.Strategy{}, @@ -318,9 +285,7 @@ func TestClearRun(t *testing.T) { bt.MetaData.DateStarted = time.Now() bt.m.Unlock() err = rm.ClearTask(bt.MetaData.ID) - if !errors.Is(err, errCannotClear) { - t.Errorf("received '%v' expected '%v'", err, errCannotClear) - } + assert.ErrorIs(t, err, errCannotClear) bt.m.Lock() bt.MetaData.DateStarted = time.Time{} @@ -337,9 +302,7 @@ func TestClearRun(t *testing.T) { rm = nil err = rm.ClearTask(id) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestClearAllRuns(t *testing.T) { @@ -398,7 +361,5 @@ func TestClearAllRuns(t *testing.T) { rm = nil _, _, err = rm.ClearAllTasks() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } diff --git a/backtester/eventhandlers/eventholder/eventholder_test.go b/backtester/eventhandlers/eventholder/eventholder_test.go index 0926354a..84465155 100644 --- a/backtester/eventhandlers/eventholder/eventholder_test.go +++ b/backtester/eventhandlers/eventholder/eventholder_test.go @@ -1,7 +1,6 @@ package eventholder import ( - "errors" "testing" "github.com/stretchr/testify/assert" @@ -22,9 +21,7 @@ func TestReset(t *testing.T) { e = nil err = e.Reset() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestAppendEvent(t *testing.T) { diff --git a/backtester/eventhandlers/exchange/exchange_test.go b/backtester/eventhandlers/exchange/exchange_test.go index 27b16f72..f54d6419 100644 --- a/backtester/eventhandlers/exchange/exchange_test.go +++ b/backtester/eventhandlers/exchange/exchange_test.go @@ -1,7 +1,6 @@ package exchange import ( - "errors" "testing" "time" @@ -116,9 +115,7 @@ func TestReset(t *testing.T) { e = nil err = e.Reset() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestSetCurrency(t *testing.T) { @@ -295,18 +292,14 @@ func TestExecuteOrder(t *testing.T) { assert.NoError(t, err) _, err = e.ExecuteOrder(o, d, bot.OrderManager, &fakeFund{}) - if !errors.Is(err, errNoCurrencySettingsFound) { - t.Error(err) - } + assert.ErrorIs(t, err, errNoCurrencySettingsFound) cs.UseRealOrders = true cs.CanUseExchangeLimits = true o.Direction = gctorder.Sell e.CurrencySettings = []Settings{cs} _, err = e.ExecuteOrder(o, d, bot.OrderManager, &fakeFund{}) - if !errors.Is(err, exchange.ErrCredentialsAreEmpty) { - t.Errorf("received: %v but expected: %v", err, exchange.ErrCredentialsAreEmpty) - } + assert.ErrorIs(t, err, exchange.ErrCredentialsAreEmpty) o.LiquidatingPosition = true _, err = e.ExecuteOrder(o, d, bot.OrderManager, &fakeFund{}) @@ -323,9 +316,7 @@ func TestExecuteOrder(t *testing.T) { e.CurrencySettings[0].Asset = asset.Spot e.CurrencySettings[0].UseRealOrders = false _, err = e.ExecuteOrder(o, d, bot.OrderManager, &fakeFund{}) - if !errors.Is(err, gctorder.ErrAmountIsInvalid) { - t.Errorf("received: %v but expected: %v", err, gctorder.ErrAmountIsInvalid) - } + assert.ErrorIs(t, err, gctorder.ErrAmountIsInvalid) } func TestExecuteOrderBuySellSizeLimit(t *testing.T) { @@ -509,9 +500,7 @@ func TestApplySlippageToPrice(t *testing.T) { } _, err = applySlippageToPrice(gctorder.UnknownSide, decimal.NewFromInt(1), decimal.NewFromFloat(0.9)) - if !errors.Is(err, gctorder.ErrSideIsInvalid) { - t.Errorf("received '%v' expected '%v'", err, nil) - } + assert.ErrorIs(t, err, gctorder.ErrSideIsInvalid) } func TestReduceAmountToFitPortfolioLimit(t *testing.T) { @@ -538,19 +527,14 @@ func TestReduceAmountToFitPortfolioLimit(t *testing.T) { func TestVerifyOrderWithinLimits(t *testing.T) { t.Parallel() err := verifyOrderWithinLimits(nil, decimal.Zero, nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received %v expected %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) err = verifyOrderWithinLimits(&fill.Fill{}, decimal.Zero, nil) - if !errors.Is(err, errNilCurrencySettings) { - t.Errorf("received %v expected %v", err, errNilCurrencySettings) - } + assert.ErrorIs(t, err, errNilCurrencySettings) err = verifyOrderWithinLimits(&fill.Fill{}, decimal.Zero, &Settings{}) - if !errors.Is(err, errInvalidDirection) { - t.Errorf("received %v expected %v", err, errInvalidDirection) - } + assert.ErrorIs(t, err, errInvalidDirection) + f := &fill.Fill{ Direction: gctorder.Buy, } @@ -565,14 +549,11 @@ func TestVerifyOrderWithinLimits(t *testing.T) { } f.Base = &event.Base{} err = verifyOrderWithinLimits(f, decimal.NewFromFloat(0.5), s) - if !errors.Is(err, errExceededPortfolioLimit) { - t.Errorf("received %v expected %v", err, errExceededPortfolioLimit) - } + assert.ErrorIs(t, err, errExceededPortfolioLimit) + f.Direction = gctorder.Buy err = verifyOrderWithinLimits(f, decimal.NewFromInt(2), s) - if !errors.Is(err, errExceededPortfolioLimit) { - t.Errorf("received %v expected %v", err, errExceededPortfolioLimit) - } + assert.ErrorIs(t, err, errExceededPortfolioLimit) f.Direction = gctorder.Sell s.SellSide = MinMax{ @@ -580,25 +561,18 @@ func TestVerifyOrderWithinLimits(t *testing.T) { MaximumSize: decimal.NewFromInt(1), } err = verifyOrderWithinLimits(f, decimal.NewFromFloat(0.5), s) - if !errors.Is(err, errExceededPortfolioLimit) { - t.Errorf("received %v expected %v", err, errExceededPortfolioLimit) - } + assert.ErrorIs(t, err, errExceededPortfolioLimit) + f.Direction = gctorder.Sell err = verifyOrderWithinLimits(f, decimal.NewFromInt(2), s) - if !errors.Is(err, errExceededPortfolioLimit) { - t.Errorf("received %v expected %v", err, errExceededPortfolioLimit) - } + assert.ErrorIs(t, err, errExceededPortfolioLimit) } func TestAllocateFundsPostOrder(t *testing.T) { t.Parallel() - expectedError := common.ErrNilEvent err := allocateFundsPostOrder(nil, nil, nil, decimal.Zero, decimal.Zero, decimal.Zero, decimal.Zero, decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, common.ErrNilEvent) - expectedError = gctcommon.ErrNilPointer f := &fill.Fill{ Base: &event.Base{ AssetType: asset.Spot, @@ -606,101 +580,66 @@ func TestAllocateFundsPostOrder(t *testing.T) { Direction: gctorder.Buy, } err = allocateFundsPostOrder(f, nil, nil, decimal.Zero, decimal.Zero, decimal.Zero, decimal.Zero, decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) - expectedError = nil one := decimal.NewFromInt(1) item, err := funding.CreateItem(testExchange, asset.Spot, currency.BTC, decimal.NewFromInt(1337), decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "CreateItem must not error") + item2, err := funding.CreateItem(testExchange, asset.Spot, currency.USDT, decimal.NewFromInt(1337), decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "CreateItem must not error") + err = item.Reserve(one) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "Reserve must not error") + err = item2.Reserve(one) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "Reserve must not error") + fundPair, err := funding.CreatePair(item, item2) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "CreatePair must not error") + f.Order = &gctorder.Detail{} err = allocateFundsPostOrder(f, fundPair, nil, one, one, one, one, decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "allocateFundsPostOrder must not error") + f.SetDirection(gctorder.Sell) err = allocateFundsPostOrder(f, fundPair, nil, one, one, one, one, decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "allocateFundsPostOrder must not error") - expectedError = gctorder.ErrSubmissionIsNil - orderError := gctorder.ErrSubmissionIsNil - err = allocateFundsPostOrder(f, fundPair, orderError, one, one, one, one, decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + err = allocateFundsPostOrder(f, fundPair, gctorder.ErrSubmissionIsNil, one, one, one, one, decimal.Zero) + assert.ErrorIs(t, err, gctorder.ErrSubmissionIsNil) f.AssetType = asset.Futures f.SetDirection(gctorder.Short) - expectedError = nil item3, err := funding.CreateItem(testExchange, asset.Futures, currency.BTC, decimal.NewFromInt(1337), decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "CreateItem must not error") + item4, err := funding.CreateItem(testExchange, asset.Futures, currency.USDT, decimal.NewFromInt(1337), decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "CreateItem must not error") + err = item3.Reserve(one) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "Reserve must not error") + err = item4.Reserve(one) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "Reserve must not error") + collateralPair, err := funding.CreateCollateral(item, item2) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "CreateCollateral must not error") + + err = allocateFundsPostOrder(f, collateralPair, gctorder.ErrSubmissionIsNil, one, one, one, one, decimal.Zero) + assert.ErrorIs(t, err, gctorder.ErrSubmissionIsNil) - expectedError = gctorder.ErrSubmissionIsNil - err = allocateFundsPostOrder(f, collateralPair, orderError, one, one, one, one, decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } - expectedError = nil err = allocateFundsPostOrder(f, collateralPair, nil, one, one, one, one, decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "allocateFundsPostOrder must not error") - expectedError = gctorder.ErrSubmissionIsNil f.SetDirection(gctorder.Long) - err = allocateFundsPostOrder(f, collateralPair, orderError, one, one, one, one, decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } - expectedError = nil + err = allocateFundsPostOrder(f, collateralPair, gctorder.ErrSubmissionIsNil, one, one, one, one, decimal.Zero) + assert.ErrorIs(t, err, gctorder.ErrSubmissionIsNil) + err = allocateFundsPostOrder(f, collateralPair, nil, one, one, one, one, decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.NoError(t, err, "allocateFundsPostOrder should not error") f.AssetType = asset.Margin - expectedError = common.ErrInvalidDataType err = allocateFundsPostOrder(f, collateralPair, nil, one, one, one, one, decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, common.ErrInvalidDataType) } diff --git a/backtester/eventhandlers/portfolio/compliance/compliance_test.go b/backtester/eventhandlers/portfolio/compliance/compliance_test.go index dd5b8d57..74fe9056 100644 --- a/backtester/eventhandlers/portfolio/compliance/compliance_test.go +++ b/backtester/eventhandlers/portfolio/compliance/compliance_test.go @@ -1,7 +1,6 @@ package compliance import ( - "errors" "testing" "time" @@ -15,9 +14,7 @@ func TestAddSnapshot(t *testing.T) { m := Manager{} tt := time.Now() err := m.AddSnapshot(&Snapshot{}, true) - if !errors.Is(err, errSnapshotNotFound) { - t.Errorf("received: %v, expected: %v", err, errSnapshotNotFound) - } + assert.ErrorIs(t, err, errSnapshotNotFound) err = m.AddSnapshot(&Snapshot{ Timestamp: tt, @@ -70,9 +67,7 @@ func TestGetSnapshotAtTime(t *testing.T) { } _, err = m.GetSnapshotAtTime(time.Now().Add(time.Hour)) - if !errors.Is(err, errSnapshotNotFound) { - t.Errorf("received: %v, expected: %v", err, errSnapshotNotFound) - } + assert.ErrorIs(t, err, errSnapshotNotFound) } func TestGetLatestSnapshot(t *testing.T) { diff --git a/backtester/eventhandlers/portfolio/holdings/holdings_test.go b/backtester/eventhandlers/portfolio/holdings/holdings_test.go index a744bef9..bacfd42f 100644 --- a/backtester/eventhandlers/portfolio/holdings/holdings_test.go +++ b/backtester/eventhandlers/portfolio/holdings/holdings_test.go @@ -1,7 +1,6 @@ package holdings import ( - "errors" "testing" "time" @@ -58,9 +57,8 @@ func collateral(t *testing.T) *funding.CollateralPair { func TestCreate(t *testing.T) { t.Parallel() _, err := Create(nil, pair(t)) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + _, err = Create(&fill.Fill{ Base: &event.Base{AssetType: asset.Spot}, }, pair(t)) @@ -101,9 +99,7 @@ func TestUpdateValue(t *testing.T) { assert.NoError(t, err) err = h.UpdateValue(nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) h.BaseSize = decimal.NewFromInt(1) err = h.UpdateValue(&kline.Kline{ diff --git a/backtester/eventhandlers/portfolio/portfolio_test.go b/backtester/eventhandlers/portfolio/portfolio_test.go index 4fe9cff8..c733641c 100644 --- a/backtester/eventhandlers/portfolio/portfolio_test.go +++ b/backtester/eventhandlers/portfolio/portfolio_test.go @@ -1,7 +1,6 @@ package portfolio import ( - "errors" "testing" "time" @@ -48,27 +47,20 @@ func TestReset(t *testing.T) { p = nil err = p.Reset() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestSetup(t *testing.T) { t.Parallel() _, err := Setup(nil, nil, decimal.NewFromInt(-1)) - if !errors.Is(err, errSizeManagerUnset) { - t.Errorf("received: %v, expected: %v", err, errSizeManagerUnset) - } + assert.ErrorIs(t, err, errSizeManagerUnset) _, err = Setup(&size.Size{}, nil, decimal.NewFromInt(-1)) - if !errors.Is(err, errNegativeRiskFreeRate) { - t.Errorf("received: %v, expected: %v", err, errNegativeRiskFreeRate) - } + assert.ErrorIs(t, err, errNegativeRiskFreeRate) _, err = Setup(&size.Size{}, nil, decimal.NewFromInt(1)) - if !errors.Is(err, errRiskManagerUnset) { - t.Errorf("received: %v, expected: %v", err, errRiskManagerUnset) - } + assert.ErrorIs(t, err, errRiskManagerUnset) + var p *Portfolio p, err = Setup(&size.Size{}, &risk.Risk{}, decimal.NewFromInt(1)) assert.NoError(t, err) @@ -82,26 +74,18 @@ func TestSetupCurrencySettingsMap(t *testing.T) { t.Parallel() p := &Portfolio{} err := p.SetCurrencySettingsMap(nil) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received: %v, expected: %v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) err = p.SetCurrencySettingsMap(&exchange.Settings{}) - if !errors.Is(err, errExchangeUnset) { - t.Errorf("received: %v, expected: %v", err, errExchangeUnset) - } + assert.ErrorIs(t, err, errExchangeUnset) ff := &binance.Binance{} ff.Name = testExchange err = p.SetCurrencySettingsMap(&exchange.Settings{Exchange: ff}) - if !errors.Is(err, errAssetUnset) { - t.Errorf("received: %v, expected: %v", err, errAssetUnset) - } + assert.ErrorIs(t, err, errAssetUnset) err = p.SetCurrencySettingsMap(&exchange.Settings{Exchange: ff, Asset: asset.Spot}) - if !errors.Is(err, errCurrencyPairUnset) { - t.Errorf("received: %v, expected: %v", err, errCurrencyPairUnset) - } + assert.ErrorIs(t, err, errCurrencyPairUnset) err = p.SetCurrencySettingsMap(&exchange.Settings{Exchange: ff, Asset: asset.Spot, Pair: currency.NewBTCUSDT()}) assert.NoError(t, err) @@ -112,15 +96,12 @@ func TestSetHoldings(t *testing.T) { p := &Portfolio{} err := p.SetHoldingsForTimestamp(&holdings.Holding{}) - if !errors.Is(err, errHoldingsNoTimestamp) { - t.Errorf("received: %v, expected: %v", err, errHoldingsNoTimestamp) - } + assert.ErrorIs(t, err, errHoldingsNoTimestamp) + tt := time.Now() err = p.SetHoldingsForTimestamp(&holdings.Holding{Timestamp: tt}) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received: %v, expected: %v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) ff := &binance.Binance{} ff.Name = testExchange @@ -158,9 +139,7 @@ func TestGetLatestHoldingsForAllCurrencies(t *testing.T) { Pair: currency.NewBTCUSDT(), Timestamp: tt, }) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received: %v, expected: %v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) ff := &binance.Binance{} ff.Name = testExchange @@ -212,9 +191,7 @@ func TestViewHoldingAtTimePeriod(t *testing.T) { }, } _, err := p.ViewHoldingAtTimePeriod(s) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received: %v, expected: %v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) ff := &binance.Binance{} ff.Name = testExchange @@ -222,9 +199,7 @@ func TestViewHoldingAtTimePeriod(t *testing.T) { assert.NoError(t, err) _, err = p.ViewHoldingAtTimePeriod(s) - if !errors.Is(err, errNoHoldings) { - t.Errorf("received: %v, expected: %v", err, errNoHoldings) - } + assert.ErrorIs(t, err, errNoHoldings) err = p.SetHoldingsForTimestamp(&holdings.Holding{ Offset: 1, @@ -258,14 +233,11 @@ func TestUpdate(t *testing.T) { t.Parallel() p := Portfolio{} err := p.UpdateHoldings(nil, nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) err = p.UpdateHoldings(&kline.Kline{}, nil) - if !errors.Is(err, funding.ErrFundsNotFound) { - t.Errorf("received '%v' expected '%v'", err, funding.ErrFundsNotFound) - } + assert.ErrorIs(t, err, funding.ErrFundsNotFound) + bc, err := funding.CreateItem(testExchange, asset.Spot, currency.BTC, decimal.NewFromInt(1), decimal.Zero) if err != nil { t.Fatal(err) @@ -281,9 +253,7 @@ func TestUpdate(t *testing.T) { err = p.UpdateHoldings(&kline.Kline{ Base: b, }, pair) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received '%v' expected '%v'", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) tt := time.Now() err = p.SetHoldingsForTimestamp(&holdings.Holding{ @@ -293,9 +263,7 @@ func TestUpdate(t *testing.T) { Pair: currency.NewBTCUSDT(), Timestamp: tt, }) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received: %v, expected: %v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) ff := &binance.Binance{} ff.Name = testExchange @@ -316,9 +284,7 @@ func TestGetComplianceManager(t *testing.T) { t.Parallel() p := Portfolio{} _, err := p.getComplianceManager("", asset.Empty, currency.EMPTYPAIR) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received: %v, expected: %v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) ff := &binance.Binance{} ff.Name = testExchange @@ -338,16 +304,12 @@ func TestAddComplianceSnapshot(t *testing.T) { t.Parallel() p := Portfolio{} err := p.addComplianceSnapshot(nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) err = p.addComplianceSnapshot(&fill.Fill{ Base: &event.Base{}, }) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received: %v, expected: %v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) ff := &binance.Binance{} ff.Name = testExchange @@ -373,9 +335,7 @@ func TestOnFill(t *testing.T) { t.Parallel() p := Portfolio{} _, err := p.OnFill(nil, nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) f := &fill.Fill{ Base: &event.Base{ @@ -390,9 +350,8 @@ func TestOnFill(t *testing.T) { }, } _, err = p.OnFill(f, nil) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received: %v, expected: %v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) + ff := &binance.Binance{} ff.Name = testExchange err = p.SetCurrencySettingsMap(&exchange.Settings{Exchange: ff, Asset: asset.Spot, Pair: currency.NewBTCUSDT()}) @@ -411,9 +370,7 @@ func TestOnFill(t *testing.T) { t.Fatal(err) } _, err = p.OnFill(f, pair) - if !errors.Is(err, errHoldingsNoTimestamp) { - t.Errorf("received: %v, expected: %v", err, errHoldingsNoTimestamp) - } + assert.ErrorIs(t, err, errHoldingsNoTimestamp) f.Time = time.Now() _, err = p.OnFill(f, pair) @@ -428,30 +385,25 @@ func TestOnSignal(t *testing.T) { t.Parallel() p := Portfolio{} _, err := p.OnSignal(nil, nil, nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Error(err) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) + b := &event.Base{} s := &signal.Signal{ Base: b, } _, err = p.OnSignal(s, &exchange.Settings{}, nil) - if !errors.Is(err, errSizeManagerUnset) { - t.Errorf("received: %v, expected: %v", err, errSizeManagerUnset) - } + assert.ErrorIs(t, err, errSizeManagerUnset) + p.sizeManager = &size.Size{} _, err = p.OnSignal(s, &exchange.Settings{}, nil) - if !errors.Is(err, errRiskManagerUnset) { - t.Errorf("received: %v, expected: %v", err, errRiskManagerUnset) - } + assert.ErrorIs(t, err, errRiskManagerUnset) p.riskManager = &risk.Risk{} _, err = p.OnSignal(s, &exchange.Settings{}, nil) - if !errors.Is(err, funding.ErrFundsNotFound) { - t.Errorf("received: %v, expected: %v", err, funding.ErrFundsNotFound) - } + assert.ErrorIs(t, err, funding.ErrFundsNotFound) + bc, err := funding.CreateItem(testExchange, asset.Spot, currency.BTC, leet, decimal.Zero) if err != nil { t.Fatal(err) @@ -465,15 +417,12 @@ func TestOnSignal(t *testing.T) { t.Fatal(err) } _, err = p.OnSignal(s, &exchange.Settings{}, funds) - if !errors.Is(err, errInvalidDirection) { - t.Errorf("received: %v, expected: %v", err, errInvalidDirection) - } + assert.ErrorIs(t, err, errInvalidDirection) s.Direction = gctorder.Buy _, err = p.OnSignal(s, &exchange.Settings{}, funds) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received: %v, expected: %v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) + ff := &binance.Binance{} ff.Name = testExchange err = p.SetCurrencySettingsMap(&exchange.Settings{Exchange: ff, Asset: asset.Spot, Pair: currency.NewBTCUSD()}) @@ -515,9 +464,8 @@ func TestOnSignal(t *testing.T) { Timestamp: time.Now(), QuoteSize: leet, }) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received: %v, expected: %v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) + cs := &exchange.Settings{Exchange: ff, Asset: asset.Spot, Pair: currency.NewBTCUSD()} err = p.SetCurrencySettingsMap(cs) assert.NoError(t, err) @@ -554,19 +502,16 @@ func TestOnSignal(t *testing.T) { cs.Asset = asset.Futures err = p.SetCurrencySettingsMap(cs) - if !errors.Is(err, gctcommon.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrNotYetImplemented) - } + assert.ErrorIs(t, err, gctcommon.ErrNotYetImplemented) + s.Direction = gctorder.Long _, err = p.OnSignal(s, cs, collateralFunds) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received: %v, expected: %v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) + cp := currency.NewBTCUSD() _, err = p.getSettings(testExchange, asset.Futures, cp) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received: %v, expected: %v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) + exchangeSettings := &Settings{} exchangeSettings.FuturesTracker, err = futures.SetupMultiPositionTracker(&futures.MultiPositionTrackerSetup{ Exchange: testExchange, @@ -595,9 +540,7 @@ func TestOnSignal(t *testing.T) { s.Direction = gctorder.ClosePosition _, err = p.OnSignal(s, cs, collateralFunds) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received: %v, expected: %v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) } func TestGetLatestHoldings(t *testing.T) { @@ -606,9 +549,7 @@ func TestGetLatestHoldings(t *testing.T) { HoldingsSnapshots: make(map[int64]*holdings.Holding), } _, err := s.GetLatestHoldings() - if !errors.Is(err, errNoHoldings) { - t.Errorf("received: %v, expected: %v", err, errNoHoldings) - } + assert.ErrorIs(t, err, errNoHoldings) tt := time.Now() s.HoldingsSnapshots[tt.UnixNano()] = &holdings.Holding{Timestamp: tt} @@ -628,9 +569,8 @@ func TestGetSnapshotAtTime(t *testing.T) { _, err := p.GetLatestOrderSnapshotForEvent(&kline.Kline{ Base: b, }) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received: %v, expected: %v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) + cp := currency.NewPair(currency.XRP, currency.DOGE) ff := &binance.Binance{} ff.Name = testExchange @@ -685,9 +625,8 @@ func TestGetLatestSnapshot(t *testing.T) { t.Parallel() p := Portfolio{} _, err := p.GetLatestOrderSnapshots() - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received: %v, expected: %v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) + cp := currency.NewPair(currency.XRP, currency.DOGE) ff := &binance.Binance{} ff.Name = testExchange @@ -751,9 +690,7 @@ func TestCalculatePNL(t *testing.T) { Base: &event.Base{}, } err := p.UpdatePNL(ev, decimal.Zero) - if !errors.Is(err, futures.ErrNotFuturesAsset) { - t.Errorf("received: %v, expected: %v", err, futures.ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, futures.ErrNotFuturesAsset) exch := &binance.Binance{} exch.Name = testExchange @@ -767,9 +704,8 @@ func TestCalculatePNL(t *testing.T) { Pair: pair, Asset: a, }) - if !errors.Is(err, gctcommon.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrNotYetImplemented) - } + assert.ErrorIs(t, err, gctcommon.ErrNotYetImplemented) + tt := time.Now().Add(time.Hour) tt0 := time.Now().Add(-time.Hour) ev.Exchange = exch.Name @@ -778,9 +714,7 @@ func TestCalculatePNL(t *testing.T) { ev.Time = tt0 err = p.UpdatePNL(ev, decimal.Zero) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received: %v, expected: %v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) od := &gctorder.Detail{ Price: 1336, @@ -864,34 +798,27 @@ func TestTrackFuturesOrder(t *testing.T) { t.Parallel() p := &Portfolio{} _, err := p.TrackFuturesOrder(nil, nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received '%v' expected '%v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + _, err = p.TrackFuturesOrder(&fill.Fill{}, nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) + fundPair := &funding.SpotPair{} _, err = p.TrackFuturesOrder(&fill.Fill{}, fundPair) - if !errors.Is(err, gctorder.ErrSubmissionIsNil) { - t.Errorf("received '%v' expected '%v", err, gctorder.ErrSubmissionIsNil) - } + assert.ErrorIs(t, err, gctorder.ErrSubmissionIsNil) od := &gctorder.Detail{} _, err = p.TrackFuturesOrder(&fill.Fill{ Order: od, }, fundPair) - if !errors.Is(err, futures.ErrNotFuturesAsset) { - t.Errorf("received '%v' expected '%v", err, futures.ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, futures.ErrNotFuturesAsset) od.AssetType = asset.Futures _, err = p.TrackFuturesOrder(&fill.Fill{ Order: od, }, fundPair) - if !errors.Is(err, funding.ErrNotCollateral) { - t.Errorf("received '%v' expected '%v", err, funding.ErrNotCollateral) - } + assert.ErrorIs(t, err, funding.ErrNotCollateral) + cp := currency.NewBTCUSD() od.Pair = cp od.Exchange = testExchange @@ -925,16 +852,12 @@ func TestTrackFuturesOrder(t *testing.T) { _, err = p.TrackFuturesOrder(&fill.Fill{ Order: od, }, collat) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received '%v' expected '%v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) ff := &binance.Binance{} ff.Name = od.Exchange err = p.SetCurrencySettingsMap(&exchange.Settings{Exchange: ff, Asset: asset.Futures, Pair: cp}) - if !errors.Is(err, gctcommon.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrNotYetImplemented) - } + assert.ErrorIs(t, err, gctcommon.ErrNotYetImplemented) _, err = p.TrackFuturesOrder(&fill.Fill{ Order: od, @@ -944,9 +867,7 @@ func TestTrackFuturesOrder(t *testing.T) { CurrencyPair: cp, }, }, collat) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received '%v' expected '%v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) od.Side = gctorder.Long _, err = p.TrackFuturesOrder(&fill.Fill{ @@ -958,9 +879,7 @@ func TestTrackFuturesOrder(t *testing.T) { Time: od.Date, }, }, collat) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received '%v' expected '%v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) _, err = p.TrackFuturesOrder(&fill.Fill{ Order: od, @@ -972,9 +891,7 @@ func TestTrackFuturesOrder(t *testing.T) { Time: od.Date, }, }, collat) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received '%v' expected '%v", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) } func TestGetHoldingsForTime(t *testing.T) { @@ -983,18 +900,15 @@ func TestGetHoldingsForTime(t *testing.T) { HoldingsSnapshots: make(map[int64]*holdings.Holding), } _, err := s.GetHoldingsForTime(time.Now()) - if !errors.Is(err, errNoHoldings) { - t.Errorf("received '%v' expected '%v", err, errNoHoldings) - } + assert.ErrorIs(t, err, errNoHoldings) + tt := time.Now() s.HoldingsSnapshots[tt.UnixNano()] = &holdings.Holding{ Timestamp: tt, Offset: 1337, } _, err = s.GetHoldingsForTime(time.Unix(1337, 0)) - if !errors.Is(err, errNoHoldings) { - t.Errorf("received '%v' expected '%v", err, errNoHoldings) - } + assert.ErrorIs(t, err, errNoHoldings) h, err := s.GetHoldingsForTime(tt) assert.NoError(t, err) @@ -1007,11 +921,9 @@ func TestGetHoldingsForTime(t *testing.T) { func TestGetPositions(t *testing.T) { t.Parallel() p := &Portfolio{} - expectedError := common.ErrNilEvent _, err := p.GetPositions(nil) - if !errors.Is(err, expectedError) { - t.Fatalf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + ev := &fill.Fill{ Base: &event.Base{ Exchange: testExchange, @@ -1022,24 +934,18 @@ func TestGetPositions(t *testing.T) { ff := &binance.Binance{} ff.Name = testExchange err = p.SetCurrencySettingsMap(&exchange.Settings{Exchange: ff, Asset: ev.AssetType, Pair: ev.Pair()}) - if !errors.Is(err, gctcommon.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrNotYetImplemented) - } - expectedError = errNoPortfolioSettings + assert.ErrorIs(t, err, gctcommon.ErrNotYetImplemented) + _, err = p.GetPositions(ev) - if !errors.Is(err, expectedError) { - t.Fatalf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) } func TestGetLatestPNLForEvent(t *testing.T) { t.Parallel() p := &Portfolio{} - expectedError := common.ErrNilEvent _, err := p.GetLatestPNLForEvent(nil) - if !errors.Is(err, expectedError) { - t.Fatalf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + ev := &fill.Fill{ Base: &event.Base{ Exchange: testExchange, @@ -1050,14 +956,10 @@ func TestGetLatestPNLForEvent(t *testing.T) { ff := &binance.Binance{} ff.Name = testExchange err = p.SetCurrencySettingsMap(&exchange.Settings{Exchange: ff, Asset: ev.AssetType, Pair: ev.Pair()}) - if !errors.Is(err, gctcommon.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrNotYetImplemented) - } - expectedError = errNoPortfolioSettings + assert.ErrorIs(t, err, gctcommon.ErrNotYetImplemented) + _, err = p.GetLatestPNLForEvent(ev) - if !errors.Is(err, expectedError) { - t.Fatalf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) mpt, err := futures.SetupMultiPositionTracker(&futures.MultiPositionTrackerSetup{ Exchange: testExchange, @@ -1067,8 +969,7 @@ func TestGetLatestPNLForEvent(t *testing.T) { CollateralCurrency: currency.USDT, OfflineCalculation: true, }) - assert.NoError(t, err) - + require.NoError(t, err, "SetupMultiPositionTracker must not error") s := &Settings{ FuturesTracker: mpt, } @@ -1080,7 +981,6 @@ func TestGetLatestPNLForEvent(t *testing.T) { Quote: ev.Pair().Quote.Item, Asset: asset.Futures, }] = s - expectedError = nil err = s.FuturesTracker.TrackNewOrder(&gctorder.Detail{ Exchange: ev.GetExchange(), AssetType: ev.AssetType, @@ -1091,32 +991,25 @@ func TestGetLatestPNLForEvent(t *testing.T) { Date: time.Now(), Side: gctorder.Buy, }) - if !errors.Is(err, expectedError) { - t.Fatalf("received '%v' expected '%v'", err, expectedError) - } + require.NoError(t, err, "TrackNewOrder must not error") + latest, err := p.GetLatestPNLForEvent(ev) - if !errors.Is(err, expectedError) { - t.Fatalf("received '%v' expected '%v'", err, expectedError) - } - if latest == nil { - t.Error("unexpected") - } + require.NoError(t, err, "GetLatestPNLForEvent must not error") + assert.NotNil(t, latest, "GetLatestPNLForEvent should return a non-nil result") } func TestGetFuturesSettingsFromEvent(t *testing.T) { t.Parallel() p := &Portfolio{} _, err := p.getFuturesSettingsFromEvent(nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Fatalf("received '%v' expected '%v'", err, common.ErrNilEvent) - } + require.ErrorIs(t, err, common.ErrNilEvent) + b := &event.Base{} _, err = p.getFuturesSettingsFromEvent(&fill.Fill{ Base: b, }) - if !errors.Is(err, futures.ErrNotFuturesAsset) { - t.Fatalf("received '%v' expected '%v'", err, futures.ErrNotFuturesAsset) - } + require.ErrorIs(t, err, futures.ErrNotFuturesAsset) + b.Exchange = testExchange b.CurrencyPair = currency.NewBTCUSDT() b.AssetType = asset.Futures @@ -1124,25 +1017,18 @@ func TestGetFuturesSettingsFromEvent(t *testing.T) { Base: b, } _, err = p.getFuturesSettingsFromEvent(ev) - if !errors.Is(err, errNoPortfolioSettings) { - t.Fatalf("received '%v' expected '%v'", err, errNoPortfolioSettings) - } + require.ErrorIs(t, err, errNoPortfolioSettings) ff := &binance.Binance{} ff.Name = testExchange err = p.SetCurrencySettingsMap(&exchange.Settings{Exchange: ff, Asset: ev.AssetType, Pair: ev.Pair()}) - if !errors.Is(err, gctcommon.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrNotYetImplemented) - } - _, err = p.getFuturesSettingsFromEvent(ev) - if !errors.Is(err, errNoPortfolioSettings) { - t.Fatalf("received '%v' expected '%v'", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, gctcommon.ErrNotYetImplemented) _, err = p.getFuturesSettingsFromEvent(ev) - if !errors.Is(err, errNoPortfolioSettings) { - t.Fatalf("received '%v' expected '%v'", err, errNoPortfolioSettings) - } + require.ErrorIs(t, err, errNoPortfolioSettings) + + _, err = p.getFuturesSettingsFromEvent(ev) + require.ErrorIs(t, err, errNoPortfolioSettings) } func TestGetUnrealisedPNL(t *testing.T) { @@ -1287,73 +1173,42 @@ func TestGetDirection(t *testing.T) { func TestCannotPurchase(t *testing.T) { t.Parallel() - expectedError := common.ErrNilEvent _, err := cannotPurchase(nil, nil) - if !errors.Is(err, expectedError) { - t.Fatalf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, common.ErrNilEvent) s := &signal.Signal{ Base: &event.Base{}, } - expectedError = gctcommon.ErrNilPointer _, err = cannotPurchase(s, nil) - if !errors.Is(err, expectedError) { - t.Fatalf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) o := &order.Order{ Base: &event.Base{}, } s.Direction = gctorder.Buy - expectedError = nil result, err := cannotPurchase(s, o) - if !errors.Is(err, expectedError) { - t.Fatalf("received '%v' expected '%v'", err, expectedError) - } - if result.Direction != gctorder.CouldNotBuy { - t.Errorf("received '%v' expected '%v'", result.Direction, gctorder.CouldNotBuy) - } + require.NoError(t, err, "cannotPurchase must not error") + assert.Equal(t, gctorder.CouldNotBuy, result.Direction) s.Direction = gctorder.Sell - expectedError = nil result, err = cannotPurchase(s, o) - if !errors.Is(err, expectedError) { - t.Fatalf("received '%v' expected '%v'", err, expectedError) - } - if result.Direction != gctorder.CouldNotSell { - t.Errorf("received '%v' expected '%v'", result.Direction, gctorder.CouldNotSell) - } + require.NoError(t, err, "cannotPurchase must not error") + assert.Equal(t, gctorder.CouldNotSell, result.Direction) s.Direction = gctorder.Short - expectedError = nil result, err = cannotPurchase(s, o) - if !errors.Is(err, expectedError) { - t.Fatalf("received '%v' expected '%v'", err, expectedError) - } - if result.Direction != gctorder.CouldNotShort { - t.Errorf("received '%v' expected '%v'", result.Direction, gctorder.CouldNotShort) - } + require.NoError(t, err, "cannotPurchase must not error") + assert.Equal(t, gctorder.CouldNotShort, result.Direction) s.Direction = gctorder.Long - expectedError = nil result, err = cannotPurchase(s, o) - if !errors.Is(err, expectedError) { - t.Fatalf("received '%v' expected '%v'", err, expectedError) - } - if result.Direction != gctorder.CouldNotLong { - t.Errorf("received '%v' expected '%v'", result.Direction, gctorder.CouldNotLong) - } + require.NoError(t, err, "cannotPurchase must not error") + assert.Equal(t, gctorder.CouldNotLong, result.Direction) s.Direction = gctorder.UnknownSide - expectedError = nil result, err = cannotPurchase(s, o) - if !errors.Is(err, expectedError) { - t.Fatalf("received '%v' expected '%v'", err, expectedError) - } - if result.Direction != gctorder.DoNothing { - t.Errorf("received '%v' expected '%v'", result.Direction, gctorder.DoNothing) - } + require.NoError(t, err, "cannotPurchase must not error") + assert.Equal(t, gctorder.DoNothing, result.Direction) } func TestCreateLiquidationOrdersForExchange(t *testing.T) { @@ -1361,9 +1216,7 @@ func TestCreateLiquidationOrdersForExchange(t *testing.T) { p := &Portfolio{} _, err := p.CreateLiquidationOrdersForExchange(nil, nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Fatalf("received '%v' expected '%v'", err, common.ErrNilEvent) - } + require.ErrorIs(t, err, common.ErrNilEvent) b := &event.Base{} @@ -1371,9 +1224,7 @@ func TestCreateLiquidationOrdersForExchange(t *testing.T) { Base: b, } _, err = p.CreateLiquidationOrdersForExchange(ev, nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Fatalf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + require.ErrorIs(t, err, gctcommon.ErrNilPointer) funds := &funding.FundManager{} _, err = p.CreateLiquidationOrdersForExchange(ev, funds) @@ -1383,9 +1234,8 @@ func TestCreateLiquidationOrdersForExchange(t *testing.T) { ff.Name = testExchange cp := currency.NewBTCUSDT() err = p.SetCurrencySettingsMap(&exchange.Settings{Exchange: ff, Asset: asset.Futures, Pair: cp}) - if !errors.Is(err, gctcommon.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrNotYetImplemented) - } + assert.ErrorIs(t, err, gctcommon.ErrNotYetImplemented) + err = p.SetCurrencySettingsMap(&exchange.Settings{Exchange: ff, Asset: asset.Spot, Pair: cp}) assert.NoError(t, err) @@ -1394,9 +1244,7 @@ func TestCreateLiquidationOrdersForExchange(t *testing.T) { require.NoError(t, err) _, err = p.getSettings(ff.Name, asset.Futures, cp) - if !errors.Is(err, errNoPortfolioSettings) { - t.Fatalf("received '%v' expected '%v'", err, errNoPortfolioSettings) - } + require.ErrorIs(t, err, errNoPortfolioSettings) od := &gctorder.Detail{ Exchange: ff.Name, @@ -1475,17 +1323,13 @@ func TestCheckLiquidationStatus(t *testing.T) { t.Parallel() p := &Portfolio{} err := p.CheckLiquidationStatus(nil, nil, nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received '%v', expected '%v'", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) ev := &kline.Kline{ Base: &event.Base{}, } err = p.CheckLiquidationStatus(ev, nil, nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v', expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) item := asset.Futures pair := currency.NewBTCUSDT() @@ -1499,15 +1343,11 @@ func TestCheckLiquidationStatus(t *testing.T) { assert.NoError(t, err) err = p.CheckLiquidationStatus(ev, collat, nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v', expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) pnl := &PNLSummary{} err = p.CheckLiquidationStatus(ev, collat, pnl) - if !errors.Is(err, futures.ErrNotFuturesAsset) { - t.Errorf("received '%v', expected '%v'", err, futures.ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, futures.ErrNotFuturesAsset) pnl.Asset = asset.Futures ev.AssetType = asset.Futures @@ -1516,13 +1356,11 @@ func TestCheckLiquidationStatus(t *testing.T) { exch := &binance.Binance{} exch.Name = ev.Exchange err = p.SetCurrencySettingsMap(&exchange.Settings{Exchange: exch, Asset: asset.Futures, Pair: pair}) - if !errors.Is(err, gctcommon.ErrNotYetImplemented) { - t.Errorf("received '%v', expected '%v'", err, gctcommon.ErrNotYetImplemented) - } + assert.ErrorIs(t, err, gctcommon.ErrNotYetImplemented) + _, err = p.getSettings(ev.Exchange, ev.AssetType, ev.Pair()) - if !errors.Is(err, errNoPortfolioSettings) { - t.Errorf("received '%v', expected '%v'", err, errNoPortfolioSettings) - } + assert.ErrorIs(t, err, errNoPortfolioSettings) + od := &gctorder.Detail{ Price: 1336, Amount: 20, @@ -1565,9 +1403,7 @@ func TestSetHoldingsForEvent(t *testing.T) { t.Parallel() p := &Portfolio{} err := p.SetHoldingsForEvent(nil, nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v', expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) item, err := funding.CreateItem(testExchange, asset.Spot, currency.BTC, decimal.Zero, decimal.Zero) assert.NoError(t, err) @@ -1576,14 +1412,10 @@ func TestSetHoldingsForEvent(t *testing.T) { assert.NoError(t, err) err = p.SetHoldingsForEvent(cp.FundReader(), nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received '%v', expected '%v'", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) err = p.SetHoldingsForEvent(cp.FundReader(), nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received '%v', expected '%v'", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) tt := time.Now() ev := &signal.Signal{ diff --git a/backtester/eventhandlers/portfolio/risk/risk_test.go b/backtester/eventhandlers/portfolio/risk/risk_test.go index d203e01d..16f4ecf4 100644 --- a/backtester/eventhandlers/portfolio/risk/risk_test.go +++ b/backtester/eventhandlers/portfolio/risk/risk_test.go @@ -1,7 +1,6 @@ package risk import ( - "errors" "testing" "github.com/shopspring/decimal" @@ -56,9 +55,8 @@ func TestEvaluateOrder(t *testing.T) { t.Parallel() r := Risk{} _, err := r.EvaluateOrder(nil, nil, compliance.Snapshot{}) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Error(err) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) + p := currency.NewBTCUSDT() e := "binance" a := asset.Spot @@ -72,9 +70,7 @@ func TestEvaluateOrder(t *testing.T) { h := []holdings.Holding{} r.CurrencySettings = make(map[key.ExchangePairAsset]*CurrencySettings) _, err = r.EvaluateOrder(o, h, compliance.Snapshot{}) - if !errors.Is(err, errNoCurrencySettings) { - t.Error(err) - } + assert.ErrorIs(t, err, errNoCurrencySettings) r.CurrencySettings[key.ExchangePairAsset{ Exchange: e, @@ -105,14 +101,11 @@ func TestEvaluateOrder(t *testing.T) { Asset: a, }].MaximumHoldingRatio = decimal.Zero _, err = r.EvaluateOrder(o, h, compliance.Snapshot{}) - if !errors.Is(err, errLeverageNotAllowed) { - t.Error(err) - } + assert.ErrorIs(t, err, errLeverageNotAllowed) + r.CanUseLeverage = true _, err = r.EvaluateOrder(o, h, compliance.Snapshot{}) - if !errors.Is(err, errCannotPlaceLeverageOrder) { - t.Error(err) - } + assert.ErrorIs(t, err, errCannotPlaceLeverageOrder) r.MaximumLeverage = decimal.NewFromInt(33) r.CurrencySettings[key.ExchangePairAsset{ @@ -141,9 +134,7 @@ func TestEvaluateOrder(t *testing.T) { }, }, }) - if !errors.Is(err, errCannotPlaceLeverageOrder) { - t.Error(err) - } + assert.ErrorIs(t, err, errCannotPlaceLeverageOrder) h = append(h, holdings.Holding{Pair: p, BaseValue: decimal.NewFromInt(1337)}, holdings.Holding{Pair: p, BaseValue: decimal.NewFromFloat(1337.42)}) r.CurrencySettings[key.ExchangePairAsset{ @@ -157,7 +148,5 @@ func TestEvaluateOrder(t *testing.T) { h = append(h, holdings.Holding{Pair: currency.NewPair(currency.DOGE, currency.LTC), BaseValue: decimal.NewFromInt(1337)}) _, err = r.EvaluateOrder(o, h, compliance.Snapshot{}) - if !errors.Is(err, errCannotPlaceLeverageOrder) { - t.Error(err) - } + assert.ErrorIs(t, err, errCannotPlaceLeverageOrder) } diff --git a/backtester/eventhandlers/portfolio/size/size_test.go b/backtester/eventhandlers/portfolio/size/size_test.go index a7221c84..acdd5f28 100644 --- a/backtester/eventhandlers/portfolio/size/size_test.go +++ b/backtester/eventhandlers/portfolio/size/size_test.go @@ -1,7 +1,6 @@ package size import ( - "errors" "testing" "time" @@ -79,9 +78,7 @@ func TestSizingUnderMinSize(t *testing.T) { feeRate := decimal.NewFromFloat(0.02) buyLimit := decimal.NewFromInt(1) _, _, err := sizer.calculateBuySize(price, availableFunds, feeRate, buyLimit, globalMinMax) - if !errors.Is(err, errLessThanMinimum) { - t.Errorf("received: %v, expected: %v", err, errLessThanMinimum) - } + assert.ErrorIs(t, err, errLessThanMinimum) } func TestMaximumBuySizeEqualZero(t *testing.T) { @@ -140,9 +137,7 @@ func TestSizingErrors(t *testing.T) { feeRate := decimal.NewFromFloat(0.02) buyLimit := decimal.NewFromInt(1) _, _, err := sizer.calculateBuySize(price, availableFunds, feeRate, buyLimit, globalMinMax) - if !errors.Is(err, errNoFunds) { - t.Errorf("received: %v, expected: %v", err, errNoFunds) - } + assert.ErrorIs(t, err, errNoFunds) } func TestCalculateSellSize(t *testing.T) { @@ -161,14 +156,12 @@ func TestCalculateSellSize(t *testing.T) { feeRate := decimal.NewFromFloat(0.02) sellLimit := decimal.NewFromInt(1) _, _, err := sizer.calculateSellSize(price, availableFunds, feeRate, sellLimit, globalMinMax) - if !errors.Is(err, errNoFunds) { - t.Errorf("received: %v, expected: %v", err, errNoFunds) - } + assert.ErrorIs(t, err, errNoFunds) + availableFunds = decimal.NewFromInt(1337) _, _, err = sizer.calculateSellSize(price, availableFunds, feeRate, sellLimit, globalMinMax) - if !errors.Is(err, errLessThanMinimum) { - t.Errorf("received: %v, expected: %v", err, errLessThanMinimum) - } + assert.ErrorIs(t, err, errLessThanMinimum) + price = decimal.NewFromInt(12) availableFunds = decimal.NewFromInt(1339) amount, fee, err := sizer.calculateSellSize(price, availableFunds, feeRate, sellLimit, globalMinMax) @@ -186,9 +179,8 @@ func TestSizeOrder(t *testing.T) { t.Parallel() s := Size{} _, _, err := s.SizeOrder(nil, decimal.Zero, nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Error(err) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) + o := &order.Order{ Base: &event.Base{ Offset: 1, @@ -201,19 +193,14 @@ func TestSizeOrder(t *testing.T) { } cs := &exchange.Settings{} _, _, err = s.SizeOrder(o, decimal.Zero, cs) - if !errors.Is(err, errNoFunds) { - t.Errorf("received: %v, expected: %v", err, errNoFunds) - } + assert.ErrorIs(t, err, errNoFunds) _, _, err = s.SizeOrder(o, decimal.NewFromInt(1337), cs) - if !errors.Is(err, errCannotAllocate) { - t.Errorf("received: %v, expected: %v", err, errCannotAllocate) - } + assert.ErrorIs(t, err, errCannotAllocate) + o.Direction = gctorder.Buy _, _, err = s.SizeOrder(o, decimal.NewFromInt(1337), cs) - if !errors.Is(err, errCannotAllocate) { - t.Errorf("received: %v, expected: %v", err, errCannotAllocate) - } + assert.ErrorIs(t, err, errCannotAllocate) o.ClosePrice = decimal.NewFromInt(1) s.BuySide.MaximumSize = decimal.NewFromInt(1) @@ -245,14 +232,10 @@ func TestSizeOrder(t *testing.T) { // TODO adjust when Binance futures wrappers are implemented cs.Exchange = &exch _, _, err = s.SizeOrder(o, decimal.NewFromInt(1337), cs) - if !errors.Is(err, gctcommon.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrNotYetImplemented) - } + assert.ErrorIs(t, err, gctcommon.ErrNotYetImplemented) o.ClosePrice = decimal.NewFromInt(1000000000) o.Amount = decimal.NewFromInt(1000000000) _, _, err = s.SizeOrder(o, decimal.NewFromInt(1337), cs) - if !errors.Is(err, gctcommon.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrNotYetImplemented) - } + assert.ErrorIs(t, err, gctcommon.ErrNotYetImplemented) } diff --git a/backtester/eventhandlers/statistics/currencystatistics_test.go b/backtester/eventhandlers/statistics/currencystatistics_test.go index 80e6eb99..e75f7896 100644 --- a/backtester/eventhandlers/statistics/currencystatistics_test.go +++ b/backtester/eventhandlers/statistics/currencystatistics_test.go @@ -1,7 +1,6 @@ package statistics import ( - "errors" "testing" "time" @@ -284,9 +283,7 @@ func TestCalculateHighestCommittedFunds(t *testing.T) { c.Asset = asset.Binary err = c.calculateHighestCommittedFunds() - if !errors.Is(err, asset.ErrNotSupported) { - t.Error(err) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) } func TestAnalysePNLGrowth(t *testing.T) { diff --git a/backtester/eventhandlers/statistics/fundingstatistics_test.go b/backtester/eventhandlers/statistics/fundingstatistics_test.go index 5119343c..bc843fac 100644 --- a/backtester/eventhandlers/statistics/fundingstatistics_test.go +++ b/backtester/eventhandlers/statistics/fundingstatistics_test.go @@ -1,7 +1,6 @@ package statistics import ( - "errors" "testing" "time" @@ -22,9 +21,8 @@ import ( func TestCalculateFundingStatistics(t *testing.T) { t.Parallel() _, err := CalculateFundingStatistics(nil, nil, decimal.Zero, gctkline.OneHour) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received %v expected %v", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) + f, err := funding.SetupFundingManager(&engine.ExchangeManager{}, true, true, false) assert.NoError(t, err) @@ -41,9 +39,7 @@ func TestCalculateFundingStatistics(t *testing.T) { assert.NoError(t, err) _, err = CalculateFundingStatistics(f, nil, decimal.Zero, gctkline.OneHour) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received %v expected %v", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) usdKline := gctkline.Item{ Exchange: "binance", @@ -67,9 +63,7 @@ func TestCalculateFundingStatistics(t *testing.T) { assert.NoError(t, err) err = f.AddUSDTrackingData(dfk) - if !errors.Is(err, funding.ErrUSDTrackingDisabled) { - t.Errorf("received %v expected %v", err, funding.ErrUSDTrackingDisabled) - } + assert.ErrorIs(t, err, funding.ErrUSDTrackingDisabled) cs := make(map[key.ExchangePairAsset]*CurrencyPairStatistic) _, err = CalculateFundingStatistics(f, cs, decimal.Zero, gctkline.OneHour) @@ -94,9 +88,8 @@ func TestCalculateFundingStatistics(t *testing.T) { Asset: asset.Spot, }] = &CurrencyPairStatistic{} _, err = CalculateFundingStatistics(f, cs, decimal.Zero, gctkline.OneHour) - if !errors.Is(err, errMissingSnapshots) { - t.Errorf("received %v expected %v", err, errMissingSnapshots) - } + assert.ErrorIs(t, err, errMissingSnapshots) + err = f.CreateSnapshot(usdKline.Candles[0].Time) assert.NoError(t, err) @@ -115,17 +108,13 @@ func TestCalculateFundingStatistics(t *testing.T) { func TestCalculateIndividualFundingStatistics(t *testing.T) { _, err := CalculateIndividualFundingStatistics(true, nil, nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received %v expected %v", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) _, err = CalculateIndividualFundingStatistics(true, &funding.ReportItem{}, nil) assert.NoError(t, err) _, err = CalculateIndividualFundingStatistics(false, &funding.ReportItem{}, nil) - if !errors.Is(err, errMissingSnapshots) { - t.Errorf("received %v expected %v", err, errMissingSnapshots) - } + assert.ErrorIs(t, err, errMissingSnapshots) ri := &funding.ReportItem{ Snapshots: []funding.ItemSnapshot{ @@ -146,17 +135,14 @@ func TestCalculateIndividualFundingStatistics(t *testing.T) { }, } _, err = CalculateIndividualFundingStatistics(false, ri, rs) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received %v expected %v", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) rs[0].stat = &CurrencyPairStatistic{} ri.USDInitialFunds = decimal.NewFromInt(1000) ri.USDFinalFunds = decimal.NewFromInt(1337) _, err = CalculateIndividualFundingStatistics(false, ri, rs) - if !errors.Is(err, errMissingSnapshots) { - t.Errorf("received %v expected %v", err, errMissingSnapshots) - } + assert.ErrorIs(t, err, errMissingSnapshots) + cp := currency.NewBTCUSD() ri.USDPairCandle = &kline.DataFromKline{ Base: &data.Base{}, @@ -196,9 +182,7 @@ func TestCalculateIndividualFundingStatistics(t *testing.T) { func TestFundingStatisticsPrintResults(t *testing.T) { f := FundingStatistics{} err := f.PrintResults(false) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received %v expected %v", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) funds, err := funding.SetupFundingManager(&engine.ExchangeManager{}, true, true, false) assert.NoError(t, err) @@ -224,9 +208,7 @@ func TestFundingStatisticsPrintResults(t *testing.T) { f.TotalUSDStatistics = &TotalFundingStatistics{} f.Report.DisableUSDTracking = false err = f.PrintResults(false) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received %v expected %v", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) f.TotalUSDStatistics = &TotalFundingStatistics{ GeometricRatios: &Ratios{}, diff --git a/backtester/eventhandlers/statistics/statistics_test.go b/backtester/eventhandlers/statistics/statistics_test.go index bb4651d8..d5ab7e40 100644 --- a/backtester/eventhandlers/statistics/statistics_test.go +++ b/backtester/eventhandlers/statistics/statistics_test.go @@ -1,7 +1,6 @@ package statistics import ( - "errors" "testing" "time" @@ -50,9 +49,7 @@ func TestReset(t *testing.T) { s = nil err = s.Reset() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestAddDataEventForTime(t *testing.T) { @@ -63,9 +60,8 @@ func TestAddDataEventForTime(t *testing.T) { p := currency.NewBTCUSDT() s := Statistic{} err := s.SetEventForOffset(nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + err = s.SetEventForOffset(&kline.Kline{ Base: &event.Base{ Exchange: exch, @@ -103,13 +99,11 @@ func TestAddSignalEventForTime(t *testing.T) { p := currency.NewBTCUSDT() s := Statistic{} err := s.SetEventForOffset(nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + err = s.SetEventForOffset(&signal.Signal{}) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + s.ExchangeAssetPairStatistics = make(map[key.ExchangePairAsset]*CurrencyPairStatistic) b := &event.Base{} err = s.SetEventForOffset(&signal.Signal{ @@ -148,13 +142,11 @@ func TestAddExchangeEventForTime(t *testing.T) { p := currency.NewBTCUSDT() s := Statistic{} err := s.SetEventForOffset(nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + err = s.SetEventForOffset(&order.Order{}) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + s.ExchangeAssetPairStatistics = make(map[key.ExchangePairAsset]*CurrencyPairStatistic) b := &event.Base{} @@ -194,13 +186,11 @@ func TestAddFillEventForTime(t *testing.T) { p := currency.NewBTCUSDT() s := Statistic{} err := s.SetEventForOffset(nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + err = s.SetEventForOffset(&fill.Fill{}) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + s.ExchangeAssetPairStatistics = make(map[key.ExchangePairAsset]*CurrencyPairStatistic) b := &event.Base{} err = s.SetEventForOffset(&fill.Fill{ @@ -245,14 +235,11 @@ func TestAddHoldingsForTime(t *testing.T) { p := currency.NewBTCUSDT() s := Statistic{} err := s.AddHoldingsForTime(&holdings.Holding{}) - if !errors.Is(err, errExchangeAssetPairStatsUnset) { - t.Errorf("received: %v, expected: %v", err, errExchangeAssetPairStatsUnset) - } + assert.ErrorIs(t, err, errExchangeAssetPairStatsUnset) + s.ExchangeAssetPairStatistics = make(map[key.ExchangePairAsset]*CurrencyPairStatistic) err = s.AddHoldingsForTime(&holdings.Holding{}) - if !errors.Is(err, errCurrencyStatisticsUnset) { - t.Errorf("received: %v, expected: %v", err, errCurrencyStatisticsUnset) - } + assert.ErrorIs(t, err, errCurrencyStatisticsUnset) err = s.SetEventForOffset(&kline.Kline{ Base: &event.Base{ @@ -302,24 +289,19 @@ func TestAddComplianceSnapshotForTime(t *testing.T) { s := Statistic{} err := s.AddComplianceSnapshotForTime(nil, nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + err = s.AddComplianceSnapshotForTime(nil, &fill.Fill{}) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) err = s.AddComplianceSnapshotForTime(&compliance.Snapshot{}, &fill.Fill{}) - if !errors.Is(err, errExchangeAssetPairStatsUnset) { - t.Errorf("received: %v, expected: %v", err, errExchangeAssetPairStatsUnset) - } + assert.ErrorIs(t, err, errExchangeAssetPairStatsUnset) + s.ExchangeAssetPairStatistics = make(map[key.ExchangePairAsset]*CurrencyPairStatistic) b := &event.Base{} err = s.AddComplianceSnapshotForTime(&compliance.Snapshot{}, &fill.Fill{Base: b}) - if !errors.Is(err, errCurrencyStatisticsUnset) { - t.Errorf("received: %v, expected: %v", err, errCurrencyStatisticsUnset) - } + assert.ErrorIs(t, err, errCurrencyStatisticsUnset) + b.Exchange = exch b.Time = tt b.Interval = gctkline.OneDay @@ -483,9 +465,8 @@ func TestPrintAllEventsChronologically(t *testing.T) { a := asset.Spot p := currency.NewBTCUSDT() err := s.SetEventForOffset(nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + err = s.SetEventForOffset(&kline.Kline{ Base: &event.Base{ Exchange: exch, @@ -540,9 +521,7 @@ func TestCalculateTheResults(t *testing.T) { t.Parallel() s := Statistic{} err := s.CalculateAllResults() - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) tt := time.Now().Add(-gctkline.OneDay.Duration() * 7) tt2 := time.Now().Add(-gctkline.OneDay.Duration() * 6) @@ -551,9 +530,8 @@ func TestCalculateTheResults(t *testing.T) { p := currency.NewBTCUSDT() p2 := currency.NewPair(currency.XRP, currency.DOGE) err = s.SetEventForOffset(nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + err = s.SetEventForOffset(&kline.Kline{ Base: &event.Base{ Exchange: exch, @@ -741,13 +719,10 @@ func TestCalculateTheResults(t *testing.T) { s.FundManager = funds err = s.CalculateAllResults() - if !errors.Is(err, errMissingSnapshots) { - t.Errorf("received '%v' expected '%v'", err, errMissingSnapshots) - } + assert.ErrorIs(t, err, errMissingSnapshots) + err = s.CalculateAllResults() - if !errors.Is(err, errMissingSnapshots) { - t.Errorf("received '%v' expected '%v'", err, errMissingSnapshots) - } + assert.ErrorIs(t, err, errMissingSnapshots) funds, err = funding.SetupFundingManager(&engine.ExchangeManager{}, false, true, false) assert.NoError(t, err) @@ -760,9 +735,7 @@ func TestCalculateTheResults(t *testing.T) { s.FundManager = funds err = s.CalculateAllResults() - if !errors.Is(err, errMissingSnapshots) { - t.Errorf("received '%v' expected '%v'", err, errMissingSnapshots) - } + assert.ErrorIs(t, err, errMissingSnapshots) err = s.AddComplianceSnapshotForTime(&compliance.Snapshot{Timestamp: tt2}, signal4) assert.NoError(t, err) @@ -847,9 +820,7 @@ func TestCalculateBiggestEventDrawdown(t *testing.T) { }) _, err := CalculateBiggestEventDrawdown(nil) - if !errors.Is(err, errReceivedNoData) { - t.Errorf("received %v expected %v", err, errReceivedNoData) - } + assert.ErrorIs(t, err, errReceivedNoData) resp, err := CalculateBiggestEventDrawdown(events) assert.NoError(t, err) @@ -872,37 +843,27 @@ func TestCalculateBiggestEventDrawdown(t *testing.T) { }, } _, err = CalculateBiggestEventDrawdown(bogusEvent) - if !errors.Is(err, gctcommon.ErrDateUnset) { - t.Errorf("received %v expected %v", err, gctcommon.ErrDateUnset) - } + assert.ErrorIs(t, err, gctcommon.ErrDateUnset) } func TestCalculateBiggestValueAtTimeDrawdown(t *testing.T) { var interval gctkline.Interval _, err := CalculateBiggestValueAtTimeDrawdown(nil, interval) - if !errors.Is(err, errReceivedNoData) { - t.Errorf("received %v expected %v", err, errReceivedNoData) - } + assert.ErrorIs(t, err, errReceivedNoData) _, err = CalculateBiggestValueAtTimeDrawdown(nil, interval) - if !errors.Is(err, errReceivedNoData) { - t.Errorf("received %v expected %v", err, errReceivedNoData) - } + assert.ErrorIs(t, err, errReceivedNoData) } func TestAddPNLForTime(t *testing.T) { t.Parallel() s := &Statistic{} err := s.AddPNLForTime(nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received %v expected %v", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) sum := &portfolio.PNLSummary{} err = s.AddPNLForTime(sum) - if !errors.Is(err, errExchangeAssetPairStatsUnset) { - t.Errorf("received %v expected %v", err, errExchangeAssetPairStatsUnset) - } + assert.ErrorIs(t, err, errExchangeAssetPairStatsUnset) tt := time.Now().Add(-gctkline.OneDay.Duration() * 7) exch := testExchange @@ -926,17 +887,13 @@ func TestAddPNLForTime(t *testing.T) { assert.NoError(t, err) err = s.AddPNLForTime(sum) - if !errors.Is(err, errCurrencyStatisticsUnset) { - t.Errorf("received %v expected %v", err, errCurrencyStatisticsUnset) - } + assert.ErrorIs(t, err, errCurrencyStatisticsUnset) sum.Exchange = exch sum.Asset = a sum.Pair = p err = s.AddPNLForTime(sum) - if !errors.Is(err, errNoDataAtOffset) { - t.Errorf("received %v expected %v", err, errNoDataAtOffset) - } + assert.ErrorIs(t, err, errNoDataAtOffset) sum.Offset = 1 err = s.AddPNLForTime(sum) diff --git a/backtester/eventhandlers/strategies/base/base_test.go b/backtester/eventhandlers/strategies/base/base_test.go index 75e67010..b614e6cb 100644 --- a/backtester/eventhandlers/strategies/base/base_test.go +++ b/backtester/eventhandlers/strategies/base/base_test.go @@ -1,7 +1,6 @@ package base import ( - "errors" "testing" "time" @@ -22,14 +21,11 @@ func TestGetBase(t *testing.T) { t.Parallel() s := Strategy{} _, err := s.GetBaseData(nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received: %v, expected: %v", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) _, err = s.GetBaseData(datakline.NewDataFromKline()) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + tt := time.Now() exch := "binance" a := asset.Spot @@ -80,7 +76,5 @@ func TestCloseAllPositions(t *testing.T) { t.Parallel() s := &Strategy{} _, err := s.CloseAllPositions(nil, nil) - if !errors.Is(err, gctcommon.ErrFunctionNotSupported) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrFunctionNotSupported) - } + assert.ErrorIs(t, err, gctcommon.ErrFunctionNotSupported) } diff --git a/backtester/eventhandlers/strategies/binancecashandcarry/binancecashandcarry_test.go b/backtester/eventhandlers/strategies/binancecashandcarry/binancecashandcarry_test.go index 8dc0ad3b..0003656e 100644 --- a/backtester/eventhandlers/strategies/binancecashandcarry/binancecashandcarry_test.go +++ b/backtester/eventhandlers/strategies/binancecashandcarry/binancecashandcarry_test.go @@ -1,12 +1,12 @@ package binancecashandcarry import ( - "errors" "testing" "time" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/backtester/common" "github.com/thrasher-corp/gocryptotrader/backtester/data" datakline "github.com/thrasher-corp/gocryptotrader/backtester/data/kline" @@ -67,23 +67,17 @@ func TestSetCustomSettings(t *testing.T) { mappalopalous[openShortDistancePercentageString] = "14" err = s.SetCustomSettings(mappalopalous) - if !errors.Is(err, base.ErrInvalidCustomSettings) { - t.Errorf("received: %v, expected: %v", err, base.ErrInvalidCustomSettings) - } + assert.ErrorIs(t, err, base.ErrInvalidCustomSettings) mappalopalous[closeShortDistancePercentageString] = float14 mappalopalous[openShortDistancePercentageString] = "14" err = s.SetCustomSettings(mappalopalous) - if !errors.Is(err, base.ErrInvalidCustomSettings) { - t.Errorf("received: %v, expected: %v", err, base.ErrInvalidCustomSettings) - } + assert.ErrorIs(t, err, base.ErrInvalidCustomSettings) mappalopalous[closeShortDistancePercentageString] = float14 mappalopalous["lol"] = float14 err = s.SetCustomSettings(mappalopalous) - if !errors.Is(err, base.ErrInvalidCustomSettings) { - t.Errorf("received: %v, expected: %v", err, base.ErrInvalidCustomSettings) - } + assert.ErrorIs(t, err, base.ErrInvalidCustomSettings) } func TestOnSignal(t *testing.T) { @@ -92,9 +86,7 @@ func TestOnSignal(t *testing.T) { openShortDistancePercentage: decimal.NewFromInt(14), } _, err := s.OnSignal(nil, nil, nil) - if !errors.Is(err, base.ErrSimultaneousProcessingOnly) { - t.Errorf("received: %v, expected: %v", err, base.ErrSimultaneousProcessingOnly) - } + assert.ErrorIs(t, err, base.ErrSimultaneousProcessingOnly) } func TestSetDefaults(t *testing.T) { @@ -141,9 +133,7 @@ func TestSortSignals(t *testing.T) { RangeHolder: &gctkline.IntervalRangeHolder{}, } _, err = sortSignals([]data.Handler{da}) - if !errors.Is(err, errNotSetup) { - t.Errorf("received: %v, expected: %v", err, errNotSetup) - } + assert.ErrorIs(t, err, errNotSetup) d2 := &data.Base{} err = d2.SetStream([]data.Event{&eventkline.Kline{ @@ -178,35 +168,23 @@ func TestSortSignals(t *testing.T) { func TestCreateSignals(t *testing.T) { t.Parallel() s := Strategy{} - expectedError := gctcommon.ErrNilPointer _, err := s.createSignals(nil, nil, nil, decimal.Zero, false) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v", err, expectedError) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) spotSignal := &signal.Signal{ Base: &event.Base{AssetType: asset.Spot}, } _, err = s.createSignals(nil, spotSignal, nil, decimal.Zero, false) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v", err, expectedError) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) // targeting first case - expectedError = nil futuresSignal := &signal.Signal{ Base: &event.Base{AssetType: asset.Futures}, } resp, err := s.createSignals(nil, spotSignal, futuresSignal, decimal.Zero, false) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v", err, expectedError) - } - if len(resp) != 1 { - t.Errorf("received '%v' expected '%v", len(resp), 1) - } - if resp[0].GetAssetType() != asset.Spot { - t.Errorf("received '%v' expected '%v", resp[0].GetAssetType(), asset.Spot) - } + require.NoError(t, err, "createSignals must not error") + require.Len(t, resp, 1, "createSignals must return one signal") + assert.Equal(t, asset.Spot, resp[0].GetAssetType()) // targeting second case: pos := []futures.Position{ @@ -215,80 +193,47 @@ func TestCreateSignals(t *testing.T) { }, } resp, err = s.createSignals(pos, spotSignal, futuresSignal, decimal.Zero, false) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v", err, expectedError) - } - if len(resp) != 2 { - t.Errorf("received '%v' expected '%v", len(resp), 2) - } + require.NoError(t, err, "createSignals must not error") + require.Len(t, resp, 2, "createSignals must return two signals") caseTested := false for i := range resp { if resp[i].GetAssetType().IsFutures() { - if resp[i].GetDirection() != gctorder.ClosePosition { - t.Errorf("received '%v' expected '%v", resp[i].GetDirection(), gctorder.ClosePosition) - } + assert.Equal(t, gctorder.ClosePosition, resp[i].GetDirection()) caseTested = true + break } } - if !caseTested { - t.Fatal("unhandled issue in test scenario") - } + require.True(t, caseTested, "Unhandled issue in test scenario") // targeting third case resp, err = s.createSignals(pos, spotSignal, futuresSignal, decimal.Zero, true) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v", err, expectedError) - } - if len(resp) != 2 { - t.Errorf("received '%v' expected '%v", len(resp), 2) - } + require.NoError(t, err, "createSignals must not error") + require.Len(t, resp, 2, "createSignals must return two signals") + caseTested = false for i := range resp { if resp[i].GetAssetType().IsFutures() { - if resp[i].GetDirection() != gctorder.ClosePosition { - t.Errorf("received '%v' expected '%v", resp[i].GetDirection(), gctorder.ClosePosition) - } + assert.Equal(t, gctorder.ClosePosition, resp[i].GetDirection()) caseTested = true + break } } - if !caseTested { - t.Fatal("unhandled issue in test scenario") - } + require.True(t, caseTested, "Unhandled issue in test scenario") // targeting first case after a cash and carry is completed, have a new one opened pos[0].Status = gctorder.Closed resp, err = s.createSignals(pos, spotSignal, futuresSignal, decimal.NewFromInt(1337), true) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v", err, expectedError) - } - if len(resp) != 1 { - t.Errorf("received '%v' expected '%v", len(resp), 1) - } - caseTested = false - for i := range resp { - if resp[i].GetAssetType() == asset.Spot { - if resp[i].GetDirection() != gctorder.Buy { - t.Errorf("received '%v' expected '%v", resp[i].GetDirection(), gctorder.Buy) - } - if resp[i].GetFillDependentEvent() == nil { - t.Errorf("received '%v' expected '%v'", nil, "fill dependent event") - } - caseTested = true - } - } - if !caseTested { - t.Fatal("unhandled issue in test scenario") - } + require.NoError(t, err, "createSignals must not error") + require.Len(t, resp, 1, "createSignals must return one signal") + assert.Equal(t, asset.Spot, resp[0].GetAssetType()) + assert.Equal(t, gctorder.Buy, resp[0].GetDirection()) + assert.NotNil(t, resp[0].GetFillDependentEvent(), "GetFillDependentEvent should not return nil") // targeting default case pos[0].Status = gctorder.UnknownStatus resp, err = s.createSignals(pos, spotSignal, futuresSignal, decimal.NewFromInt(1337), true) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v", err, expectedError) - } - if len(resp) != 2 { - t.Errorf("received '%v' expected '%v", len(resp), 2) - } + require.NoError(t, err, "createSignals must not error") + assert.Len(t, resp, 2, "createSignals should return two signals") } // fakeFunds overrides default implementation @@ -324,9 +269,7 @@ func TestOnSimultaneousSignals(t *testing.T) { t.Parallel() s := Strategy{} _, err := s.OnSimultaneousSignals(nil, nil, nil) - if !errors.Is(err, base.ErrNoDataToProcess) { - t.Errorf("received '%v' expected '%v", err, base.ErrNoDataToProcess) - } + assert.ErrorIs(t, err, base.ErrNoDataToProcess) cp := currency.NewBTCUSD() d := &datakline.DataFromKline{ @@ -363,15 +306,11 @@ func TestOnSimultaneousSignals(t *testing.T) { } f := &fakeFunds{} _, err = s.OnSimultaneousSignals(signals, f, nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) p := &portfolerino{} _, err = s.OnSimultaneousSignals(signals, f, p) - if !errors.Is(err, errNotSetup) { - t.Errorf("received '%v' expected '%v", err, errNotSetup) - } + assert.ErrorIs(t, err, errNotSetup) d2 := &datakline.DataFromKline{ Base: &data.Base{}, diff --git a/backtester/eventhandlers/strategies/dollarcostaverage/dollarcostaverage_test.go b/backtester/eventhandlers/strategies/dollarcostaverage/dollarcostaverage_test.go index dfd8eb45..934f83ea 100644 --- a/backtester/eventhandlers/strategies/dollarcostaverage/dollarcostaverage_test.go +++ b/backtester/eventhandlers/strategies/dollarcostaverage/dollarcostaverage_test.go @@ -1,7 +1,6 @@ package dollarcostaverage import ( - "errors" "testing" "time" @@ -37,17 +36,13 @@ func TestSupportsSimultaneousProcessing(t *testing.T) { func TestSetCustomSettings(t *testing.T) { s := Strategy{} err := s.SetCustomSettings(nil) - if !errors.Is(err, base.ErrCustomSettingsUnsupported) { - t.Errorf("received: %v, expected: %v", err, base.ErrCustomSettingsUnsupported) - } + assert.ErrorIs(t, err, base.ErrCustomSettingsUnsupported) } func TestOnSignal(t *testing.T) { s := Strategy{} _, err := s.OnSignal(nil, nil, nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) dStart := time.Date(2020, 1, 0, 0, 0, 0, 0, time.UTC) dEnd := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) @@ -124,9 +119,8 @@ func TestOnSignal(t *testing.T) { func TestOnSignals(t *testing.T) { s := Strategy{} _, err := s.OnSignal(nil, nil, nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + dStart := time.Date(2020, 1, 0, 0, 0, 0, 0, time.UTC) dEnd := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) exch := "binance" diff --git a/backtester/eventhandlers/strategies/rsi/rsi_test.go b/backtester/eventhandlers/strategies/rsi/rsi_test.go index 8dff0f10..8a112502 100644 --- a/backtester/eventhandlers/strategies/rsi/rsi_test.go +++ b/backtester/eventhandlers/strategies/rsi/rsi_test.go @@ -1,13 +1,13 @@ package rsi import ( - "errors" "strings" "testing" "time" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/backtester/common" "github.com/thrasher-corp/gocryptotrader/backtester/data" "github.com/thrasher-corp/gocryptotrader/backtester/data/kline" @@ -54,39 +54,30 @@ func TestSetCustomSettings(t *testing.T) { mappalopalous[rsiPeriodKey] = "14" err = s.SetCustomSettings(mappalopalous) - if !errors.Is(err, base.ErrInvalidCustomSettings) { - t.Errorf("received: %v, expected: %v", err, base.ErrInvalidCustomSettings) - } + assert.ErrorIs(t, err, base.ErrInvalidCustomSettings) mappalopalous[rsiPeriodKey] = float14 mappalopalous[rsiLowKey] = "14" err = s.SetCustomSettings(mappalopalous) - if !errors.Is(err, base.ErrInvalidCustomSettings) { - t.Errorf("received: %v, expected: %v", err, base.ErrInvalidCustomSettings) - } + assert.ErrorIs(t, err, base.ErrInvalidCustomSettings) mappalopalous[rsiLowKey] = float14 mappalopalous[rsiHighKey] = "14" err = s.SetCustomSettings(mappalopalous) - if !errors.Is(err, base.ErrInvalidCustomSettings) { - t.Errorf("received: %v, expected: %v", err, base.ErrInvalidCustomSettings) - } + assert.ErrorIs(t, err, base.ErrInvalidCustomSettings) mappalopalous[rsiHighKey] = float14 mappalopalous["lol"] = float14 err = s.SetCustomSettings(mappalopalous) - if !errors.Is(err, base.ErrInvalidCustomSettings) { - t.Errorf("received: %v, expected: %v", err, base.ErrInvalidCustomSettings) - } + assert.ErrorIs(t, err, base.ErrInvalidCustomSettings) } func TestOnSignal(t *testing.T) { t.Parallel() s := Strategy{} _, err := s.OnSignal(nil, nil, nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + dStart := time.Date(2020, 1, 0, 0, 0, 0, 0, time.UTC) dEnd := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) exch := "binance" @@ -120,9 +111,7 @@ func TestOnSignal(t *testing.T) { } var resp signal.Event _, err = s.OnSignal(da, nil, nil) - if !errors.Is(err, base.ErrTooMuchBadData) { - t.Fatalf("expected: %v, received %v", base.ErrTooMuchBadData, err) - } + require.ErrorIs(t, err, base.ErrTooMuchBadData) s.rsiPeriod = decimal.NewFromInt(1) _, err = s.OnSignal(da, nil, nil) @@ -166,9 +155,8 @@ func TestOnSignals(t *testing.T) { t.Parallel() s := Strategy{} _, err := s.OnSignal(nil, nil, nil) - if !errors.Is(err, common.ErrNilEvent) { - t.Errorf("received: %v, expected: %v", err, common.ErrNilEvent) - } + assert.ErrorIs(t, err, common.ErrNilEvent) + dInsert := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) exch := "binance" a := asset.Spot diff --git a/backtester/eventhandlers/strategies/strategies_test.go b/backtester/eventhandlers/strategies/strategies_test.go index dd1cb920..cd6812ab 100644 --- a/backtester/eventhandlers/strategies/strategies_test.go +++ b/backtester/eventhandlers/strategies/strategies_test.go @@ -1,7 +1,6 @@ package strategies import ( - "errors" "testing" "github.com/stretchr/testify/assert" @@ -26,13 +25,10 @@ func TestLoadStrategyByName(t *testing.T) { t.Parallel() var resp Handler _, err := LoadStrategyByName("test", false) - if !errors.Is(err, base.ErrStrategyNotFound) { - t.Errorf("received: %v, expected: %v", err, base.ErrStrategyNotFound) - } + assert.ErrorIs(t, err, base.ErrStrategyNotFound) + _, err = LoadStrategyByName("test", true) - if !errors.Is(err, base.ErrStrategyNotFound) { - t.Errorf("received: %v, expected: %v", err, base.ErrStrategyNotFound) - } + assert.ErrorIs(t, err, base.ErrStrategyNotFound) resp, err = LoadStrategyByName(dollarcostaverage.Name, false) assert.NoError(t, err) @@ -60,13 +56,10 @@ func TestLoadStrategyByName(t *testing.T) { func TestAddStrategy(t *testing.T) { t.Parallel() err := AddStrategy(nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) + err = AddStrategy(new(dollarcostaverage.Strategy)) - if !errors.Is(err, ErrStrategyAlreadyExists) { - t.Errorf("received '%v' expected '%v'", err, ErrStrategyAlreadyExists) - } + assert.ErrorIs(t, err, ErrStrategyAlreadyExists) err = AddStrategy(new(customStrategy)) assert.NoError(t, err) diff --git a/backtester/eventhandlers/strategies/top2bottom2/top2bottom2_test.go b/backtester/eventhandlers/strategies/top2bottom2/top2bottom2_test.go index cddf9b14..3f9455e1 100644 --- a/backtester/eventhandlers/strategies/top2bottom2/top2bottom2_test.go +++ b/backtester/eventhandlers/strategies/top2bottom2/top2bottom2_test.go @@ -1,7 +1,6 @@ package top2bottom2 import ( - "errors" "strings" "testing" "time" @@ -61,47 +60,37 @@ func TestSetCustomSettings(t *testing.T) { mappalopalous[mfiPeriodKey] = "14" err = s.SetCustomSettings(mappalopalous) - if !errors.Is(err, base.ErrInvalidCustomSettings) { - t.Errorf("received: %v, expected: %v", err, base.ErrInvalidCustomSettings) - } + assert.ErrorIs(t, err, base.ErrInvalidCustomSettings) mappalopalous[mfiPeriodKey] = float14 mappalopalous[mfiLowKey] = "14" err = s.SetCustomSettings(mappalopalous) - if !errors.Is(err, base.ErrInvalidCustomSettings) { - t.Errorf("received: %v, expected: %v", err, base.ErrInvalidCustomSettings) - } + assert.ErrorIs(t, err, base.ErrInvalidCustomSettings) mappalopalous[mfiLowKey] = float14 mappalopalous[mfiHighKey] = "14" err = s.SetCustomSettings(mappalopalous) - if !errors.Is(err, base.ErrInvalidCustomSettings) { - t.Errorf("received: %v, expected: %v", err, base.ErrInvalidCustomSettings) - } + assert.ErrorIs(t, err, base.ErrInvalidCustomSettings) mappalopalous[mfiHighKey] = float14 mappalopalous["lol"] = float14 err = s.SetCustomSettings(mappalopalous) - if !errors.Is(err, base.ErrInvalidCustomSettings) { - t.Errorf("received: %v, expected: %v", err, base.ErrInvalidCustomSettings) - } + assert.ErrorIs(t, err, base.ErrInvalidCustomSettings) } func TestOnSignal(t *testing.T) { t.Parallel() s := Strategy{} - if _, err := s.OnSignal(nil, nil, nil); !errors.Is(err, errStrategyOnlySupportsSimultaneousProcessing) { - t.Errorf("received: %v, expected: %v", err, errStrategyOnlySupportsSimultaneousProcessing) - } + _, err := s.OnSignal(nil, nil, nil) + assert.ErrorIs(t, err, errStrategyOnlySupportsSimultaneousProcessing) } func TestOnSignals(t *testing.T) { t.Parallel() s := Strategy{} _, err := s.OnSignal(nil, nil, nil) - if !errors.Is(err, errStrategyOnlySupportsSimultaneousProcessing) { - t.Errorf("received: %v, expected: %v", err, errStrategyOnlySupportsSimultaneousProcessing) - } + assert.ErrorIs(t, err, errStrategyOnlySupportsSimultaneousProcessing) + dInsert := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) exch := "binance" a := asset.Spot diff --git a/backtester/funding/collateralpair_test.go b/backtester/funding/collateralpair_test.go index fc9c6604..e6a9b258 100644 --- a/backtester/funding/collateralpair_test.go +++ b/backtester/funding/collateralpair_test.go @@ -1,10 +1,11 @@ package funding import ( - "errors" "testing" "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" gctorder "github.com/thrasher-corp/gocryptotrader/exchanges/order" @@ -32,11 +33,8 @@ func TestCollateralTakeProfit(t *testing.T) { available: decimal.NewFromInt(1), }, } - var expectedError error err := c.TakeProfit(decimal.NewFromInt(1), decimal.NewFromInt(1)) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.NoError(t, err, "TakeProfit should not error") } func TestCollateralCollateralCurrency(t *testing.T) { @@ -85,9 +83,8 @@ func TestCollateralGetPairReader(t *testing.T) { contract: &Item{}, collateral: &Item{}, } - if _, err := c.GetPairReader(); !errors.Is(err, ErrNotPair) { - t.Errorf("received '%v' expected '%v'", err, ErrNotPair) - } + _, err := c.GetPairReader() + assert.ErrorIs(t, err, ErrNotPair) } func TestCollateralGetCollateralReader(t *testing.T) { @@ -95,20 +92,14 @@ func TestCollateralGetCollateralReader(t *testing.T) { c := &CollateralPair{ collateral: &Item{available: decimal.NewFromInt(1337)}, } - var expectedError error cr, err := c.GetCollateralReader() - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } - if cr != c { - t.Error("expected the same thing") - } + require.NoError(t, err, "GetCollateralReader must not error") + assert.Equal(t, cr, c) } func TestCollateralUpdateContracts(t *testing.T) { t.Parallel() b := gctorder.Buy - var expectedError error c := &CollateralPair{ collateral: &Item{ asset: asset.Futures, @@ -119,26 +110,23 @@ func TestCollateralUpdateContracts(t *testing.T) { } leet := decimal.NewFromInt(1337) err := c.UpdateContracts(gctorder.Buy, leet) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.NoError(t, err, "UpdateContracts should not error") + if !c.contract.available.Equal(leet) { t.Errorf("received '%v' expected '%v'", c.contract.available, leet) } b = gctorder.Sell err = c.UpdateContracts(gctorder.Buy, leet) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.NoError(t, err, "UpdateContracts should not error") + if !c.contract.available.Equal(decimal.Zero) { t.Errorf("received '%v' expected '%v'", c.contract.available, decimal.Zero) } c.currentDirection = nil err = c.UpdateContracts(gctorder.Buy, leet) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.NoError(t, err, "UpdateContracts should not error") + if !c.contract.available.Equal(leet) { t.Errorf("received '%v' expected '%v'", c.contract.available, leet) } @@ -156,24 +144,15 @@ func TestCollateralReleaseContracts(t *testing.T) { currentDirection: &b, } - expectedError := errPositiveOnly err := c.ReleaseContracts(decimal.Zero) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, errPositiveOnly) - expectedError = errCannotAllocate err = c.ReleaseContracts(decimal.NewFromInt(1337)) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, errCannotAllocate) - expectedError = nil c.contract.available = decimal.NewFromInt(1337) err = c.ReleaseContracts(decimal.NewFromInt(1337)) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.NoError(t, err, "ReleaseContracts should not error") } func TestCollateralFundReader(t *testing.T) { @@ -192,9 +171,8 @@ func TestCollateralPairReleaser(t *testing.T) { collateral: &Item{}, contract: &Item{}, } - if _, err := c.PairReleaser(); !errors.Is(err, ErrNotPair) { - t.Errorf("received '%v' expected '%v'", err, ErrNotPair) - } + _, err := c.PairReleaser() + assert.ErrorIs(t, err, ErrNotPair) } func TestCollateralFundReserver(t *testing.T) { @@ -213,10 +191,8 @@ func TestCollateralCollateralReleaser(t *testing.T) { collateral: &Item{}, contract: &Item{}, } - var expectedError error - if _, err := c.CollateralReleaser(); !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + _, err := c.CollateralReleaser() + assert.NoError(t, err, "CollateralReleaser should not error") } func TestCollateralFundReleaser(t *testing.T) { @@ -239,45 +215,22 @@ func TestCollateralReserve(t *testing.T) { }, contract: &Item{asset: asset.Futures}, } - var expectedError error err := c.Reserve(decimal.NewFromInt(1), gctorder.Long) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } - if !c.collateral.reserved.Equal(decimal.NewFromInt(1)) { - t.Errorf("received '%v' expected '%v'", c.collateral.reserved, decimal.NewFromInt(1)) - } - if !c.collateral.available.Equal(decimal.NewFromInt(1336)) { - t.Errorf("received '%v' expected '%v'", c.collateral.available, decimal.NewFromInt(1336)) - } + require.NoError(t, err, "Reserve must not error") + assert.Equal(t, decimal.NewFromInt(1), c.collateral.reserved) + assert.Equal(t, decimal.NewFromInt(1336), c.collateral.available) err = c.Reserve(decimal.NewFromInt(1), gctorder.Short) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } - if !c.collateral.reserved.Equal(decimal.NewFromInt(2)) { - t.Errorf("received '%v' expected '%v'", c.collateral.reserved, decimal.NewFromInt(2)) - } - if !c.collateral.available.Equal(decimal.NewFromInt(1335)) { - t.Errorf("received '%v' expected '%v'", c.collateral.available, decimal.NewFromInt(1335)) - } + require.NoError(t, err, "Reserve must not error") + assert.Equal(t, decimal.NewFromInt(2), c.collateral.reserved) + assert.Equal(t, decimal.NewFromInt(1335), c.collateral.available) err = c.Reserve(decimal.NewFromInt(2), gctorder.ClosePosition) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } - if !c.collateral.reserved.Equal(decimal.NewFromInt(4)) { - t.Errorf("received '%v' expected '%v'", c.collateral.reserved, decimal.Zero) - } - if !c.collateral.available.Equal(decimal.NewFromInt(1333)) { - t.Errorf("received '%v' expected '%v'", c.collateral.available, decimal.NewFromInt(1333)) - } - - expectedError = errCannotAllocate + require.NoError(t, err, "Reserve must not error") + assert.Equal(t, decimal.NewFromInt(4), c.collateral.reserved) + assert.Equal(t, decimal.NewFromInt(1333), c.collateral.available) err = c.Reserve(decimal.NewFromInt(2), gctorder.Buy) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, errCannotAllocate) } func TestCollateralLiquidate(t *testing.T) { diff --git a/backtester/funding/funding_test.go b/backtester/funding/funding_test.go index 8d78e027..9223fbcc 100644 --- a/backtester/funding/funding_test.go +++ b/backtester/funding/funding_test.go @@ -1,7 +1,6 @@ package funding import ( - "errors" "testing" "time" @@ -98,32 +97,24 @@ func TestTransfer(t *testing.T) { items: nil, } err := f.Transfer(decimal.Zero, nil, nil, false) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) + err = f.Transfer(decimal.Zero, &Item{}, nil, false) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) + err = f.Transfer(decimal.Zero, &Item{}, &Item{}, false) - if !errors.Is(err, errZeroAmountReceived) { - t.Errorf("received '%v' expected '%v'", err, errZeroAmountReceived) - } + assert.ErrorIs(t, err, errZeroAmountReceived) + err = f.Transfer(elite, &Item{}, &Item{}, false) - if !errors.Is(err, errNotEnoughFunds) { - t.Errorf("received '%v' expected '%v'", err, errNotEnoughFunds) - } + assert.ErrorIs(t, err, errNotEnoughFunds) + item1 := &Item{exchange: "hello", asset: a, currency: base, available: elite} err = f.Transfer(elite, item1, item1, false) - if !errors.Is(err, errCannotTransferToSameFunds) { - t.Errorf("received '%v' expected '%v'", err, errCannotTransferToSameFunds) - } + assert.ErrorIs(t, err, errCannotTransferToSameFunds) item2 := &Item{exchange: "hello", asset: a, currency: quote} err = f.Transfer(elite, item1, item2, false) - if !errors.Is(err, errTransferMustBeSameCurrency) { - t.Errorf("received '%v' expected '%v'", err, errTransferMustBeSameCurrency) - } + assert.ErrorIs(t, err, errTransferMustBeSameCurrency) item2.exchange = "moto" item2.currency = base @@ -159,9 +150,7 @@ func TestAddItem(t *testing.T) { assert.NoError(t, err) err = f.AddItem(baseItem) - if !errors.Is(err, ErrAlreadyExists) { - t.Errorf("received '%v' expected '%v'", err, ErrAlreadyExists) - } + assert.ErrorIs(t, err, ErrAlreadyExists) } func TestExists(t *testing.T) { @@ -237,9 +226,7 @@ func TestAddPair(t *testing.T) { assert.NoError(t, err) err = f.AddPair(p) - if !errors.Is(err, ErrAlreadyExists) { - t.Errorf("received '%v' expected '%v'", err, ErrAlreadyExists) - } + assert.ErrorIs(t, err, ErrAlreadyExists) } func TestGetFundingForEvent(t *testing.T) { @@ -247,9 +234,8 @@ func TestGetFundingForEvent(t *testing.T) { e := &fakeEvent{} f := FundManager{} _, err := f.GetFundingForEvent(e) - if !errors.Is(err, ErrFundsNotFound) { - t.Errorf("received '%v' expected '%v'", err, ErrFundsNotFound) - } + assert.ErrorIs(t, err, ErrFundsNotFound) + baseItem, err := CreateItem(exchName, a, pair.Base, decimal.Zero, decimal.Zero) assert.NoError(t, err) @@ -270,9 +256,8 @@ func TestGetFundingForEAP(t *testing.T) { t.Parallel() f := FundManager{} _, err := f.getFundingForEAP(exchName, a, pair) - if !errors.Is(err, ErrFundsNotFound) { - t.Errorf("received '%v' expected '%v'", err, ErrFundsNotFound) - } + assert.ErrorIs(t, err, ErrFundsNotFound) + baseItem, err := CreateItem(exchName, a, pair.Base, decimal.Zero, decimal.Zero) assert.NoError(t, err) @@ -289,20 +274,16 @@ func TestGetFundingForEAP(t *testing.T) { assert.NoError(t, err) _, err = CreatePair(baseItem, nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) + _, err = CreatePair(nil, quoteItem) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) + p, err = CreatePair(baseItem, quoteItem) assert.NoError(t, err) err = f.AddPair(p) - if !errors.Is(err, ErrAlreadyExists) { - t.Errorf("received '%v' expected '%v'", err, ErrAlreadyExists) - } + assert.ErrorIs(t, err, ErrAlreadyExists) } func TestGenerateReport(t *testing.T) { @@ -392,9 +373,7 @@ func TestCreateSnapshot(t *testing.T) { t.Parallel() f := FundManager{} err := f.CreateSnapshot(time.Time{}) - if !errors.Is(err, gctcommon.ErrDateUnset) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrDateUnset) - } + assert.ErrorIs(t, err, gctcommon.ErrDateUnset) f.items = append(f.items, &Item{}) dfk := &kline.DataFromKline{ @@ -408,9 +387,7 @@ func TestCreateSnapshot(t *testing.T) { }, } err = dfk.Load() - if !errors.Is(err, data.ErrInvalidEventSupplied) { - t.Errorf("received '%v' expected '%v'", err, nil) - } + assert.ErrorIs(t, err, data.ErrInvalidEventSupplied) f.items = append(f.items, &Item{ exchange: "test", @@ -430,14 +407,10 @@ func TestAddUSDTrackingData(t *testing.T) { t.Parallel() f := FundManager{} err := f.AddUSDTrackingData(nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) err = f.AddUSDTrackingData(kline.NewDataFromKline()) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) dfk := &kline.DataFromKline{ Base: &data.Base{}, @@ -450,9 +423,8 @@ func TestAddUSDTrackingData(t *testing.T) { }, } err = dfk.Load() - if !errors.Is(err, data.ErrInvalidEventSupplied) { - t.Errorf("received '%v' expected '%v'", err, data.ErrInvalidEventSupplied) - } + assert.ErrorIs(t, err, data.ErrInvalidEventSupplied) + quoteItem, err := CreateItem(exchName, a, pair.Quote, elite, decimal.Zero) assert.NoError(t, err) @@ -461,15 +433,11 @@ func TestAddUSDTrackingData(t *testing.T) { f.disableUSDTracking = true err = f.AddUSDTrackingData(dfk) - if !errors.Is(err, ErrUSDTrackingDisabled) { - t.Errorf("received '%v' expected '%v'", err, ErrUSDTrackingDisabled) - } + assert.ErrorIs(t, err, ErrUSDTrackingDisabled) f.disableUSDTracking = false err = f.AddUSDTrackingData(dfk) - if !errors.Is(err, errCannotMatchTrackingToItem) { - t.Errorf("received '%v' expected '%v'", err, errCannotMatchTrackingToItem) - } + assert.ErrorIs(t, err, errCannotMatchTrackingToItem) dfk = &kline.DataFromKline{ Base: &data.Base{}, @@ -517,9 +485,8 @@ func TestFundingLiquidate(t *testing.T) { t.Parallel() f := FundManager{} err := f.Liquidate(nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) + f.items = append(f.items, &Item{ exchange: "test", asset: asset.Spot, @@ -545,9 +512,8 @@ func TestHasExchangeBeenLiquidated(t *testing.T) { t.Parallel() f := FundManager{} err := f.Liquidate(nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) + f.items = append(f.items, &Item{ exchange: "test", asset: asset.Spot, @@ -626,20 +592,12 @@ func TestRealisePNL(t *testing.T) { isCollateral: true, }) - var expectedError error err := f.RealisePNL("test", asset.Futures, currency.BTC, decimal.NewFromInt(1)) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } - if !f.items[0].available.Equal(decimal.NewFromInt(1337)) { - t.Errorf("received '%v' expected '%v'", f.items[0].available, decimal.NewFromInt(1337)) - } + require.NoError(t, err, "RealisePNL must not error") + assert.Equal(t, decimal.NewFromInt(1337), f.items[0].available) - expectedError = ErrFundsNotFound err = f.RealisePNL("test2", asset.Futures, currency.BTC, decimal.NewFromInt(1)) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, ErrFundsNotFound) } func TestCreateCollateral(t *testing.T) { @@ -658,32 +616,21 @@ func TestCreateCollateral(t *testing.T) { available: decimal.NewFromInt(1336), } - var expectedError error _, err := CreateCollateral(collat, contract) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.NoError(t, err, "CreateCollateral should not error") - expectedError = gctcommon.ErrNilPointer _, err = CreateCollateral(nil, contract) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) _, err = CreateCollateral(collat, nil) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) } func TestUpdateCollateral(t *testing.T) { t.Parallel() f := &FundManager{} - expectedError := common.ErrNilEvent err := f.UpdateCollateralForEvent(nil, false) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, common.ErrNilEvent) ev := &signal.Signal{ Base: &event.Base{ @@ -700,22 +647,16 @@ func TestUpdateCollateral(t *testing.T) { }) em := engine.NewExchangeManager() exch, err := em.NewExchangeByName(exchName) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) exch.SetDefaults() err = em.Add(exch) require.NoError(t, err) f.exchangeManager = em - expectedError = nil err = f.UpdateCollateralForEvent(ev, false) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.NoError(t, err, "UpdateCollateralForEvent should not error") - expectedError = gctcommon.ErrNotYetImplemented f.items = append(f.items, &Item{ exchange: exchName, asset: asset.Futures, @@ -724,9 +665,7 @@ func TestUpdateCollateral(t *testing.T) { isCollateral: true, }) err = f.UpdateCollateralForEvent(ev, false) - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } + assert.ErrorIs(t, err, gctcommon.ErrNotYetImplemented) } func TestCreateFuturesCurrencyCode(t *testing.T) { @@ -740,20 +679,14 @@ func TestLinkCollateralCurrency(t *testing.T) { t.Parallel() f := FundManager{} err := f.LinkCollateralCurrency(nil, currency.EMPTYCODE) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v', expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) item := &Item{} err = f.LinkCollateralCurrency(item, currency.EMPTYCODE) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v', expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) err = f.LinkCollateralCurrency(item, currency.BTC) - if !errors.Is(err, errNotFutures) { - t.Errorf("received '%v', expected '%v'", err, errNotFutures) - } + assert.ErrorIs(t, err, errNotFutures) item.asset = asset.Futures err = f.LinkCollateralCurrency(item, currency.BTC) @@ -764,9 +697,7 @@ func TestLinkCollateralCurrency(t *testing.T) { } err = f.LinkCollateralCurrency(item, currency.LTC) - if !errors.Is(err, ErrAlreadyExists) { - t.Errorf("received '%v', expected '%v'", err, ErrAlreadyExists) - } + assert.ErrorIs(t, err, ErrAlreadyExists) f.items = append(f.items, item.pairedWith) item.pairedWith = nil @@ -778,25 +709,17 @@ func TestSetFunding(t *testing.T) { t.Parallel() f := &FundManager{} err := f.SetFunding("", 0, nil, false) - if !errors.Is(err, engine.ErrExchangeNameIsEmpty) { - t.Errorf("received '%v', expected '%v'", err, engine.ErrExchangeNameIsEmpty) - } + assert.ErrorIs(t, err, engine.ErrExchangeNameIsEmpty) err = f.SetFunding(exchName, 0, nil, false) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("received '%v', expected '%v'", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) err = f.SetFunding(exchName, asset.Spot, nil, false) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v', expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) bal := &account.Balance{} err = f.SetFunding(exchName, asset.Spot, bal, false) - if !errors.Is(err, currency.ErrCurrencyCodeEmpty) { - t.Errorf("received '%v', expected '%v'", err, currency.ErrCurrencyCodeEmpty) - } + assert.ErrorIs(t, err, currency.ErrCurrencyCodeEmpty) bal.Currency = currency.BTC bal.Total = 1337 @@ -829,9 +752,7 @@ func TestUpdateFundingFromLiveData(t *testing.T) { t.Parallel() f := &FundManager{} err := f.UpdateFundingFromLiveData(false) - if !errors.Is(err, engine.ErrNilSubsystem) { - t.Errorf("received '%v', expected '%v'", err, engine.ErrNilSubsystem) - } + assert.ErrorIs(t, err, engine.ErrNilSubsystem) f.exchangeManager = engine.NewExchangeManager() err = f.UpdateFundingFromLiveData(false) @@ -843,9 +764,7 @@ func TestUpdateFundingFromLiveData(t *testing.T) { require.NoError(t, err) err = f.UpdateFundingFromLiveData(false) - if !errors.Is(err, exchange.ErrCredentialsAreEmpty) { - t.Errorf("received '%v', expected '%v'", err, exchange.ErrCredentialsAreEmpty) - } + assert.ErrorIs(t, err, exchange.ErrCredentialsAreEmpty) // enter api keys to gain coverage here apiKey := "" @@ -868,9 +787,7 @@ func TestUpdateAllCollateral(t *testing.T) { t.Parallel() f := &FundManager{} err := f.UpdateAllCollateral(false, false) - if !errors.Is(err, engine.ErrNilSubsystem) { - t.Errorf("received '%v', expected '%v'", err, engine.ErrNilSubsystem) - } + assert.ErrorIs(t, err, engine.ErrNilSubsystem) f.exchangeManager = engine.NewExchangeManager() err = f.UpdateAllCollateral(false, false) @@ -882,9 +799,7 @@ func TestUpdateAllCollateral(t *testing.T) { require.NoError(t, err) err = f.UpdateAllCollateral(false, false) - if !errors.Is(err, gctcommon.ErrNotYetImplemented) { - t.Errorf("received '%v', expected '%v'", err, gctcommon.ErrNotYetImplemented) - } + assert.ErrorIs(t, err, gctcommon.ErrNotYetImplemented) f.items = []*Item{ { @@ -895,9 +810,7 @@ func TestUpdateAllCollateral(t *testing.T) { }, } err = f.UpdateAllCollateral(false, false) - if !errors.Is(err, gctcommon.ErrNotYetImplemented) { - t.Errorf("received '%v', expected '%v'", err, gctcommon.ErrNotYetImplemented) - } + assert.ErrorIs(t, err, gctcommon.ErrNotYetImplemented) f.items[0].trackingCandles = kline.NewDataFromKline() err = f.items[0].trackingCandles.SetStream([]data.Event{ @@ -906,15 +819,11 @@ func TestUpdateAllCollateral(t *testing.T) { assert.NoError(t, err) err = f.UpdateAllCollateral(false, false) - if !errors.Is(err, gctcommon.ErrNotYetImplemented) { - t.Errorf("received '%v', expected '%v'", err, gctcommon.ErrNotYetImplemented) - } + assert.ErrorIs(t, err, gctcommon.ErrNotYetImplemented) f.items[0].asset = asset.Futures err = f.UpdateAllCollateral(false, false) - if !errors.Is(err, gctcommon.ErrNotYetImplemented) { - t.Errorf("received '%v', expected '%v'", err, gctcommon.ErrNotYetImplemented) - } + assert.ErrorIs(t, err, gctcommon.ErrNotYetImplemented) apiKey := "" apiSec := "" diff --git a/backtester/funding/item_test.go b/backtester/funding/item_test.go index b5e3f5aa..4d74a736 100644 --- a/backtester/funding/item_test.go +++ b/backtester/funding/item_test.go @@ -1,7 +1,6 @@ package funding import ( - "errors" "testing" "github.com/shopspring/decimal" @@ -53,33 +52,24 @@ func TestReserve(t *testing.T) { t.Parallel() i := Item{} err := i.Reserve(decimal.Zero) - if !errors.Is(err, errZeroAmountReceived) { - t.Errorf("received '%v' expected '%v'", err, errZeroAmountReceived) - } + assert.ErrorIs(t, err, errZeroAmountReceived) + err = i.Reserve(elite) - if !errors.Is(err, errCannotAllocate) { - t.Errorf("received '%v' expected '%v'", err, errCannotAllocate) - } + assert.ErrorIs(t, err, errCannotAllocate) i.reserved = elite err = i.Reserve(elite) - if !errors.Is(err, errCannotAllocate) { - t.Errorf("received '%v' expected '%v'", err, errCannotAllocate) - } + assert.ErrorIs(t, err, errCannotAllocate) i.available = elite err = i.Reserve(elite) assert.NoError(t, err) err = i.Reserve(elite) - if !errors.Is(err, errCannotAllocate) { - t.Errorf("received '%v' expected '%v'", err, errCannotAllocate) - } + assert.ErrorIs(t, err, errCannotAllocate) err = i.Reserve(neg) - if !errors.Is(err, errZeroAmountReceived) { - t.Errorf("received '%v' expected '%v'", err, errZeroAmountReceived) - } + assert.ErrorIs(t, err, errZeroAmountReceived) } func TestIncreaseAvailable(t *testing.T) { @@ -92,26 +82,21 @@ func TestIncreaseAvailable(t *testing.T) { t.Errorf("expected %v", elite) } err = i.IncreaseAvailable(decimal.Zero) - if !errors.Is(err, errZeroAmountReceived) { - t.Errorf("received '%v' expected '%v'", err, errZeroAmountReceived) - } + assert.ErrorIs(t, err, errZeroAmountReceived) + err = i.IncreaseAvailable(neg) - if !errors.Is(err, errZeroAmountReceived) { - t.Errorf("received '%v' expected '%v'", err, errZeroAmountReceived) - } + assert.ErrorIs(t, err, errZeroAmountReceived) } func TestRelease(t *testing.T) { t.Parallel() i := Item{} err := i.Release(decimal.Zero, decimal.Zero) - if !errors.Is(err, errZeroAmountReceived) { - t.Errorf("received '%v' expected '%v'", err, errZeroAmountReceived) - } + assert.ErrorIs(t, err, errZeroAmountReceived) + err = i.Release(elite, decimal.Zero) - if !errors.Is(err, errCannotAllocate) { - t.Errorf("received '%v' expected '%v'", err, errCannotAllocate) - } + assert.ErrorIs(t, err, errCannotAllocate) + i.reserved = elite err = i.Release(elite, decimal.Zero) assert.NoError(t, err) @@ -121,13 +106,10 @@ func TestRelease(t *testing.T) { assert.NoError(t, err) err = i.Release(neg, decimal.Zero) - if !errors.Is(err, errZeroAmountReceived) { - t.Errorf("received '%v' expected '%v'", err, errZeroAmountReceived) - } + assert.ErrorIs(t, err, errZeroAmountReceived) + err = i.Release(elite, neg) - if !errors.Is(err, errNegativeAmountReceived) { - t.Errorf("received '%v' expected '%v'", err, errNegativeAmountReceived) - } + assert.ErrorIs(t, err, errNegativeAmountReceived) } func TestMatchesCurrency(t *testing.T) { diff --git a/backtester/funding/spotpair_test.go b/backtester/funding/spotpair_test.go index aa76d206..2becf3be 100644 --- a/backtester/funding/spotpair_test.go +++ b/backtester/funding/spotpair_test.go @@ -1,11 +1,11 @@ package funding import ( - "errors" "testing" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" gctorder "github.com/thrasher-corp/gocryptotrader/exchanges/order" ) @@ -89,24 +89,19 @@ func TestReservePair(t *testing.T) { quoteItem.pairedWith = baseItem pairItems := SpotPair{base: baseItem, quote: quoteItem} err = pairItems.Reserve(decimal.Zero, gctorder.Buy) - if !errors.Is(err, errZeroAmountReceived) { - t.Errorf("received '%v' expected '%v'", err, errZeroAmountReceived) - } + assert.ErrorIs(t, err, errZeroAmountReceived) + err = pairItems.Reserve(elite, gctorder.Buy) assert.NoError(t, err) err = pairItems.Reserve(decimal.Zero, gctorder.Sell) - if !errors.Is(err, errZeroAmountReceived) { - t.Errorf("received '%v' expected '%v'", err, errZeroAmountReceived) - } + assert.ErrorIs(t, err, errZeroAmountReceived) + err = pairItems.Reserve(elite, gctorder.Sell) - if !errors.Is(err, errCannotAllocate) { - t.Errorf("received '%v' expected '%v'", err, errCannotAllocate) - } + assert.ErrorIs(t, err, errCannotAllocate) + err = pairItems.Reserve(elite, gctorder.DoNothing) - if !errors.Is(err, errCannotAllocate) { - t.Errorf("received '%v' expected '%v'", err, errCannotAllocate) - } + assert.ErrorIs(t, err, errCannotAllocate) } func TestReleasePair(t *testing.T) { @@ -121,46 +116,34 @@ func TestReleasePair(t *testing.T) { quoteItem.pairedWith = baseItem pairItems := SpotPair{base: baseItem, quote: quoteItem} err = pairItems.Reserve(decimal.Zero, gctorder.Buy) - if !errors.Is(err, errZeroAmountReceived) { - t.Errorf("received '%v' expected '%v'", err, errZeroAmountReceived) - } + assert.ErrorIs(t, err, errZeroAmountReceived) + err = pairItems.Reserve(elite, gctorder.Buy) assert.NoError(t, err) err = pairItems.Reserve(decimal.Zero, gctorder.Sell) - if !errors.Is(err, errZeroAmountReceived) { - t.Errorf("received '%v' expected '%v'", err, errZeroAmountReceived) - } + assert.ErrorIs(t, err, errZeroAmountReceived) + err = pairItems.Reserve(elite, gctorder.Sell) - if !errors.Is(err, errCannotAllocate) { - t.Errorf("received '%v' expected '%v'", err, errCannotAllocate) - } + assert.ErrorIs(t, err, errCannotAllocate) err = pairItems.Release(decimal.Zero, decimal.Zero, gctorder.Buy) - if !errors.Is(err, errZeroAmountReceived) { - t.Errorf("received '%v' expected '%v'", err, errZeroAmountReceived) - } + assert.ErrorIs(t, err, errZeroAmountReceived) + err = pairItems.Release(elite, decimal.Zero, gctorder.Buy) assert.NoError(t, err) err = pairItems.Release(elite, decimal.Zero, gctorder.Buy) - if !errors.Is(err, errCannotAllocate) { - t.Errorf("received '%v' expected '%v'", err, errCannotAllocate) - } + assert.ErrorIs(t, err, errCannotAllocate) err = pairItems.Release(elite, decimal.Zero, gctorder.DoNothing) - if !errors.Is(err, errCannotAllocate) { - t.Errorf("received '%v' expected '%v'", err, errCannotAllocate) - } + assert.ErrorIs(t, err, errCannotAllocate) err = pairItems.Release(elite, decimal.Zero, gctorder.Sell) - if !errors.Is(err, errCannotAllocate) { - t.Errorf("received '%v' expected '%v'", err, errCannotAllocate) - } + assert.ErrorIs(t, err, errCannotAllocate) + err = pairItems.Release(decimal.Zero, decimal.Zero, gctorder.Sell) - if !errors.Is(err, errZeroAmountReceived) { - t.Errorf("received '%v' expected '%v'", err, errZeroAmountReceived) - } + assert.ErrorIs(t, err, errZeroAmountReceived) } func TestIncreaseAvailablePair(t *testing.T) { @@ -175,24 +158,21 @@ func TestIncreaseAvailablePair(t *testing.T) { quoteItem.pairedWith = baseItem pairItems := SpotPair{base: baseItem, quote: quoteItem} err = pairItems.IncreaseAvailable(decimal.Zero, gctorder.Buy) - if !errors.Is(err, errZeroAmountReceived) { - t.Errorf("received '%v' expected '%v'", err, errZeroAmountReceived) - } + assert.ErrorIs(t, err, errZeroAmountReceived) + if !pairItems.quote.available.Equal(elite) { t.Errorf("received '%v' expected '%v'", elite, pairItems.quote.available) } err = pairItems.IncreaseAvailable(decimal.Zero, gctorder.Sell) - if !errors.Is(err, errZeroAmountReceived) { - t.Errorf("received '%v' expected '%v'", err, errZeroAmountReceived) - } + assert.ErrorIs(t, err, errZeroAmountReceived) + if !pairItems.base.available.IsZero() { t.Errorf("received '%v' expected '%v'", decimal.Zero, pairItems.base.available) } err = pairItems.IncreaseAvailable(elite.Neg(), gctorder.Sell) - if !errors.Is(err, errZeroAmountReceived) { - t.Errorf("received '%v' expected '%v'", err, errZeroAmountReceived) - } + assert.ErrorIs(t, err, errZeroAmountReceived) + if !pairItems.quote.available.Equal(elite) { t.Errorf("received '%v' expected '%v'", elite, pairItems.quote.available) } @@ -204,9 +184,8 @@ func TestIncreaseAvailablePair(t *testing.T) { } err = pairItems.IncreaseAvailable(elite, gctorder.DoNothing) - if !errors.Is(err, errCannotAllocate) { - t.Errorf("received '%v' expected '%v'", err, errCannotAllocate) - } + assert.ErrorIs(t, err, errCannotAllocate) + if !pairItems.base.available.Equal(elite) { t.Errorf("received '%v' expected '%v'", elite, pairItems.base.available) } @@ -243,14 +222,9 @@ func TestGetPairReader(t *testing.T) { p := &SpotPair{ base: &Item{exchange: "hello"}, } - var expectedError error ip, err := p.GetPairReader() - if !errors.Is(err, expectedError) { - t.Errorf("received '%v' expected '%v'", err, expectedError) - } - if ip != p { - t.Error("expected the same thing") - } + require.NoError(t, err, "GetPairReader must not error") + assert.Equal(t, p, ip) } func TestGetCollateralReader(t *testing.T) { @@ -258,9 +232,8 @@ func TestGetCollateralReader(t *testing.T) { p := &SpotPair{ base: &Item{exchange: "hello"}, } - if _, err := p.GetCollateralReader(); !errors.Is(err, ErrNotCollateral) { - t.Errorf("received '%v' expected '%v'", err, ErrNotCollateral) - } + _, err := p.GetCollateralReader() + assert.ErrorIs(t, err, ErrNotCollateral) } func TestFundReader(t *testing.T) { @@ -307,9 +280,8 @@ func TestCollateralReleaser(t *testing.T) { p := &SpotPair{ base: &Item{exchange: "hello"}, } - if _, err := p.CollateralReleaser(); !errors.Is(err, ErrNotCollateral) { - t.Errorf("received '%v' expected '%v'", err, ErrNotCollateral) - } + _, err := p.GetCollateralReader() + assert.ErrorIs(t, err, ErrNotCollateral) } func TestLiquidate(t *testing.T) { diff --git a/backtester/funding/trackingcurrencies/trackingcurrencies_test.go b/backtester/funding/trackingcurrencies/trackingcurrencies_test.go index 55a37fee..85e17815 100644 --- a/backtester/funding/trackingcurrencies/trackingcurrencies_test.go +++ b/backtester/funding/trackingcurrencies/trackingcurrencies_test.go @@ -1,7 +1,6 @@ package trackingcurrencies import ( - "errors" "testing" "github.com/stretchr/testify/assert" @@ -22,20 +21,14 @@ func TestCreateUSDTrackingPairs(t *testing.T) { t.Parallel() _, err := CreateUSDTrackingPairs(nil, nil) - if !errors.Is(err, errNilPairsReceived) { - t.Errorf("received '%v' expected '%v'", err, errNilPairsReceived) - } + assert.ErrorIs(t, err, errNilPairsReceived) _, err = CreateUSDTrackingPairs([]TrackingPair{{}}, nil) - if !errors.Is(err, errExchangeManagerRequired) { - t.Errorf("received '%v' expected '%v'", err, errExchangeManagerRequired) - } + assert.ErrorIs(t, err, errExchangeManagerRequired) em := engine.NewExchangeManager() _, err = CreateUSDTrackingPairs([]TrackingPair{{Exchange: eName}}, em) - if !errors.Is(err, engine.ErrExchangeNotFound) { - t.Errorf("received '%v' expected '%v'", err, engine.ErrExchangeNotFound) - } + assert.ErrorIs(t, err, engine.ErrExchangeNotFound) s1 := TrackingPair{ Exchange: eName, @@ -142,9 +135,8 @@ func TestFindMatchingUSDPairs(t *testing.T) { t.Run(tt.description, func(t *testing.T) { t.Parallel() basePair, quotePair, err := findMatchingUSDPairs(tt.initialPair, tt.availablePairs) - if !errors.Is(err, tt.expectedErr) { - t.Fatalf("'%v' received '%v' expected '%v'", tt.description, err, tt.expectedErr) - } + require.ErrorIs(t, err, tt.expectedErr) + if basePair != tt.basePair { t.Fatalf("'%v' received '%v' expected '%v'", tt.description, basePair, tt.basePair) } diff --git a/backtester/plugins/strategies/loader_test.go b/backtester/plugins/strategies/loader_test.go index be4a2f26..569466ba 100644 --- a/backtester/plugins/strategies/loader_test.go +++ b/backtester/plugins/strategies/loader_test.go @@ -1,7 +1,6 @@ package strategies import ( - "errors" "testing" "github.com/stretchr/testify/assert" @@ -17,14 +16,10 @@ import ( func TestAddStrategies(t *testing.T) { t.Parallel() err := addStrategies(nil) - if !errors.Is(err, errNoStrategies) { - t.Error(err) - } + assert.ErrorIs(t, err, errNoStrategies) err = addStrategies([]strategies.Handler{&dollarcostaverage.Strategy{}}) - if !errors.Is(err, strategies.ErrStrategyAlreadyExists) { - t.Error(err) - } + assert.ErrorIs(t, err, strategies.ErrStrategyAlreadyExists) err = addStrategies([]strategies.Handler{&CustomStrategy{}}) assert.NoError(t, err) diff --git a/backtester/report/chart_test.go b/backtester/report/chart_test.go index fc133a4c..be91103e 100644 --- a/backtester/report/chart_test.go +++ b/backtester/report/chart_test.go @@ -1,7 +1,6 @@ package report import ( - "errors" "testing" "time" @@ -24,9 +23,8 @@ import ( func TestCreateUSDTotalsChart(t *testing.T) { t.Parallel() _, err := createUSDTotalsChart(nil, nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) + tt := time.Now() items := []statistics.ValueAtTime{ { @@ -36,9 +34,8 @@ func TestCreateUSDTotalsChart(t *testing.T) { }, } _, err = createUSDTotalsChart(items, nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) + stats := []statistics.FundingItemStatistics{ { ReportItem: &funding.ReportItem{ @@ -68,9 +65,8 @@ func TestCreateUSDTotalsChart(t *testing.T) { func TestCreateHoldingsOverTimeChart(t *testing.T) { t.Parallel() _, err := createHoldingsOverTimeChart(nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) + tt := time.Now() items := []statistics.FundingItemStatistics{ { @@ -101,9 +97,7 @@ func TestCreateHoldingsOverTimeChart(t *testing.T) { func TestCreatePNLCharts(t *testing.T) { t.Parallel() _, err := createPNLCharts(nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) tt := time.Now() var d Data @@ -160,9 +154,7 @@ func TestCreatePNLCharts(t *testing.T) { func TestCreateFuturesSpotDiffChart(t *testing.T) { t.Parallel() _, err := createFuturesSpotDiffChart(nil) - if !errors.Is(err, gctcommon.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, gctcommon.ErrNilPointer) - } + assert.ErrorIs(t, err, gctcommon.ErrNilPointer) tt := time.Now() cp := currency.NewBTCUSD() diff --git a/backtester/report/report_test.go b/backtester/report/report_test.go index 8a7ce340..382aafd8 100644 --- a/backtester/report/report_test.go +++ b/backtester/report/report_test.go @@ -1,7 +1,6 @@ package report import ( - "errors" "testing" "time" @@ -323,16 +322,14 @@ func TestEnhanceCandles(t *testing.T) { tt := time.Now() var d Data err := d.enhanceCandles() - if !errors.Is(err, errNoCandles) { - t.Errorf("received: %v, expected: %v", err, errNoCandles) - } + assert.ErrorIs(t, err, errNoCandles) + err = d.SetKlineData(&gctkline.Item{}) assert.NoError(t, err) err = d.enhanceCandles() - if !errors.Is(err, errStatisticsUnset) { - t.Errorf("received: %v, expected: %v", err, errStatisticsUnset) - } + assert.ErrorIs(t, err, errStatisticsUnset) + d.Statistics = &statistics.Statistic{} err = d.enhanceCandles() assert.NoError(t, err) diff --git a/cmd/exchange_wrapper_standards/exchange_wrapper_standards_test.go b/cmd/exchange_wrapper_standards/exchange_wrapper_standards_test.go index 3702ae5a..353297ee 100644 --- a/cmd/exchange_wrapper_standards/exchange_wrapper_standards_test.go +++ b/cmd/exchange_wrapper_standards/exchange_wrapper_standards_test.go @@ -98,15 +98,13 @@ func setupExchange(ctx context.Context, t *testing.T, name string, cfg *config.C } err = exch.UpdateTradablePairs(ctx, true) - if err != nil && !errors.Is(err, context.DeadlineExceeded) { - t.Fatalf("Cannot setup %v UpdateTradablePairs %v", name, err) - } + require.Truef(t, errors.Is(err, context.DeadlineExceeded) || err == nil, "Exchange %s UpdateTradablePairs must not error: %s", name, err) b := exch.GetBase() assets := b.CurrencyPairs.GetAssetTypes(false) - require.NotEmptyf(t, assets, "exchange %s must have assets", name) + require.NotEmptyf(t, assets, "Exchange %s must have assets", name) for _, a := range assets { - require.NoErrorf(t, b.CurrencyPairs.SetAssetEnabled(a, true), "exchange %s SetAssetEnabled must not error for %s", name, a) + require.NoErrorf(t, b.CurrencyPairs.SetAssetEnabled(a, true), "Exchange %s SetAssetEnabled must not error for asset %s: %s", name, a, err) } // Add +1 to len to verify that exchanges can handle requests with unset pairs and assets @@ -127,9 +125,7 @@ assets: t.Fatalf("Cannot setup %v asset %v getPairFromPairs %v", name, assets[j], err) } err = b.CurrencyPairs.EnablePair(assets[j], p) - if err != nil && !errors.Is(err, currency.ErrPairAlreadyEnabled) { - t.Fatalf("Cannot setup %v asset %v EnablePair %v", name, assets[j], err) - } + require.Truef(t, errors.Is(err, currency.ErrPairAlreadyEnabled) || err == nil, "Exchange %s EnablePair must not error for %s", name, p) p, err = b.FormatExchangeCurrency(p, assets[j]) if err != nil { t.Fatalf("Cannot setup %v asset %v FormatExchangeCurrency %v", name, assets[j], err) diff --git a/common/common_test.go b/common/common_test.go index 2fe64d29..7d75fda7 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -81,9 +81,7 @@ func TestSendHTTPRequest(t *testing.T) { func TestSetHTTPClientWithTimeout(t *testing.T) { t.Parallel() err := SetHTTPClientWithTimeout(-0) - if !errors.Is(err, errCannotSetInvalidTimeout) { - t.Fatalf("received: %v but expected: %v", err, errCannotSetInvalidTimeout) - } + require.ErrorIs(t, err, errCannotSetInvalidTimeout) err = SetHTTPClientWithTimeout(time.Second * 15) require.NoError(t, err) @@ -92,9 +90,7 @@ func TestSetHTTPClientWithTimeout(t *testing.T) { func TestSetHTTPUserAgent(t *testing.T) { t.Parallel() err := SetHTTPUserAgent("") - if !errors.Is(err, errUserAgentInvalid) { - t.Fatalf("received: %v but expected: %v", err, errUserAgentInvalid) - } + require.ErrorIs(t, err, errUserAgentInvalid) err = SetHTTPUserAgent("testy test") require.NoError(t, err) @@ -103,9 +99,7 @@ func TestSetHTTPUserAgent(t *testing.T) { func TestSetHTTPClient(t *testing.T) { t.Parallel() err := SetHTTPClient(nil) - if !errors.Is(err, errHTTPClientInvalid) { - t.Fatalf("received: %v but expected: %v", err, errHTTPClientInvalid) - } + require.ErrorIs(t, err, errHTTPClientInvalid) err = SetHTTPClient(new(http.Client)) require.NoError(t, err) @@ -497,39 +491,25 @@ func TestParseStartEndDate(t *testing.T) { nt := time.Time{} err := StartEndTimeCheck(nt, nt) - if !errors.Is(err, ErrDateUnset) { - t.Errorf("received %v, expected %v", err, ErrDateUnset) - } + assert.ErrorIs(t, err, ErrDateUnset) err = StartEndTimeCheck(et, nt) - if !errors.Is(err, ErrDateUnset) { - t.Errorf("received %v, expected %v", err, ErrDateUnset) - } + assert.ErrorIs(t, err, ErrDateUnset) err = StartEndTimeCheck(et, zeroValueUnix) - if !errors.Is(err, ErrDateUnset) { - t.Errorf("received %v, expected %v", err, ErrDateUnset) - } + assert.ErrorIs(t, err, ErrDateUnset) err = StartEndTimeCheck(zeroValueUnix, et) - if !errors.Is(err, ErrDateUnset) { - t.Errorf("received %v, expected %v", err, ErrDateUnset) - } + assert.ErrorIs(t, err, ErrDateUnset) err = StartEndTimeCheck(et, et) - if !errors.Is(err, ErrStartEqualsEnd) { - t.Errorf("received %v, expected %v", err, ErrStartEqualsEnd) - } + assert.ErrorIs(t, err, ErrStartEqualsEnd) err = StartEndTimeCheck(et, pt) - if !errors.Is(err, ErrStartAfterEnd) { - t.Errorf("received %v, expected %v", err, ErrStartAfterEnd) - } + assert.ErrorIs(t, err, ErrStartAfterEnd) err = StartEndTimeCheck(ft, ft.Add(time.Hour)) - if !errors.Is(err, ErrStartAfterTimeNow) { - t.Errorf("received %v, expected %v", err, ErrStartAfterTimeNow) - } + assert.ErrorIs(t, err, ErrStartAfterTimeNow) err = StartEndTimeCheck(pt, et) assert.NoError(t, err) @@ -547,9 +527,7 @@ func TestGetAssertError(t *testing.T) { } err = GetTypeAssertError("bruh", struct{}{}) - if !errors.Is(err, ErrTypeAssertFailure) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrTypeAssertFailure) - } + require.ErrorIs(t, err, ErrTypeAssertFailure) err = GetTypeAssertError("string", struct{}{}) if err.Error() != "type assert failure from struct {} to string" { diff --git a/common/math/math_test.go b/common/math/math_test.go index 7efd2664..d29827a9 100644 --- a/common/math/math_test.go +++ b/common/math/math_test.go @@ -1,7 +1,6 @@ package math import ( - "errors" "math" "testing" @@ -135,9 +134,7 @@ func TestSortinoRatio(t *testing.T) { t.Error(err) } _, err = SortinoRatio(nil, rfr, avg) - if !errors.Is(err, errZeroValue) { - t.Errorf("expected: %v, received %v", errZeroValue, err) - } + assert.ErrorIs(t, err, errZeroValue) var r float64 r, err = SortinoRatio(figures, rfr, avg) @@ -233,17 +230,14 @@ func TestInformationRatio(t *testing.T) { } _, err = InformationRatio(figures, []float64{1}, avg, avgComparison) - if !errors.Is(err, errInformationBadLength) { - t.Errorf("expected: %v, received %v", errInformationBadLength, err) - } + assert.ErrorIs(t, err, errInformationBadLength) } func TestCalmarRatio(t *testing.T) { t.Parallel() _, err := CalmarRatio(0, 0, 0, 0) - if !errors.Is(err, errCalmarHighest) { - t.Errorf("expected: %v, received %v", errCalmarHighest, err) - } + assert.ErrorIs(t, err, errCalmarHighest) + var ratio float64 ratio, err = CalmarRatio(50000, 15000, 0.2, 0.1) if err != nil { @@ -261,17 +255,14 @@ func TestCAGR(t *testing.T) { 0, 0, 0) - if !errors.Is(err, errCAGRNoIntervals) { - t.Error(err) - } + assert.ErrorIs(t, err, errCAGRNoIntervals) + _, err = CompoundAnnualGrowthRate( 0, 0, 0, 1) - if !errors.Is(err, errCAGRZeroOpenValue) { - t.Error(err) - } + assert.ErrorIs(t, err, errCAGRZeroOpenValue) var cagr float64 cagr, err = CompoundAnnualGrowthRate( @@ -313,9 +304,8 @@ func TestCAGR(t *testing.T) { func TestCalculateSharpeRatio(t *testing.T) { t.Parallel() result, err := SharpeRatio(nil, 0, 0) - if !errors.Is(err, errZeroValue) { - t.Error(err) - } + assert.ErrorIs(t, err, errZeroValue) + if result != 0 { t.Error("expected 0") } @@ -394,9 +384,8 @@ func TestGeometricAverage(t *testing.T) { t.Parallel() values := []float64{1, 2, 3, 4, 5, 6, 7, 8} _, err := GeometricMean(nil) - if !errors.Is(err, errZeroValue) { - t.Error(err) - } + assert.ErrorIs(t, err, errZeroValue) + var mean float64 mean, err = GeometricMean(values) if err != nil { @@ -417,9 +406,8 @@ func TestGeometricAverage(t *testing.T) { values = []float64{-1, 12, 13, 19, 10} mean, err = GeometricMean(values) - if !errors.Is(err, errGeometricNegative) { - t.Error(err) - } + assert.ErrorIs(t, err, errGeometricNegative) + if mean != 0 { t.Errorf("expected %v, received %v", 0, mean) } @@ -429,9 +417,7 @@ func TestFinancialGeometricAverage(t *testing.T) { t.Parallel() values := []float64{1, 2, 3, 4, 5, 6, 7, 8} _, err := FinancialGeometricMean(nil) - if !errors.Is(err, errZeroValue) { - t.Error(err) - } + assert.ErrorIs(t, err, errZeroValue) var mean float64 mean, err = FinancialGeometricMean(values) @@ -462,18 +448,15 @@ func TestFinancialGeometricAverage(t *testing.T) { values = []float64{-2, 12, 13, 19, 10} _, err = FinancialGeometricMean(values) - if !errors.Is(err, errNegativeValueOutOfRange) { - t.Error(err) - } + assert.ErrorIs(t, err, errNegativeValueOutOfRange) } func TestArithmeticAverage(t *testing.T) { t.Parallel() values := []float64{1, 2, 3, 4, 5, 6, 7, 8} _, err := ArithmeticMean(nil) - if !errors.Is(err, errZeroValue) { - t.Error(err) - } + assert.ErrorIs(t, err, errZeroValue) + var avg float64 avg, err = ArithmeticMean(values) if err != nil { @@ -500,38 +483,21 @@ func TestDecimalSortinoRatio(t *testing.T) { decimal.NewFromFloat(0.23), } avg, err := DecimalArithmeticMean(figures) - if err != nil { - t.Error(err) - } + require.NoError(t, err) _, err = DecimalSortinoRatio(nil, rfr, avg) - if !errors.Is(err, errZeroValue) { - t.Errorf("expected: %v, received %v", errZeroValue, err) - } + assert.ErrorIs(t, err, errZeroValue) - var r decimal.Decimal - r, err = DecimalSortinoRatio(figures, rfr, avg) - if err != nil && !errors.Is(err, ErrInexactConversion) { - t.Error(err) - } + r, err := DecimalSortinoRatio(figures, rfr, avg) + assert.ErrorIs(t, err, ErrInexactConversion) rf, exact := r.Float64() - if !exact && rf != 3.0377875479459906 { - t.Errorf("expected 3.0377875479459906, received %v", r) - } else if rf != 3.0377875479459907 { - t.Errorf("expected 3.0377875479459907, received %v", r) - } + assert.False(t, exact) + assert.Equal(t, 3.0377875479459906, rf) avg, err = DecimalFinancialGeometricMean(figures) - if err != nil { - t.Error(err) - } - + require.NoError(t, err) r, err = DecimalSortinoRatio(figures, rfr, avg) - if err != nil && !errors.Is(err, ErrInexactConversion) { - t.Error(err) - } - if !r.Equal(decimal.NewFromFloat(2.8712802265603243)) { - t.Errorf("expected 2.525203164136098, received %v", r) - } + assert.ErrorIs(t, err, ErrInexactConversion) + assert.True(t, r.Equal(decimal.NewFromFloat(2.8712802265603243))) // this follows and matches the example calculation from // https://www.wallstreetmojo.com/sortino-ratio/ @@ -550,16 +516,10 @@ func TestDecimalSortinoRatio(t *testing.T) { decimal.NewFromFloat(0.02), } avg, err = DecimalArithmeticMean(example) - if err != nil { - t.Error(err) - } + require.NoError(t, err) r, err = DecimalSortinoRatio(example, decimal.NewFromFloat(0.06), avg) - if err != nil && !errors.Is(err, ErrInexactConversion) { - t.Error(err) - } - if rr := r.Round(1); !rr.Equal(decimal.NewFromFloat(0.2)) { - t.Errorf("expected 0.2, received %v", rr) - } + assert.ErrorIs(t, err, ErrInexactConversion) + assert.True(t, r.Round(1).Equal(decimal.NewFromFloat(0.2))) } func TestDecimalInformationRatio(t *testing.T) { @@ -593,57 +553,37 @@ func TestDecimalInformationRatio(t *testing.T) { decimal.Zero, } avg, err := DecimalArithmeticMean(figures) - if err != nil { - t.Error(err) - } - if !avg.Equal(decimal.NewFromFloat(0.01145)) { - t.Error(avg) - } - var avgComparison decimal.Decimal - avgComparison, err = DecimalArithmeticMean(comparisonFigures) - if err != nil { - t.Error(err) - } - if !avgComparison.Equal(decimal.NewFromFloat(0.005425)) { - t.Error(avgComparison) - } + require.NoError(t, err) + assert.True(t, decimal.NewFromFloat(0.01145).Equal(avg)) + + avgComparison, err := DecimalArithmeticMean(comparisonFigures) + require.NoError(t, err) + assert.True(t, decimal.NewFromFloat(0.005425).Equal(avgComparison)) eachDiff := make([]decimal.Decimal, len(figures)) for i := range figures { eachDiff[i] = figures[i].Sub(comparisonFigures[i]) } stdDev, err := DecimalPopulationStandardDeviation(eachDiff) - if err != nil && !errors.Is(err, ErrInexactConversion) { - t.Error(err) - } - if !stdDev.Equal(decimal.NewFromFloat(0.028992588851865227)) { - t.Error(stdDev) - } + require.ErrorIs(t, err, ErrInexactConversion) + assert.Equal(t, decimal.NewFromFloat(0.028992588851865227), stdDev) + information := avg.Sub(avgComparison).Div(stdDev) - if !information.Equal(decimal.NewFromFloat(0.2078117283966652)) { - t.Errorf("expected %v received %v", 0.2078117283966652, information) - } - var information2 decimal.Decimal - information2, err = DecimalInformationRatio(figures, comparisonFigures, avg, avgComparison) - if err != nil { - t.Error(err) - } - if !information.Equal(information2) { - t.Error(information2) - } + assert.Equal(t, decimal.NewFromFloat(0.2078117283966652), information) + + information2, err := DecimalInformationRatio(figures, comparisonFigures, avg, avgComparison) + require.NoError(t, err) + assert.Equal(t, information, information2) _, err = DecimalInformationRatio(figures, []decimal.Decimal{decimal.NewFromInt(1)}, avg, avgComparison) - if !errors.Is(err, errInformationBadLength) { - t.Errorf("expected: %v, received %v", errInformationBadLength, err) - } + assert.ErrorIs(t, err, errInformationBadLength) } func TestDecimalCalmarRatio(t *testing.T) { t.Parallel() _, err := DecimalCalmarRatio(decimal.Zero, decimal.Zero, decimal.Zero, decimal.Zero) - if !errors.Is(err, errCalmarHighest) { - t.Errorf("expected: %v, received %v", errCalmarHighest, err) - } + assert.ErrorIs(t, err, errCalmarHighest) + var ratio decimal.Decimal ratio, err = DecimalCalmarRatio( decimal.NewFromInt(50000), @@ -661,9 +601,8 @@ func TestDecimalCalmarRatio(t *testing.T) { func TestDecimalCalculateSharpeRatio(t *testing.T) { t.Parallel() result, err := DecimalSharpeRatio(nil, decimal.Zero, decimal.Zero) - if !errors.Is(err, errZeroValue) { - t.Error(err) - } + assert.ErrorIs(t, err, errZeroValue) + if !result.IsZero() { t.Error("expected 0") } @@ -758,9 +697,8 @@ func TestDecimalGeometricAverage(t *testing.T) { decimal.NewFromInt(8), } _, err := DecimalGeometricMean(nil) - if !errors.Is(err, errZeroValue) { - t.Error(err) - } + assert.ErrorIs(t, err, errZeroValue) + var mean decimal.Decimal mean, err = DecimalGeometricMean(values) if err != nil { @@ -793,9 +731,8 @@ func TestDecimalGeometricAverage(t *testing.T) { decimal.NewFromInt(10), } mean, err = DecimalGeometricMean(values) - if !errors.Is(err, errGeometricNegative) { - t.Error(err) - } + assert.ErrorIs(t, err, errGeometricNegative) + if !mean.IsZero() { t.Errorf("expected %v, received %v", 0, mean) } @@ -814,9 +751,7 @@ func TestDecimalFinancialGeometricAverage(t *testing.T) { decimal.NewFromInt(8), } _, err := DecimalFinancialGeometricMean(nil) - if !errors.Is(err, errZeroValue) { - t.Error(err) - } + assert.ErrorIs(t, err, errZeroValue) var mean decimal.Decimal mean, err = DecimalFinancialGeometricMean(values) @@ -865,9 +800,7 @@ func TestDecimalFinancialGeometricAverage(t *testing.T) { decimal.NewFromInt(10), } _, err = DecimalFinancialGeometricMean(values) - if !errors.Is(err, errNegativeValueOutOfRange) { - t.Error(err) - } + assert.ErrorIs(t, err, errNegativeValueOutOfRange) } func TestDecimalArithmeticAverage(t *testing.T) { @@ -883,9 +816,8 @@ func TestDecimalArithmeticAverage(t *testing.T) { decimal.NewFromInt(8), } _, err := DecimalArithmeticMean(nil) - if !errors.Is(err, errZeroValue) { - t.Error(err) - } + assert.ErrorIs(t, err, errZeroValue) + var avg decimal.Decimal avg, err = DecimalArithmeticMean(values) if err != nil { diff --git a/config/config_test.go b/config/config_test.go index b5895038..6baf7bce 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1,7 +1,6 @@ package config import ( - "errors" "os" "path/filepath" "runtime" @@ -1129,9 +1128,7 @@ func TestGetExchangeConfig(t *testing.T) { err.Error()) } _, err = cfg.GetExchangeConfig("Testy") - if !errors.Is(err, ErrExchangeNotFound) { - t.Errorf("received '%v' expected '%v'", err, ErrExchangeNotFound) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) } func TestGetForexProviders(t *testing.T) { @@ -2006,9 +2003,7 @@ func TestMigrateConfig(t *testing.T) { func TestExchangeConfigValidate(t *testing.T) { err := (*Exchange)(nil).Validate() - if !errors.Is(err, errExchangeConfigIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeConfigIsNil) - } + require.ErrorIs(t, err, errExchangeConfigIsNil) err = (&Exchange{}).Validate() require.NoError(t, err) diff --git a/currency/code_test.go b/currency/code_test.go index ec224275..5d78262a 100644 --- a/currency/code_test.go +++ b/currency/code_test.go @@ -1,7 +1,6 @@ package currency import ( - "errors" "testing" "github.com/stretchr/testify/assert" @@ -253,9 +252,7 @@ func TestBaseCode(t *testing.T) { Symbol: "BTC", ID: 1337, }) - if !errors.Is(err, errRoleUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errRoleUnset) - } + require.ErrorIs(t, err, errRoleUnset) err = main.UpdateCurrency(&Item{ FullName: "Bitcoin", @@ -331,14 +328,10 @@ func TestBaseCode(t *testing.T) { } err = main.LoadItem(nil) - if !errors.Is(err, errItemIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errItemIsNil) - } + require.ErrorIs(t, err, errItemIsNil) err = main.LoadItem(&Item{}) - if !errors.Is(err, errItemIsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, errItemIsEmpty) - } + require.ErrorIs(t, err, errItemIsEmpty) err = main.LoadItem(&Item{ ID: 0, diff --git a/currency/currency_test.go b/currency/currency_test.go index cf86f1a4..de0c15c6 100644 --- a/currency/currency_test.go +++ b/currency/currency_test.go @@ -1,7 +1,6 @@ package currency import ( - "errors" "testing" "github.com/stretchr/testify/require" @@ -93,19 +92,13 @@ func TestUpdateCurrencies(t *testing.T) { func TestConvertFiat(t *testing.T) { _, err := ConvertFiat(0, LTC, USD) - if !errors.Is(err, errInvalidAmount) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidAmount) - } + require.ErrorIs(t, err, errInvalidAmount) _, err = ConvertFiat(100, LTC, USD) - if !errors.Is(err, errNotFiatCurrency) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNotFiatCurrency) - } + require.ErrorIs(t, err, errNotFiatCurrency) _, err = ConvertFiat(100, USD, LTC) - if !errors.Is(err, errNotFiatCurrency) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNotFiatCurrency) - } + require.ErrorIs(t, err, errNotFiatCurrency) _, err = ConvertFiat(100, AUD, USD) if err != nil { @@ -135,14 +128,10 @@ func TestConvertFiat(t *testing.T) { func TestGetForeignExchangeRate(t *testing.T) { _, err := GetForeignExchangeRate(NewPair(EMPTYCODE, EMPTYCODE)) - if !errors.Is(err, errNotFiatCurrency) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNotFiatCurrency) - } + require.ErrorIs(t, err, errNotFiatCurrency) _, err = GetForeignExchangeRate(NewPair(USD, EMPTYCODE)) - if !errors.Is(err, errNotFiatCurrency) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNotFiatCurrency) - } + require.ErrorIs(t, err, errNotFiatCurrency) one, err := GetForeignExchangeRate(NewPair(USD, USD)) require.NoError(t, err) diff --git a/currency/forexprovider/exchangeratesapi.io/exchangeratesapi_test.go b/currency/forexprovider/exchangeratesapi.io/exchangeratesapi_test.go index d2fa41f6..1dbb78c3 100644 --- a/currency/forexprovider/exchangeratesapi.io/exchangeratesapi_test.go +++ b/currency/forexprovider/exchangeratesapi.io/exchangeratesapi_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/currency/forexprovider/base" ) @@ -23,7 +25,7 @@ func TestMain(t *testing.M) { APIKey: apiKey, APIKeyLvl: apiKeyLevel, }) - if err != nil && !errors.Is(err, errAPIKeyNotSet) { + if err != nil && !(errors.Is(err, errAPIKeyNotSet)) { log.Fatal(err) } os.Exit(t.Run()) @@ -71,9 +73,7 @@ func TestGetLatestRates(t *testing.T) { if e.APIKeyLvl <= apiKeyFree { _, err = e.GetLatestRates("USD", "") - if !errors.Is(err, errCannotSetBaseCurrencyOnFreePlan) { - t.Errorf("expected: %s, got %s", errCannotSetBaseCurrencyOnFreePlan, err) - } + assert.ErrorIs(t, err, errCannotSetBaseCurrencyOnFreePlan) } result, err = e.GetLatestRates("EUR", "AUD") @@ -102,9 +102,7 @@ func TestGetHistoricalRates(t *testing.T) { if e.APIKeyLvl <= apiKeyFree { _, err = e.GetHistoricalRates(time.Now(), "USD", []string{"AUD"}) - if !errors.Is(err, errCannotSetBaseCurrencyOnFreePlan) { - t.Errorf("expected: %s, got %s", errCannotSetBaseCurrencyOnFreePlan, err) - } + assert.ErrorIs(t, err, errCannotSetBaseCurrencyOnFreePlan) } _, err = e.GetHistoricalRates(time.Now(), "EUR", []string{"AUD,USD"}) @@ -120,9 +118,8 @@ func TestConvertCurrency(t *testing.T) { if e.APIKeyLvl <= apiKeyFree { _, err := e.ConvertCurrency("USD", "AUD", 1000, time.Time{}) - if !errors.Is(err, errAPIKeyLevelRestrictedAccess) { - t.Errorf("expected: %s, got %s", errAPIKeyLevelRestrictedAccess, err) - } + assert.ErrorIs(t, err, errAPIKeyLevelRestrictedAccess) + return } @@ -144,22 +141,17 @@ func TestGetTimeSeriesRates(t *testing.T) { if e.APIKeyLvl <= apiKeyFree { _, err := e.GetTimeSeriesRates(time.Time{}, time.Time{}, "EUR", []string{"EUR,USD"}) - if !errors.Is(err, errAPIKeyLevelRestrictedAccess) { - t.Errorf("expected %s, got %s", errAPIKeyLevelRestrictedAccess, err) - } + assert.ErrorIs(t, err, errAPIKeyLevelRestrictedAccess) + return } _, err := e.GetTimeSeriesRates(time.Time{}, time.Time{}, "USD", []string{"EUR", "USD"}) - if !errors.Is(err, errStartEndDatesInvalid) { - t.Fatalf("received '%v' expected '%v'", err, errStartEndDatesInvalid) - } + require.ErrorIs(t, err, errStartEndDatesInvalid) tmNow := time.Now() _, err = e.GetTimeSeriesRates(tmNow.AddDate(0, 1, 0), tmNow, "USD", []string{"EUR", "USD"}) - if !errors.Is(err, errStartAfterEnd) { - t.Fatalf("received '%v' expected '%v'", err, errStartAfterEnd) - } + require.ErrorIs(t, err, errStartAfterEnd) _, err = e.GetTimeSeriesRates(tmNow.AddDate(0, -1, 0), tmNow, "EUR", []string{"AUD,USD"}) if err != nil { @@ -174,9 +166,8 @@ func TestGetFluctuation(t *testing.T) { if e.APIKeyLvl <= apiKeyFree { _, err := e.GetFluctuations(time.Time{}, time.Time{}, "EUR", "") - if !errors.Is(err, errAPIKeyLevelRestrictedAccess) { - t.Errorf("expected: %s, got %s", errAPIKeyLevelRestrictedAccess, err) - } + assert.ErrorIs(t, err, errAPIKeyLevelRestrictedAccess) + return } diff --git a/currency/manager_test.go b/currency/manager_test.go index cc613042..a60903a4 100644 --- a/currency/manager_test.go +++ b/currency/manager_test.go @@ -1,7 +1,6 @@ package currency import ( - "errors" "testing" "github.com/stretchr/testify/assert" @@ -81,14 +80,10 @@ func TestGet(t *testing.T) { } _, err = p.Get(asset.Empty) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) _, err = p.Get(asset.CoinMarginedFutures) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) } func TestPairsManagerMatch(t *testing.T) { @@ -97,26 +92,18 @@ func TestPairsManagerMatch(t *testing.T) { p := &PairsManager{} _, err := p.Match("", 1337) - if !errors.Is(err, ErrSymbolStringEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrSymbolStringEmpty) - } + require.ErrorIs(t, err, ErrSymbolStringEmpty) _, err = p.Match("sillyBilly", 1337) - if !errors.Is(err, errPairMatcherIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errPairMatcherIsNil) - } + require.ErrorIs(t, err, errPairMatcherIsNil) p = initTest(t) _, err = p.Match("sillyBilly", 1337) - if !errors.Is(err, ErrPairNotFound) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrPairNotFound) - } + require.ErrorIs(t, err, ErrPairNotFound) _, err = p.Match("sillyBilly", asset.Spot) - if !errors.Is(err, ErrPairNotFound) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrPairNotFound) - } + require.ErrorIs(t, err, ErrPairNotFound) whatIgot, err := p.Match("bTCuSD", asset.Spot) require.NoError(t, err) @@ -172,14 +159,10 @@ func TestStore(t *testing.T) { } err = p.Store(asset.Empty, nil) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) err = p.Store(asset.Futures, nil) - if !errors.Is(err, errPairStoreIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errPairStoreIsNil) - } + require.ErrorIs(t, err, errPairStoreIsNil) } func TestDelete(t *testing.T) { @@ -240,9 +223,7 @@ func TestGetPairs(t *testing.T) { } pairs, err = p.GetPairs(asset.Empty, true) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) if pairs != nil { t.Fatal("pairs shouldn't be populated") @@ -263,14 +244,10 @@ func TestStoreFormat(t *testing.T) { p := &PairsManager{} err := p.StoreFormat(0, &PairFormat{Delimiter: "~"}, true) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: %v but expected: %v", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) err = p.StoreFormat(asset.Spot, nil, true) - if !errors.Is(err, ErrPairFormatIsNil) { - t.Fatalf("received: %v but expected: %v", err, ErrPairFormatIsNil) - } + require.ErrorIs(t, err, ErrPairFormatIsNil) err = p.StoreFormat(asset.Spot, &PairFormat{Delimiter: "~"}, true) require.NoError(t, err) @@ -302,9 +279,7 @@ func TestStorePairs(t *testing.T) { p := initTest(t) err := p.StorePairs(0, nil, false) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: %v but expected: %v", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) p.Pairs = nil @@ -410,10 +385,7 @@ func TestDisablePair(t *testing.T) { func TestEnablePair(t *testing.T) { t.Parallel() p := initTest(t) - - if err := p.EnablePair(asset.Empty, NewBTCUSD()); !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported) - } + require.ErrorIs(t, p.EnablePair(asset.Empty, NewBTCUSD()), asset.ErrNotSupported) p.Pairs = nil // Test enabling a pair when the pair manager is not initialised @@ -522,9 +494,7 @@ func TestFullStoreUnmarshalMarshal(t *testing.T) { data = []byte(`{"bro":{"assetEnabled":null,"enabled":"","available":""}}`) err = json.Unmarshal(data, &another) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) } func TestIsPairAvailable(t *testing.T) { @@ -691,9 +661,7 @@ func TestEnsureOnePairEnabled(t *testing.T) { Pairs: map[asset.Item]*PairStore{}, } _, _, err = pm.EnsureOnePairEnabled() - if !errors.Is(err, ErrCurrencyPairsEmpty) { - t.Errorf("received: '%v' but expected: '%v'", err, ErrCurrencyPairsEmpty) - } + assert.ErrorIs(t, err, ErrCurrencyPairsEmpty) } func TestLoad(t *testing.T) { diff --git a/currency/pair_test.go b/currency/pair_test.go index 4907a0e5..126f4ef9 100644 --- a/currency/pair_test.go +++ b/currency/pair_test.go @@ -1,7 +1,6 @@ package currency import ( - "errors" "strconv" "testing" @@ -543,9 +542,8 @@ func TestRandomPairFromPairs(t *testing.T) { // Test that an empty pairs array returns an empty currency pair var emptyPairs Pairs result, err := emptyPairs.GetRandomPair() - if !errors.Is(err, ErrCurrencyPairsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrCurrencyPairsEmpty) - } + require.ErrorIs(t, err, ErrCurrencyPairsEmpty) + if !result.IsEmpty() { t.Error("TestRandomPairFromPairs: Unexpected values") } @@ -739,9 +737,8 @@ func TestOther(t *testing.T) { if !received.Equal(DAI) { t.Fatal("unexpected value") } - if _, err := NewPair(DAI, XRP).Other(BTC); !errors.Is(err, ErrCurrencyCodeEmpty) { - t.Fatal("unexpected value") - } + _, err = NewPair(DAI, XRP).Other(BTC) + require.ErrorIs(t, err, ErrCurrencyCodeEmpty) } func TestIsPopulated(t *testing.T) { @@ -801,10 +798,7 @@ func TestGetOrderParameters(t *testing.T) { case !tc.market && !tc.selling: resp, err = tc.Pair.LimitBuyOrderParameters(tc.currency) } - - if !errors.Is(err, tc.expectedError) { - t.Fatalf("received %v, expected %v", err, tc.expectedError) - } + require.ErrorIs(t, err, tc.expectedError) if tc.expectedParams == nil { if resp != nil { diff --git a/currency/pairs_test.go b/currency/pairs_test.go index cceb2d0d..426a40b8 100644 --- a/currency/pairs_test.go +++ b/currency/pairs_test.go @@ -1,7 +1,6 @@ package currency import ( - "errors" "slices" "testing" @@ -52,13 +51,10 @@ func TestPairsString(t *testing.T) { func TestPairsFromString(t *testing.T) { t.Parallel() - if _, err := NewPairsFromString("", ""); !errors.Is(err, errNoDelimiter) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoDelimiter) - } - - if _, err := NewPairsFromString("", ","); !errors.Is(err, errCannotCreatePair) { - t.Fatalf("received: '%v' but expected: '%v'", err, errCannotCreatePair) - } + _, err := NewPairsFromString("", "") + assert.ErrorIs(t, err, errNoDelimiter) + _, err = NewPairsFromString("", ",") + assert.ErrorIs(t, err, errCannotCreatePair) pairs, err := NewPairsFromString("ALGO-AUD,BAT-AUD,BCH-AUD,BSV-AUD,BTC-AUD,COMP-AUD,ENJ-AUD,ETC-AUD,ETH-AUD,ETH-BTC,GNT-AUD,LINK-AUD,LTC-AUD,LTC-BTC,MCAU-AUD,OMG-AUD,POWR-AUD,UNI-AUD,USDT-AUD,XLM-AUD,XRP-AUD,XRP-BTC", ",") require.NoError(t, err) @@ -70,12 +66,7 @@ func TestPairsFromString(t *testing.T) { "UNI-AUD", "USDT-AUD", "XLM-AUD", "XRP-AUD", "XRP-BTC", } - returned := pairs.Strings() - for x := range returned { - if returned[x] != expected[x] { - t.Fatalf("received: '%v' but expected: '%v'", returned[x], expected[x]) - } - } + assert.Equal(t, expected, pairs.Strings(), "NewPairsFromString should return the correct pairs") } func TestPairsJoin(t *testing.T) { @@ -273,9 +264,7 @@ func TestContainsAll(t *testing.T) { } err := pairs.ContainsAll(nil, true) - if !errors.Is(err, ErrCurrencyPairsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrCurrencyPairsEmpty) - } + require.ErrorIs(t, err, ErrCurrencyPairsEmpty) err = pairs.ContainsAll(Pairs{NewBTCUSD()}, true) require.NoError(t, err) @@ -284,14 +273,10 @@ func TestContainsAll(t *testing.T) { require.NoError(t, err) err = pairs.ContainsAll(Pairs{NewPair(XRP, BTC)}, false) - if !errors.Is(err, ErrPairNotContainedInAvailablePairs) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrPairNotContainedInAvailablePairs) - } + require.ErrorIs(t, err, ErrPairNotContainedInAvailablePairs) err = pairs.ContainsAll(Pairs{NewPair(XRP, BTC)}, true) - if !errors.Is(err, ErrPairNotContainedInAvailablePairs) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrPairNotContainedInAvailablePairs) - } + require.ErrorIs(t, err, ErrPairNotContainedInAvailablePairs) err = pairs.ContainsAll(pairs, true) require.NoError(t, err) @@ -307,17 +292,14 @@ func TestContainsAll(t *testing.T) { } err = pairs.ContainsAll(duplication, false) - if !errors.Is(err, ErrPairDuplication) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrPairDuplication) - } + require.ErrorIs(t, err, ErrPairDuplication) } func TestDeriveFrom(t *testing.T) { t.Parallel() _, err := Pairs{}.DeriveFrom("") - if !errors.Is(err, ErrCurrencyPairsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrCurrencyPairsEmpty) - } + require.ErrorIs(t, err, ErrCurrencyPairsEmpty) + testCases := Pairs{ NewBTCUSDT(), NewPair(USDC, USDT), @@ -327,14 +309,10 @@ func TestDeriveFrom(t *testing.T) { } _, err = testCases.DeriveFrom("") - if !errors.Is(err, errSymbolEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, errSymbolEmpty) - } + require.ErrorIs(t, err, errSymbolEmpty) _, err = testCases.DeriveFrom("btcUSD") - if !errors.Is(err, ErrPairNotFound) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrPairNotFound) - } + require.ErrorIs(t, err, ErrPairNotFound) got, err := testCases.DeriveFrom("USDCUSD") require.NoError(t, err) @@ -429,9 +407,7 @@ func TestGetMatch(t *testing.T) { } _, err := pairs.GetMatch(NewPair(BTC, WABI)) - if !errors.Is(err, ErrPairNotFound) { - t.Fatalf("received: '%v' but expected '%v'", err, ErrPairNotFound) - } + require.ErrorIs(t, err, ErrPairNotFound) expected := NewBTCUSD() match, err := pairs.GetMatch(expected) @@ -628,9 +604,7 @@ func TestValidateAndConform(t *testing.T) { } _, err := conformMe.ValidateAndConform(EMPTYFORMAT, false) - if !errors.Is(err, ErrCurrencyPairEmpty) { - t.Fatalf("received: '%v' but expected '%v'", err, ErrCurrencyPairEmpty) - } + require.ErrorIs(t, err, ErrCurrencyPairEmpty) duplication, err := NewPairFromString("linkusdt") if err != nil { @@ -649,9 +623,7 @@ func TestValidateAndConform(t *testing.T) { } _, err = conformMe.ValidateAndConform(EMPTYFORMAT, false) - if !errors.Is(err, ErrPairDuplication) { - t.Fatalf("received: '%v' but expected '%v'", err, ErrPairDuplication) - } + require.ErrorIs(t, err, ErrPairDuplication) conformMe = Pairs{ NewBTCUSD(), @@ -737,9 +709,8 @@ func TestGetPairsByQuote(t *testing.T) { t.Parallel() var available Pairs - if _, err := available.GetPairsByQuote(EMPTYCODE); !errors.Is(err, ErrCurrencyPairsEmpty) { - t.Fatalf("received: '%v' but expected '%v'", err, ErrCurrencyPairsEmpty) - } + _, err := available.GetPairsByQuote(EMPTYCODE) + require.ErrorIs(t, err, ErrCurrencyPairsEmpty) available = Pairs{ NewBTCUSD(), @@ -751,9 +722,8 @@ func TestGetPairsByQuote(t *testing.T) { NewPair(DAI, XRP), } - if _, err := available.GetPairsByQuote(EMPTYCODE); !errors.Is(err, ErrCurrencyCodeEmpty) { - t.Fatalf("received: '%v' but expected '%v'", err, ErrCurrencyCodeEmpty) - } + _, err = available.GetPairsByQuote(EMPTYCODE) + require.ErrorIs(t, err, ErrCurrencyCodeEmpty) got, err := available.GetPairsByQuote(USD) require.NoError(t, err) @@ -774,9 +744,8 @@ func TestGetPairsByBase(t *testing.T) { t.Parallel() var available Pairs - if _, err := available.GetPairsByBase(EMPTYCODE); !errors.Is(err, ErrCurrencyPairsEmpty) { - t.Fatalf("received: '%v' but expected '%v'", err, ErrCurrencyPairsEmpty) - } + _, err := available.GetPairsByBase(EMPTYCODE) + require.ErrorIs(t, err, ErrCurrencyPairsEmpty) available = Pairs{ NewBTCUSD(), @@ -788,9 +757,8 @@ func TestGetPairsByBase(t *testing.T) { NewPair(DAI, XRP), } - if _, err := available.GetPairsByBase(EMPTYCODE); !errors.Is(err, ErrCurrencyCodeEmpty) { - t.Fatalf("received: '%v' but expected '%v'", err, ErrCurrencyCodeEmpty) - } + _, err = available.GetPairsByBase(EMPTYCODE) + require.ErrorIs(t, err, ErrCurrencyCodeEmpty) got, err := available.GetPairsByBase(USD) require.NoError(t, err) diff --git a/database/database_test.go b/database/database_test.go index 59eb0097..18fe3dbc 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -2,7 +2,6 @@ package database import ( "database/sql" - "errors" "os" "path/filepath" "testing" @@ -18,33 +17,25 @@ func TestSetConfig(t *testing.T) { assert.NoError(t, err) err = inst.SetConfig(nil) - if !errors.Is(err, ErrNilConfig) { - t.Errorf("received %v, expected %v", err, ErrNilConfig) - } + assert.ErrorIs(t, err, ErrNilConfig) inst = nil err = inst.SetConfig(&Config{}) - if !errors.Is(err, ErrNilInstance) { - t.Errorf("received %v, expected %v", err, ErrNilInstance) - } + assert.ErrorIs(t, err, ErrNilInstance) } func TestSetSQLiteConnection(t *testing.T) { t.Parallel() inst := &Instance{} err := inst.SetSQLiteConnection(nil) - if !errors.Is(err, errNilSQL) { - t.Errorf("received %v, expected %v", err, errNilSQL) - } + assert.ErrorIs(t, err, errNilSQL) err = inst.SetSQLiteConnection(&sql.DB{}) assert.NoError(t, err) inst = nil err = inst.SetSQLiteConnection(nil) - if !errors.Is(err, ErrNilInstance) { - t.Errorf("received %v, expected %v", err, ErrNilInstance) - } + assert.ErrorIs(t, err, ErrNilInstance) } func TestSetPostgresConnection(t *testing.T) { @@ -142,19 +133,16 @@ func TestPing(t *testing.T) { inst.SQL = nil err = inst.Ping() - if !errors.Is(err, errNilSQL) { - t.Errorf("received %v, expected %v", err, errNilSQL) - } + assert.ErrorIs(t, err, errNilSQL) + inst.SetConnected(false) err = inst.Ping() - if !errors.Is(err, ErrDatabaseNotConnected) { - t.Errorf("received %v, expected %v", err, ErrDatabaseNotConnected) - } + assert.ErrorIs(t, err, ErrDatabaseNotConnected) + inst = nil err = inst.Ping() - if !errors.Is(err, ErrNilInstance) { - t.Errorf("received %v, expected %v", err, ErrNilInstance) - } + assert.ErrorIs(t, err, ErrNilInstance) + err = con.Close() assert.NoError(t, err) @@ -166,9 +154,7 @@ func TestGetSQL(t *testing.T) { t.Parallel() inst := &Instance{} _, err := inst.GetSQL() - if !errors.Is(err, errNilSQL) { - t.Errorf("received %v, expected %v", err, errNilSQL) - } + assert.ErrorIs(t, err, errNilSQL) databaseFullLocation := filepath.Join(DB.DataPath, "TestGetSQL") con, err := sql.Open("sqlite3", databaseFullLocation) @@ -182,7 +168,5 @@ func TestGetSQL(t *testing.T) { inst = nil _, err = inst.GetSQL() - if !errors.Is(err, ErrNilInstance) { - t.Errorf("received %v, expected %v", err, ErrNilInstance) - } + assert.ErrorIs(t, err, ErrNilInstance) } diff --git a/database/repository/candle/candle_test.go b/database/repository/candle/candle_test.go index 715e5784..0f077df6 100644 --- a/database/repository/candle/candle_test.go +++ b/database/repository/candle/candle_test.go @@ -1,13 +1,14 @@ package candle import ( - "errors" "fmt" "os" "path/filepath" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/database" "github.com/thrasher-corp/gocryptotrader/database/drivers" @@ -245,20 +246,11 @@ func TestSeries(t *testing.T) { } _, err = Series("", "", "", 0, "", start, end) - if !errors.Is(err, errInvalidInput) { - t.Fatal(err) - } + require.ErrorIs(t, err, errInvalidInput) - _, err = Series(testExchanges[0].Name, - "BTC", "MOON", - 864000, "spot", - start, end) - if err != nil && !errors.Is(err, errInvalidInput) && !errors.Is(err, ErrNoCandleDataFound) { - t.Fatal(err) - } - if err = testhelpers.CloseDatabase(dbConn); err != nil { - t.Error(err) - } + _, err = Series(testExchanges[0].Name, "BTC", "MOON", 864000, "spot", start, end) + require.ErrorIs(t, err, ErrNoCandleDataFound) + assert.NoError(t, testhelpers.CloseDatabase(dbConn)) }) } } diff --git a/database/repository/datahistoryjob/datahistoryjob_test.go b/database/repository/datahistoryjob/datahistoryjob_test.go index 55a2d1b0..4c910fed 100644 --- a/database/repository/datahistoryjob/datahistoryjob_test.go +++ b/database/repository/datahistoryjob/datahistoryjob_test.go @@ -1,7 +1,6 @@ package datahistoryjob import ( - "errors" "fmt" "log" "os" @@ -258,9 +257,8 @@ func TestDataHistoryJob(t *testing.T) { assert.NoError(t, err) err = db.SetRelationshipByNickname(results[2].Nickname, results[2].Nickname, 0) - if !errors.Is(err, errCannotSetSamePrerequisite) { - t.Errorf("received %v expected %v", err, errCannotSetSamePrerequisite) - } + assert.ErrorIs(t, err, errCannotSetSamePrerequisite) + err = db.SetRelationshipByNickname(results[3].Nickname, results[2].Nickname, 0) assert.NoError(t, err) diff --git a/database/repository/withdraw/withdraw_test.go b/database/repository/withdraw/withdraw_test.go index 805fb123..fad18060 100644 --- a/database/repository/withdraw/withdraw_test.go +++ b/database/repository/withdraw/withdraw_test.go @@ -1,13 +1,14 @@ package withdraw import ( - "errors" "fmt" "math/rand" "os" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/database" @@ -159,11 +160,7 @@ func withdrawHelper(t *testing.T) { seedWithdrawData() _, err := GetEventByUUID(withdraw.DryRunID.String()) - if err != nil { - if !errors.Is(err, common.ErrNoResults) { - t.Fatal(err) - } - } + require.ErrorIs(t, err, common.ErrNoResults) v, err := GetEventsByExchange(testExchanges[0].Name, 10) if err != nil { @@ -182,9 +179,7 @@ func withdrawHelper(t *testing.T) { if len(v) > 0 { _, err = GetEventByUUID(v[0].ID.String()) if err != nil { - if !errors.Is(err, common.ErrNoResults) { - t.Error(err) - } + assert.ErrorIs(t, err, common.ErrNoResults) } } diff --git a/engine/apiserver_test.go b/engine/apiserver_test.go index 5d9bc565..06e523cd 100644 --- a/engine/apiserver_test.go +++ b/engine/apiserver_test.go @@ -1,7 +1,6 @@ package engine import ( - "errors" "io" "net/http" "net/http/httptest" @@ -17,29 +16,19 @@ import ( func TestSetupAPIServerManager(t *testing.T) { t.Parallel() _, err := setupAPIServerManager(nil, nil, nil, nil, nil, "") - if !errors.Is(err, errNilRemoteConfig) { - t.Errorf("error '%v', expected '%v'", err, errNilRemoteConfig) - } + assert.ErrorIs(t, err, errNilRemoteConfig) _, err = setupAPIServerManager(&config.RemoteControlConfig{}, nil, nil, nil, nil, "") - if !errors.Is(err, errNilPProfConfig) { - t.Errorf("error '%v', expected '%v'", err, errNilPProfConfig) - } + assert.ErrorIs(t, err, errNilPProfConfig) _, err = setupAPIServerManager(&config.RemoteControlConfig{}, &config.Profiler{}, nil, nil, nil, "") - if !errors.Is(err, errNilExchangeManager) { - t.Errorf("error '%v', expected '%v'", err, errNilExchangeManager) - } + assert.ErrorIs(t, err, errNilExchangeManager) _, err = setupAPIServerManager(&config.RemoteControlConfig{}, &config.Profiler{}, &ExchangeManager{}, nil, nil, "") - if !errors.Is(err, errNilBot) { - t.Errorf("error '%v', expected '%v'", err, errNilBot) - } + assert.ErrorIs(t, err, errNilBot) _, err = setupAPIServerManager(&config.RemoteControlConfig{}, &config.Profiler{}, &ExchangeManager{}, &fakeBot{}, nil, "") - if !errors.Is(err, errEmptyConfigPath) { - t.Errorf("error '%v', expected '%v'", err, errEmptyConfigPath) - } + assert.ErrorIs(t, err, errEmptyConfigPath) wd, _ := os.Getwd() _, err = setupAPIServerManager(&config.RemoteControlConfig{}, &config.Profiler{}, &ExchangeManager{}, &fakeBot{}, nil, wd) @@ -53,9 +42,8 @@ func TestStartRESTServer(t *testing.T) { assert.NoError(t, err) err = m.StartRESTServer() - if !errors.Is(err, errServerDisabled) { - t.Errorf("error '%v', expected '%v'", err, errServerDisabled) - } + assert.ErrorIs(t, err, errServerDisabled) + m.remoteConfig.DeprecatedRPC.Enabled = true err = m.StartRESTServer() if err != nil { @@ -70,9 +58,8 @@ func TestStartWebsocketServer(t *testing.T) { assert.NoError(t, err) err = m.StartWebsocketServer() - if !errors.Is(err, errServerDisabled) { - t.Errorf("error '%v', expected '%v'", err, errServerDisabled) - } + assert.ErrorIs(t, err, errServerDisabled) + m.remoteConfig.WebsocketRPC.Enabled = true err = m.StartWebsocketServer() assert.NoError(t, err) @@ -90,9 +77,7 @@ func TestStopRESTServer(t *testing.T) { assert.NoError(t, err) err = m.StopRESTServer() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) err = m.StartRESTServer() assert.NoError(t, err) @@ -120,9 +105,7 @@ func TestWebsocketStop(t *testing.T) { assert.NoError(t, err) err = m.StopWebsocketServer() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) err = m.StartWebsocketServer() assert.NoError(t, err) diff --git a/engine/communication_manager_test.go b/engine/communication_manager_test.go index f60ecfd4..232fabe7 100644 --- a/engine/communication_manager_test.go +++ b/engine/communication_manager_test.go @@ -1,7 +1,6 @@ package engine import ( - "errors" "testing" "github.com/stretchr/testify/assert" @@ -12,14 +11,10 @@ import ( func TestSetup(t *testing.T) { t.Parallel() _, err := SetupCommunicationManager(nil) - if !errors.Is(err, errNilConfig) { - t.Errorf("error '%v', expected '%v'", err, errNilConfig) - } + assert.ErrorIs(t, err, errNilConfig) _, err = SetupCommunicationManager(&base.CommunicationsConfig{}) - if !errors.Is(err, communications.ErrNoRelayersEnabled) { - t.Errorf("error '%v', expected '%v'", err, communications.ErrNoRelayersEnabled) - } + assert.ErrorIs(t, err, communications.ErrNoRelayersEnabled) m, err := SetupCommunicationManager(&base.CommunicationsConfig{ SlackConfig: base.SlackConfig{ @@ -72,9 +67,7 @@ func TestStart(t *testing.T) { m.started = 1 err = m.Start() - if !errors.Is(err, ErrSubSystemAlreadyStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemAlreadyStarted) - } + assert.ErrorIs(t, err, ErrSubSystemAlreadyStarted) } func TestGetStatus(t *testing.T) { @@ -94,9 +87,7 @@ func TestGetStatus(t *testing.T) { m.started = 0 _, err = m.GetStatus() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) } func TestStop(t *testing.T) { @@ -115,14 +106,11 @@ func TestStop(t *testing.T) { assert.NoError(t, err) err = m.Stop() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) + m = nil err = m.Stop() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestPushEvent(t *testing.T) { diff --git a/engine/connection_manager_test.go b/engine/connection_manager_test.go index f769336c..a9a3c359 100644 --- a/engine/connection_manager_test.go +++ b/engine/connection_manager_test.go @@ -1,7 +1,6 @@ package engine import ( - "errors" "testing" "github.com/stretchr/testify/assert" @@ -11,9 +10,7 @@ import ( func TestSetupConnectionManager(t *testing.T) { t.Parallel() _, err := setupConnectionManager(nil) - if !errors.Is(err, errNilConfig) { - t.Errorf("error '%v', expected '%v'", err, errNilConfig) - } + assert.ErrorIs(t, err, errNilConfig) m, err := setupConnectionManager(&config.ConnectionMonitorConfig{}) assert.NoError(t, err) @@ -53,22 +50,18 @@ func TestConnectionMonitorStart(t *testing.T) { assert.NoError(t, err) err = m.Start() - if !errors.Is(err, ErrSubSystemAlreadyStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemAlreadyStarted) - } + assert.ErrorIs(t, err, ErrSubSystemAlreadyStarted) + m = nil err = m.Start() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestConnectionMonitorStop(t *testing.T) { t.Parallel() err := (&connectionManager{started: 1}).Stop() - if !errors.Is(err, errConnectionCheckerIsNil) { - t.Errorf("error '%v', expected '%v'", err, errConnectionCheckerIsNil) - } + assert.ErrorIs(t, err, errConnectionCheckerIsNil) + m, err := setupConnectionManager(&config.ConnectionMonitorConfig{}) assert.NoError(t, err) @@ -79,14 +72,11 @@ func TestConnectionMonitorStop(t *testing.T) { assert.NoError(t, err) err = m.Stop() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) + m = nil err = m.Stop() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestConnectionMonitorIsOnline(t *testing.T) { diff --git a/engine/currency_state_manager_test.go b/engine/currency_state_manager_test.go index 8550c99b..06eaff3f 100644 --- a/engine/currency_state_manager_test.go +++ b/engine/currency_state_manager_test.go @@ -18,9 +18,7 @@ import ( func TestSetupCurrencyStateManager(t *testing.T) { t.Parallel() _, err := SetupCurrencyStateManager(0, nil) - if !errors.Is(err, errNilExchangeManager) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNilExchangeManager) - } + require.ErrorIs(t, err, errNilExchangeManager) cm, err := SetupCurrencyStateManager(0, &ExchangeManager{}) require.NoError(t, err) @@ -130,27 +128,19 @@ func (f *fakerino) GetBase() *exchange.Base { func TestCurrencyStateManagerIsRunning(t *testing.T) { t.Parallel() err := (*CurrencyStateManager)(nil).Stop() - if !errors.Is(err, ErrNilSubsystem) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrNilSubsystem) - } + require.ErrorIs(t, err, ErrNilSubsystem) err = (&CurrencyStateManager{}).Stop() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrSubSystemNotStarted) - } + require.ErrorIs(t, err, ErrSubSystemNotStarted) err = (&CurrencyStateManager{started: 1, shutdown: make(chan struct{})}).Stop() require.NoError(t, err) err = (*CurrencyStateManager)(nil).Start() - if !errors.Is(err, ErrNilSubsystem) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrNilSubsystem) - } + require.ErrorIs(t, err, ErrNilSubsystem) err = (&CurrencyStateManager{started: 1}).Start() - if !errors.Is(err, ErrSubSystemAlreadyStarted) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrSubSystemAlreadyStarted) - } + require.ErrorIs(t, err, ErrSubSystemAlreadyStarted) man := &CurrencyStateManager{ shutdown: make(chan struct{}), @@ -195,25 +185,19 @@ func TestCurrencyStateManagerIsRunning(t *testing.T) { func TestGetAllRPC(t *testing.T) { t.Parallel() _, err := (*CurrencyStateManager)(nil).GetAllRPC("") - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrSubSystemNotStarted) - } + require.ErrorIs(t, err, ErrSubSystemNotStarted) _, err = (&CurrencyStateManager{ started: 1, iExchangeManager: &fakeExchangeManagerino{ErrorMeOne: true}, }).GetAllRPC("") - if !errors.Is(err, errManager) { - t.Fatalf("received: '%v' but expected: '%v'", err, errManager) - } + require.ErrorIs(t, err, errManager) _, err = (&CurrencyStateManager{ started: 1, iExchangeManager: &fakeExchangeManagerino{ErrorMeTwo: true}, }).GetAllRPC("") - if !errors.Is(err, errExchange) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExchange) - } + require.ErrorIs(t, err, errExchange) _, err = (&CurrencyStateManager{ started: 1, @@ -225,25 +209,19 @@ func TestGetAllRPC(t *testing.T) { func TestCanWithdrawRPC(t *testing.T) { t.Parallel() _, err := (*CurrencyStateManager)(nil).CanWithdrawRPC("", currency.EMPTYCODE, asset.Empty) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrSubSystemNotStarted) - } + require.ErrorIs(t, err, ErrSubSystemNotStarted) _, err = (&CurrencyStateManager{ started: 1, iExchangeManager: &fakeExchangeManagerino{ErrorMeOne: true}, }).CanWithdrawRPC("", currency.EMPTYCODE, asset.Empty) - if !errors.Is(err, errManager) { - t.Fatalf("received: '%v' but expected: '%v'", err, errManager) - } + require.ErrorIs(t, err, errManager) _, err = (&CurrencyStateManager{ started: 1, iExchangeManager: &fakeExchangeManagerino{ErrorMeTwo: true}, }).CanWithdrawRPC("", currency.EMPTYCODE, asset.Empty) - if !errors.Is(err, errExchange) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExchange) - } + require.ErrorIs(t, err, errExchange) _, err = (&CurrencyStateManager{ started: 1, @@ -255,25 +233,19 @@ func TestCanWithdrawRPC(t *testing.T) { func TestCanDepositRPC(t *testing.T) { t.Parallel() _, err := (*CurrencyStateManager)(nil).CanDepositRPC("", currency.EMPTYCODE, asset.Empty) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrSubSystemNotStarted) - } + require.ErrorIs(t, err, ErrSubSystemNotStarted) _, err = (&CurrencyStateManager{ started: 1, iExchangeManager: &fakeExchangeManagerino{ErrorMeOne: true}, }).CanDepositRPC("", currency.EMPTYCODE, asset.Empty) - if !errors.Is(err, errManager) { - t.Fatalf("received: '%v' but expected: '%v'", err, errManager) - } + require.ErrorIs(t, err, errManager) _, err = (&CurrencyStateManager{ started: 1, iExchangeManager: &fakeExchangeManagerino{ErrorMeTwo: true}, }).CanDepositRPC("", currency.EMPTYCODE, asset.Empty) - if !errors.Is(err, errExchange) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExchange) - } + require.ErrorIs(t, err, errExchange) _, err = (&CurrencyStateManager{ started: 1, @@ -285,25 +257,19 @@ func TestCanDepositRPC(t *testing.T) { func TestCanTradeRPC(t *testing.T) { t.Parallel() _, err := (*CurrencyStateManager)(nil).CanTradeRPC("", currency.EMPTYCODE, asset.Empty) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrSubSystemNotStarted) - } + require.ErrorIs(t, err, ErrSubSystemNotStarted) _, err = (&CurrencyStateManager{ started: 1, iExchangeManager: &fakeExchangeManagerino{ErrorMeOne: true}, }).CanTradeRPC("", currency.EMPTYCODE, asset.Empty) - if !errors.Is(err, errManager) { - t.Fatalf("received: '%v' but expected: '%v'", err, errManager) - } + require.ErrorIs(t, err, errManager) _, err = (&CurrencyStateManager{ started: 1, iExchangeManager: &fakeExchangeManagerino{ErrorMeTwo: true}, }).CanTradeRPC("", currency.EMPTYCODE, asset.Empty) - if !errors.Is(err, errExchange) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExchange) - } + require.ErrorIs(t, err, errExchange) _, err = (&CurrencyStateManager{ started: 1, @@ -315,25 +281,19 @@ func TestCanTradeRPC(t *testing.T) { func TestCanTradePairRPC(t *testing.T) { t.Parallel() _, err := (*CurrencyStateManager)(nil).CanTradePairRPC("", currency.EMPTYPAIR, asset.Empty) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrSubSystemNotStarted) - } + require.ErrorIs(t, err, ErrSubSystemNotStarted) _, err = (&CurrencyStateManager{ started: 1, iExchangeManager: &fakeExchangeManagerino{ErrorMeOne: true}, }).CanTradePairRPC("", currency.EMPTYPAIR, asset.Empty) - if !errors.Is(err, errManager) { - t.Fatalf("received: '%v' but expected: '%v'", err, errManager) - } + require.ErrorIs(t, err, errManager) _, err = (&CurrencyStateManager{ started: 1, iExchangeManager: &fakeExchangeManagerino{ErrorMeTwo: true}, }).CanTradePairRPC("", currency.EMPTYPAIR, asset.Empty) - if !errors.Is(err, errExchange) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExchange) - } + require.ErrorIs(t, err, errExchange) _, err = (&CurrencyStateManager{ started: 1, diff --git a/engine/database_connection_test.go b/engine/database_connection_test.go index 1ea9d291..5ffc8198 100644 --- a/engine/database_connection_test.go +++ b/engine/database_connection_test.go @@ -1,7 +1,6 @@ package engine import ( - "errors" "log" "sync" "testing" @@ -28,9 +27,7 @@ func CreateDatabase(t *testing.T) { func TestSetupDatabaseConnectionManager(t *testing.T) { _, err := SetupDatabaseConnectionManager(nil) - if !errors.Is(err, errNilConfig) { - t.Errorf("error '%v', expected '%v'", err, errNilConfig) - } + assert.ErrorIs(t, err, errNilConfig) m, err := SetupDatabaseConnectionManager(&database.Config{}) assert.NoError(t, err) @@ -47,21 +44,18 @@ func TestStartSQLite(t *testing.T) { var wg sync.WaitGroup err = m.Start(&wg) - if !errors.Is(err, database.ErrDatabaseSupportDisabled) { - t.Errorf("error '%v', expected '%v'", err, database.ErrDatabaseSupportDisabled) - } + assert.ErrorIs(t, err, database.ErrDatabaseSupportDisabled) + m, err = SetupDatabaseConnectionManager(&database.Config{Enabled: true}) assert.NoError(t, err) err = m.Start(&wg) - if !errors.Is(err, database.ErrNoDatabaseProvided) { - t.Errorf("error '%v', expected '%v'", err, database.ErrNoDatabaseProvided) - } + assert.ErrorIs(t, err, database.ErrNoDatabaseProvided) + m.cfg = database.Config{Driver: database.DBSQLite} err = m.Start(&wg) - if !errors.Is(err, database.ErrDatabaseSupportDisabled) { - t.Errorf("error '%v', expected '%v'", err, database.ErrDatabaseSupportDisabled) - } + assert.ErrorIs(t, err, database.ErrDatabaseSupportDisabled) + _, err = SetupDatabaseConnectionManager(&database.Config{ Enabled: true, Driver: database.DBSQLite, @@ -80,19 +74,15 @@ func TestStartPostgres(t *testing.T) { var wg sync.WaitGroup err = m.Start(&wg) - if !errors.Is(err, database.ErrDatabaseSupportDisabled) { - t.Errorf("error '%v', expected '%v'", err, database.ErrDatabaseSupportDisabled) - } + assert.ErrorIs(t, err, database.ErrDatabaseSupportDisabled) + m.cfg.Enabled = true err = m.Start(&wg) - if !errors.Is(err, database.ErrNoDatabaseProvided) { - t.Errorf("error '%v', expected '%v'", err, database.ErrNoDatabaseProvided) - } + assert.ErrorIs(t, err, database.ErrNoDatabaseProvided) + m.cfg.Driver = database.DBPostgreSQL err = m.Start(&wg) - if !errors.Is(err, database.ErrFailedToConnect) { - t.Errorf("error '%v', expected '%v'", err, database.ErrFailedToConnect) - } + assert.ErrorIs(t, err, database.ErrFailedToConnect) } func TestDatabaseConnectionManagerIsRunning(t *testing.T) { @@ -136,9 +126,7 @@ func TestDatabaseConnectionManagerStop(t *testing.T) { assert.NoError(t, err) err = m.Stop() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) var wg sync.WaitGroup err = m.Start(&wg) @@ -149,18 +137,15 @@ func TestDatabaseConnectionManagerStop(t *testing.T) { m = nil err = m.Stop() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestCheckConnection(t *testing.T) { CreateDatabase(t) var m *DatabaseConnectionManager err := m.checkConnection() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) + m, err = SetupDatabaseConnectionManager(&database.Config{ Enabled: true, Driver: database.DBSQLite, @@ -172,9 +157,8 @@ func TestCheckConnection(t *testing.T) { assert.NoError(t, err) err = m.checkConnection() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) + var wg sync.WaitGroup err = m.Start(&wg) assert.NoError(t, err) @@ -186,9 +170,7 @@ func TestCheckConnection(t *testing.T) { assert.NoError(t, err) err = m.checkConnection() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) err = m.Start(&wg) assert.NoError(t, err) @@ -198,9 +180,7 @@ func TestCheckConnection(t *testing.T) { m.dbConn.SetConnected(false) err = m.checkConnection() - if !errors.Is(err, database.ErrDatabaseNotConnected) { - t.Errorf("error '%v', expected '%v'", err, database.ErrDatabaseNotConnected) - } + assert.ErrorIs(t, err, database.ErrDatabaseNotConnected) err = m.Stop() assert.NoError(t, err) diff --git a/engine/datahistory_manager_test.go b/engine/datahistory_manager_test.go index 75a297f7..06e01b26 100644 --- a/engine/datahistory_manager_test.go +++ b/engine/datahistory_manager_test.go @@ -28,24 +28,16 @@ import ( func TestSetupDataHistoryManager(t *testing.T) { t.Parallel() _, err := SetupDataHistoryManager(nil, nil, nil) - if !errors.Is(err, errNilExchangeManager) { - t.Errorf("error '%v', expected '%v'", err, errNilConfig) - } + assert.ErrorIs(t, err, errNilExchangeManager) _, err = SetupDataHistoryManager(NewExchangeManager(), nil, nil) - if !errors.Is(err, errNilDatabaseConnectionManager) { - t.Errorf("error '%v', expected '%v'", err, errNilDatabaseConnectionManager) - } + assert.ErrorIs(t, err, errNilDatabaseConnectionManager) _, err = SetupDataHistoryManager(NewExchangeManager(), &DatabaseConnectionManager{}, nil) - if !errors.Is(err, errNilConfig) { - t.Errorf("error '%v', expected '%v'", err, errNilConfig) - } + assert.ErrorIs(t, err, errNilConfig) _, err = SetupDataHistoryManager(NewExchangeManager(), &DatabaseConnectionManager{}, &config.DataHistoryManager{}) - if !errors.Is(err, database.ErrNilInstance) { - t.Errorf("error '%v', expected '%v'", err, database.ErrNilInstance) - } + assert.ErrorIs(t, err, database.ErrNilInstance) dbInst := &database.Instance{} err = dbInst.SetConfig(&database.Config{Enabled: true}) @@ -92,14 +84,11 @@ func TestDataHistoryManagerStart(t *testing.T) { assert.NoError(t, err) err = m.Start() - if !errors.Is(err, ErrSubSystemAlreadyStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemAlreadyStarted) - } + assert.ErrorIs(t, err, ErrSubSystemAlreadyStarted) + m = nil err = m.Start() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestDataHistoryManagerStop(t *testing.T) { @@ -110,58 +99,43 @@ func TestDataHistoryManagerStop(t *testing.T) { assert.NoError(t, err) err = m.Stop() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) + m = nil err = m.Stop() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestUpsertJob(t *testing.T) { t.Parallel() m, _ := createDHM(t) err := m.UpsertJob(nil, false) - if !errors.Is(err, errNilJob) { - t.Errorf("error '%v', expected '%v'", err, errNilJob) - } + assert.ErrorIs(t, err, errNilJob) + dhj := &DataHistoryJob{} err = m.UpsertJob(dhj, false) - if !errors.Is(err, errNicknameUnset) { - t.Errorf("error '%v', expected '%v'", err, errNicknameUnset) - } + assert.ErrorIs(t, err, errNicknameUnset) + dhj.Nickname = "test1337" err = m.UpsertJob(dhj, false) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("error '%v', expected '%v'", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) dhj.Asset = asset.Spot err = m.UpsertJob(dhj, false) - if !errors.Is(err, errCurrencyPairUnset) { - t.Errorf("error '%v', expected '%v'", err, errCurrencyPairUnset) - } + assert.ErrorIs(t, err, errCurrencyPairUnset) dhj.Exchange = strings.ToLower(testExchange) dhj.Pair = currency.NewPair(currency.BTC, currency.DOGE) err = m.UpsertJob(dhj, false) - if !errors.Is(err, errCurrencyNotEnabled) { - t.Errorf("error '%v', expected '%v'", err, errCurrencyNotEnabled) - } + assert.ErrorIs(t, err, errCurrencyNotEnabled) dhj.Pair = currency.NewBTCUSD() err = m.UpsertJob(dhj, false) - if !errors.Is(err, kline.ErrUnsupportedInterval) { - t.Errorf("error '%v', expected '%v'", err, kline.ErrUnsupportedInterval) - } + assert.ErrorIs(t, err, kline.ErrUnsupportedInterval) dhj.Interval = kline.OneHour err = m.UpsertJob(dhj, false) - if !errors.Is(err, common.ErrDateUnset) { - t.Errorf("error '%v', expected '%v'", err, common.ErrDateUnset) - } + assert.ErrorIs(t, err, common.ErrDateUnset) dhj.StartDate = time.Now().Add(-time.Hour) dhj.EndDate = time.Now() @@ -169,9 +143,7 @@ func TestUpsertJob(t *testing.T) { assert.NoError(t, err) err = m.UpsertJob(dhj, true) - if !errors.Is(err, errNicknameInUse) { - t.Errorf("error '%v', expected '%v'", err, errNicknameInUse) - } + assert.ErrorIs(t, err, errNicknameInUse) newJob := &DataHistoryJob{ Nickname: dhj.Nickname, @@ -194,9 +166,7 @@ func TestUpsertJob(t *testing.T) { PrerequisiteJobNickname: "hellomoto", } err = m.UpsertJob(newJob, false) - if !errors.Is(err, errInvalidDataHistoryDataType) { - t.Errorf("error '%v', expected '%v'", err, errInvalidDataHistoryDataType) - } + assert.ErrorIs(t, err, errInvalidDataHistoryDataType) newJob.DataType = dataHistoryTradeDataType err = m.UpsertJob(newJob, false) @@ -219,31 +189,23 @@ func TestSetJobStatus(t *testing.T) { assert.NoError(t, err) err = m.SetJobStatus("", "", 0) - if !errors.Is(err, errNicknameIDUnset) { - t.Errorf("error '%v', expected '%v'", err, errNicknameIDUnset) - } + assert.ErrorIs(t, err, errNicknameIDUnset) err = m.SetJobStatus("1337", "1337", 0) - if !errors.Is(err, errOnlyNicknameOrID) { - t.Errorf("error '%v', expected '%v'", err, errOnlyNicknameOrID) - } + assert.ErrorIs(t, err, errOnlyNicknameOrID) err = m.SetJobStatus(dhj.Nickname, "", dataHistoryStatusRemoved) assert.NoError(t, err) err = m.SetJobStatus("", dhj.ID.String(), dataHistoryStatusActive) - if !errors.Is(err, errBadStatus) { - t.Errorf("error '%v', expected '%v'", err, errBadStatus) - } + assert.ErrorIs(t, err, errBadStatus) j.Status = int64(dataHistoryStatusActive) err = m.SetJobStatus("", dhj.ID.String(), dataHistoryStatusPaused) assert.NoError(t, err) err = m.SetJobStatus("", dhj.ID.String(), dataHistoryStatusFailed) - if !errors.Is(err, errBadStatus) { - t.Errorf("error '%v', expected '%v'", err, errBadStatus) - } + assert.ErrorIs(t, err, errBadStatus) dhj.Status = dataHistoryStatusPaused err = m.SetJobStatus(dhj.Nickname, "", dataHistoryStatusActive) @@ -251,9 +213,7 @@ func TestSetJobStatus(t *testing.T) { dhj.Status = dataHistoryStatusRemoved err = m.SetJobStatus(dhj.Nickname, "", dataHistoryStatusActive) - if !errors.Is(err, errBadStatus) { - t.Errorf("error '%v', expected '%v'", err, errBadStatus) - } + assert.ErrorIs(t, err, errBadStatus) dhj.Status = dataHistoryStatusPaused err = m.SetJobStatus(dhj.Nickname, "", dataHistoryStatusRemoved) @@ -261,15 +221,11 @@ func TestSetJobStatus(t *testing.T) { atomic.StoreInt32(&m.started, 0) err = m.SetJobStatus("", dhj.ID.String(), 0) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) m = nil err = m.SetJobStatus("", dhj.ID.String(), 0) - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestGetByNickname(t *testing.T) { @@ -298,15 +254,11 @@ func TestGetByNickname(t *testing.T) { atomic.StoreInt32(&m.started, 0) _, err = m.GetByNickname("test123", false) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) m = nil _, err = m.GetByNickname("test123", false) - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestGetByID(t *testing.T) { @@ -328,24 +280,18 @@ func TestGetByID(t *testing.T) { assert.NoError(t, err) _, err = m.GetByID(uuid.UUID{}) - if !errors.Is(err, errEmptyID) { - t.Errorf("error '%v', expected '%v'", err, errEmptyID) - } + assert.ErrorIs(t, err, errEmptyID) _, err = m.GetByID(dhj.ID) assert.NoError(t, err) atomic.StoreInt32(&m.started, 0) _, err = m.GetByID(dhj.ID) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) m = nil _, err = m.GetByID(dhj.ID) - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestRetrieveJobs(t *testing.T) { @@ -372,15 +318,11 @@ func TestRetrieveJobs(t *testing.T) { atomic.StoreInt32(&m.started, 0) _, err = m.retrieveJobs() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) m = nil _, err = m.retrieveJobs() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestGetActiveJobs(t *testing.T) { @@ -416,61 +358,44 @@ func TestGetActiveJobs(t *testing.T) { atomic.StoreInt32(&m.started, 0) _, err = m.GetActiveJobs() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) m = nil _, err = m.GetActiveJobs() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestValidateJob(t *testing.T) { t.Parallel() m, _ := createDHM(t) err := m.validateJob(nil) - if !errors.Is(err, errNilJob) { - t.Errorf("error '%v', expected '%v'", err, errNilJob) - } + assert.ErrorIs(t, err, errNilJob) + dhj := &DataHistoryJob{} err = m.validateJob(dhj) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("error '%v', expected '%v'", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) dhj.Asset = asset.Spot err = m.validateJob(dhj) - if !errors.Is(err, errCurrencyPairUnset) { - t.Errorf("error '%v', expected '%v'", err, errCurrencyPairUnset) - } + assert.ErrorIs(t, err, errCurrencyPairUnset) dhj.Exchange = testExchange dhj.Pair = currency.NewPair(currency.BTC, currency.XRP) err = m.validateJob(dhj) - if !errors.Is(err, errCurrencyNotEnabled) { - t.Errorf("error '%v', expected '%v'", err, errCurrencyNotEnabled) - } + assert.ErrorIs(t, err, errCurrencyNotEnabled) dhj.Pair = currency.NewBTCUSD() err = m.validateJob(dhj) - if !errors.Is(err, kline.ErrUnsupportedInterval) { - t.Errorf("error '%v', expected '%v'", err, kline.ErrUnsupportedInterval) - } + assert.ErrorIs(t, err, kline.ErrUnsupportedInterval) dhj.Interval = kline.OneMin err = m.validateJob(dhj) - if !errors.Is(err, common.ErrDateUnset) { - t.Errorf("error '%v', expected '%v'", err, common.ErrDateUnset) - } + assert.ErrorIs(t, err, common.ErrDateUnset) dhj.StartDate = time.Now().Add(time.Minute) dhj.EndDate = time.Now().Add(time.Hour) err = m.validateJob(dhj) - if !errors.Is(err, common.ErrStartAfterTimeNow) { - t.Errorf("error '%v', expected '%v'", err, errInvalidTimes) - } + assert.ErrorIs(t, err, common.ErrStartAfterTimeNow) dhj.StartDate = time.Now().Add(-time.Hour * 60) dhj.EndDate = time.Now().Add(-time.Minute) @@ -489,15 +414,12 @@ func TestValidateJob(t *testing.T) { dhj.DataType = dataHistoryCandleValidationSecondarySourceType err = m.validateJob(dhj) - if !errors.Is(err, errExchangeNameUnset) { - t.Errorf("error '%v', expected '%v'", err, errExchangeNameUnset) - } + assert.ErrorIs(t, err, errExchangeNameUnset) + dhj.SecondaryExchangeSource = "lol" dhj.Exchange = "" err = m.validateJob(dhj) - if !errors.Is(err, errExchangeNameUnset) { - t.Errorf("error '%v', expected '%v'", err, errExchangeNameUnset) - } + assert.ErrorIs(t, err, errExchangeNameUnset) } func TestGetAllJobStatusBetween(t *testing.T) { @@ -528,15 +450,11 @@ func TestGetAllJobStatusBetween(t *testing.T) { m.started = 0 _, err = m.GetAllJobStatusBetween(time.Now().Add(-time.Hour), time.Now().Add(-time.Minute*30)) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) m = nil _, err = m.GetAllJobStatusBetween(time.Now().Add(-time.Hour), time.Now().Add(-time.Minute*30)) - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestPrepareJobs(t *testing.T) { @@ -550,14 +468,11 @@ func TestPrepareJobs(t *testing.T) { } m.started = 0 _, err = m.PrepareJobs() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) + m = nil _, err = m.PrepareJobs() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestCompareJobsToData(t *testing.T) { @@ -583,9 +498,7 @@ func TestCompareJobsToData(t *testing.T) { dhj.DataType = 1337 err = m.compareJobsToData(dhj) - if !errors.Is(err, errUnknownDataType) { - t.Errorf("error '%v', expected '%v'", err, errUnknownDataType) - } + assert.ErrorIs(t, err, errUnknownDataType) dhj.DataType = dataHistoryConvertCandlesDataType err = m.compareJobsToData(dhj) @@ -593,14 +506,11 @@ func TestCompareJobsToData(t *testing.T) { m.started = 0 err = m.compareJobsToData(dhj) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) + m = nil err = m.compareJobsToData(dhj) - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestRunJob(t *testing.T) { //nolint:tparallel // There is a race condition caused by the DataHistoryJob and it's a big change to fix. @@ -684,17 +594,13 @@ func TestRunJob(t *testing.T) { //nolint:tparallel // There is a race condition test.Status = dataHistoryIntervalIssuesFound err = m.runJob(test) - if !errors.Is(err, errJobInvalid) { - t.Errorf("error '%v', expected '%v'", err, errJobInvalid) - } + assert.ErrorIs(t, err, errJobInvalid) rh := test.rangeHolder test.Status = dataHistoryStatusActive test.rangeHolder = nil err = m.runJob(test) - if !errors.Is(err, errJobInvalid) { - t.Errorf("error '%v', expected '%v'", err, errJobInvalid) - } + assert.ErrorIs(t, err, errJobInvalid) test.rangeHolder = rh err = m.runJob(test) @@ -703,14 +609,11 @@ func TestRunJob(t *testing.T) { //nolint:tparallel // There is a race condition } var badM *DataHistoryManager err := badM.runJob(nil) - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) + badM = &DataHistoryManager{} err = badM.runJob(nil) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) } func TestGenerateJobSummaryTest(t *testing.T) { @@ -737,15 +640,11 @@ func TestGenerateJobSummaryTest(t *testing.T) { atomic.StoreInt32(&m.started, 0) _, err = m.GenerateJobSummary("TestGenerateJobSummary") - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) m = nil _, err = m.GenerateJobSummary("TestGenerateJobSummary") - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestRunJobs(t *testing.T) { @@ -756,15 +655,11 @@ func TestRunJobs(t *testing.T) { atomic.StoreInt32(&m.started, 0) err = m.runJobs() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) m = nil err = m.runJobs() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestConverters(t *testing.T) { @@ -935,9 +830,8 @@ func TestProcessCandleData(t *testing.T) { t.Parallel() m, _ := createDHM(t) _, err := m.processCandleData(nil, nil, time.Time{}, time.Time{}, 0) - if !errors.Is(err, errNilJob) { - t.Errorf("received %v expected %v", err, errNilJob) - } + assert.ErrorIs(t, err, errNilJob) + j := &DataHistoryJob{ Nickname: "", Exchange: testExchange, @@ -948,9 +842,7 @@ func TestProcessCandleData(t *testing.T) { Interval: kline.OneHour, } _, err = m.processCandleData(j, nil, time.Time{}, time.Time{}, 0) - if !errors.Is(err, ErrExchangeNotFound) { - t.Errorf("received %v expected %v", err, ErrExchangeNotFound) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) em := NewExchangeManager() exch, err := em.NewExchangeByName(testExchange) @@ -961,9 +853,7 @@ func TestProcessCandleData(t *testing.T) { IBotExchange: exch, } _, err = m.processCandleData(j, exch, time.Time{}, time.Time{}, 0) - if !errors.Is(err, common.ErrDateUnset) { - t.Errorf("received %v expected %v", err, common.ErrDateUnset) - } + assert.ErrorIs(t, err, common.ErrDateUnset) m.candleSaver = dataHistoryCandleSaver j.rangeHolder, err = kline.CalculateCandleDateRanges(j.StartDate, j.EndDate, j.Interval, 1337) @@ -988,9 +878,8 @@ func TestProcessTradeData(t *testing.T) { t.Parallel() m, _ := createDHM(t) _, err := m.processTradeData(nil, nil, time.Time{}, time.Time{}, 0) - if !errors.Is(err, errNilJob) { - t.Errorf("received %v expected %v", err, errNilJob) - } + assert.ErrorIs(t, err, errNilJob) + j := &DataHistoryJob{ Nickname: "", Exchange: testExchange, @@ -1001,9 +890,7 @@ func TestProcessTradeData(t *testing.T) { Interval: kline.OneHour, } _, err = m.processTradeData(j, nil, time.Time{}, time.Time{}, 0) - if !errors.Is(err, ErrExchangeNotFound) { - t.Errorf("received %v expected %v", err, ErrExchangeNotFound) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) em := NewExchangeManager() exch, err := em.NewExchangeByName(testExchange) @@ -1014,9 +901,8 @@ func TestProcessTradeData(t *testing.T) { IBotExchange: exch, } _, err = m.processTradeData(j, exch, time.Time{}, time.Time{}, 0) - if !errors.Is(err, common.ErrDateUnset) { - t.Errorf("received %v expected %v", err, common.ErrDateUnset) - } + assert.ErrorIs(t, err, common.ErrDateUnset) + j.rangeHolder, err = kline.CalculateCandleDateRanges(j.StartDate, j.EndDate, j.Interval, 1337) if err != nil { t.Error(err) @@ -1040,9 +926,8 @@ func TestConvertJobTradesToCandles(t *testing.T) { t.Parallel() m, _ := createDHM(t) _, err := m.convertTradesToCandles(nil, time.Time{}, time.Time{}) - if !errors.Is(err, errNilJob) { - t.Errorf("received %v expected %v", err, errNilJob) - } + assert.ErrorIs(t, err, errNilJob) + j := &DataHistoryJob{ Nickname: "", Exchange: testExchange, @@ -1053,9 +938,8 @@ func TestConvertJobTradesToCandles(t *testing.T) { Interval: kline.OneHour, } _, err = m.convertTradesToCandles(j, time.Time{}, time.Time{}) - if !errors.Is(err, common.ErrDateUnset) { - t.Errorf("received %v expected %v", err, common.ErrDateUnset) - } + assert.ErrorIs(t, err, common.ErrDateUnset) + m.tradeLoader = dataHistoryTraderLoader m.candleSaver = dataHistoryCandleSaver r, err := m.convertTradesToCandles(j, j.StartDate, j.EndDate) @@ -1071,9 +955,8 @@ func TestUpscaleJobCandleData(t *testing.T) { m, _ := createDHM(t) m.candleSaver = dataHistoryCandleSaver _, err := m.convertCandleData(nil, time.Time{}, time.Time{}) - if !errors.Is(err, errNilJob) { - t.Errorf("received %v expected %v", err, errNilJob) - } + assert.ErrorIs(t, err, errNilJob) + j := &DataHistoryJob{ Nickname: "", Exchange: testExchange, @@ -1085,9 +968,7 @@ func TestUpscaleJobCandleData(t *testing.T) { ConversionInterval: kline.OneDay, } _, err = m.convertCandleData(j, time.Time{}, time.Time{}) - if !errors.Is(err, common.ErrDateUnset) { - t.Errorf("received %v expected %v", err, common.ErrDateUnset) - } + assert.ErrorIs(t, err, common.ErrDateUnset) r, err := m.convertCandleData(j, j.StartDate, j.EndDate) assert.NoError(t, err) @@ -1102,9 +983,8 @@ func TestValidateCandles(t *testing.T) { m, _ := createDHM(t) m.candleSaver = dataHistoryCandleSaver _, err := m.validateCandles(nil, nil, time.Time{}, time.Time{}) - if !errors.Is(err, errNilJob) { - t.Errorf("received %v expected %v", err, errNilJob) - } + assert.ErrorIs(t, err, errNilJob) + j := &DataHistoryJob{ Nickname: "", Exchange: testExchange, @@ -1115,9 +995,7 @@ func TestValidateCandles(t *testing.T) { Interval: kline.OneHour, } _, err = m.validateCandles(j, nil, time.Time{}, time.Time{}) - if !errors.Is(err, ErrExchangeNotFound) { - t.Errorf("received %v expected %v", err, ErrExchangeNotFound) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) em := NewExchangeManager() exch, err := em.NewExchangeByName(testExchange) @@ -1128,9 +1006,8 @@ func TestValidateCandles(t *testing.T) { IBotExchange: exch, } _, err = m.validateCandles(j, exch, time.Time{}, time.Time{}) - if !errors.Is(err, common.ErrDateUnset) { - t.Errorf("received %v expected %v", err, common.ErrDateUnset) - } + assert.ErrorIs(t, err, common.ErrDateUnset) + j.rangeHolder, err = kline.CalculateCandleDateRanges(j.StartDate, j.EndDate, j.Interval, 1337) if err != nil { t.Error(err) @@ -1165,20 +1042,15 @@ func TestSetJobRelationship(t *testing.T) { assert.NoError(t, err) err = m.SetJobRelationship("", "") - if !errors.Is(err, errNicknameUnset) { - t.Errorf("received %v expected %v", err, errNicknameUnset) - } + assert.ErrorIs(t, err, errNicknameUnset) + m.started = 0 err = m.SetJobRelationship("", "") - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("received %v expected %v", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) m = nil err = m.SetJobRelationship("", "") - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("received %v expected %v", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestCheckCandleIssue(t *testing.T) { @@ -1255,9 +1127,8 @@ func TestCompletionCheck(t *testing.T) { t.Parallel() m, _ := createDHM(t) err := m.completeJob(nil, false, false) - if !errors.Is(err, errNilJob) { - t.Errorf("received %v expected %v", err, errNilJob) - } + assert.ErrorIs(t, err, errNilJob) + j := &DataHistoryJob{ Status: dataHistoryStatusActive, } @@ -1283,9 +1154,7 @@ func TestCompletionCheck(t *testing.T) { } err = m.completeJob(j, true, true) - if !errors.Is(err, errJobInvalid) { - t.Errorf("received %v expected %v", err, errJobInvalid) - } + assert.ErrorIs(t, err, errJobInvalid) } func TestSaveCandlesInBatches(t *testing.T) { @@ -1294,27 +1163,19 @@ func TestSaveCandlesInBatches(t *testing.T) { candleSaver: dataHistoryCandleSaver, } err := dhm.saveCandlesInBatches(nil, nil, nil) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("received %v expected %v", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) dhm.started = 1 err = dhm.saveCandlesInBatches(nil, nil, nil) - if !errors.Is(err, errNilJob) { - t.Errorf("received %v expected %v", err, errNilJob) - } + assert.ErrorIs(t, err, errNilJob) job := &DataHistoryJob{} err = dhm.saveCandlesInBatches(job, nil, nil) - if !errors.Is(err, errNilCandles) { - t.Errorf("received %v expected %v", err, errNilCandles) - } + assert.ErrorIs(t, err, errNilCandles) candles := &kline.Item{} err = dhm.saveCandlesInBatches(job, candles, nil) - if !errors.Is(err, errNilResult) { - t.Errorf("received %v expected %v", err, errNilResult) - } + assert.ErrorIs(t, err, errNilResult) result := &DataHistoryJobResult{} err = dhm.saveCandlesInBatches(job, candles, result) diff --git a/engine/depositaddress_test.go b/engine/depositaddress_test.go index b0a210c6..00ec7a0f 100644 --- a/engine/depositaddress_test.go +++ b/engine/depositaddress_test.go @@ -1,7 +1,6 @@ package engine import ( - "errors" "testing" "github.com/stretchr/testify/assert" @@ -80,9 +79,7 @@ func TestSync(t *testing.T) { }, }, }) - if !errors.Is(err, ErrDepositAddressStoreIsNil) { - t.Errorf("received %v, expected %v", err, ErrDepositAddressStoreIsNil) - } + assert.ErrorIs(t, err, ErrDepositAddressStoreIsNil) m = nil err = m.Sync(map[string]map[string][]deposit.Address{ @@ -94,18 +91,14 @@ func TestSync(t *testing.T) { }, }, }) - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("received %v, expected %v", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestGetDepositAddressByExchangeAndCurrency(t *testing.T) { t.Parallel() m := SetupDepositAddressManager() _, err := m.GetDepositAddressByExchangeAndCurrency("", "", currency.BTC) - if !errors.Is(err, ErrDepositAddressStoreIsNil) { - t.Errorf("received %v, expected %v", err, ErrDepositAddressStoreIsNil) - } + assert.ErrorIs(t, err, ErrDepositAddressStoreIsNil) m.store = map[string]map[string][]deposit.Address{ bitStamp: { @@ -132,21 +125,16 @@ func TestGetDepositAddressByExchangeAndCurrency(t *testing.T) { }, } _, err = m.GetDepositAddressByExchangeAndCurrency("asdf", "", currency.BTC) - if !errors.Is(err, ErrExchangeNotFound) { - t.Errorf("received %v, expected %v", err, ErrExchangeNotFound) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) + _, err = m.GetDepositAddressByExchangeAndCurrency(bitStamp, "", currency.LTC) - if !errors.Is(err, ErrDepositAddressNotFound) { - t.Errorf("received %v, expected %v", err, ErrDepositAddressNotFound) - } + assert.ErrorIs(t, err, ErrDepositAddressNotFound) + _, err = m.GetDepositAddressByExchangeAndCurrency(bitStamp, "", currency.BNB) - if !errors.Is(err, errNoDepositAddressesRetrieved) { - t.Errorf("received %v, expected %v", err, errNoDepositAddressesRetrieved) - } + assert.ErrorIs(t, err, errNoDepositAddressesRetrieved) + _, err = m.GetDepositAddressByExchangeAndCurrency(bitStamp, "NON-EXISTENT-CHAIN", currency.USDT) - if !errors.Is(err, errDepositAddressChainNotFound) { - t.Errorf("received %v, expected %v", err, errDepositAddressChainNotFound) - } + assert.ErrorIs(t, err, errDepositAddressChainNotFound) if r, _ := m.GetDepositAddressByExchangeAndCurrency(bitStamp, "ErC20", currency.USDT); r.Address != "0x1b" && r.Chain != "ERC20" { t.Error("unexpected values") @@ -165,9 +153,7 @@ func TestGetDepositAddressesByExchange(t *testing.T) { t.Parallel() m := SetupDepositAddressManager() _, err := m.GetDepositAddressesByExchange("") - if !errors.Is(err, ErrDepositAddressStoreIsNil) { - t.Errorf("received %v, expected %v", err, ErrDepositAddressStoreIsNil) - } + assert.ErrorIs(t, err, ErrDepositAddressStoreIsNil) m.store = map[string]map[string][]deposit.Address{ bitStamp: { @@ -179,9 +165,7 @@ func TestGetDepositAddressesByExchange(t *testing.T) { }, } _, err = m.GetDepositAddressesByExchange("non-existent") - if !errors.Is(err, ErrDepositAddressNotFound) { - t.Errorf("received %v, expected %v", err, ErrDepositAddressNotFound) - } + assert.ErrorIs(t, err, ErrDepositAddressNotFound) _, err = m.GetDepositAddressesByExchange(bitStamp) assert.NoError(t, err) diff --git a/engine/engine_test.go b/engine/engine_test.go index 1d793845..b2e8ab24 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -1,7 +1,6 @@ package engine import ( - "errors" "os" "slices" "strings" @@ -158,9 +157,7 @@ func TestStartStopTwoDoesNotCausePanic(t *testing.T) { func TestGetExchangeByName(t *testing.T) { t.Parallel() _, err := (*ExchangeManager)(nil).GetExchangeByName("tehehe") - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("received: %v expected: %v", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) em := NewExchangeManager() exch, err := em.NewExchangeByName(testExchange) @@ -190,9 +187,7 @@ func TestGetExchangeByName(t *testing.T) { } _, err = e.GetExchangeByName("Asdasd") - if !errors.Is(err, ErrExchangeNotFound) { - t.Errorf("received: %v expected: %v", err, ErrExchangeNotFound) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) } func TestUnloadExchange(t *testing.T) { @@ -211,9 +206,7 @@ func TestUnloadExchange(t *testing.T) { Config: &config.Config{Exchanges: []config.Exchange{{Name: testExchange}}}, } err = e.UnloadExchange("asdf") - if !errors.Is(err, config.ErrExchangeNotFound) { - t.Errorf("error '%v', expected '%v'", err, config.ErrExchangeNotFound) - } + assert.ErrorIs(t, err, config.ErrExchangeNotFound) err = e.UnloadExchange(testExchange) if err != nil { @@ -222,9 +215,7 @@ func TestUnloadExchange(t *testing.T) { } err = e.UnloadExchange(testExchange) - if !errors.Is(err, ErrExchangeNotFound) { - t.Errorf("error '%v', expected '%v'", err, ErrExchangeNotFound) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) } func TestDryRunParamInteraction(t *testing.T) { @@ -317,9 +308,7 @@ func TestRegisterWebsocketDataHandler(t *testing.T) { t.Parallel() var e *Engine err := e.RegisterWebsocketDataHandler(nil, false) - if !errors.Is(err, errNilBot) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNilBot) - } + require.ErrorIs(t, err, errNilBot) e = &Engine{WebsocketRoutineManager: &WebsocketRoutineManager{}} err = e.RegisterWebsocketDataHandler(func(_ string, _ any) error { return nil }, false) @@ -330,9 +319,7 @@ func TestSetDefaultWebsocketDataHandler(t *testing.T) { t.Parallel() var e *Engine err := e.SetDefaultWebsocketDataHandler() - if !errors.Is(err, errNilBot) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNilBot) - } + require.ErrorIs(t, err, errNilBot) e = &Engine{WebsocketRoutineManager: &WebsocketRoutineManager{}} err = e.SetDefaultWebsocketDataHandler() diff --git a/engine/event_manager_test.go b/engine/event_manager_test.go index 5e8531bf..9c2b726c 100644 --- a/engine/event_manager_test.go +++ b/engine/event_manager_test.go @@ -1,7 +1,6 @@ package engine import ( - "errors" "sync/atomic" "testing" @@ -17,14 +16,10 @@ import ( func TestSetupEventManager(t *testing.T) { t.Parallel() _, err := setupEventManager(nil, nil, 0, false) - if !errors.Is(err, errNilComManager) { - t.Errorf("error '%v', expected '%v'", err, errNilComManager) - } + assert.ErrorIs(t, err, errNilComManager) _, err = setupEventManager(&CommunicationManager{}, nil, 0, false) - if !errors.Is(err, errNilExchangeManager) { - t.Errorf("error '%v', expected '%v'", err, errNilExchangeManager) - } + assert.ErrorIs(t, err, errNilExchangeManager) m, err := setupEventManager(&CommunicationManager{}, &ExchangeManager{}, 0, false) require.NoError(t, err) @@ -45,15 +40,11 @@ func TestEventManagerStart(t *testing.T) { assert.NoError(t, err) err = m.Start() - if !errors.Is(err, ErrSubSystemAlreadyStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemAlreadyStarted) - } + assert.ErrorIs(t, err, ErrSubSystemAlreadyStarted) m = nil err = m.Start() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestEventManagerIsRunning(t *testing.T) { @@ -89,14 +80,11 @@ func TestEventManagerStop(t *testing.T) { assert.NoError(t, err) err = m.Stop() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) + m = nil err = m.Stop() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestEventManagerAdd(t *testing.T) { @@ -106,16 +94,14 @@ func TestEventManagerAdd(t *testing.T) { assert.NoError(t, err) _, err = m.Add("", "", EventConditionParams{}, currency.NewPair(currency.BTC, currency.USDC), asset.Spot, "") - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) + err = m.Start() assert.NoError(t, err) _, err = m.Add("", "", EventConditionParams{}, currency.NewPair(currency.BTC, currency.USDC), asset.Spot, "") - if !errors.Is(err, errExchangeDisabled) { - t.Errorf("error '%v', expected '%v'", err, errExchangeDisabled) - } + assert.ErrorIs(t, err, errExchangeDisabled) + exch, err := em.NewExchangeByName(testExchange) if err != nil { t.Fatal(err) @@ -125,9 +111,7 @@ func TestEventManagerAdd(t *testing.T) { require.NoError(t, err) _, err = m.Add(testExchange, "", EventConditionParams{}, currency.NewPair(currency.BTC, currency.USDC), asset.Spot, "") - if !errors.Is(err, errInvalidItem) { - t.Errorf("error '%v', expected '%v'", err, errInvalidItem) - } + assert.ErrorIs(t, err, errInvalidItem) cond := EventConditionParams{ Condition: ConditionGreaterThan, @@ -135,9 +119,7 @@ func TestEventManagerAdd(t *testing.T) { OrderbookAmount: 1337, } _, err = m.Add(testExchange, ItemPrice, cond, currency.NewPair(currency.BTC, currency.USDC), asset.Spot, "") - if !errors.Is(err, errInvalidAction) { - t.Errorf("error '%v', expected '%v'", err, errInvalidAction) - } + assert.ErrorIs(t, err, errInvalidAction) _, err = m.Add(testExchange, ItemPrice, cond, currency.NewPair(currency.BTC, currency.USDC), asset.Spot, ActionTest) assert.NoError(t, err) diff --git a/engine/exchange_manager_test.go b/engine/exchange_manager_test.go index 02d82863..22c322fe 100644 --- a/engine/exchange_manager_test.go +++ b/engine/exchange_manager_test.go @@ -1,11 +1,11 @@ package engine import ( - "errors" "fmt" "strings" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" "github.com/thrasher-corp/gocryptotrader/exchanges/bitfinex" @@ -33,24 +33,20 @@ func TestExchangeManagerAdd(t *testing.T) { t.Parallel() var m *ExchangeManager err := m.Add(nil) - if !errors.Is(err, ErrNilSubsystem) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrNilSubsystem) - } + require.ErrorIs(t, err, ErrNilSubsystem) m = NewExchangeManager() err = m.Add(nil) - if !errors.Is(err, errExchangeIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeIsNil) - } + require.ErrorIs(t, err, errExchangeIsNil) + b := new(bitfinex.Bitfinex) b.SetDefaults() err = m.Add(b) require.NoError(t, err) err = m.Add(b) - if !errors.Is(err, ErrExchangeAlreadyLoaded) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrExchangeAlreadyLoaded) - } + require.ErrorIs(t, err, ErrExchangeAlreadyLoaded) + exchanges, err := m.GetExchanges() if err != nil { t.Error("no exchange manager found") @@ -64,9 +60,7 @@ func TestExchangeManagerGetExchanges(t *testing.T) { t.Parallel() var m *ExchangeManager _, err := m.GetExchanges() - if !errors.Is(err, ErrNilSubsystem) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrNilSubsystem) - } + require.ErrorIs(t, err, ErrNilSubsystem) m = NewExchangeManager() exchanges, err := m.GetExchanges() @@ -94,21 +88,15 @@ func TestExchangeManagerRemoveExchange(t *testing.T) { t.Parallel() var m *ExchangeManager err := m.RemoveExchange("") - if !errors.Is(err, ErrNilSubsystem) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrNilSubsystem) - } + require.ErrorIs(t, err, ErrNilSubsystem) m = NewExchangeManager() err = m.RemoveExchange("") - if !errors.Is(err, ErrExchangeNameIsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrExchangeNameIsEmpty) - } + require.ErrorIs(t, err, ErrExchangeNameIsEmpty) err = m.RemoveExchange("Bitfinex") - if !errors.Is(err, ErrExchangeNotFound) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrExchangeNotFound) - } + require.ErrorIs(t, err, ErrExchangeNotFound) b := new(bitfinex.Bitfinex) b.SetDefaults() @@ -116,9 +104,7 @@ func TestExchangeManagerRemoveExchange(t *testing.T) { require.NoError(t, err) err = m.RemoveExchange("Bitstamp") - if !errors.Is(err, ErrExchangeNotFound) { - t.Errorf("received: %v but expected: %v", err, ErrExchangeNotFound) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) err = m.RemoveExchange("BiTFiNeX") require.NoError(t, err) @@ -134,23 +120,17 @@ func TestExchangeManagerRemoveExchange(t *testing.T) { require.NoError(t, err) err = m.RemoveExchange("BiTFiNeX") - if !errors.Is(err, errExpectedTestError) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExpectedTestError) - } + require.ErrorIs(t, err, errExpectedTestError) } func TestNewExchangeByName(t *testing.T) { var m *ExchangeManager _, err := m.NewExchangeByName("") - if !errors.Is(err, ErrNilSubsystem) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrNilSubsystem) - } + require.ErrorIs(t, err, ErrNilSubsystem) m = NewExchangeManager() _, err = m.NewExchangeByName("") - if !errors.Is(err, ErrExchangeNameIsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrExchangeNameIsEmpty) - } + require.ErrorIs(t, err, ErrExchangeNameIsEmpty) exchanges := exchange.Exchanges exchanges = append(exchanges, "fake") @@ -175,9 +155,7 @@ func TestNewExchangeByName(t *testing.T) { require.NoError(t, err) _, err = m.NewExchangeByName("bitfinex") - if !errors.Is(err, ErrExchangeAlreadyLoaded) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrExchangeAlreadyLoaded) - } + require.ErrorIs(t, err, ErrExchangeAlreadyLoaded) } type ExchangeBuilder struct{} @@ -215,9 +193,7 @@ func TestExchangeManagerShutdown(t *testing.T) { t.Parallel() var m *ExchangeManager err := m.Shutdown(-1) - if !errors.Is(err, ErrNilSubsystem) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrNilSubsystem) - } + require.ErrorIs(t, err, ErrNilSubsystem) m = NewExchangeManager() err = m.Shutdown(-1) diff --git a/engine/helpers_test.go b/engine/helpers_test.go index ff15cd71..da5600f8 100644 --- a/engine/helpers_test.go +++ b/engine/helpers_test.go @@ -107,9 +107,7 @@ func TestGetSubsystemsStatus(t *testing.T) { func TestGetRPCEndpoints(t *testing.T) { _, err := (&Engine{}).GetRPCEndpoints() - if !errors.Is(err, errNilConfig) { - t.Fatalf("received: %v, but expected: %v", err, errNilConfig) - } + require.ErrorIs(t, err, errNilConfig) m, err := (&Engine{Config: &config.Config{}}).GetRPCEndpoints() require.NoError(t, err) @@ -220,21 +218,10 @@ func TestSetSubsystem(t *testing.T) { //nolint // TO-DO: Fix race t.Parallel() u t.Run(tt.Subsystem, func(t *testing.T) { t.Parallel() err := tt.Engine.SetSubsystem(tt.Subsystem, true) - if !errors.Is(err, tt.EnableError) { - t.Fatalf( - "while enabled %s subsystem received: %#v, but expected: %v", - tt.Subsystem, - err, - tt.EnableError) - } + require.ErrorIs(t, err, tt.EnableError) + err = tt.Engine.SetSubsystem(tt.Subsystem, false) - if !errors.Is(err, tt.DisableError) { - t.Fatalf( - "while disabling %s subsystem received: %#v, but expected: %v", - tt.Subsystem, - err, - tt.DisableError) - } + require.ErrorIs(t, err, tt.DisableError) }) } } @@ -998,12 +985,10 @@ func TestGetCryptocurrencyDepositAddressesByExchange(t *testing.T) { const exchName = "fake" e := createDepositEngine(&fakeDepositExchangeOpts{SupportsAuth: true, SupportsMultiChain: true}) _, err := e.GetCryptocurrencyDepositAddressesByExchange(exchName) - if err != nil { - t.Error(err) - } - if _, err = e.GetCryptocurrencyDepositAddressesByExchange("non-existent"); !errors.Is(err, ErrExchangeNotFound) { - t.Errorf("received %s, expected: %s", err, ErrExchangeNotFound) - } + assert.NoError(t, err, "GetCryptocurrencyDepositAddressesByExchange should not error") + _, err = e.GetCryptocurrencyDepositAddressesByExchange("non-existent") + assert.ErrorIs(t, err, ErrExchangeNotFound) + e.DepositAddressManager = SetupDepositAddressManager() _, err = e.GetCryptocurrencyDepositAddressesByExchange(exchName) if err == nil { @@ -1087,9 +1072,8 @@ func TestGetExchangeNames(t *testing.T) { for i := range bot.Config.Exchanges { exch, err := bot.ExchangeManager.NewExchangeByName(bot.Config.Exchanges[i].Name) - if err != nil && !errors.Is(err, ErrExchangeAlreadyLoaded) { - t.Fatal(err) - } + require.Truef(t, err == nil || errors.Is(err, ErrExchangeAlreadyLoaded), + "%s NewExchangeByName must not error: %s", bot.Config.Exchanges[i].Name, err) if exch != nil { exch.SetDefaults() err = bot.ExchangeManager.Add(exch) @@ -1264,9 +1248,7 @@ func TestNewSupportedExchangeByName(t *testing.T) { } _, err := NewSupportedExchangeByName("") - if !errors.Is(err, ErrExchangeNotFound) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrExchangeNotFound) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) } func TestNewExchangeByNameWithDefaults(t *testing.T) { diff --git a/engine/ntp_manager_test.go b/engine/ntp_manager_test.go index 03e0a1d0..6c01a3be 100644 --- a/engine/ntp_manager_test.go +++ b/engine/ntp_manager_test.go @@ -1,7 +1,6 @@ package engine import ( - "errors" "testing" "time" @@ -11,13 +10,11 @@ import ( func TestSetupNTPManager(t *testing.T) { _, err := setupNTPManager(nil, false) - if !errors.Is(err, errNilConfig) { - t.Errorf("error '%v', expected '%v'", err, errNilConfig) - } + assert.ErrorIs(t, err, errNilConfig) + _, err = setupNTPManager(&config.NTPClientConfig{}, false) - if !errors.Is(err, errNilNTPConfigValues) { - t.Errorf("error '%v', expected '%v'", err, errNilNTPConfigValues) - } + assert.ErrorIs(t, err, errNilNTPConfigValues) + sec := time.Second cfg := &config.NTPClientConfig{ AllowedDifference: &sec, @@ -62,9 +59,7 @@ func TestNTPManagerIsRunning(t *testing.T) { func TestNTPManagerStart(t *testing.T) { var m *ntpManager err := m.Start() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) sec := time.Second cfg := &config.NTPClientConfig{ @@ -75,26 +70,20 @@ func TestNTPManagerStart(t *testing.T) { assert.NoError(t, err) err = m.Start() - if !errors.Is(err, errNTPManagerDisabled) { - t.Errorf("error '%v', expected '%v'", err, errNTPManagerDisabled) - } + assert.ErrorIs(t, err, errNTPManagerDisabled) m.level = 1 err = m.Start() assert.NoError(t, err) err = m.Start() - if !errors.Is(err, ErrSubSystemAlreadyStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemAlreadyStarted) - } + assert.ErrorIs(t, err, ErrSubSystemAlreadyStarted) } func TestNTPManagerStop(t *testing.T) { var m *ntpManager err := m.Stop() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) sec := time.Second cfg := &config.NTPClientConfig{ @@ -106,9 +95,7 @@ func TestNTPManagerStop(t *testing.T) { assert.NoError(t, err) err = m.Stop() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) err = m.Start() assert.NoError(t, err) @@ -120,9 +107,8 @@ func TestNTPManagerStop(t *testing.T) { func TestFetchNTPTime(t *testing.T) { var m *ntpManager _, err := m.FetchNTPTime() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) + sec := time.Second cfg := &config.NTPClientConfig{ AllowedDifference: &sec, @@ -133,9 +119,7 @@ func TestFetchNTPTime(t *testing.T) { assert.NoError(t, err) _, err = m.FetchNTPTime() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) err = m.Start() assert.NoError(t, err) @@ -168,9 +152,7 @@ func TestProcessTime(t *testing.T) { assert.NoError(t, err) err = m.processTime() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) err = m.Start() assert.NoError(t, err) diff --git a/engine/order_manager_test.go b/engine/order_manager_test.go index 896362e5..363c5460 100644 --- a/engine/order_manager_test.go +++ b/engine/order_manager_test.go @@ -183,9 +183,8 @@ func TestSetupOrderManager(t *testing.T) { func TestOrderManagerStart(t *testing.T) { var m *OrderManager err := m.Start() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) + var wg sync.WaitGroup m, err = SetupOrderManager(NewExchangeManager(), &CommunicationManager{}, &wg, &config.OrderManager{}) assert.NoError(t, err) @@ -194,9 +193,7 @@ func TestOrderManagerStart(t *testing.T) { assert.NoError(t, err) err = m.Start() - if !errors.Is(err, ErrSubSystemAlreadyStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemAlreadyStarted) - } + assert.ErrorIs(t, err, ErrSubSystemAlreadyStarted) } func TestOrderManagerIsRunning(t *testing.T) { @@ -224,18 +221,14 @@ func TestOrderManagerIsRunning(t *testing.T) { func TestOrderManagerStop(t *testing.T) { var m *OrderManager err := m.Stop() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) var wg sync.WaitGroup m, err = SetupOrderManager(NewExchangeManager(), &CommunicationManager{}, &wg, &config.OrderManager{}) assert.NoError(t, err) err = m.Stop() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) err = m.Start() assert.NoError(t, err) @@ -596,9 +589,7 @@ func TestSubmit(t *testing.T) { m.cfg.AllowedPairs = nil _, err = m.Submit(t.Context(), o) - if !errors.Is(err, exchange.ErrAuthenticationSupportNotEnabled) { - t.Errorf("received: %v but expected: %v", err, exchange.ErrAuthenticationSupportNotEnabled) - } + assert.ErrorIs(t, err, exchange.ErrAuthenticationSupportNotEnabled) err = m.orderStore.add(&order.Detail{ Exchange: testExchange, @@ -1146,20 +1137,15 @@ func TestGetFuturesPositionsForExchange(t *testing.T) { o := &OrderManager{} cp := currency.NewBTCUSDT() _, err := o.GetFuturesPositionsForExchange("test", asset.Spot, cp) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("received '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) + o.started = 1 o.orderStore.futuresPositionController = futures.SetupPositionController() _, err = o.GetFuturesPositionsForExchange("test", asset.Spot, cp) - if !errors.Is(err, futures.ErrNotFuturesAsset) { - t.Errorf("received '%v', expected '%v'", err, futures.ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, futures.ErrNotFuturesAsset) _, err = o.GetFuturesPositionsForExchange("test", asset.Futures, cp) - if !errors.Is(err, futures.ErrPositionNotFound) { - t.Errorf("received '%v', expected '%v'", err, futures.ErrPositionNotFound) - } + assert.ErrorIs(t, err, futures.ErrPositionNotFound) err = o.orderStore.futuresPositionController.TrackNewOrder(&order.Detail{ OrderID: "test", @@ -1182,9 +1168,7 @@ func TestGetFuturesPositionsForExchange(t *testing.T) { o = nil _, err = o.GetFuturesPositionsForExchange("test", asset.Futures, cp) - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("received '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestClearFuturesPositionsForExchange(t *testing.T) { @@ -1192,20 +1176,15 @@ func TestClearFuturesPositionsForExchange(t *testing.T) { o := &OrderManager{} cp := currency.NewBTCUSDT() err := o.ClearFuturesTracking("test", asset.Spot, cp) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("received '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) + o.started = 1 o.orderStore.futuresPositionController = futures.SetupPositionController() err = o.ClearFuturesTracking("test", asset.Spot, cp) - if !errors.Is(err, futures.ErrNotFuturesAsset) { - t.Errorf("received '%v', expected '%v'", err, futures.ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, futures.ErrNotFuturesAsset) err = o.ClearFuturesTracking("test", asset.Futures, cp) - if !errors.Is(err, futures.ErrPositionNotFound) { - t.Errorf("received '%v', expected '%v'", err, futures.ErrPositionNotFound) - } + assert.ErrorIs(t, err, futures.ErrPositionNotFound) err = o.orderStore.futuresPositionController.TrackNewOrder(&order.Detail{ OrderID: "test", @@ -1231,9 +1210,7 @@ func TestClearFuturesPositionsForExchange(t *testing.T) { o = nil err = o.ClearFuturesTracking("test", asset.Futures, cp) - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("received '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestUpdateOpenPositionUnrealisedPNL(t *testing.T) { @@ -1241,20 +1218,15 @@ func TestUpdateOpenPositionUnrealisedPNL(t *testing.T) { o := &OrderManager{} cp := currency.NewBTCUSDT() _, err := o.UpdateOpenPositionUnrealisedPNL("test", asset.Spot, cp, 1, time.Now()) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("received '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) + o.started = 1 o.orderStore.futuresPositionController = futures.SetupPositionController() _, err = o.UpdateOpenPositionUnrealisedPNL("test", asset.Spot, cp, 1, time.Now()) - if !errors.Is(err, futures.ErrNotFuturesAsset) { - t.Errorf("received '%v', expected '%v'", err, futures.ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, futures.ErrNotFuturesAsset) _, err = o.UpdateOpenPositionUnrealisedPNL("test", asset.Futures, cp, 1, time.Now()) - if !errors.Is(err, futures.ErrPositionNotFound) { - t.Errorf("received '%v', expected '%v'", err, futures.ErrPositionNotFound) - } + assert.ErrorIs(t, err, futures.ErrPositionNotFound) err = o.orderStore.futuresPositionController.TrackNewOrder(&order.Detail{ OrderID: "test", @@ -1277,9 +1249,7 @@ func TestUpdateOpenPositionUnrealisedPNL(t *testing.T) { o = nil _, err = o.UpdateOpenPositionUnrealisedPNL("test", asset.Spot, cp, 1, time.Now()) - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("received '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestSubmitFakeOrder(t *testing.T) { @@ -1353,19 +1323,16 @@ func TestUpdateExisting(t *testing.T) { s := &store{} s.Orders = make(map[string][]*order.Detail) err := s.updateExisting(nil) - if !errors.Is(err, errNilOrder) { - t.Errorf("received '%v', expected '%v'", err, errNilOrder) - } + assert.ErrorIs(t, err, errNilOrder) + od := &order.Detail{Exchange: testExchange} err = s.updateExisting(od) - if !errors.Is(err, ErrExchangeNotFound) { - t.Errorf("received '%v', expected '%v'", err, ErrExchangeNotFound) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) + s.Orders[strings.ToLower(testExchange)] = nil err = s.updateExisting(od) - if !errors.Is(err, ErrOrderNotFound) { - t.Errorf("received '%v', expected '%v'", err, ErrOrderNotFound) - } + assert.ErrorIs(t, err, ErrOrderNotFound) + od.Exchange = testExchange od.AssetType = asset.Futures od.OrderID = "123" @@ -1413,20 +1380,15 @@ func TestOrderManagerAdd(t *testing.T) { t.Parallel() o := &OrderManager{} err := o.Add(nil) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("received '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) + o.started = 1 err = o.Add(nil) - if !errors.Is(err, errNilOrder) { - t.Errorf("received '%v', expected '%v'", err, errNilOrder) - } + assert.ErrorIs(t, err, errNilOrder) o = nil err = o.Add(nil) - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("received '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestGetAllOpenFuturesPositions(t *testing.T) { @@ -1437,23 +1399,17 @@ func TestGetAllOpenFuturesPositions(t *testing.T) { o.started = 0 _, err = o.GetAllOpenFuturesPositions() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("received '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) o.started = 1 o.activelyTrackFuturesPositions = true o.orderStore.futuresPositionController = futures.SetupPositionController() _, err = o.GetAllOpenFuturesPositions() - if !errors.Is(err, futures.ErrNoPositionsFound) { - t.Errorf("received '%v', expected '%v'", err, futures.ErrNoPositionsFound) - } + assert.ErrorIs(t, err, futures.ErrNoPositionsFound) o = nil _, err = o.GetAllOpenFuturesPositions() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("received '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestGetOpenFuturesPosition(t *testing.T) { @@ -1465,15 +1421,11 @@ func TestGetOpenFuturesPosition(t *testing.T) { o.started = 0 cp := currency.NewPair(currency.BTC, currency.PERP) _, err = o.GetOpenFuturesPosition(testExchange, asset.Spot, cp) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("received '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) o.started = 1 _, err = o.GetOpenFuturesPosition(testExchange, asset.Spot, cp) - if !errors.Is(err, futures.ErrNotFuturesAsset) { - t.Errorf("received '%v', expected '%v'", err, futures.ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, futures.ErrNotFuturesAsset) em := NewExchangeManager() exch, err := em.NewExchangeByName("binance") @@ -1516,14 +1468,10 @@ func TestGetOpenFuturesPosition(t *testing.T) { o.started = 1 _, err = o.GetOpenFuturesPosition(testExchange, asset.Spot, cp) - if !errors.Is(err, futures.ErrNotFuturesAsset) { - t.Errorf("received '%v', expected '%v'", err, futures.ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, futures.ErrNotFuturesAsset) _, err = o.GetOpenFuturesPosition(testExchange, asset.Futures, cp) - if !errors.Is(err, futures.ErrPositionNotFound) { - t.Errorf("received '%v', expected '%v'", err, futures.ErrPositionNotFound) - } + assert.ErrorIs(t, err, futures.ErrPositionNotFound) err = o.orderStore.futuresPositionController.TrackNewOrder(&order.Detail{ AssetType: asset.Futures, @@ -1542,18 +1490,15 @@ func TestGetOpenFuturesPosition(t *testing.T) { o = nil _, err = o.GetOpenFuturesPosition(testExchange, asset.Spot, cp) - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("received '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestProcessFuturesPositions(t *testing.T) { t.Parallel() o := &OrderManager{} err := o.processFuturesPositions(nil, nil) - if !errors.Is(err, errFuturesTrackingDisabled) { - t.Errorf("received '%v', expected '%v'", err, errFuturesTrackingDisabled) - } + assert.ErrorIs(t, err, errFuturesTrackingDisabled) + em := NewExchangeManager() exch, err := em.NewExchangeByName("binance") if err != nil { @@ -1601,9 +1546,7 @@ func TestProcessFuturesPositions(t *testing.T) { o.started = 1 err = o.processFuturesPositions(fakeExchange, nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v', expected '%v'", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) position := &futures.PositionResponse{ Asset: asset.Spot, @@ -1611,9 +1554,7 @@ func TestProcessFuturesPositions(t *testing.T) { Orders: nil, } err = o.processFuturesPositions(fakeExchange, position) - if !errors.Is(err, errNilOrder) { - t.Errorf("received '%v', expected '%v'", err, errNilOrder) - } + assert.ErrorIs(t, err, errNilOrder) od := &order.Detail{ AssetType: asset.Spot, @@ -1629,9 +1570,7 @@ func TestProcessFuturesPositions(t *testing.T) { *od, } err = o.processFuturesPositions(fakeExchange, position) - if !errors.Is(err, futures.ErrNotFuturesAsset) { - t.Errorf("received '%v', expected '%v'", err, futures.ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, futures.ErrNotFuturesAsset) position.Orders[0].AssetType = asset.Futures position.Asset = asset.Futures diff --git a/engine/portfolio_manager_test.go b/engine/portfolio_manager_test.go index e8b73913..70da56a8 100644 --- a/engine/portfolio_manager_test.go +++ b/engine/portfolio_manager_test.go @@ -1,7 +1,6 @@ package engine import ( - "errors" "sync" "testing" @@ -11,9 +10,7 @@ import ( func TestSetupPortfolioManager(t *testing.T) { _, err := setupPortfolioManager(nil, 0, nil) - if !errors.Is(err, errNilExchangeManager) { - t.Errorf("error '%v', expected '%v'", err, errNilExchangeManager) - } + assert.ErrorIs(t, err, errNilExchangeManager) m, err := setupPortfolioManager(NewExchangeManager(), 0, nil) assert.NoError(t, err) @@ -49,42 +46,32 @@ func TestPortfolioManagerStart(t *testing.T) { var m *portfolioManager var wg sync.WaitGroup err := m.Start(nil) - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) m, err = setupPortfolioManager(NewExchangeManager(), 0, nil) assert.NoError(t, err) err = m.Start(nil) - if !errors.Is(err, errNilWaitGroup) { - t.Errorf("error '%v', expected '%v'", err, errNilWaitGroup) - } + assert.ErrorIs(t, err, errNilWaitGroup) err = m.Start(&wg) assert.NoError(t, err) err = m.Start(&wg) - if !errors.Is(err, ErrSubSystemAlreadyStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemAlreadyStarted) - } + assert.ErrorIs(t, err, ErrSubSystemAlreadyStarted) } func TestPortfolioManagerStop(t *testing.T) { var m *portfolioManager var wg sync.WaitGroup err := m.Stop() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) m, err = setupPortfolioManager(NewExchangeManager(), 0, nil) assert.NoError(t, err) err = m.Stop() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) err = m.Start(&wg) assert.NoError(t, err) diff --git a/engine/rpcserver_test.go b/engine/rpcserver_test.go index 9075d448..c94c096f 100644 --- a/engine/rpcserver_test.go +++ b/engine/rpcserver_test.go @@ -559,9 +559,8 @@ func TestGetSavedTrades(t *testing.T) { defer CleanRPCTest(t, engerino) s := RPCServer{Engine: engerino} _, err := s.GetSavedTrades(t.Context(), &gctrpc.GetSavedTradesRequest{}) - if !errors.Is(err, errInvalidArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, errInvalidArguments) + _, err = s.GetSavedTrades(t.Context(), &gctrpc.GetSavedTradesRequest{ Exchange: fakeExchangeName, Pair: &gctrpc.CurrencyPair{ @@ -573,9 +572,8 @@ func TestGetSavedTrades(t *testing.T) { Start: time.Date(2020, 0, 0, 0, 0, 0, 0, time.UTC).Format(common.SimpleTimeFormatWithTimezone), End: time.Date(2020, 1, 1, 1, 1, 1, 1, time.UTC).Format(common.SimpleTimeFormatWithTimezone), }) - if !errors.Is(err, ErrExchangeNotFound) { - t.Error(err) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) + _, err = s.GetSavedTrades(t.Context(), &gctrpc.GetSavedTradesRequest{ Exchange: testExchange, Pair: &gctrpc.CurrencyPair{ @@ -630,9 +628,7 @@ func TestConvertTradesToCandles(t *testing.T) { s := RPCServer{Engine: engerino} // bad param test _, err := s.ConvertTradesToCandles(t.Context(), &gctrpc.ConvertTradesToCandlesRequest{}) - if !errors.Is(err, errInvalidArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, errInvalidArguments) // bad exchange test _, err = s.ConvertTradesToCandles(t.Context(), &gctrpc.ConvertTradesToCandlesRequest{ @@ -647,9 +643,7 @@ func TestConvertTradesToCandles(t *testing.T) { End: time.Date(2020, 0, 0, 1, 0, 0, 0, time.UTC).Format(common.SimpleTimeFormatWithTimezone), TimeInterval: int64(kline.OneHour.Duration()), }) - if !errors.Is(err, ErrExchangeNotFound) { - t.Error(err) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) // no trades test _, err = s.ConvertTradesToCandles(t.Context(), &gctrpc.ConvertTradesToCandlesRequest{ @@ -664,9 +658,7 @@ func TestConvertTradesToCandles(t *testing.T) { End: time.Date(2020, 0, 0, 1, 0, 0, 0, time.UTC).Format(common.SimpleTimeFormatWithTimezone), TimeInterval: int64(kline.OneHour.Duration()), }) - if !errors.Is(err, errNoTrades) { - t.Errorf("received '%v' expected '%v'", err, errNoTrades) - } + assert.ErrorIs(t, err, errNoTrades) // add a trade err = sqltrade.Insert(sqltrade.Data{ @@ -782,9 +774,7 @@ func TestGetHistoricCandles(t *testing.T) { End: defaultEnd.Format(common.SimpleTimeFormatWithTimezone), AssetType: asset.Spot.String(), }) - if !errors.Is(err, ErrExchangeNameIsEmpty) { - t.Errorf("received '%v', expected '%v'", err, ErrExchangeNameIsEmpty) - } + assert.ErrorIs(t, err, ErrExchangeNameIsEmpty) _, err = s.GetHistoricCandles(t.Context(), &gctrpc.GetHistoricCandlesRequest{ Exchange: "bruh", @@ -796,9 +786,7 @@ func TestGetHistoricCandles(t *testing.T) { End: defaultEnd.Format(common.SimpleTimeFormatWithTimezone), AssetType: asset.Spot.String(), }) - if !errors.Is(err, ErrExchangeNotFound) { - t.Errorf("received '%v', expected '%v'", err, ErrExchangeNotFound) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) _, err = s.GetHistoricCandles(t.Context(), &gctrpc.GetHistoricCandlesRequest{ Exchange: testExchange, @@ -807,9 +795,8 @@ func TestGetHistoricCandles(t *testing.T) { Pair: nil, AssetType: asset.Spot.String(), }) - if !errors.Is(err, errCurrencyPairUnset) { - t.Errorf("received '%v', expected '%v'", err, errCurrencyPairUnset) - } + assert.ErrorIs(t, err, errCurrencyPairUnset) + _, err = s.GetHistoricCandles(t.Context(), &gctrpc.GetHistoricCandlesRequest{ Exchange: testExchange, Pair: &gctrpc.CurrencyPair{ @@ -819,9 +806,8 @@ func TestGetHistoricCandles(t *testing.T) { Start: "2020-01-02 15:04:05 UTC", End: "2020-01-02 15:04:05 UTC", }) - if !errors.Is(err, common.ErrStartEqualsEnd) { - t.Errorf("received %v, expected %v", err, common.ErrStartEqualsEnd) - } + assert.ErrorIs(t, err, common.ErrStartEqualsEnd) + var results *gctrpc.GetHistoricCandlesResponse // default run results, err = s.GetHistoricCandles(t.Context(), &gctrpc.GetHistoricCandlesRequest{ @@ -928,10 +914,7 @@ func TestFindMissingSavedTradeIntervals(t *testing.T) { t.Error("expected error") return } - if !errors.Is(err, errInvalidArguments) { - t.Error(err) - return - } + require.ErrorIs(t, err, errInvalidArguments) cp := currency.NewBTCUSD() // no data found response defaultStart := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC).UTC() @@ -1030,10 +1013,7 @@ func TestFindMissingSavedCandleIntervals(t *testing.T) { t.Error("expected error") return } - if !errors.Is(err, errInvalidArguments) { - t.Error(err) - return - } + require.ErrorIs(t, err, errInvalidArguments) cp := currency.NewBTCUSD() // no data found response defaultStart := time.Date(2020, 0, 0, 0, 0, 0, 0, time.UTC) @@ -1172,9 +1152,8 @@ func TestGetRecentTrades(t *testing.T) { defer CleanRPCTest(t, engerino) s := RPCServer{Engine: engerino} _, err := s.GetRecentTrades(t.Context(), &gctrpc.GetSavedTradesRequest{}) - if !errors.Is(err, errInvalidArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, errInvalidArguments) + _, err = s.GetRecentTrades(t.Context(), &gctrpc.GetSavedTradesRequest{ Exchange: fakeExchangeName, Pair: &gctrpc.CurrencyPair{ @@ -1186,9 +1165,8 @@ func TestGetRecentTrades(t *testing.T) { Start: time.Date(2020, 0, 0, 0, 0, 0, 0, time.UTC).Format(common.SimpleTimeFormatWithTimezone), End: time.Date(2020, 0, 0, 1, 0, 0, 0, time.UTC).Format(common.SimpleTimeFormatWithTimezone), }) - if !errors.Is(err, ErrExchangeNotFound) { - t.Error(err) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) + _, err = s.GetRecentTrades(t.Context(), &gctrpc.GetSavedTradesRequest{ Exchange: testExchange, Pair: &gctrpc.CurrencyPair{ @@ -1220,9 +1198,8 @@ func TestGetHistoricTrades(t *testing.T) { defer CleanRPCTest(t, engerino) s := RPCServer{Engine: engerino} err := s.GetHistoricTrades(&gctrpc.GetSavedTradesRequest{}, nil) - if !errors.Is(err, errInvalidArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, errInvalidArguments) + err = s.GetHistoricTrades(&gctrpc.GetSavedTradesRequest{ Exchange: fakeExchangeName, Pair: &gctrpc.CurrencyPair{ @@ -1234,9 +1211,8 @@ func TestGetHistoricTrades(t *testing.T) { Start: time.Date(2020, 0, 0, 0, 0, 0, 0, time.UTC).Format(common.SimpleTimeFormatWithTimezone), End: time.Date(2020, 0, 0, 1, 0, 0, 0, time.UTC).Format(common.SimpleTimeFormatWithTimezone), }, nil) - if !errors.Is(err, ErrExchangeNotFound) { - t.Error(err) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) + err = s.GetHistoricTrades(&gctrpc.GetSavedTradesRequest{ Exchange: testExchange, Pair: &gctrpc.CurrencyPair{ @@ -1304,9 +1280,7 @@ func TestUpdateAccountInfo(t *testing.T) { assert.NoError(t, err) _, err = s.UpdateAccountInfo(t.Context(), &gctrpc.GetAccountInfoRequest{Exchange: fakeExchangeName, AssetType: asset.Futures.String()}) - if !errors.Is(err, currency.ErrAssetNotFound) { - t.Errorf("received '%v', expected '%v'", err, currency.ErrAssetNotFound) - } + assert.ErrorIs(t, err, currency.ErrAssetNotFound) _, err = s.UpdateAccountInfo(t.Context(), &gctrpc.GetAccountInfoRequest{ Exchange: fakeExchangeName, @@ -1352,42 +1326,32 @@ func TestGetOrders(t *testing.T) { } _, err = s.GetOrders(t.Context(), nil) - if !errors.Is(err, errInvalidArguments) { - t.Errorf("received '%v', expected '%v'", err, errInvalidArguments) - } + assert.ErrorIs(t, err, errInvalidArguments) _, err = s.GetOrders(t.Context(), &gctrpc.GetOrdersRequest{ AssetType: asset.Spot.String(), Pair: p, }) - if !errors.Is(err, ErrExchangeNameIsEmpty) { - t.Errorf("received '%v', expected '%v'", ErrExchangeNameIsEmpty, err) - } + assert.ErrorIs(t, err, ErrExchangeNameIsEmpty) _, err = s.GetOrders(t.Context(), &gctrpc.GetOrdersRequest{ Exchange: "bruh", AssetType: asset.Spot.String(), Pair: p, }) - if !errors.Is(err, ErrExchangeNotFound) { - t.Errorf("received '%v', expected '%v'", ErrExchangeNotFound, err) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) _, err = s.GetOrders(t.Context(), &gctrpc.GetOrdersRequest{ Exchange: exchName, AssetType: asset.Spot.String(), }) - if !errors.Is(err, errCurrencyPairUnset) { - t.Errorf("received '%v', expected '%v'", err, errCurrencyPairUnset) - } + assert.ErrorIs(t, err, errCurrencyPairUnset) _, err = s.GetOrders(t.Context(), &gctrpc.GetOrdersRequest{ Exchange: exchName, Pair: p, }) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("received '%v', expected '%v'", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) _, err = s.GetOrders(t.Context(), &gctrpc.GetOrdersRequest{ Exchange: exchName, @@ -1396,9 +1360,7 @@ func TestGetOrders(t *testing.T) { StartDate: time.Now().UTC().Add(time.Second).Format(common.SimpleTimeFormatWithTimezone), EndDate: time.Now().UTC().Add(-time.Hour).Format(common.SimpleTimeFormatWithTimezone), }) - if !errors.Is(err, common.ErrStartAfterEnd) { - t.Errorf("received %v, expected %v", err, common.ErrStartAfterEnd) - } + assert.ErrorIs(t, err, common.ErrStartAfterEnd) _, err = s.GetOrders(t.Context(), &gctrpc.GetOrdersRequest{ Exchange: exchName, @@ -1407,9 +1369,7 @@ func TestGetOrders(t *testing.T) { StartDate: time.Now().UTC().Add(-time.Hour).Format(common.SimpleTimeFormatWithTimezone), EndDate: time.Now().UTC().Add(time.Hour).Format(common.SimpleTimeFormatWithTimezone), }) - if !errors.Is(err, exchange.ErrCredentialsAreEmpty) { - t.Errorf("received '%v', expected '%v'", err, exchange.ErrCredentialsAreEmpty) - } + assert.ErrorIs(t, err, exchange.ErrCredentialsAreEmpty) b.SetCredentials("test", "test", "", "", "", "") b.API.AuthenticatedSupport = true @@ -1462,9 +1422,7 @@ func TestGetOrder(t *testing.T) { } _, err = s.GetOrder(t.Context(), nil) - if !errors.Is(err, errInvalidArguments) { - t.Errorf("received '%v', expected '%v'", err, errInvalidArguments) - } + assert.ErrorIs(t, err, errInvalidArguments) _, err = s.GetOrder(t.Context(), &gctrpc.GetOrderRequest{ Exchange: "test123", @@ -1472,9 +1430,7 @@ func TestGetOrder(t *testing.T) { Pair: p, Asset: "spot", }) - if !errors.Is(err, ErrExchangeNotFound) { - t.Errorf("received '%v', expected '%v'", err, ErrExchangeNotFound) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) _, err = s.GetOrder(t.Context(), &gctrpc.GetOrderRequest{ Exchange: exchName, @@ -1482,9 +1438,7 @@ func TestGetOrder(t *testing.T) { Pair: nil, Asset: "", }) - if !errors.Is(err, errCurrencyPairUnset) { - t.Errorf("received '%v', expected '%v'", err, errCurrencyPairUnset) - } + assert.ErrorIs(t, err, errCurrencyPairUnset) _, err = s.GetOrder(t.Context(), &gctrpc.GetOrderRequest{ Exchange: exchName, @@ -1492,9 +1446,7 @@ func TestGetOrder(t *testing.T) { Pair: p, Asset: "", }) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("received '%v', expected '%v'", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) _, err = s.GetOrder(t.Context(), &gctrpc.GetOrderRequest{ Exchange: exchName, @@ -1502,18 +1454,15 @@ func TestGetOrder(t *testing.T) { Pair: p, Asset: asset.Spot.String(), }) - if !errors.Is(err, ErrOrderIDCannotBeEmpty) { - t.Errorf("received '%v', expected '%v'", err, ErrOrderIDCannotBeEmpty) - } + assert.ErrorIs(t, err, ErrOrderIDCannotBeEmpty) + _, err = s.GetOrder(t.Context(), &gctrpc.GetOrderRequest{ Exchange: exchName, OrderId: "1234", Pair: p, Asset: asset.Spot.String(), }) - if !errors.Is(err, exchange.ErrCredentialsAreEmpty) { - t.Errorf("received '%v', expected '%v'", err, exchange.ErrCredentialsAreEmpty) - } + assert.ErrorIs(t, err, exchange.ErrCredentialsAreEmpty) } func TestCheckVars(t *testing.T) { @@ -1653,14 +1602,10 @@ func TestRPCServerUpsertDataHistoryJob(t *testing.T) { s := RPCServer{Engine: &Engine{dataHistoryManager: m, ExchangeManager: em}} _, err = s.UpsertDataHistoryJob(t.Context(), nil) - if !errors.Is(err, errNilRequestData) { - t.Errorf("received %v, expected %v", err, errNilRequestData) - } + assert.ErrorIs(t, err, errNilRequestData) _, err = s.UpsertDataHistoryJob(t.Context(), &gctrpc.UpsertDataHistoryJobRequest{}) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("received %v, expected %v", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) job := &gctrpc.UpsertDataHistoryJobRequest{ Nickname: "hellomoto", @@ -1702,19 +1647,13 @@ func TestGetDataHistoryJobDetails(t *testing.T) { assert.NoError(t, err) _, err = s.GetDataHistoryJobDetails(t.Context(), nil) - if !errors.Is(err, errNilRequestData) { - t.Errorf("received %v, expected %v", err, errNilRequestData) - } + assert.ErrorIs(t, err, errNilRequestData) _, err = s.GetDataHistoryJobDetails(t.Context(), &gctrpc.GetDataHistoryJobDetailsRequest{}) - if !errors.Is(err, errNicknameIDUnset) { - t.Errorf("received %v, expected %v", err, errNicknameIDUnset) - } + assert.ErrorIs(t, err, errNicknameIDUnset) _, err = s.GetDataHistoryJobDetails(t.Context(), &gctrpc.GetDataHistoryJobDetailsRequest{Id: "123", Nickname: "123"}) - if !errors.Is(err, errOnlyNicknameOrID) { - t.Errorf("received %v, expected %v", err, errOnlyNicknameOrID) - } + assert.ErrorIs(t, err, errOnlyNicknameOrID) _, err = s.GetDataHistoryJobDetails(t.Context(), &gctrpc.GetDataHistoryJobDetailsRequest{Nickname: "TestGetDataHistoryJobDetails"}) assert.NoError(t, err) @@ -1751,19 +1690,13 @@ func TestSetDataHistoryJobStatus(t *testing.T) { require.NoError(t, err) _, err = s.SetDataHistoryJobStatus(t.Context(), nil) - if !errors.Is(err, errNilRequestData) { - t.Errorf("received %v, expected %v", err, errNilRequestData) - } + assert.ErrorIs(t, err, errNilRequestData) _, err = s.SetDataHistoryJobStatus(t.Context(), &gctrpc.SetDataHistoryJobStatusRequest{}) - if !errors.Is(err, errNicknameIDUnset) { - t.Errorf("received %v, expected %v", err, errNicknameIDUnset) - } + assert.ErrorIs(t, err, errNicknameIDUnset) _, err = s.SetDataHistoryJobStatus(t.Context(), &gctrpc.SetDataHistoryJobStatusRequest{Id: "123", Nickname: "123"}) - if !errors.Is(err, errOnlyNicknameOrID) { - t.Errorf("received %v, expected %v", err, errOnlyNicknameOrID) - } + assert.ErrorIs(t, err, errOnlyNicknameOrID) id := dhj.ID _, err = s.SetDataHistoryJobStatus(t.Context(), &gctrpc.SetDataHistoryJobStatusRequest{Nickname: "TestDeleteDataHistoryJob", Status: int64(dataHistoryStatusRemoved)}) @@ -1775,9 +1708,8 @@ func TestSetDataHistoryJobStatus(t *testing.T) { assert.NoError(t, err) _, err = s.SetDataHistoryJobStatus(t.Context(), &gctrpc.SetDataHistoryJobStatusRequest{Id: id.String(), Status: int64(dataHistoryStatusActive)}) - if !errors.Is(err, errBadStatus) { - t.Errorf("received %v, expected %v", err, errBadStatus) - } + assert.ErrorIs(t, err, errBadStatus) + j.Status = int64(dataHistoryStatusActive) _, err = s.SetDataHistoryJobStatus(t.Context(), &gctrpc.SetDataHistoryJobStatusRequest{Id: id.String(), Status: int64(dataHistoryStatusPaused)}) assert.NoError(t, err) @@ -1827,17 +1759,13 @@ func TestGetDataHistoryJobsBetween(t *testing.T) { } _, err := s.GetDataHistoryJobsBetween(t.Context(), nil) - if !errors.Is(err, errNilRequestData) { - t.Fatalf("received %v, expected %v", err, errNilRequestData) - } + require.ErrorIs(t, err, errNilRequestData) _, err = s.GetDataHistoryJobsBetween(t.Context(), &gctrpc.GetDataHistoryJobsBetweenRequest{ StartDate: time.Now().UTC().Add(time.Minute).Format(common.SimpleTimeFormatWithTimezone), EndDate: time.Now().UTC().Format(common.SimpleTimeFormatWithTimezone), }) - if !errors.Is(err, common.ErrStartAfterEnd) { - t.Fatalf("received %v, expected %v", err, common.ErrStartAfterEnd) - } + require.ErrorIs(t, err, common.ErrStartAfterEnd) err = m.UpsertJob(dhj, false) require.NoError(t, err) @@ -1917,42 +1845,32 @@ func TestGetManagedOrders(t *testing.T) { } _, err = s.GetManagedOrders(t.Context(), nil) - if !errors.Is(err, errInvalidArguments) { - t.Errorf("received '%v', expected '%v'", err, errInvalidArguments) - } + assert.ErrorIs(t, err, errInvalidArguments) _, err = s.GetManagedOrders(t.Context(), &gctrpc.GetOrdersRequest{ AssetType: asset.Spot.String(), Pair: p, }) - if !errors.Is(err, ErrExchangeNameIsEmpty) { - t.Errorf("received '%v', expected '%v'", ErrExchangeNameIsEmpty, err) - } + assert.ErrorIs(t, err, ErrExchangeNameIsEmpty) _, err = s.GetManagedOrders(t.Context(), &gctrpc.GetOrdersRequest{ Exchange: "bruh", AssetType: asset.Spot.String(), Pair: p, }) - if !errors.Is(err, ErrExchangeNotFound) { - t.Errorf("received '%v', expected '%v'", ErrExchangeNotFound, err) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) _, err = s.GetManagedOrders(t.Context(), &gctrpc.GetOrdersRequest{ Exchange: exchName, AssetType: asset.Spot.String(), }) - if !errors.Is(err, errCurrencyPairUnset) { - t.Errorf("received '%v', expected '%v'", err, errCurrencyPairUnset) - } + assert.ErrorIs(t, err, errCurrencyPairUnset) _, err = s.GetManagedOrders(t.Context(), &gctrpc.GetOrdersRequest{ Exchange: exchName, Pair: p, }) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("received '%v', expected '%v'", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) o := order.Detail{ Price: 100000, @@ -2087,14 +2005,10 @@ func TestUpdateDataHistoryJobPrerequisite(t *testing.T) { m, _ := createDHM(t) s := RPCServer{Engine: &Engine{dataHistoryManager: m}} _, err := s.UpdateDataHistoryJobPrerequisite(t.Context(), nil) - if !errors.Is(err, errNilRequestData) { - t.Errorf("received %v, expected %v", err, errNilRequestData) - } + assert.ErrorIs(t, err, errNilRequestData) _, err = s.UpdateDataHistoryJobPrerequisite(t.Context(), &gctrpc.UpdateDataHistoryJobPrerequisiteRequest{}) - if !errors.Is(err, errNicknameUnset) { - t.Errorf("received %v, expected %v", err, errNicknameUnset) - } + assert.ErrorIs(t, err, errNicknameUnset) _, err = s.UpdateDataHistoryJobPrerequisite(t.Context(), &gctrpc.UpdateDataHistoryJobPrerequisiteRequest{ Nickname: "test456", @@ -2112,9 +2026,7 @@ func TestCurrencyStateGetAll(t *testing.T) { t.Parallel() _, err := (&RPCServer{Engine: &Engine{}}).CurrencyStateGetAll(t.Context(), &gctrpc.CurrencyStateGetAllRequest{Exchange: fakeExchangeName}) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("received %v, expected %v", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) } func TestCurrencyStateWithdraw(t *testing.T) { @@ -2125,9 +2037,7 @@ func TestCurrencyStateWithdraw(t *testing.T) { &gctrpc.CurrencyStateWithdrawRequest{ Exchange: "wow", Asset: "meow", }) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: %v, but expected: %v", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) _, err = (&RPCServer{ Engine: &Engine{}, @@ -2135,9 +2045,7 @@ func TestCurrencyStateWithdraw(t *testing.T) { &gctrpc.CurrencyStateWithdrawRequest{ Exchange: "wow", Asset: "spot", }) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Fatalf("received: %v, but expected: %v", err, ErrSubSystemNotStarted) - } + require.ErrorIs(t, err, ErrSubSystemNotStarted) } func TestCurrencyStateDeposit(t *testing.T) { @@ -2146,17 +2054,13 @@ func TestCurrencyStateDeposit(t *testing.T) { Engine: &Engine{}, }).CurrencyStateDeposit(t.Context(), &gctrpc.CurrencyStateDepositRequest{Exchange: "wow", Asset: "meow"}) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: %v, but expected: %v", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) _, err = (&RPCServer{ Engine: &Engine{}, }).CurrencyStateDeposit(t.Context(), &gctrpc.CurrencyStateDepositRequest{Exchange: "wow", Asset: "spot"}) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Fatalf("received: %v, but expected: %v", err, ErrSubSystemNotStarted) - } + require.ErrorIs(t, err, ErrSubSystemNotStarted) } func TestCurrencyStateTrading(t *testing.T) { @@ -2165,17 +2069,13 @@ func TestCurrencyStateTrading(t *testing.T) { Engine: &Engine{}, }).CurrencyStateTrading(t.Context(), &gctrpc.CurrencyStateTradingRequest{Exchange: "wow", Asset: "meow"}) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: %v, but expected: %v", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) _, err = (&RPCServer{ Engine: &Engine{}, }).CurrencyStateTrading(t.Context(), &gctrpc.CurrencyStateTradingRequest{Exchange: "wow", Asset: "spot"}) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Fatalf("received: %v, but expected: %v", err, ErrSubSystemNotStarted) - } + require.ErrorIs(t, err, ErrSubSystemNotStarted) } func TestCurrencyStateTradingPair(t *testing.T) { @@ -2296,9 +2196,7 @@ func TestGetFuturesPositionsOrders(t *testing.T) { Quote: cp.Quote.String(), }, }) - if !errors.Is(err, futures.ErrNotFuturesAsset) { - t.Errorf("received '%v', expected '%v'", err, futures.ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, futures.ErrNotFuturesAsset) } func TestGetCollateral(t *testing.T) { @@ -2348,9 +2246,7 @@ func TestGetCollateral(t *testing.T) { Exchange: fakeExchangeName, Asset: asset.Futures.String(), }) - if !errors.Is(err, exchange.ErrCredentialsAreEmpty) { - t.Fatalf("received '%v', expected '%v'", err, exchange.ErrCredentialsAreEmpty) - } + require.ErrorIs(t, err, exchange.ErrCredentialsAreEmpty) ctx := account.DeployCredentialsToContext(t.Context(), &account.Credentials{Key: "fakerino", Secret: "supafake"}) @@ -2359,9 +2255,7 @@ func TestGetCollateral(t *testing.T) { Exchange: fakeExchangeName, Asset: asset.Futures.String(), }) - if !errors.Is(err, errNoAccountInformation) { - t.Fatalf("received '%v', expected '%v'", err, errNoAccountInformation) - } + require.ErrorIs(t, err, errNoAccountInformation) ctx = account.DeployCredentialsToContext(t.Context(), &account.Credentials{Key: "fakerino", Secret: "supafake", SubAccount: "1337"}) @@ -2385,9 +2279,7 @@ func TestGetCollateral(t *testing.T) { Asset: asset.Spot.String(), IncludeBreakdown: true, }) - if !errors.Is(err, futures.ErrNotFuturesAsset) { - t.Errorf("received '%v', expected '%v'", err, futures.ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, futures.ErrNotFuturesAsset) _, err = s.GetCollateral(ctx, &gctrpc.GetCollateralRequest{ Exchange: fakeExchangeName, @@ -2402,15 +2294,11 @@ func TestShutdown(t *testing.T) { t.Parallel() s := RPCServer{Engine: &Engine{}} _, err := s.Shutdown(t.Context(), &gctrpc.ShutdownRequest{}) - if !errors.Is(err, errShutdownNotAllowed) { - t.Fatalf("received: '%v' but expected: '%v'", err, errShutdownNotAllowed) - } + require.ErrorIs(t, err, errShutdownNotAllowed) s.Engine.Settings.EnableGRPCShutdown = true _, err = s.Shutdown(t.Context(), &gctrpc.ShutdownRequest{}) - if !errors.Is(err, errGRPCShutdownSignalIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errGRPCShutdownSignalIsNil) - } + require.ErrorIs(t, err, errGRPCShutdownSignalIsNil) s.Engine.GRPCShutdownSignal = make(chan struct{}, 1) _, err = s.Shutdown(t.Context(), &gctrpc.ShutdownRequest{}) @@ -2461,25 +2349,19 @@ func TestGetTechnicalAnalysis(t *testing.T) { } _, err = s.GetTechnicalAnalysis(t.Context(), &gctrpc.GetTechnicalAnalysisRequest{}) - if !errors.Is(err, ErrExchangeNameIsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrExchangeNameIsEmpty) - } + require.ErrorIs(t, err, ErrExchangeNameIsEmpty) _, err = s.GetTechnicalAnalysis(t.Context(), &gctrpc.GetTechnicalAnalysisRequest{ Exchange: fakeExchangeName, }) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) _, err = s.GetTechnicalAnalysis(t.Context(), &gctrpc.GetTechnicalAnalysisRequest{ Exchange: fakeExchangeName, AssetType: "upsideprofitcontract", Pair: &gctrpc.CurrencyPair{}, }) - if !errors.Is(err, errExpectedTestError) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExpectedTestError) - } + require.ErrorIs(t, err, errExpectedTestError) _, err = s.GetTechnicalAnalysis(t.Context(), &gctrpc.GetTechnicalAnalysisRequest{ Exchange: fakeExchangeName, @@ -2487,9 +2369,7 @@ func TestGetTechnicalAnalysis(t *testing.T) { Pair: &gctrpc.CurrencyPair{Base: "btc", Quote: "usd"}, Interval: int64(kline.OneDay), }) - if !errors.Is(err, errInvalidStrategy) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidStrategy) - } + require.ErrorIs(t, err, errInvalidStrategy) resp, err := s.GetTechnicalAnalysis(t.Context(), &gctrpc.GetTechnicalAnalysisRequest{ Exchange: fakeExchangeName, @@ -2701,27 +2581,19 @@ func TestGetMarginRatesHistory(t *testing.T) { }, } _, err = s.GetMarginRatesHistory(t.Context(), nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) request := &gctrpc.GetMarginRatesHistoryRequest{} _, err = s.GetMarginRatesHistory(t.Context(), request) - if !errors.Is(err, ErrExchangeNameIsEmpty) { - t.Errorf("received '%v' expected '%v'", err, ErrExchangeNameIsEmpty) - } + assert.ErrorIs(t, err, ErrExchangeNameIsEmpty) request.Exchange = fakeExchangeName _, err = s.GetMarginRatesHistory(t.Context(), request) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("received '%v' expected '%v'", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) request.Asset = asset.Spot.String() _, err = s.GetMarginRatesHistory(t.Context(), request) - if !errors.Is(err, currency.ErrCurrencyNotFound) { - t.Errorf("received '%v' expected '%v'", err, currency.ErrCurrencyNotFound) - } + assert.ErrorIs(t, err, currency.ErrCurrencyNotFound) request.Currency = "usd" _, err = s.GetMarginRatesHistory(t.Context(), request) @@ -2759,21 +2631,15 @@ func TestGetMarginRatesHistory(t *testing.T) { request.CalculateOffline = true _, err = s.GetMarginRatesHistory(t.Context(), request) - if !errors.Is(err, common.ErrCannotCalculateOffline) { - t.Errorf("received '%v' expected '%v'", err, common.ErrCannotCalculateOffline) - } + assert.ErrorIs(t, err, common.ErrCannotCalculateOffline) request.TakerFeeRate = "-1337" _, err = s.GetMarginRatesHistory(t.Context(), request) - if !errors.Is(err, common.ErrCannotCalculateOffline) { - t.Errorf("received '%v' expected '%v'", err, common.ErrCannotCalculateOffline) - } + assert.ErrorIs(t, err, common.ErrCannotCalculateOffline) request.TakerFeeRate = "1337" _, err = s.GetMarginRatesHistory(t.Context(), request) - if !errors.Is(err, common.ErrCannotCalculateOffline) { - t.Errorf("received '%v' expected '%v'", err, common.ErrCannotCalculateOffline) - } + assert.ErrorIs(t, err, common.ErrCannotCalculateOffline) request.Rates = []*gctrpc.MarginRate{ { @@ -2860,9 +2726,8 @@ func TestGetFundingRates(t *testing.T) { } _, err = s.GetFundingRates(t.Context(), nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received: '%v' but expected: '%v'", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) + request := &gctrpc.GetFundingRatesRequest{ Exchange: "", Asset: "", @@ -2873,20 +2738,15 @@ func TestGetFundingRates(t *testing.T) { IncludePayments: false, } _, err = s.GetFundingRates(t.Context(), request) - if !errors.Is(err, ErrExchangeNameIsEmpty) { - t.Errorf("received: '%v' but expected: '%v'", err, ErrExchangeNameIsEmpty) - } + assert.ErrorIs(t, err, ErrExchangeNameIsEmpty) + request.Exchange = exch.GetName() _, err = s.GetFundingRates(t.Context(), request) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) request.Asset = asset.Spot.String() _, err = s.GetFundingRates(t.Context(), request) - if !errors.Is(err, futures.ErrNotFuturesAsset) { - t.Errorf("received: '%v' but expected: '%v'", err, futures.ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, futures.ErrNotFuturesAsset) request.Asset = asset.Futures.String() request.Pair = &gctrpc.CurrencyPair{ @@ -2962,9 +2822,8 @@ func TestGetLatestFundingRate(t *testing.T) { } _, err = s.GetLatestFundingRate(t.Context(), nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received: '%v' but expected: '%v'", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) + request := &gctrpc.GetLatestFundingRateRequest{ Exchange: "", Asset: "", @@ -2972,20 +2831,15 @@ func TestGetLatestFundingRate(t *testing.T) { IncludePredicted: false, } _, err = s.GetLatestFundingRate(t.Context(), request) - if !errors.Is(err, ErrExchangeNameIsEmpty) { - t.Errorf("received: '%v' but expected: '%v'", err, ErrExchangeNameIsEmpty) - } + assert.ErrorIs(t, err, ErrExchangeNameIsEmpty) + request.Exchange = exch.GetName() _, err = s.GetLatestFundingRate(t.Context(), request) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) request.Asset = asset.Spot.String() _, err = s.GetLatestFundingRate(t.Context(), request) - if !errors.Is(err, futures.ErrNotFuturesAsset) { - t.Errorf("received: '%v' but expected: '%v'", err, futures.ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, futures.ErrNotFuturesAsset) request.Asset = asset.Futures.String() request.Pair = &gctrpc.CurrencyPair{ @@ -3057,15 +2911,11 @@ func TestGetManagedPosition(t *testing.T) { }, } _, err = s.GetManagedPosition(t.Context(), nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v', expected '%v'", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) request := &gctrpc.GetManagedPositionRequest{} _, err = s.GetManagedPosition(t.Context(), request) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v', expected '%v'", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) request.Pair = &gctrpc.CurrencyPair{ Delimiter: "-", @@ -3073,21 +2923,15 @@ func TestGetManagedPosition(t *testing.T) { Quote: "USD", } _, err = s.GetManagedPosition(t.Context(), request) - if !errors.Is(err, ErrExchangeNameIsEmpty) { - t.Errorf("received '%v', expected '%v'", err, ErrExchangeNameIsEmpty) - } + assert.ErrorIs(t, err, ErrExchangeNameIsEmpty) request.Exchange = fakeExchangeName _, err = s.GetManagedPosition(t.Context(), request) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("received '%v', expected '%v'", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) request.Asset = asset.Spot.String() _, err = s.GetManagedPosition(t.Context(), request) - if !errors.Is(err, futures.ErrNotFuturesAsset) { - t.Errorf("received '%v', expected '%v'", err, futures.ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, futures.ErrNotFuturesAsset) request.Asset = asset.Futures.String() s.OrderManager, err = SetupOrderManager(em, &CommunicationManager{}, &wg, &config.OrderManager{FuturesTrackingSeekDuration: time.Hour}) @@ -3096,9 +2940,7 @@ func TestGetManagedPosition(t *testing.T) { s.OrderManager.started = 1 s.OrderManager.activelyTrackFuturesPositions = true _, err = s.GetManagedPosition(t.Context(), request) - if !errors.Is(err, futures.ErrPositionNotFound) { - t.Errorf("received '%v', expected '%v'", err, futures.ErrPositionNotFound) - } + assert.ErrorIs(t, err, futures.ErrPositionNotFound) err = s.OrderManager.orderStore.futuresPositionController.TrackNewOrder(&order.Detail{ Leverage: 1337, @@ -3193,9 +3035,7 @@ func TestGetAllManagedPositions(t *testing.T) { }, } _, err = s.GetAllManagedPositions(t.Context(), nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v', expected '%v'", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) request := &gctrpc.GetAllManagedPositionsRequest{} s.OrderManager, err = SetupOrderManager(em, &CommunicationManager{}, &wg, &config.OrderManager{FuturesTrackingSeekDuration: time.Hour, ActivelyTrackFuturesPositions: true}) @@ -3203,9 +3043,7 @@ func TestGetAllManagedPositions(t *testing.T) { s.OrderManager.started = 1 _, err = s.GetAllManagedPositions(t.Context(), request) - if !errors.Is(err, futures.ErrNoPositionsFound) { - t.Errorf("received '%v', expected '%v'", err, futures.ErrNoPositionsFound) - } + assert.ErrorIs(t, err, futures.ErrNoPositionsFound) err = s.OrderManager.orderStore.futuresPositionController.TrackNewOrder(&order.Detail{ Leverage: 1337, @@ -3275,22 +3113,16 @@ func TestGetOrderbookMovement(t *testing.T) { req := &gctrpc.GetOrderbookMovementRequest{} _, err = s.GetOrderbookMovement(t.Context(), req) - if !errors.Is(err, ErrExchangeNameIsEmpty) { - t.Fatalf("received: '%+v' but expected: '%v'", err, ErrExchangeNameIsEmpty) - } + require.ErrorIs(t, err, ErrExchangeNameIsEmpty) req.Exchange = "fake" _, err = s.GetOrderbookMovement(t.Context(), req) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: '%+v' but expected: '%v'", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) req.Asset = asset.Spot.String() req.Pair = &gctrpc.CurrencyPair{} _, err = s.GetOrderbookMovement(t.Context(), req) - if !errors.Is(err, currency.ErrCurrencyPairEmpty) { - t.Fatalf("received: '%+v' but expected: '%v'", err, currency.ErrCurrencyPairEmpty) - } + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) req.Pair = &gctrpc.CurrencyPair{ Base: currency.BTC.String(), @@ -3382,22 +3214,16 @@ func TestGetOrderbookAmountByNominal(t *testing.T) { req := &gctrpc.GetOrderbookAmountByNominalRequest{} _, err = s.GetOrderbookAmountByNominal(t.Context(), req) - if !errors.Is(err, ErrExchangeNameIsEmpty) { - t.Fatalf("received: '%+v' but expected: '%v'", err, ErrExchangeNameIsEmpty) - } + require.ErrorIs(t, err, ErrExchangeNameIsEmpty) req.Exchange = "fake" _, err = s.GetOrderbookAmountByNominal(t.Context(), req) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: '%+v' but expected: '%v'", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) req.Asset = asset.Spot.String() req.Pair = &gctrpc.CurrencyPair{} _, err = s.GetOrderbookAmountByNominal(t.Context(), req) - if !errors.Is(err, currency.ErrCurrencyPairEmpty) { - t.Fatalf("received: '%+v' but expected: '%v'", err, currency.ErrCurrencyPairEmpty) - } + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) req.Pair = &gctrpc.CurrencyPair{ Base: currency.BTC.String(), @@ -3482,22 +3308,16 @@ func TestGetOrderbookAmountByImpact(t *testing.T) { req := &gctrpc.GetOrderbookAmountByImpactRequest{} _, err = s.GetOrderbookAmountByImpact(t.Context(), req) - if !errors.Is(err, ErrExchangeNameIsEmpty) { - t.Fatalf("received: '%+v' but expected: '%v'", err, ErrExchangeNameIsEmpty) - } + require.ErrorIs(t, err, ErrExchangeNameIsEmpty) req.Exchange = "fake" _, err = s.GetOrderbookAmountByImpact(t.Context(), req) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: '%+v' but expected: '%v'", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) req.Asset = asset.Spot.String() req.Pair = &gctrpc.CurrencyPair{} _, err = s.GetOrderbookAmountByImpact(t.Context(), req) - if !errors.Is(err, currency.ErrCurrencyPairEmpty) { - t.Fatalf("received: '%+v' but expected: '%v'", err, currency.ErrCurrencyPairEmpty) - } + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) req.Pair = &gctrpc.CurrencyPair{ Base: currency.BTC.String(), @@ -3582,15 +3402,11 @@ func TestChangePositionMargin(t *testing.T) { s := RPCServer{Engine: &Engine{ExchangeManager: em}} _, err = s.ChangePositionMargin(t.Context(), nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) req := &gctrpc.ChangePositionMarginRequest{} _, err = s.ChangePositionMargin(t.Context(), req) - if !errors.Is(err, currency.ErrCurrencyPairEmpty) { - t.Errorf("received '%v' expected '%v'", err, currency.ErrCurrencyPairEmpty) - } + assert.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) req.Exchange = fakeExchangeName req.Pair = &gctrpc.CurrencyPair{ @@ -3641,15 +3457,11 @@ func TestSetLeverage(t *testing.T) { s := RPCServer{Engine: &Engine{ExchangeManager: em}} _, err = s.SetLeverage(t.Context(), nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) req := &gctrpc.SetLeverageRequest{} _, err = s.SetLeverage(t.Context(), req) - if !errors.Is(err, currency.ErrCurrencyPairEmpty) { - t.Error(err) - } + assert.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) req.Exchange = fakeExchangeName req.Pair = &gctrpc.CurrencyPair{ @@ -3670,9 +3482,7 @@ func TestSetLeverage(t *testing.T) { req.OrderSide = "lol" _, err = s.SetLeverage(t.Context(), req) - if !errors.Is(err, order.ErrSideIsInvalid) { - t.Error(err) - } + assert.ErrorIs(t, err, order.ErrSideIsInvalid) req.OrderSide = order.Long.String() _, err = s.SetLeverage(t.Context(), req) @@ -3713,15 +3523,11 @@ func TestGetLeverage(t *testing.T) { s := RPCServer{Engine: &Engine{ExchangeManager: em}} _, err = s.GetLeverage(t.Context(), nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) req := &gctrpc.GetLeverageRequest{} _, err = s.GetLeverage(t.Context(), req) - if !errors.Is(err, currency.ErrCurrencyPairEmpty) { - t.Error(err) - } + assert.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) req.Exchange = fakeExchangeName req.Pair = &gctrpc.CurrencyPair{ @@ -3745,9 +3551,7 @@ func TestGetLeverage(t *testing.T) { req.OrderSide = "lol" _, err = s.GetLeverage(t.Context(), req) - if !errors.Is(err, order.ErrSideIsInvalid) { - t.Error(err) - } + assert.ErrorIs(t, err, order.ErrSideIsInvalid) req.OrderSide = order.Long.String() _, err = s.GetLeverage(t.Context(), req) @@ -3788,15 +3592,11 @@ func TestSetMarginType(t *testing.T) { s := RPCServer{Engine: &Engine{ExchangeManager: em}} _, err = s.SetMarginType(t.Context(), nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) req := &gctrpc.SetMarginTypeRequest{} _, err = s.SetMarginType(t.Context(), req) - if !errors.Is(err, currency.ErrCurrencyPairEmpty) { - t.Error(err) - } + assert.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) req.Exchange = fakeExchangeName req.Pair = &gctrpc.CurrencyPair{ @@ -3844,15 +3644,11 @@ func TestSetCollateralMode(t *testing.T) { s := RPCServer{Engine: &Engine{ExchangeManager: em}} _, err = s.SetCollateralMode(t.Context(), nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) req := &gctrpc.SetCollateralModeRequest{} _, err = s.SetCollateralMode(t.Context(), req) - if !errors.Is(err, ErrExchangeNameIsEmpty) { - t.Error(err) - } + assert.ErrorIs(t, err, ErrExchangeNameIsEmpty) req.Exchange = fakeExchangeName req.Asset = asset.USDTMarginedFutures.String() @@ -3885,15 +3681,11 @@ func TestGetCollateralMode(t *testing.T) { s := RPCServer{Engine: &Engine{ExchangeManager: em}} _, err = s.GetCollateralMode(t.Context(), nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v' expected '%v'", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) req := &gctrpc.GetCollateralModeRequest{} _, err = s.GetCollateralMode(t.Context(), req) - if !errors.Is(err, ErrExchangeNameIsEmpty) { - t.Error(err) - } + assert.ErrorIs(t, err, ErrExchangeNameIsEmpty) req.Exchange = fakeExchangeName req.Asset = asset.USDTMarginedFutures.String() diff --git a/engine/sync_manager_test.go b/engine/sync_manager_test.go index 6db28364..1335ab3d 100644 --- a/engine/sync_manager_test.go +++ b/engine/sync_manager_test.go @@ -19,39 +19,25 @@ import ( func TestSetupSyncManager(t *testing.T) { t.Parallel() _, err := SetupSyncManager(nil, nil, nil, false) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("error '%v', expected '%v'", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) _, err = SetupSyncManager(&config.SyncManagerConfig{}, nil, nil, false) - if !errors.Is(err, errNoSyncItemsEnabled) { - t.Errorf("error '%v', expected '%v'", err, errNoSyncItemsEnabled) - } + assert.ErrorIs(t, err, errNoSyncItemsEnabled) _, err = SetupSyncManager(&config.SyncManagerConfig{SynchronizeTrades: true}, nil, nil, false) - if !errors.Is(err, errNilExchangeManager) { - t.Errorf("error '%v', expected '%v'", err, errNilExchangeManager) - } + assert.ErrorIs(t, err, errNilExchangeManager) _, err = SetupSyncManager(&config.SyncManagerConfig{SynchronizeTrades: true}, &ExchangeManager{}, nil, false) - if !errors.Is(err, errNilConfig) { - t.Errorf("error '%v', expected '%v'", err, errNilConfig) - } + assert.ErrorIs(t, err, errNilConfig) _, err = SetupSyncManager(&config.SyncManagerConfig{SynchronizeTrades: true}, &ExchangeManager{}, &config.RemoteControlConfig{}, true) - if !errors.Is(err, currency.ErrCurrencyCodeEmpty) { - t.Errorf("error '%v', expected '%v'", err, currency.ErrCurrencyCodeEmpty) - } + assert.ErrorIs(t, err, currency.ErrCurrencyCodeEmpty) _, err = SetupSyncManager(&config.SyncManagerConfig{SynchronizeTrades: true, FiatDisplayCurrency: currency.BTC}, &ExchangeManager{}, &config.RemoteControlConfig{}, true) - if !errors.Is(err, currency.ErrFiatDisplayCurrencyIsNotFiat) { - t.Errorf("error '%v', expected '%v'", err, currency.ErrFiatDisplayCurrencyIsNotFiat) - } + assert.ErrorIs(t, err, currency.ErrFiatDisplayCurrencyIsNotFiat) _, err = SetupSyncManager(&config.SyncManagerConfig{SynchronizeTrades: true, FiatDisplayCurrency: currency.USD}, &ExchangeManager{}, &config.RemoteControlConfig{}, true) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("error '%v', expected '%v'", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) m, err := SetupSyncManager(&config.SyncManagerConfig{SynchronizeTrades: true, FiatDisplayCurrency: currency.USD, PairFormatDisplay: ¤cy.EMPTYFORMAT}, &ExchangeManager{}, &config.RemoteControlConfig{}, true) assert.NoError(t, err) @@ -81,24 +67,18 @@ func TestSyncManagerStart(t *testing.T) { assert.NoError(t, err) err = m.Start() - if !errors.Is(err, ErrSubSystemAlreadyStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemAlreadyStarted) - } + assert.ErrorIs(t, err, ErrSubSystemAlreadyStarted) m = nil err = m.Start() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) } func TestSyncManagerStop(t *testing.T) { t.Parallel() var m *SyncManager err := m.Stop() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) em := NewExchangeManager() exch, err := em.NewExchangeByName("Bitstamp") @@ -113,9 +93,7 @@ func TestSyncManagerStop(t *testing.T) { assert.NoError(t, err) err = m.Stop() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) err = m.Start() assert.NoError(t, err) @@ -232,15 +210,11 @@ func TestRelayWebsocketEvent(t *testing.T) { func TestWaitForInitialSync(t *testing.T) { var m *SyncManager err := m.WaitForInitialSync() - if !errors.Is(err, ErrNilSubsystem) { - t.Fatalf("received %v, but expected: %v", err, ErrNilSubsystem) - } + require.ErrorIs(t, err, ErrNilSubsystem) m = &SyncManager{} err = m.WaitForInitialSync() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Fatalf("received %v, but expected: %v", err, ErrSubSystemNotStarted) - } + require.ErrorIs(t, err, ErrSubSystemNotStarted) m.started = 1 err = m.WaitForInitialSync() @@ -251,15 +225,11 @@ func TestSyncManagerWebsocketUpdate(t *testing.T) { t.Parallel() var m *SyncManager err := m.WebsocketUpdate("", currency.EMPTYPAIR, 1, 47, nil) - if !errors.Is(err, ErrNilSubsystem) { - t.Fatalf("received %v, but expected: %v", err, ErrNilSubsystem) - } + require.ErrorIs(t, err, ErrNilSubsystem) m = &SyncManager{} err = m.WebsocketUpdate("", currency.EMPTYPAIR, 1, 47, nil) - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Fatalf("received %v, but expected: %v", err, ErrSubSystemNotStarted) - } + require.ErrorIs(t, err, ErrSubSystemNotStarted) m.started = 1 // not started initial sync @@ -283,14 +253,10 @@ func TestSyncManagerWebsocketUpdate(t *testing.T) { m.config.SynchronizeTrades = true err = m.WebsocketUpdate("", currency.EMPTYPAIR, asset.Spot, 1336, nil) - if !errors.Is(err, errUnknownSyncItem) { - t.Fatalf("received %v, but expected: %v", err, errUnknownSyncItem) - } + require.ErrorIs(t, err, errUnknownSyncItem) err = m.WebsocketUpdate("", currency.EMPTYPAIR, asset.Spot, SyncItemOrderbook, nil) - if !errors.Is(err, errCouldNotSyncNewData) { - t.Fatalf("received %v, but expected: %v", err, errCouldNotSyncNewData) - } + require.ErrorIs(t, err, errCouldNotSyncNewData) m.add(key.ExchangePairAsset{ Asset: asset.Spot, diff --git a/engine/websocketroutine_manager_test.go b/engine/websocketroutine_manager_test.go index 1b65438c..6a8c7fa8 100644 --- a/engine/websocketroutine_manager_test.go +++ b/engine/websocketroutine_manager_test.go @@ -20,23 +20,16 @@ import ( func TestWebsocketRoutineManagerSetup(t *testing.T) { _, err := setupWebsocketRoutineManager(nil, nil, nil, nil, false) - if !errors.Is(err, errNilExchangeManager) { - t.Errorf("error '%v', expected '%v'", err, errNilExchangeManager) - } + assert.ErrorIs(t, err, errNilExchangeManager) _, err = setupWebsocketRoutineManager(NewExchangeManager(), (*OrderManager)(nil), nil, nil, false) - if !errors.Is(err, errNilCurrencyPairSyncer) { - t.Errorf("error '%v', expected '%v'", err, errNilCurrencyPairSyncer) - } + assert.ErrorIs(t, err, errNilCurrencyPairSyncer) + _, err = setupWebsocketRoutineManager(NewExchangeManager(), &OrderManager{}, &SyncManager{}, nil, false) - if !errors.Is(err, errNilCurrencyConfig) { - t.Errorf("error '%v', expected '%v'", err, errNilCurrencyConfig) - } + assert.ErrorIs(t, err, errNilCurrencyConfig) _, err = setupWebsocketRoutineManager(NewExchangeManager(), &OrderManager{}, &SyncManager{}, ¤cy.Config{}, true) - if !errors.Is(err, errNilCurrencyPairFormat) { - t.Errorf("error '%v', expected '%v'", err, errNilCurrencyPairFormat) - } + assert.ErrorIs(t, err, errNilCurrencyPairFormat) m, err := setupWebsocketRoutineManager(NewExchangeManager(), &OrderManager{}, &SyncManager{}, ¤cy.Config{CurrencyPairFormat: ¤cy.PairFormat{}}, false) assert.NoError(t, err) @@ -49,9 +42,8 @@ func TestWebsocketRoutineManagerSetup(t *testing.T) { func TestWebsocketRoutineManagerStart(t *testing.T) { var m *WebsocketRoutineManager err := m.Start() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) + cfg := ¤cy.Config{CurrencyPairFormat: ¤cy.PairFormat{ Uppercase: false, Delimiter: "-", @@ -63,9 +55,7 @@ func TestWebsocketRoutineManagerStart(t *testing.T) { assert.NoError(t, err) err = m.Start() - if !errors.Is(err, ErrSubSystemAlreadyStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemAlreadyStarted) - } + assert.ErrorIs(t, err, ErrSubSystemAlreadyStarted) } func TestWebsocketRoutineManagerIsRunning(t *testing.T) { @@ -95,17 +85,13 @@ func TestWebsocketRoutineManagerIsRunning(t *testing.T) { func TestWebsocketRoutineManagerStop(t *testing.T) { var m *WebsocketRoutineManager err := m.Stop() - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("error '%v', expected '%v'", err, ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) m, err = setupWebsocketRoutineManager(NewExchangeManager(), &OrderManager{}, &SyncManager{}, ¤cy.Config{CurrencyPairFormat: ¤cy.PairFormat{}}, false) assert.NoError(t, err) err = m.Stop() - if !errors.Is(err, ErrSubSystemNotStarted) { - t.Errorf("error '%v', expected '%v'", err, ErrSubSystemNotStarted) - } + assert.ErrorIs(t, err, ErrSubSystemNotStarted) err = m.Start() assert.NoError(t, err) @@ -226,9 +212,7 @@ func TestWebsocketRoutineManagerHandleData(t *testing.T) { if err == nil { t.Error("Expected error") } - if !errors.Is(err, classificationError.Err) { - t.Errorf("error '%v', expected '%v'", err, classificationError.Err) - } + assert.ErrorIs(t, err, classificationError.Err) err = m.websocketDataHandler(exchName, &orderbook.Base{ Exchange: "Bitstamp", @@ -247,17 +231,13 @@ func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) { t.Parallel() var m *WebsocketRoutineManager err := m.registerWebsocketDataHandler(nil, false) - if !errors.Is(err, ErrNilSubsystem) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrNilSubsystem) - } + require.ErrorIs(t, err, ErrNilSubsystem) m = new(WebsocketRoutineManager) m.shutdown = make(chan struct{}) err = m.registerWebsocketDataHandler(nil, false) - if !errors.Is(err, errNilWebsocketDataHandlerFunction) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNilWebsocketDataHandlerFunction) - } + require.ErrorIs(t, err, errNilWebsocketDataHandlerFunction) // externally defined capture device dataChan := make(chan any) @@ -301,17 +281,13 @@ func TestSetWebsocketDataHandler(t *testing.T) { t.Parallel() var m *WebsocketRoutineManager err := m.setWebsocketDataHandler(nil) - if !errors.Is(err, ErrNilSubsystem) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrNilSubsystem) - } + require.ErrorIs(t, err, ErrNilSubsystem) m = new(WebsocketRoutineManager) m.shutdown = make(chan struct{}) err = m.setWebsocketDataHandler(nil) - if !errors.Is(err, errNilWebsocketDataHandlerFunction) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNilWebsocketDataHandlerFunction) - } + require.ErrorIs(t, err, errNilWebsocketDataHandlerFunction) err = m.registerWebsocketDataHandler(m.websocketDataHandler, false) require.NoError(t, err) diff --git a/engine/withdraw_manager_test.go b/engine/withdraw_manager_test.go index 7ff30898..e7535422 100644 --- a/engine/withdraw_manager_test.go +++ b/engine/withdraw_manager_test.go @@ -1,7 +1,6 @@ package engine import ( - "errors" "sync" "testing" "time" @@ -81,17 +80,14 @@ func TestSubmitWithdrawal(t *testing.T) { }, } _, err = m.SubmitWithdrawal(t.Context(), req) - if !errors.Is(err, common.ErrFunctionNotSupported) { - t.Errorf("received %v, expected %v", err, common.ErrFunctionNotSupported) - } + assert.ErrorIs(t, err, common.ErrFunctionNotSupported) req.Type = withdraw.Crypto req.Currency = currency.BTC req.Crypto.Address = "1337" _, err = m.SubmitWithdrawal(t.Context(), req) - if !errors.Is(err, withdraw.ErrStrAddressNotWhiteListed) { - t.Errorf("received %v, expected %v", err, withdraw.ErrStrAddressNotWhiteListed) - } + assert.ErrorIs(t, err, withdraw.ErrStrAddressNotWhiteListed) + var wg sync.WaitGroup err = pm.Start(&wg) if err != nil { @@ -106,20 +102,14 @@ func TestSubmitWithdrawal(t *testing.T) { assert.NoError(t, err) _, err = m.SubmitWithdrawal(t.Context(), req) - if !errors.Is(err, withdraw.ErrStrExchangeNotSupportedByAddress) { - t.Errorf("received %v, expected %v", err, withdraw.ErrStrExchangeNotSupportedByAddress) - } + assert.ErrorIs(t, err, withdraw.ErrStrExchangeNotSupportedByAddress) adds[0].SupportedExchanges = withdrawManagerTestExchangeName _, err = m.SubmitWithdrawal(t.Context(), req) - if !errors.Is(err, exchange.ErrAuthenticationSupportNotEnabled) { - t.Errorf("received '%v', expected '%v'", err, exchange.ErrAuthenticationSupportNotEnabled) - } + assert.ErrorIs(t, err, exchange.ErrAuthenticationSupportNotEnabled) _, err = m.SubmitWithdrawal(t.Context(), nil) - if !errors.Is(err, withdraw.ErrRequestCannotBeNil) { - t.Errorf("received %v, expected %v", err, withdraw.ErrRequestCannotBeNil) - } + assert.ErrorIs(t, err, withdraw.ErrRequestCannotBeNil) m.isDryRun = true _, err = m.SubmitWithdrawal(t.Context(), req) @@ -137,9 +127,7 @@ func TestWithdrawEventByID(t *testing.T) { ID: withdraw.DryRunID, } _, err = m.WithdrawalEventByID(withdraw.DryRunID.String()) - if !errors.Is(err, ErrWithdrawRequestNotFound) { - t.Errorf("received %v, expected %v", err, ErrWithdrawRequestNotFound) - } + assert.ErrorIs(t, err, ErrWithdrawRequestNotFound) withdraw.Cache.Add(withdraw.DryRunID.String(), tempResp) v, err := m.WithdrawalEventByID(withdraw.DryRunID.String()) @@ -159,18 +147,10 @@ func TestWithdrawalEventByExchange(t *testing.T) { } _, err = (*WithdrawManager)(nil).WithdrawalEventByExchange("xxx", 0) - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("received: %v but expected: %v", - err, - ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) _, err = m.WithdrawalEventByExchange("xxx", 0) - if !errors.Is(err, ErrExchangeNotFound) { - t.Errorf("received: %v but expected: %v", - err, - ErrExchangeNotFound) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) } func TestWithdrawEventByDate(t *testing.T) { @@ -182,18 +162,10 @@ func TestWithdrawEventByDate(t *testing.T) { } _, err = (*WithdrawManager)(nil).WithdrawEventByDate("xxx", time.Now(), time.Now(), 1) - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("received: %v but expected: %v", - err, - ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) _, err = m.WithdrawEventByDate("xxx", time.Now(), time.Now(), 1) - if !errors.Is(err, ErrExchangeNotFound) { - t.Errorf("received: %v but expected: %v", - err, - ErrExchangeNotFound) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) } func TestWithdrawalEventByExchangeID(t *testing.T) { @@ -205,16 +177,8 @@ func TestWithdrawalEventByExchangeID(t *testing.T) { } _, err = (*WithdrawManager)(nil).WithdrawalEventByExchangeID("xxx", "xxx") - if !errors.Is(err, ErrNilSubsystem) { - t.Errorf("received: %v but expected: %v", - err, - ErrNilSubsystem) - } + assert.ErrorIs(t, err, ErrNilSubsystem) _, err = m.WithdrawalEventByExchangeID("xxx", "xxx") - if !errors.Is(err, ErrExchangeNotFound) { - t.Errorf("received: %v but expected: %v", - err, - ErrExchangeNotFound) - } + assert.ErrorIs(t, err, ErrExchangeNotFound) } diff --git a/exchanges/account/account_test.go b/exchanges/account/account_test.go index e2e768cf..de7a82db 100644 --- a/exchanges/account/account_test.go +++ b/exchanges/account/account_test.go @@ -1,7 +1,6 @@ package account import ( - "errors" "sync" "testing" "time" @@ -59,9 +58,7 @@ func TestCollectBalances(t *testing.T) { } _, err = CollectBalances(map[string][]Balance{}, asset.Empty) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) } func TestGetHoldings(t *testing.T) { @@ -250,9 +247,7 @@ func TestBalanceInternalWait(t *testing.T) { t.Parallel() var bi *ProtectedBalance _, _, err := bi.Wait(0) - if !errors.Is(err, errBalanceIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errBalanceIsNil) - } + require.ErrorIs(t, err, errBalanceIsNil) bi = &ProtectedBalance{} waiter, _, err := bi.Wait(time.Nanosecond) diff --git a/exchanges/account/credentials_test.go b/exchanges/account/credentials_test.go index 8a5c490e..0d1cf0f0 100644 --- a/exchanges/account/credentials_test.go +++ b/exchanges/account/credentials_test.go @@ -1,7 +1,6 @@ package account import ( - "errors" "testing" "github.com/stretchr/testify/require" @@ -28,9 +27,7 @@ func TestIsEmpty(t *testing.T) { func TestParseCredentialsMetadata(t *testing.T) { t.Parallel() _, err := ParseCredentialsMetadata(t.Context(), nil) - if !errors.Is(err, errMetaDataIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errMetaDataIsNil) - } + require.ErrorIs(t, err, errMetaDataIsNil) _, err = ParseCredentialsMetadata(t.Context(), metadata.MD{}) require.NoError(t, err) @@ -40,18 +37,14 @@ func TestParseCredentialsMetadata(t *testing.T) { nortyMD, _ := metadata.FromOutgoingContext(ctx) _, err = ParseCredentialsMetadata(t.Context(), nortyMD) - if !errors.Is(err, errInvalidCredentialMetaDataLength) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidCredentialMetaDataLength) - } + require.ErrorIs(t, err, errInvalidCredentialMetaDataLength) ctx = metadata.AppendToOutgoingContext(t.Context(), string(ContextCredentialsFlag), "brokenstring") nortyMD, _ = metadata.FromOutgoingContext(ctx) _, err = ParseCredentialsMetadata(t.Context(), nortyMD) - if !errors.Is(err, errMissingInfo) { - t.Fatalf("received: '%v' but expected: '%v'", err, errMissingInfo) - } + require.ErrorIs(t, err, errMissingInfo) beforeCreds := Credentials{ Key: "superkey", diff --git a/exchanges/alert/alert_test.go b/exchanges/alert/alert_test.go index f7ac362a..0deb2a99 100644 --- a/exchanges/alert/alert_test.go +++ b/exchanges/alert/alert_test.go @@ -1,7 +1,6 @@ package alert import ( - "errors" "log" "sync" "testing" @@ -124,9 +123,7 @@ func getSize() int { func TestSetPreAllocationCommsBuffer(t *testing.T) { t.Parallel() err := SetPreAllocationCommsBuffer(-1) - if !errors.Is(err, errInvalidBufferSize) { - t.Fatalf("received: '%v' but expected '%v'", err, errInvalidBufferSize) - } + require.ErrorIs(t, err, errInvalidBufferSize) if getSize() != 5 { t.Fatal("unexpected amount") diff --git a/exchanges/asset/asset_test.go b/exchanges/asset/asset_test.go index b2256995..e98edf70 100644 --- a/exchanges/asset/asset_test.go +++ b/exchanges/asset/asset_test.go @@ -1,7 +1,6 @@ package asset import ( - "errors" "slices" "testing" @@ -125,9 +124,8 @@ func TestNew(t *testing.T) { t.Run("", func(t *testing.T) { t.Parallel() returned, err := New(tt.Input) - if !errors.Is(err, tt.Error) { - t.Fatalf("received: '%v' but expected: '%v'", err, tt.Error) - } + require.ErrorIs(t, err, tt.Error) + if returned != tt.Expected { t.Fatalf("received: '%v' but expected: '%v'", returned, tt.Expected) } @@ -174,9 +172,7 @@ func TestUnmarshalMarshal(t *testing.T) { } err = json.Unmarshal([]byte(`"confused"`), &spot) - if !errors.Is(err, ErrNotSupported) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrNotSupported) - } + require.ErrorIs(t, err, ErrNotSupported) err = json.Unmarshal([]byte(`""`), &spot) require.NoError(t, err) diff --git a/exchanges/binance/binance_test.go b/exchanges/binance/binance_test.go index b974b738..67256fce 100644 --- a/exchanges/binance/binance_test.go +++ b/exchanges/binance/binance_test.go @@ -80,9 +80,7 @@ func TestUServerTime(t *testing.T) { func TestWrapperGetServerTime(t *testing.T) { t.Parallel() _, err := b.GetServerTime(t.Context(), asset.Empty) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) st, err := b.GetServerTime(t.Context(), asset.Spot) require.NoError(t, err) @@ -2289,28 +2287,18 @@ func TestGetHistoricCandles(t *testing.T) { bAssets := b.GetAssetTypes(false) for i := range bAssets { cps, err := b.GetAvailablePairs(bAssets[i]) - if err != nil { - t.Error(err) - } + require.NoErrorf(t, err, "GetAvailablePairs for asset %s must not error", bAssets[i]) + require.NotEmptyf(t, cps, "GetAvailablePairs for asset %s must return at least one pair", bAssets[i]) err = b.CurrencyPairs.EnablePair(bAssets[i], cps[0]) - if err != nil && !errors.Is(err, currency.ErrPairAlreadyEnabled) { - t.Fatal(err) - } + require.Truef(t, err == nil || errors.Is(err, currency.ErrPairAlreadyEnabled), + "EnablePair for asset %s and pair %s must not error: %s", bAssets[i], cps[0], err) _, err = b.GetHistoricCandles(t.Context(), cps[0], bAssets[i], kline.OneDay, startTime, end) - if err != nil { - t.Error(err) - } + assert.NoErrorf(t, err, "GetHistoricCandles should not error for asset %s and pair %s", bAssets[i], cps[0]) } - pair, err := currency.NewPairFromString("BTC-USDT") - if err != nil { - t.Fatal(err) - } startTime = time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) - _, err = b.GetHistoricCandles(t.Context(), pair, asset.Spot, kline.Interval(time.Hour*7), startTime, end) - if !errors.Is(err, kline.ErrRequestExceedsExchangeLimits) { - t.Fatalf("received: '%v', but expected: '%v'", err, kline.ErrRequestExceedsExchangeLimits) - } + _, err := b.GetHistoricCandles(t.Context(), currency.NewBTCUSDT(), asset.Spot, kline.Interval(time.Hour*7), startTime, end) + require.ErrorIs(t, err, kline.ErrRequestExceedsExchangeLimits) } func TestGetHistoricCandlesExtended(t *testing.T) { @@ -2320,17 +2308,13 @@ func TestGetHistoricCandlesExtended(t *testing.T) { bAssets := b.GetAssetTypes(false) for i := range bAssets { cps, err := b.GetAvailablePairs(bAssets[i]) - if err != nil { - t.Error(err) - } + require.NoErrorf(t, err, "GetAvailablePairs for asset %s must not error", bAssets[i]) + require.NotEmptyf(t, cps, "GetAvailablePairs for asset %s must return at least one pair", bAssets[i]) err = b.CurrencyPairs.EnablePair(bAssets[i], cps[0]) - if err != nil && !errors.Is(err, currency.ErrPairAlreadyEnabled) { - t.Fatal(err) - } + require.Truef(t, err == nil || errors.Is(err, currency.ErrPairAlreadyEnabled), + "EnablePair for asset %s and pair %s must not error: %s", bAssets[i], cps[0], err) _, err = b.GetHistoricCandlesExtended(t.Context(), cps[0], bAssets[i], kline.OneDay, startTime, end) - if err != nil { - t.Error(err) - } + assert.NoErrorf(t, err, "GetHistoricCandlesExtended should not error for asset %s and pair %s", bAssets[i], cps[0]) } } @@ -2544,14 +2528,10 @@ func TestSetExchangeOrderExecutionLimits(t *testing.T) { } err = limit.Conforms(0.000001, 0.1, order.Limit) - if !errors.Is(err, order.ErrAmountBelowMin) { - t.Fatalf("expected %v, but received %v", order.ErrAmountBelowMin, err) - } + require.ErrorIs(t, err, order.ErrAmountBelowMin) err = limit.Conforms(0.01, 1, order.Limit) - if !errors.Is(err, order.ErrPriceBelowMin) { - t.Fatalf("expected %v, but received %v", order.ErrPriceBelowMin, err) - } + require.ErrorIs(t, err, order.ErrPriceBelowMin) } func TestWsOrderExecutionReport(t *testing.T) { @@ -2821,9 +2801,7 @@ func TestGetHistoricalFundingRates(t *testing.T) { IncludePayments: true, IncludePredictedRate: true, }) - if !errors.Is(err, common.ErrFunctionNotSupported) { - t.Error(err) - } + assert.ErrorIs(t, err, common.ErrFunctionNotSupported) _, err = b.GetHistoricalFundingRates(t.Context(), &fundingrate.HistoricalRatesRequest{ Asset: asset.USDTMarginedFutures, @@ -2832,9 +2810,7 @@ func TestGetHistoricalFundingRates(t *testing.T) { EndDate: e, PaymentCurrency: currency.DOGE, }) - if !errors.Is(err, common.ErrFunctionNotSupported) { - t.Error(err) - } + assert.ErrorIs(t, err, common.ErrFunctionNotSupported) r := &fundingrate.HistoricalRatesRequest{ Asset: asset.USDTMarginedFutures, @@ -2869,26 +2845,21 @@ func TestGetLatestFundingRates(t *testing.T) { Pair: cp, IncludePredictedRate: true, }) - if !errors.Is(err, common.ErrFunctionNotSupported) { - t.Error(err) - } + assert.ErrorIs(t, err, common.ErrFunctionNotSupported) + err = b.CurrencyPairs.EnablePair(asset.USDTMarginedFutures, cp) - if err != nil && !errors.Is(err, currency.ErrPairAlreadyEnabled) { - t.Fatal(err) - } + require.Truef(t, err == nil || errors.Is(err, currency.ErrPairAlreadyEnabled), + "EnablePair for asset %s and pair %s must not error: %s", asset.USDTMarginedFutures, cp, err) + _, err = b.GetLatestFundingRates(t.Context(), &fundingrate.LatestRateRequest{ Asset: asset.USDTMarginedFutures, Pair: cp, }) - if err != nil { - t.Error(err) - } + assert.NoError(t, err, "GetLatestFundingRates should not error for USDTMarginedFutures") _, err = b.GetLatestFundingRates(t.Context(), &fundingrate.LatestRateRequest{ Asset: asset.CoinMarginedFutures, }) - if err != nil { - t.Error(err) - } + assert.NoError(t, err, "GetLatestFundingRates should not error for CoinMarginedFutures") } func TestIsPerpetualFutureCurrency(t *testing.T) { @@ -2965,13 +2936,11 @@ func TestGetCollateralMode(t *testing.T) { t.Parallel() sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) _, err := b.GetCollateralMode(t.Context(), asset.Spot) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("received '%v', expected '%v'", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) + _, err = b.GetCollateralMode(t.Context(), asset.CoinMarginedFutures) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("received '%v', expected '%v'", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) + _, err = b.GetCollateralMode(t.Context(), asset.USDTMarginedFutures) assert.NoError(t, err) } @@ -2980,20 +2949,16 @@ func TestSetCollateralMode(t *testing.T) { t.Parallel() sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) err := b.SetCollateralMode(t.Context(), asset.Spot, collateral.SingleMode) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("received '%v', expected '%v'", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) + err = b.SetCollateralMode(t.Context(), asset.CoinMarginedFutures, collateral.SingleMode) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("received '%v', expected '%v'", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) + err = b.SetCollateralMode(t.Context(), asset.USDTMarginedFutures, collateral.MultiMode) assert.NoError(t, err) err = b.SetCollateralMode(t.Context(), asset.USDTMarginedFutures, collateral.PortfolioMode) - if !errors.Is(err, order.ErrCollateralInvalid) { - t.Errorf("received '%v', expected '%v'", err, order.ErrCollateralInvalid) - } + assert.ErrorIs(t, err, order.ErrCollateralInvalid) } func TestChangePositionMargin(t *testing.T) { @@ -3052,9 +3017,7 @@ func TestGetPositionSummary(t *testing.T) { Pair: p, UnderlyingPair: bb, }) - if !errors.Is(err, asset.ErrNotSupported) { - t.Error(err) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) } func TestGetFuturesPositionOrders(t *testing.T) { @@ -3100,9 +3063,7 @@ func TestSetMarginType(t *testing.T) { assert.NoError(t, err) err = b.SetMarginType(t.Context(), asset.Spot, currency.NewBTCUSDT(), margin.Isolated) - if !errors.Is(err, asset.ErrNotSupported) { - t.Error(err) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) } func TestGetLeverage(t *testing.T) { @@ -3122,9 +3083,7 @@ func TestGetLeverage(t *testing.T) { t.Error(err) } _, err = b.GetLeverage(t.Context(), asset.Spot, currency.NewBTCUSDT(), 0, order.UnknownSide) - if !errors.Is(err, asset.ErrNotSupported) { - t.Error(err) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) } func TestSetLeverage(t *testing.T) { @@ -3144,9 +3103,7 @@ func TestSetLeverage(t *testing.T) { t.Error(err) } err = b.SetLeverage(t.Context(), asset.Spot, p, margin.Multi, 5, order.UnknownSide) - if !errors.Is(err, asset.ErrNotSupported) { - t.Error(err) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) } func TestGetCryptoLoansIncomeHistory(t *testing.T) { @@ -3159,18 +3116,14 @@ func TestGetCryptoLoansIncomeHistory(t *testing.T) { func TestCryptoLoanBorrow(t *testing.T) { t.Parallel() - if _, err := b.CryptoLoanBorrow(t.Context(), currency.EMPTYCODE, 1000, currency.BTC, 1, 7); !errors.Is(err, errLoanCoinMustBeSet) { - t.Errorf("received %v, expected %v", err, errLoanCoinMustBeSet) - } - if _, err := b.CryptoLoanBorrow(t.Context(), currency.USDT, 1000, currency.EMPTYCODE, 1, 7); !errors.Is(err, errCollateralCoinMustBeSet) { - t.Errorf("received %v, expected %v", err, errCollateralCoinMustBeSet) - } - if _, err := b.CryptoLoanBorrow(t.Context(), currency.USDT, 0, currency.BTC, 1, 0); !errors.Is(err, errLoanTermMustBeSet) { - t.Errorf("received %v, expected %v", err, errLoanTermMustBeSet) - } - if _, err := b.CryptoLoanBorrow(t.Context(), currency.USDT, 0, currency.BTC, 0, 7); !errors.Is(err, errEitherLoanOrCollateralAmountsMustBeSet) { - t.Errorf("received %v, expected %v", err, errEitherLoanOrCollateralAmountsMustBeSet) - } + _, err := b.CryptoLoanBorrow(t.Context(), currency.EMPTYCODE, 1000, currency.BTC, 1, 7) + assert.ErrorIs(t, err, errLoanCoinMustBeSet) + _, err = b.CryptoLoanBorrow(t.Context(), currency.USDT, 1000, currency.EMPTYCODE, 1, 7) + assert.ErrorIs(t, err, errCollateralCoinMustBeSet) + _, err = b.CryptoLoanBorrow(t.Context(), currency.USDT, 0, currency.BTC, 1, 0) + assert.ErrorIs(t, err, errLoanTermMustBeSet) + _, err = b.CryptoLoanBorrow(t.Context(), currency.USDT, 0, currency.BTC, 0, 7) + assert.ErrorIs(t, err, errEitherLoanOrCollateralAmountsMustBeSet) sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) if _, err := b.CryptoLoanBorrow(t.Context(), currency.USDT, 1000, currency.BTC, 1, 7); err != nil { @@ -3196,12 +3149,10 @@ func TestCryptoLoanOngoingOrders(t *testing.T) { func TestCryptoLoanRepay(t *testing.T) { t.Parallel() - if _, err := b.CryptoLoanRepay(t.Context(), 0, 1000, 1, false); !errors.Is(err, errOrderIDMustBeSet) { - t.Errorf("received %v, expected %v", err, errOrderIDMustBeSet) - } - if _, err := b.CryptoLoanRepay(t.Context(), 42069, 0, 1, false); !errors.Is(err, errAmountMustBeSet) { - t.Errorf("received %v, expected %v", err, errAmountMustBeSet) - } + _, err := b.CryptoLoanRepay(t.Context(), 0, 1000, 1, false) + assert.ErrorIs(t, err, errOrderIDMustBeSet) + _, err = b.CryptoLoanRepay(t.Context(), 42069, 0, 1, false) + assert.ErrorIs(t, err, errAmountMustBeSet) sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) if _, err := b.CryptoLoanRepay(t.Context(), 42069, 1000, 1, false); err != nil { @@ -3219,12 +3170,10 @@ func TestCryptoLoanRepaymentHistory(t *testing.T) { func TestCryptoLoanAdjustLTV(t *testing.T) { t.Parallel() - if _, err := b.CryptoLoanAdjustLTV(t.Context(), 0, true, 1); !errors.Is(err, errOrderIDMustBeSet) { - t.Errorf("received %v, expected %v", err, errOrderIDMustBeSet) - } - if _, err := b.CryptoLoanAdjustLTV(t.Context(), 42069, true, 0); !errors.Is(err, errAmountMustBeSet) { - t.Errorf("received %v, expected %v", err, errAmountMustBeSet) - } + _, err := b.CryptoLoanAdjustLTV(t.Context(), 0, true, 1) + assert.ErrorIs(t, err, errOrderIDMustBeSet) + _, err = b.CryptoLoanAdjustLTV(t.Context(), 42069, true, 0) + assert.ErrorIs(t, err, errAmountMustBeSet) sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) if _, err := b.CryptoLoanAdjustLTV(t.Context(), 42069, true, 1); err != nil { @@ -3258,15 +3207,12 @@ func TestCryptoLoanCollateralAssetsData(t *testing.T) { func TestCryptoLoanCheckCollateralRepayRate(t *testing.T) { t.Parallel() - if _, err := b.CryptoLoanCheckCollateralRepayRate(t.Context(), currency.EMPTYCODE, currency.BNB, 69); !errors.Is(err, errLoanCoinMustBeSet) { - t.Errorf("received %v, expected %v", err, errLoanCoinMustBeSet) - } - if _, err := b.CryptoLoanCheckCollateralRepayRate(t.Context(), currency.BUSD, currency.EMPTYCODE, 69); !errors.Is(err, errCollateralCoinMustBeSet) { - t.Errorf("received %v, expected %v", err, errCollateralCoinMustBeSet) - } - if _, err := b.CryptoLoanCheckCollateralRepayRate(t.Context(), currency.BUSD, currency.BNB, 0); !errors.Is(err, errAmountMustBeSet) { - t.Errorf("received %v, expected %v", err, errAmountMustBeSet) - } + _, err := b.CryptoLoanCheckCollateralRepayRate(t.Context(), currency.EMPTYCODE, currency.BNB, 69) + assert.ErrorIs(t, err, errLoanCoinMustBeSet) + _, err = b.CryptoLoanCheckCollateralRepayRate(t.Context(), currency.BUSD, currency.EMPTYCODE, 69) + assert.ErrorIs(t, err, errCollateralCoinMustBeSet) + _, err = b.CryptoLoanCheckCollateralRepayRate(t.Context(), currency.BUSD, currency.BNB, 0) + assert.ErrorIs(t, err, errAmountMustBeSet) sharedtestvalues.SkipTestIfCredentialsUnset(t, b) if _, err := b.CryptoLoanCheckCollateralRepayRate(t.Context(), currency.BUSD, currency.BNB, 69); err != nil { @@ -3288,15 +3234,12 @@ func TestCryptoLoanCustomiseMarginCall(t *testing.T) { func TestFlexibleLoanBorrow(t *testing.T) { t.Parallel() - if _, err := b.FlexibleLoanBorrow(t.Context(), currency.EMPTYCODE, currency.USDC, 1, 0); !errors.Is(err, errLoanCoinMustBeSet) { - t.Errorf("received %v, expected %v", err, errLoanCoinMustBeSet) - } - if _, err := b.FlexibleLoanBorrow(t.Context(), currency.ATOM, currency.EMPTYCODE, 1, 0); !errors.Is(err, errCollateralCoinMustBeSet) { - t.Errorf("received %v, expected %v", err, errCollateralCoinMustBeSet) - } - if _, err := b.FlexibleLoanBorrow(t.Context(), currency.ATOM, currency.USDC, 0, 0); !errors.Is(err, errEitherLoanOrCollateralAmountsMustBeSet) { - t.Errorf("received %v, expected %v", err, errEitherLoanOrCollateralAmountsMustBeSet) - } + _, err := b.FlexibleLoanBorrow(t.Context(), currency.EMPTYCODE, currency.USDC, 1, 0) + assert.ErrorIs(t, err, errLoanCoinMustBeSet) + _, err = b.FlexibleLoanBorrow(t.Context(), currency.ATOM, currency.EMPTYCODE, 1, 0) + assert.ErrorIs(t, err, errCollateralCoinMustBeSet) + _, err = b.FlexibleLoanBorrow(t.Context(), currency.ATOM, currency.USDC, 0, 0) + assert.ErrorIs(t, err, errEitherLoanOrCollateralAmountsMustBeSet) sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) if _, err := b.FlexibleLoanBorrow(t.Context(), currency.ATOM, currency.USDC, 1, 0); err != nil { @@ -3322,16 +3265,12 @@ func TestFlexibleLoanBorrowHistory(t *testing.T) { func TestFlexibleLoanRepay(t *testing.T) { t.Parallel() - - if _, err := b.FlexibleLoanRepay(t.Context(), currency.EMPTYCODE, currency.BTC, 1, false, false); !errors.Is(err, errLoanCoinMustBeSet) { - t.Errorf("received %v, expected %v", err, errLoanCoinMustBeSet) - } - if _, err := b.FlexibleLoanRepay(t.Context(), currency.USDT, currency.EMPTYCODE, 1, false, false); !errors.Is(err, errCollateralCoinMustBeSet) { - t.Errorf("received %v, expected %v", err, errCollateralCoinMustBeSet) - } - if _, err := b.FlexibleLoanRepay(t.Context(), currency.USDT, currency.BTC, 0, false, false); !errors.Is(err, errAmountMustBeSet) { - t.Errorf("received %v, expected %v", err, errAmountMustBeSet) - } + _, err := b.FlexibleLoanRepay(t.Context(), currency.EMPTYCODE, currency.BTC, 1, false, false) + assert.ErrorIs(t, err, errLoanCoinMustBeSet) + _, err = b.FlexibleLoanRepay(t.Context(), currency.USDT, currency.EMPTYCODE, 1, false, false) + assert.ErrorIs(t, err, errCollateralCoinMustBeSet) + _, err = b.FlexibleLoanRepay(t.Context(), currency.USDT, currency.BTC, 0, false, false) + assert.ErrorIs(t, err, errAmountMustBeSet) sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) if _, err := b.FlexibleLoanRepay(t.Context(), currency.ATOM, currency.USDC, 1, false, false); err != nil { @@ -3349,12 +3288,10 @@ func TestFlexibleLoanRepayHistory(t *testing.T) { func TestFlexibleLoanAdjustLTV(t *testing.T) { t.Parallel() - if _, err := b.FlexibleLoanAdjustLTV(t.Context(), currency.EMPTYCODE, currency.BTC, 1, true); !errors.Is(err, errLoanCoinMustBeSet) { - t.Errorf("received %v, expected %v", err, errLoanCoinMustBeSet) - } - if _, err := b.FlexibleLoanAdjustLTV(t.Context(), currency.USDT, currency.EMPTYCODE, 1, true); !errors.Is(err, errCollateralCoinMustBeSet) { - t.Errorf("received %v, expected %v", err, errCollateralCoinMustBeSet) - } + _, err := b.FlexibleLoanAdjustLTV(t.Context(), currency.EMPTYCODE, currency.BTC, 1, true) + assert.ErrorIs(t, err, errLoanCoinMustBeSet) + _, err = b.FlexibleLoanAdjustLTV(t.Context(), currency.USDT, currency.EMPTYCODE, 1, true) + assert.ErrorIs(t, err, errCollateralCoinMustBeSet) sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) if _, err := b.FlexibleLoanAdjustLTV(t.Context(), currency.USDT, currency.BTC, 1, true); err != nil { @@ -3389,13 +3326,11 @@ func TestFlexibleCollateralAssetsData(t *testing.T) { func TestGetFuturesContractDetails(t *testing.T) { t.Parallel() _, err := b.GetFuturesContractDetails(t.Context(), asset.Spot) - if !errors.Is(err, futures.ErrNotFuturesAsset) { - t.Error(err) - } + assert.ErrorIs(t, err, futures.ErrNotFuturesAsset) + _, err = b.GetFuturesContractDetails(t.Context(), asset.Futures) - if !errors.Is(err, asset.ErrNotSupported) { - t.Error(err) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) + _, err = b.GetFuturesContractDetails(t.Context(), asset.USDTMarginedFutures) assert.NoError(t, err) diff --git a/exchanges/binance/ratelimit_test.go b/exchanges/binance/ratelimit_test.go index 8625ea43..febda44a 100644 --- a/exchanges/binance/ratelimit_test.go +++ b/exchanges/binance/ratelimit_test.go @@ -51,9 +51,8 @@ func TestRateLimit_Limit(t *testing.T) { defer cancel() } - if err := rl.InitiateRateLimit(ctx, tt.Limit); err != nil && !errors.Is(err, context.DeadlineExceeded) { - t.Fatalf("error applying rate limit: %v", err) - } + err := rl.InitiateRateLimit(ctx, tt.Limit) + require.Truef(t, err == nil || errors.Is(err, context.DeadlineExceeded), "InitiateRateLimit must not error: %s", err) }) } } diff --git a/exchanges/binanceus/binanceus_test.go b/exchanges/binanceus/binanceus_test.go index ebad1dba..8c4f4108 100644 --- a/exchanges/binanceus/binanceus_test.go +++ b/exchanges/binanceus/binanceus_test.go @@ -1,7 +1,6 @@ package binanceus import ( - "errors" "log" "os" reflects "reflect" @@ -273,14 +272,11 @@ func TestGetOrderInfo(t *testing.T) { func TestGetDepositAddress(t *testing.T) { t.Parallel() - sharedtestvalues.SkipTestIfCredentialsUnset(t, bi) _, err := bi.GetDepositAddress(t.Context(), currency.EMPTYCODE, "", currency.BNB.String()) - if err != nil && !errors.Is(err, errMissingRequiredArgumentCoin) { - t.Errorf("Binanceus GetDepositAddress() expecting %v, but found %v", errMissingRequiredArgumentCoin, err) - } - if _, err := bi.GetDepositAddress(t.Context(), currency.USDT, "", currency.BNB.String()); err != nil { - t.Error("Binanceus GetDepositAddress() error", err) - } + assert.ErrorIs(t, err, errMissingRequiredArgumentCoin) + sharedtestvalues.SkipTestIfCredentialsUnset(t, bi) + _, err = bi.GetDepositAddress(t.Context(), currency.USDT, "", currency.BNB.String()) + assert.NoError(t, err) } func TestGetWithdrawalHistory(t *testing.T) { @@ -390,9 +386,7 @@ func TestGetHistoricCandles(t *testing.T) { endTime := time.Date(2021, 2, 15, 0, 0, 0, 0, time.UTC) _, err := bi.GetHistoricCandles(t.Context(), pair, asset.Spot, kline.Interval(time.Hour*5), startTime, endTime) - if !errors.Is(err, kline.ErrRequestExceedsExchangeLimits) { - t.Fatalf("received: '%v', but expected: '%v'", err, kline.ErrRequestExceedsExchangeLimits) - } + require.ErrorIs(t, err, kline.ErrRequestExceedsExchangeLimits) _, err = bi.GetHistoricCandles(t.Context(), pair, asset.Spot, kline.OneDay, startTime, endTime) if err != nil { @@ -590,32 +584,30 @@ func TestGetMasterAccountTotalUSDValue(t *testing.T) { func TestGetSubaccountStatusList(t *testing.T) { t.Parallel() + _, err := bi.GetSubaccountStatusList(t.Context(), "") + assert.ErrorIs(t, err, errMissingSubAccountEmail) + sharedtestvalues.SkipTestIfCredentialsUnset(t, bi) - if _, er := bi.GetSubaccountStatusList(t.Context(), ""); er != nil && !errors.Is(er, errMissingSubAccountEmail) { - t.Errorf("Binanceus GetSubaccountStatusList() expecting %v, but found %v", errMissingSubAccountEmail, er) - } - if _, er := bi.GetSubaccountStatusList(t.Context(), "someone@thrasher.corp"); er != nil && !strings.Contains(er.Error(), "Sub-account function is not enabled.") { - t.Errorf("Binanceus GetSubaccountStatusList() expecting %s, but found %v", "Sub-account function is not enabled.", er) - } + _, err = bi.GetSubaccountStatusList(t.Context(), "someone@thrasher.corp") + assert.ErrorContains(t, err, "Sub-account function is not enabled.") } func TestGetSubAccountDepositAddress(t *testing.T) { t.Parallel() - sharedtestvalues.SkipTestIfCredentialsUnset(t, bi) - if _, er := bi.GetSubAccountDepositAddress(t.Context(), SubAccountDepositAddressRequestParams{}); er != nil && !errors.Is(er, errMissingSubAccountEmail) { - t.Errorf("Binanceus GetSubAccountDepositAddress() %v, but found %v", errMissingSubAccountEmail, er) - } - if _, er := bi.GetSubAccountDepositAddress(t.Context(), SubAccountDepositAddressRequestParams{ + _, err := bi.GetSubAccountDepositAddress(t.Context(), SubAccountDepositAddressRequestParams{}) + assert.ErrorIs(t, err, errMissingSubAccountEmail) + _, err = bi.GetSubAccountDepositAddress(t.Context(), SubAccountDepositAddressRequestParams{ Email: "someone@thrasher.io", - }); er != nil && !errors.Is(er, errMissingCurrencyCoin) { - t.Errorf("Binanceus GetSubAccountDepositAddress() %v, but found %v", errMissingCurrencyCoin, er) - } - if _, er := bi.GetSubAccountDepositAddress(t.Context(), SubAccountDepositAddressRequestParams{ + }) + assert.ErrorIs(t, err, errMissingCurrencyCoin) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, bi) + + _, err = bi.GetSubAccountDepositAddress(t.Context(), SubAccountDepositAddressRequestParams{ Email: "someone@thrasher.io", Coin: currency.BTC, - }); er != nil && !strings.Contains(er.Error(), "This parent sub have no relation") { - t.Errorf("Binanceus GetSubAccountDepositAddress() %v, but found %v", errMissingCurrencyCoin, er) - } + }) + assert.ErrorContains(t, err, "This parent sub have no relation") } var subAccountDepositHistoryItemJSON = `{ @@ -634,16 +626,14 @@ var subAccountDepositHistoryItemJSON = `{ func TestGetSubAccountDepositHistory(t *testing.T) { t.Parallel() var resp SubAccountDepositItem - if er := json.Unmarshal([]byte(subAccountDepositHistoryItemJSON), &resp); er != nil { - t.Error("Binanceus Decerializing to SubAccountDepositItem error", er) - } + require.NoError(t, json.Unmarshal([]byte(subAccountDepositHistoryItemJSON), &resp)) + _, err := bi.GetSubAccountDepositHistory(t.Context(), "", currency.BTC, 1, time.Time{}, time.Time{}, 0, 0) + assert.ErrorIs(t, err, errMissingSubAccountEmail) + sharedtestvalues.SkipTestIfCredentialsUnset(t, bi) - if _, er := bi.GetSubAccountDepositHistory(t.Context(), "", currency.BTC, 1, time.Time{}, time.Time{}, 0, 0); er != nil && !errors.Is(er, errMissingSubAccountEmail) { - t.Errorf("Binanceus GetSubAccountDepositHistory() expecting %v, but found %v", errMissingSubAccountEmail, er) - } - if _, er := bi.GetSubAccountDepositHistory(t.Context(), "someone@thrasher.io", currency.BTC, 1, time.Time{}, time.Time{}, 0, 0); er != nil && !strings.Contains(er.Error(), "This parent sub have no relation") { - t.Errorf("Binanceus GetSubAccountDepositHistory() expecting %s, but found %v", "This parent sub have no relation", er) - } + + _, err = bi.GetSubAccountDepositHistory(t.Context(), "someone@thrasher.io", currency.BTC, 1, time.Time{}, time.Time{}, 0, 0) + assert.ErrorContains(t, err, "This parent sub have no relation") } var subaccountItemJSON = `{ @@ -683,22 +673,16 @@ var referalRewardHistoryResponse = `{ func TestGetReferralRewardHistory(t *testing.T) { t.Parallel() var resp ReferralRewardHistoryResponse - if er := json.Unmarshal([]byte(referalRewardHistoryResponse), &resp); er != nil { - t.Error("Binanceus decerializing to ReferalRewardHistoryResponse error", er) - } + require.NoError(t, json.Unmarshal([]byte(referalRewardHistoryResponse), &resp)) + _, err := bi.GetReferralRewardHistory(t.Context(), 9, 5, 50) + assert.ErrorIs(t, err, errInvalidUserBusinessType) + _, err = bi.GetReferralRewardHistory(t.Context(), 1, 0, 50) + assert.ErrorIs(t, err, errMissingPageNumber) + _, err = bi.GetReferralRewardHistory(t.Context(), 1, 5, 0) + assert.ErrorIs(t, err, errInvalidRowNumber) sharedtestvalues.SkipTestIfCredentialsUnset(t, bi) - if _, er := bi.GetReferralRewardHistory(t.Context(), 9, 5, 50); !errors.Is(er, errInvalidUserBusinessType) { - t.Errorf("Binanceus GetReferralRewardHistory() expecting %v, but found %v", errInvalidUserBusinessType, er) - } - if _, er := bi.GetReferralRewardHistory(t.Context(), 1, 0, 50); !errors.Is(er, errMissingPageNumber) { - t.Errorf("Binanceus GetReferralRewardHistory() expecting %v, but found %v", errMissingPageNumber, er) - } - if _, er := bi.GetReferralRewardHistory(t.Context(), 1, 5, 0); !errors.Is(er, errInvalidRowNumber) { - t.Errorf("Binanceus GetReferralRewardHistory() expecting %v, but found %v", errInvalidRowNumber, er) - } - if _, er := bi.GetReferralRewardHistory(t.Context(), 1, 5, 50); er != nil { - t.Error("Binanceus GetReferralRewardHistory() error", er) - } + _, err = bi.GetReferralRewardHistory(t.Context(), 1, 5, 50) + assert.NoError(t, err) } func TestGetSubaccountTransferHistory(t *testing.T) { @@ -715,33 +699,26 @@ func TestGetSubaccountTransferHistory(t *testing.T) { func TestExecuteSubAccountTransfer(t *testing.T) { t.Parallel() + _, err := bi.ExecuteSubAccountTransfer(t.Context(), &SubAccountTransferRequestParams{}) + assert.ErrorIs(t, err, errUnacceptableSenderEmail) + sharedtestvalues.SkipTestIfCredentialsUnset(t, bi, canManipulateRealOrders) - _, er := bi.ExecuteSubAccountTransfer(t.Context(), &SubAccountTransferRequestParams{}) - if !errors.Is(er, errUnacceptableSenderEmail) { - t.Errorf("binanceus error: expected %v, but found %v", errUnacceptableSenderEmail, er) - } - _, er = bi.ExecuteSubAccountTransfer(t.Context(), &SubAccountTransferRequestParams{ + _, err = bi.ExecuteSubAccountTransfer(t.Context(), &SubAccountTransferRequestParams{ FromEmail: "fromemail@thrasher.io", - ToEmail: "toemail@threasher.io", + ToEmail: "toemail@thrasher.io", Asset: "BTC", Amount: 0.000005, }) - if er != nil && !strings.Contains(er.Error(), "You are not authorized to execute this request.") { - t.Errorf("Binanceus GetSubaccountTransferHistory() error %v", er) - } + assert.ErrorContains(t, err, "You are not authorized to execute this request.") } func TestGetSubaccountAssets(t *testing.T) { t.Parallel() + _, err := bi.GetSubaccountAssets(t.Context(), "") + assert.ErrorIs(t, err, errNotValidEmailAddress) sharedtestvalues.SkipTestIfCredentialsUnset(t, bi) - _, er := bi.GetSubaccountAssets(t.Context(), "") - if !errors.Is(er, errNotValidEmailAddress) { - t.Errorf("Binanceus GetSubaccountAssets() expected %v, but found %v", er, errNotValidEmailAddress) - } - _, er = bi.GetSubaccountAssets(t.Context(), "subaccount@thrasher.io") - if er != nil && !strings.Contains(er.Error(), "This account does not exist.") { - t.Fatal("Binanceus GetSubaccountAssets() error", er) - } + _, err = bi.GetSubaccountAssets(t.Context(), "subaccount@thrasher.io") + assert.ErrorContains(t, err, "This account does not exist.") } func TestGetOrderRateLimits(t *testing.T) { @@ -819,19 +796,15 @@ func TestNewOrder(t *testing.T) { func TestGetOrder(t *testing.T) { t.Parallel() + _, err := bi.GetOrder(t.Context(), &OrderRequestParams{}) + assert.ErrorIs(t, err, errIncompleteArguments) sharedtestvalues.SkipTestIfCredentialsUnset(t, bi) - _, er := bi.GetOrder(t.Context(), &OrderRequestParams{}) - if !errors.Is(er, errIncompleteArguments) { - t.Errorf("Binanceus GetOrder() error expecting %v, but found %v", errIncompleteArguments, er) - } - _, er = bi.GetOrder(t.Context(), &OrderRequestParams{ + _, err = bi.GetOrder(t.Context(), &OrderRequestParams{ Symbol: "BTCUSDT", OrigClientOrderID: "something", }) // You can check the existence of an order using a valid Symbol and OrigClient Order ID - if er != nil && !strings.Contains(er.Error(), "Order does not exist.") { - t.Error("Binanceus GetOrder() error", er) - } + assert.ErrorContains(t, err, "Order does not exist.") } var openOrdersItemJSON = `{ @@ -890,47 +863,42 @@ func TestCancelExistingOrder(t *testing.T) { func TestCancelOpenOrdersForSymbol(t *testing.T) { t.Parallel() + _, err := bi.CancelOpenOrdersForSymbol(t.Context(), "") + assert.ErrorIs(t, err, errMissingCurrencySymbol) + sharedtestvalues.SkipTestIfCredentialsUnset(t, bi, canManipulateRealOrders) - _, er := bi.CancelOpenOrdersForSymbol(t.Context(), "") - if !errors.Is(er, errMissingCurrencySymbol) { - t.Errorf("Binanceus CancelOpenOrdersForSymbol() error expecting %v, but found %v", errIncompleteArguments, er) - } - _, er = bi.CancelOpenOrdersForSymbol(t.Context(), "BTCUSDT") - if er != nil && !strings.Contains(er.Error(), "Unknown order sent") { - t.Error("Binanceus CancelOpenOrdersForSymbol() error", er) - } + + _, err = bi.CancelOpenOrdersForSymbol(t.Context(), "BTCUSDT") + assert.NoError(t, err) } // TestGetTrades test for fetching the list of // trades attached with this account. func TestGetTrades(t *testing.T) { t.Parallel() + _, err := bi.GetTrades(t.Context(), &GetTradesParams{}) + assert.ErrorIs(t, err, errIncompleteArguments) + sharedtestvalues.SkipTestIfCredentialsUnset(t, bi) - _, er := bi.GetTrades(t.Context(), &GetTradesParams{}) - if !errors.Is(er, errIncompleteArguments) { - t.Errorf(" Binanceus GetTrades() expecting error %v, but found %v", errIncompleteArguments, er) - } - _, er = bi.GetTrades(t.Context(), &GetTradesParams{Symbol: "BTCUSDT"}) - if er != nil { - t.Error("Binanceus GetTrades() error", er) - } + + _, err = bi.GetTrades(t.Context(), &GetTradesParams{Symbol: "BTCUSDT"}) + assert.NoError(t, err) } func TestCreateNewOCOOrder(t *testing.T) { t.Parallel() - sharedtestvalues.SkipTestIfCredentialsUnset(t, bi, canManipulateRealOrders) - _, er := bi.CreateNewOCOOrder(t.Context(), + _, err := bi.CreateNewOCOOrder(t.Context(), &OCOOrderInputParams{ StopPrice: 1000, Side: order.Buy.String(), Quantity: 0.0000001, Price: 1232334.00, }) - if !errors.Is(er, errIncompleteArguments) { - t.Errorf("Binanceus CreatenewOCOOrder() error expected %v, but found %v", errIncompleteArguments, er) - } - _, er = bi.CreateNewOCOOrder( - t.Context(), + assert.ErrorIs(t, err, errIncompleteArguments) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, bi, canManipulateRealOrders) + + _, err = bi.CreateNewOCOOrder(t.Context(), &OCOOrderInputParams{ Symbol: "XTZUSD", Price: 100, @@ -941,9 +909,7 @@ func TestCreateNewOCOOrder(t *testing.T) { StopLimitTimeInForce: "GTC", RecvWindow: 6000, }) - if er != nil && !strings.Contains(er.Error(), "Precision is over the maximum defined for this asset.") { - t.Error("Binanceus CreateNewOCOOrder() error", er) - } + assert.ErrorContains(t, err, "Precision is over the maximum defined for this asset.") } var ocoOrderJSON = `{ @@ -971,20 +937,16 @@ var ocoOrderJSON = `{ func TestGetOCOOrder(t *testing.T) { t.Parallel() var resp OCOOrderResponse - if er := json.Unmarshal([]byte(ocoOrderJSON), &resp); er != nil { - t.Error("Binanceus decerializing OCOOrderResponse error", er) - } + require.NoError(t, json.Unmarshal([]byte(ocoOrderJSON), &resp)) + _, err := bi.GetOCOOrder(t.Context(), &GetOCOOrderRequestParams{}) + assert.ErrorIs(t, err, errIncompleteArguments) + sharedtestvalues.SkipTestIfCredentialsUnset(t, bi) - _, er := bi.GetOCOOrder(t.Context(), &GetOCOOrderRequestParams{}) - if !errors.Is(er, errIncompleteArguments) { - t.Errorf("Binanceus GetOCOOrder() error expecting %v, but found %v", errIncompleteArguments, er) - } - _, er = bi.GetOCOOrder(t.Context(), &GetOCOOrderRequestParams{ + + _, err = bi.GetOCOOrder(t.Context(), &GetOCOOrderRequestParams{ OrderListID: "123445", }) - if er != nil && !strings.Contains(er.Error(), "Order list does not exist.") { - t.Error("Binanceus GetOCOOrder() error", er) - } + assert.ErrorContains(t, err, "Order list does not exist.") } func TestGetAllOCOOrder(t *testing.T) { @@ -1007,11 +969,14 @@ func TestGetOpenOCOOrders(t *testing.T) { func TestCancelOCOOrder(t *testing.T) { t.Parallel() + _, err := bi.CancelOCOOrder(t.Context(), &OCOOrdersDeleteRequestParams{}) + assert.ErrorIs(t, err, errIncompleteArguments) sharedtestvalues.SkipTestIfCredentialsUnset(t, bi, canManipulateRealOrders) - _, er := bi.CancelOCOOrder(t.Context(), &OCOOrdersDeleteRequestParams{}) - if !errors.Is(er, errIncompleteArguments) { - t.Errorf("Binanceus CancelOCOOrder() error expected %v, but found %v", errIncompleteArguments, er) - } + _, err = bi.CancelOCOOrder(t.Context(), &OCOOrdersDeleteRequestParams{ + Symbol: "BTCUSDT", + OrderListID: 123456, + }) + assert.NoError(t, err) } // OTC end Points test code. @@ -1026,27 +991,19 @@ func TestGetSupportedCoinPairs(t *testing.T) { func TestRequestForQuote(t *testing.T) { t.Parallel() + _, err := bi.RequestForQuote(t.Context(), &RequestQuoteParams{ToCoin: "BTC", RequestCoin: "USDT", RequestAmount: 1}) + assert.ErrorIs(t, err, errMissingFromCoinName) + _, err = bi.RequestForQuote(t.Context(), &RequestQuoteParams{FromCoin: "ETH", RequestCoin: "USDT", RequestAmount: 1}) + assert.ErrorIs(t, err, errMissingToCoinName) + _, err = bi.RequestForQuote(t.Context(), &RequestQuoteParams{FromCoin: "ETH", ToCoin: "BTC", RequestCoin: "USDT"}) + assert.ErrorIs(t, err, errMissingRequestAmount) + _, err = bi.RequestForQuote(t.Context(), &RequestQuoteParams{FromCoin: "ETH", ToCoin: "BTC", RequestAmount: 1}) + assert.ErrorIs(t, err, errMissingRequestCoin) + sharedtestvalues.SkipTestIfCredentialsUnset(t, bi) - _, er := bi.RequestForQuote(t.Context(), &RequestQuoteParams{ToCoin: "BTC", RequestCoin: "USDT", RequestAmount: 1}) - if er != nil && !errors.Is(er, errMissingFromCoinName) { - t.Errorf("Binanceus RequestForQuote() expecting %v, but found %v", errMissingFromCoinName, er) - } - _, er = bi.RequestForQuote(t.Context(), &RequestQuoteParams{FromCoin: "ETH", RequestCoin: "USDT", RequestAmount: 1}) - if er != nil && !errors.Is(er, errMissingToCoinName) { - t.Errorf("Binanceus RequestForQuote() expecting %v, but found %v", errMissingToCoinName, er) - } - _, er = bi.RequestForQuote(t.Context(), &RequestQuoteParams{FromCoin: "ETH", ToCoin: "BTC", RequestCoin: "USDT"}) - if er != nil && !errors.Is(er, errMissingRequestAmount) { - t.Errorf("Binanceus RequestForQuote() expecting %v, but found %v", errMissingRequestAmount, er) - } - _, er = bi.RequestForQuote(t.Context(), &RequestQuoteParams{FromCoin: "ETH", ToCoin: "BTC", RequestAmount: 1}) - if er != nil && !errors.Is(er, errMissingRequestCoin) { - t.Errorf("Binanceus RequestForQuote() expecting %v, but found %v", errMissingRequestCoin, er) - } - _, er = bi.RequestForQuote(t.Context(), &RequestQuoteParams{FromCoin: "BTC", ToCoin: "USDT", RequestCoin: "BTC", RequestAmount: 1}) - if er != nil { - t.Error("Binanceus RequestForQuote() error", er) - } + + _, err = bi.RequestForQuote(t.Context(), &RequestQuoteParams{FromCoin: "BTC", ToCoin: "USDT", RequestCoin: "BTC", RequestAmount: 1}) + assert.NoError(t, err) } var testPlaceOTCTradeOrderJSON = `{ @@ -1057,20 +1014,13 @@ var testPlaceOTCTradeOrderJSON = `{ func TestPlaceOTCTradeOrder(t *testing.T) { t.Parallel() + var resp OTCTradeOrderResponse + require.NoError(t, json.Unmarshal([]byte(testPlaceOTCTradeOrderJSON), &resp)) + _, err := bi.PlaceOTCTradeOrder(t.Context(), "") + assert.ErrorIs(t, err, errMissingQuoteID) sharedtestvalues.SkipTestIfCredentialsUnset(t, bi, canManipulateRealOrders) - var res OTCTradeOrderResponse - er := json.Unmarshal([]byte(testPlaceOTCTradeOrderJSON), &res) - if er != nil { - t.Error("Binanceus PlaceOTCTradeOrder() error", er) - } - _, er = bi.PlaceOTCTradeOrder(t.Context(), "") - if !errors.Is(er, errMissingQuoteID) { - t.Errorf("Binanceus PlaceOTCTradeOrder() expecting %v, but found %v", errMissingQuoteID, er) - } - _, er = bi.PlaceOTCTradeOrder(t.Context(), "15848701022") - if er != nil && !strings.Contains(er.Error(), "-9000") { - t.Error("Binanceus PlaceOTCTradeOrder() error", er) - } + _, err = bi.PlaceOTCTradeOrder(t.Context(), "15848701022") + assert.ErrorContains(t, err, "-9000") } var testGetOTCTradeOrderJSON = `{ @@ -1180,31 +1130,28 @@ func TestGetAssetFeesAndWalletStatus(t *testing.T) { func TestWithdrawCrypto(t *testing.T) { t.Parallel() - sharedtestvalues.SkipTestIfCredentialsUnset(t, bi, canManipulateRealOrders) - _, er := bi.WithdrawCrypto(t.Context(), &withdraw.Request{}) - if !errors.Is(er, errMissingRequiredArgumentCoin) { - t.Errorf("Binanceus WithdrawCrypto() error expecting %v, but found %v", errMissingRequiredArgumentCoin, er) - } - if _, er = bi.WithdrawCrypto(t.Context(), &withdraw.Request{ + + _, err := bi.WithdrawCrypto(t.Context(), &withdraw.Request{}) + assert.ErrorIs(t, err, errMissingRequiredArgumentCoin) + _, err = bi.WithdrawCrypto(t.Context(), &withdraw.Request{ Currency: currency.BTC, - }); !errors.Is(er, errMissingRequiredArgumentNetwork) { - t.Errorf("Binanceus WithdrawCrypto() expecting %v, but found %v", errMissingRequiredArgumentNetwork, er) - } + }) + assert.ErrorIs(t, err, errMissingRequiredArgumentNetwork) params := &withdraw.Request{ Currency: currency.BTC, + Crypto: withdraw.CryptoRequest{ + Chain: "BSC", + }, } - params.Crypto.Chain = "BSC" - if _, er = bi.WithdrawCrypto(t.Context(), params); !errors.Is(er, errMissingRequiredParameterAddress) { - t.Errorf("Binanceus WithdrawCrypto() expecting %v, but found %v", errMissingRequiredParameterAddress, er) - } + _, err = bi.WithdrawCrypto(t.Context(), params) + assert.ErrorIs(t, err, errMissingRequiredParameterAddress) params.Crypto.Address = "1234567" - if _, er = bi.WithdrawCrypto(t.Context(), params); !errors.Is(er, errAmountValueMustBeGreaterThan0) { - t.Errorf("Binanceus WithdrawCrypto() expecting %v, but found %v", errAmountValueMustBeGreaterThan0, er) - } + _, err = bi.WithdrawCrypto(t.Context(), params) + assert.ErrorIs(t, err, errAmountValueMustBeGreaterThan0) params.Amount = 1 - if _, er = bi.WithdrawCrypto(t.Context(), params); er != nil && !strings.Contains(er.Error(), "You are not authorized to execute this request.") { - t.Error("Binanceus WithdrawCrypto() error", er) - } + sharedtestvalues.SkipTestIfCredentialsUnset(t, bi, canManipulateRealOrders) + _, err = bi.WithdrawCrypto(t.Context(), params) + assert.ErrorContains(t, err, "You are not authorized to execute this request.") } func TestFiatWithdrawalHistory(t *testing.T) { diff --git a/exchanges/bitfinex/bitfinex_test.go b/exchanges/bitfinex/bitfinex_test.go index e9949cc3..521d2c3b 100644 --- a/exchanges/bitfinex/bitfinex_test.go +++ b/exchanges/bitfinex/bitfinex_test.go @@ -2,7 +2,6 @@ package bitfinex import ( "bufio" - "errors" "log" "os" "strconv" @@ -130,9 +129,7 @@ func TestGetPairs(t *testing.T) { t.Parallel() _, err := b.GetPairs(t.Context(), asset.Binary) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) assets := b.GetAssetTypes(false) for x := range assets { @@ -1697,17 +1694,13 @@ func TestFixCasing(t *testing.T) { } _, err = b.fixCasing(currency.NewPair(currency.EMPTYCODE, currency.BTC), asset.MarginFunding) - if !errors.Is(err, currency.ErrCurrencyPairEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, currency.ErrCurrencyPairEmpty) - } + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) _, err = b.fixCasing(currency.NewPair(currency.BTC, currency.EMPTYCODE), asset.MarginFunding) require.NoError(t, err) _, err = b.fixCasing(currency.EMPTYPAIR, asset.MarginFunding) - if !errors.Is(err, currency.ErrCurrencyPairEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, currency.ErrCurrencyPairEmpty) - } + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) } func Test_FormatExchangeKlineInterval(t *testing.T) { @@ -1931,9 +1924,7 @@ func TestGetSiteListConfigData(t *testing.T) { t.Parallel() _, err := b.GetSiteListConfigData(t.Context(), "") - if !errors.Is(err, errSetCannotBeEmpty) { - t.Fatalf("received: %v, expected: %v", err, errSetCannotBeEmpty) - } + require.ErrorIs(t, err, errSetCannotBeEmpty) pairs, err := b.GetSiteListConfigData(t.Context(), bitfinexSecuritiesPairs) require.NoError(t, err) diff --git a/exchanges/btcmarkets/btcmarkets_test.go b/exchanges/btcmarkets/btcmarkets_test.go index 70b9c6fd..b48495bf 100644 --- a/exchanges/btcmarkets/btcmarkets_test.go +++ b/exchanges/btcmarkets/btcmarkets_test.go @@ -2,7 +2,6 @@ package btcmarkets import ( "context" - "errors" "fmt" "log" "os" @@ -203,9 +202,8 @@ func TestSubmitOrder(t *testing.T) { Pair: currency.NewPair(currency.BTC, currency.AUD), TimeInForce: order.PostOnly, }) - if !errors.Is(err, order.ErrTypeIsInvalid) { - t.Fatalf("received: '%v' but expected: '%v'", err, order.ErrTypeIsInvalid) - } + require.ErrorIs(t, err, order.ErrTypeIsInvalid) + _, err = b.SubmitOrder(t.Context(), &order.Submit{ Exchange: b.Name, Price: 100, @@ -216,9 +214,7 @@ func TestSubmitOrder(t *testing.T) { Pair: currency.NewPair(currency.BTC, currency.AUD), TimeInForce: order.PostOnly, }) - if !errors.Is(err, order.ErrSideIsInvalid) { - t.Fatalf("received: '%v' but expected: '%v'", err, order.ErrSideIsInvalid) - } + require.ErrorIs(t, err, order.ErrSideIsInvalid) sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) @@ -882,9 +878,7 @@ func TestChecksum(t *testing.T) { t.Fatal(err) } err = checksum(b, uint32(1223123)) - if !errors.Is(err, errChecksumFailure) { - t.Errorf("received '%v', expected '%v'", err, errChecksumFailure) - } + assert.ErrorIs(t, err, errChecksumFailure) } func TestTrim(t *testing.T) { @@ -917,9 +911,7 @@ func TestTrim(t *testing.T) { func TestFormatOrderType(t *testing.T) { t.Parallel() _, err := b.formatOrderType(0) - if !errors.Is(err, order.ErrTypeIsInvalid) { - t.Fatalf("received: '%v' but expected: '%v'", err, order.ErrTypeIsInvalid) - } + require.ErrorIs(t, err, order.ErrTypeIsInvalid) r, err := b.formatOrderType(order.Limit) require.NoError(t, err) @@ -960,9 +952,7 @@ func TestFormatOrderType(t *testing.T) { func TestFormatOrderSide(t *testing.T) { t.Parallel() _, err := b.formatOrderSide(255) - if !errors.Is(err, order.ErrSideIsInvalid) { - t.Fatalf("received: '%v' but expected: '%v'", err, order.ErrSideIsInvalid) - } + require.ErrorIs(t, err, order.ErrSideIsInvalid) f, err := b.formatOrderSide(order.Bid) require.NoError(t, err) @@ -994,19 +984,13 @@ func TestGetTimeInForce(t *testing.T) { func TestReplaceOrder(t *testing.T) { t.Parallel() _, err := b.ReplaceOrder(t.Context(), "", "bro", 0, 0) - if !errors.Is(err, errInvalidAmount) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidAmount) - } + require.ErrorIs(t, err, errInvalidAmount) _, err = b.ReplaceOrder(t.Context(), "", "bro", 1, 0) - if !errors.Is(err, errInvalidAmount) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidAmount) - } + require.ErrorIs(t, err, errInvalidAmount) _, err = b.ReplaceOrder(t.Context(), "", "bro", 1, 1) - if !errors.Is(err, errIDRequired) { - t.Fatalf("received: '%v' but expected: '%v'", err, errIDRequired) - } + require.ErrorIs(t, err, errIDRequired) sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) @@ -1017,9 +1001,7 @@ func TestReplaceOrder(t *testing.T) { func TestWrapperModifyOrder(t *testing.T) { t.Parallel() _, err := b.ModifyOrder(t.Context(), &order.Modify{}) - if !errors.Is(err, order.ErrPairIsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, order.ErrPairIsEmpty) - } + require.ErrorIs(t, err, order.ErrPairIsEmpty) sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) @@ -1041,9 +1023,7 @@ func TestWrapperModifyOrder(t *testing.T) { func TestUpdateOrderExecutionLimits(t *testing.T) { t.Parallel() err := b.UpdateOrderExecutionLimits(t.Context(), asset.Empty) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: '%v' but expected: '%v'", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) err = b.UpdateOrderExecutionLimits(t.Context(), asset.Spot) require.NoError(t, err) @@ -1060,9 +1040,7 @@ func TestConvertToKlineCandle(t *testing.T) { t.Parallel() _, err := convertToKlineCandle(nil) - if !errors.Is(err, errFailedToConvertToCandle) { - t.Fatalf("received: '%v' but expected: '%v'", err, errFailedToConvertToCandle) - } + require.ErrorIs(t, err, errFailedToConvertToCandle) data := [6]string{time.RFC3339[:len(time.RFC3339)-5], "1.0", "2", "3", "4", "5"} diff --git a/exchanges/bybit/bybit.go b/exchanges/bybit/bybit.go index d28621de..f651bb23 100644 --- a/exchanges/bybit/bybit.go +++ b/exchanges/bybit/bybit.go @@ -49,45 +49,44 @@ const ( ) var ( - errCategoryNotSet = errors.New("category not set") - errBaseNotSet = errors.New("base coin not set when category is option") - errInvalidTriggerDirection = errors.New("invalid trigger direction") - errInvalidTriggerPriceType = errors.New("invalid trigger price type") - errNilArgument = errors.New("nil argument") - errMissingUserID = errors.New("sub user id missing") - errMissingUsername = errors.New("username is missing") - errInvalidMemberType = errors.New("invalid member type") - errMissingTransferID = errors.New("transfer ID is required") - errMemberIDRequired = errors.New("member ID is required") - errNonePointerArgument = errors.New("argument must be pointer") - errEitherOrderIDOROrderLinkIDRequired = errors.New("either orderId or orderLinkId required") - errNoOrderPassed = errors.New("no order passed") - errSymbolOrSettleCoinRequired = errors.New("provide symbol or settleCoin at least one") - errInvalidTradeModeValue = errors.New("invalid trade mode value") - errTakeProfitOrStopLossModeMissing = errors.New("TP/SL mode missing") - errMissingAccountType = errors.New("account type not specified") - errMembersIDsNotSet = errors.New("members IDs not set") - errMissingChainType = errors.New("missing chain type is empty") - errMissingChainInformation = errors.New("missing transfer chain") - errMissingAddressInfo = errors.New("address is required") - errMissingWithdrawalID = errors.New("missing withdrawal id") - errTimeWindowRequired = errors.New("time window is required") - errFrozenPeriodRequired = errors.New("frozen period required") - errQuantityLimitRequired = errors.New("quantity limit required") - errInvalidPushData = errors.New("invalid push data") - errInvalidLeverage = errors.New("leverage can't be zero or less then it") - errInvalidPositionMode = errors.New("position mode is invalid") - errInvalidMode = errors.New("mode can't be empty or missing") - errInvalidOrderFilter = errors.New("invalid order filter") - errInvalidCategory = errors.New("invalid category") - errEitherSymbolOrCoinRequired = errors.New("either symbol or coin required") - errOrderLinkIDMissing = errors.New("order link id missing") - errSymbolMissing = errors.New("symbol missing") - errInvalidAutoAddMarginValue = errors.New("invalid add auto margin value") - errDisconnectTimeWindowNotSet = errors.New("disconnect time window not set") - errAPIKeyIsNotUnified = errors.New("api key is not unified") - errEndpointAvailableForNormalAPIKeyHolders = errors.New("endpoint available for normal API key holders only") - errInvalidContractLength = errors.New("contract length cannot be less than or equal to zero") + errCategoryNotSet = errors.New("category not set") + errBaseNotSet = errors.New("base coin not set when category is option") + errInvalidTriggerDirection = errors.New("invalid trigger direction") + errInvalidTriggerPriceType = errors.New("invalid trigger price type") + errNilArgument = errors.New("nil argument") + errMissingUserID = errors.New("sub user id missing") + errMissingUsername = errors.New("username is missing") + errInvalidMemberType = errors.New("invalid member type") + errMissingTransferID = errors.New("transfer ID is required") + errMemberIDRequired = errors.New("member ID is required") + errNonePointerArgument = errors.New("argument must be pointer") + errEitherOrderIDOROrderLinkIDRequired = errors.New("either orderId or orderLinkId required") + errNoOrderPassed = errors.New("no order passed") + errSymbolOrSettleCoinRequired = errors.New("provide symbol or settleCoin at least one") + errInvalidTradeModeValue = errors.New("invalid trade mode value") + errTakeProfitOrStopLossModeMissing = errors.New("TP/SL mode missing") + errMissingAccountType = errors.New("account type not specified") + errMembersIDsNotSet = errors.New("members IDs not set") + errMissingChainType = errors.New("missing chain type is empty") + errMissingChainInformation = errors.New("missing transfer chain") + errMissingAddressInfo = errors.New("address is required") + errMissingWithdrawalID = errors.New("missing withdrawal id") + errTimeWindowRequired = errors.New("time window is required") + errFrozenPeriodRequired = errors.New("frozen period required") + errQuantityLimitRequired = errors.New("quantity limit required") + errInvalidPushData = errors.New("invalid push data") + errInvalidLeverage = errors.New("leverage can't be zero or less then it") + errInvalidPositionMode = errors.New("position mode is invalid") + errInvalidMode = errors.New("mode can't be empty or missing") + errInvalidOrderFilter = errors.New("invalid order filter") + errInvalidCategory = errors.New("invalid category") + errEitherSymbolOrCoinRequired = errors.New("either symbol or coin required") + errOrderLinkIDMissing = errors.New("order link id missing") + errSymbolMissing = errors.New("symbol missing") + errInvalidAutoAddMarginValue = errors.New("invalid add auto margin value") + errDisconnectTimeWindowNotSet = errors.New("disconnect time window not set") + errAPIKeyIsNotUnified = errors.New("api key is not unified") + errInvalidContractLength = errors.New("contract length cannot be less than or equal to zero") ) var ( diff --git a/exchanges/bybit/bybit_test.go b/exchanges/bybit/bybit_test.go index e7324bb3..f81ee85f 100644 --- a/exchanges/bybit/bybit_test.go +++ b/exchanges/bybit/bybit_test.go @@ -183,13 +183,10 @@ func TestGetRiskLimit(t *testing.T) { t.Error(err) } _, err = b.GetRiskLimit(t.Context(), "option", optionsTradablePair.String()) - if !errors.Is(err, errInvalidCategory) { - t.Error(err) - } + assert.ErrorIs(t, err, errInvalidCategory) + _, err = b.GetRiskLimit(t.Context(), "spot", spotTradablePair.String()) - if !errors.Is(err, errInvalidCategory) { - t.Error(err) - } + assert.ErrorIs(t, err, errInvalidCategory) } // test cases for Wrapper @@ -330,9 +327,7 @@ func TestGetHistoricCandles(t *testing.T) { t.Error(err) } _, err = b.GetHistoricCandles(t.Context(), optionsTradablePair, asset.Options, kline.OneHour, start, end) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("expected %v, got %v", asset.ErrNotSupported, err) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) } func TestGetHistoricCandlesExtended(t *testing.T) { @@ -356,9 +351,7 @@ func TestGetHistoricCandlesExtended(t *testing.T) { t.Error(err) } _, err = b.GetHistoricCandlesExtended(t.Context(), optionsTradablePair, asset.Options, kline.FiveMin, startTime, end) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("found '%v', expected '%v'", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) } func TestCancelOrder(t *testing.T) { @@ -429,9 +422,7 @@ func TestCancelAllOrders(t *testing.T) { t.Error(err) } _, err = b.CancelAllOrders(t.Context(), &order.Cancel{Exchange: b.Name, AssetType: asset.Futures, Pair: spotTradablePair}) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("expected %v, but found %v", asset.ErrNotSupported, err) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) } func TestGetOrderInfo(t *testing.T) { @@ -643,13 +634,11 @@ func TestGetTickersV5(t *testing.T) { func TestGetFundingRateHistory(t *testing.T) { t.Parallel() _, err := b.GetFundingRateHistory(t.Context(), "bruh", "", time.Time{}, time.Time{}, 0) - if !errors.Is(err, errInvalidCategory) { - t.Errorf("expected %v, got %v", errInvalidCategory, err) - } + assert.ErrorIs(t, err, errInvalidCategory) + _, err = b.GetFundingRateHistory(t.Context(), "spot", spotTradablePair.String(), time.Time{}, time.Time{}, 100) - if !errors.Is(err, errInvalidCategory) { - t.Errorf("expected %v, got %v", errInvalidCategory, err) - } + assert.ErrorIs(t, err, errInvalidCategory) + _, err = b.GetFundingRateHistory(t.Context(), "linear", usdtMarginedTradablePair.String(), time.Time{}, time.Time{}, 100) if err != nil { t.Error(err) @@ -663,9 +652,7 @@ func TestGetFundingRateHistory(t *testing.T) { t.Error(err) } _, err = b.GetFundingRateHistory(t.Context(), "option", optionsTradablePair.String(), time.Time{}, time.Time{}, 100) - if !errors.Is(err, errInvalidCategory) { - t.Errorf("expected %v, got %v", errInvalidCategory, err) - } + assert.ErrorIs(t, err, errInvalidCategory) } func TestGetPublicTradingHistory(t *testing.T) { @@ -695,9 +682,8 @@ func TestGetPublicTradingHistory(t *testing.T) { func TestGetOpenInterestData(t *testing.T) { t.Parallel() _, err := b.GetOpenInterestData(t.Context(), "spot", spotTradablePair.String(), "5min", time.Time{}, time.Time{}, 0, "") - if !errors.Is(err, errInvalidCategory) { - t.Errorf("expected %v, got %v", errInvalidCategory, err) - } + assert.ErrorIs(t, err, errInvalidCategory) + _, err = b.GetOpenInterestData(t.Context(), "linear", usdtMarginedTradablePair.String(), "5min", time.Time{}, time.Time{}, 0, "") if err != nil { t.Error(err) @@ -711,9 +697,7 @@ func TestGetOpenInterestData(t *testing.T) { t.Error(err) } _, err = b.GetOpenInterestData(t.Context(), "option", optionsTradablePair.String(), "5min", time.Time{}, time.Time{}, 0, "") - if !errors.Is(err, errInvalidCategory) { - t.Errorf("expected %v, got %v", errInvalidCategory, err) - } + assert.ErrorIs(t, err, errInvalidCategory) } func TestGetHistoricalVolatility(t *testing.T) { @@ -729,9 +713,7 @@ func TestGetHistoricalVolatility(t *testing.T) { t.Error(err) } _, err = b.GetHistoricalVolatility(t.Context(), "spot", "", 123, start, end) - if !errors.Is(err, errInvalidCategory) { - t.Errorf("expected %v, but found %v", errInvalidCategory, err) - } + assert.ErrorIs(t, err, errInvalidCategory) } func TestGetInsurance(t *testing.T) { @@ -745,9 +727,8 @@ func TestGetInsurance(t *testing.T) { func TestGetDeliveryPrice(t *testing.T) { t.Parallel() _, err := b.GetDeliveryPrice(t.Context(), "spot", spotTradablePair.String(), "", "", 200) - if !errors.Is(err, errInvalidCategory) { - t.Errorf("expected %v, but found %v", errInvalidCategory, err) - } + assert.ErrorIs(t, err, errInvalidCategory) + _, err = b.GetDeliveryPrice(t.Context(), "linear", "", "", "", 200) if err != nil { t.Error(err) @@ -797,49 +778,42 @@ func TestPlaceOrder(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) ctx := t.Context() _, err := b.PlaceOrder(ctx, nil) - if !errors.Is(err, errNilArgument) { - t.Fatalf("expected %v, got %v", errNilArgument, err) - } + require.ErrorIs(t, err, errNilArgument) + _, err = b.PlaceOrder(ctx, &PlaceOrderParams{}) - if !errors.Is(err, errCategoryNotSet) { - t.Fatalf("expected %v, got %v", errCategoryNotSet, err) - } + require.ErrorIs(t, err, errCategoryNotSet) + _, err = b.PlaceOrder(ctx, &PlaceOrderParams{ Category: "my-category", }) - if !errors.Is(err, errInvalidCategory) { - t.Fatalf("expected %v, got %v", errInvalidCategory, err) - } + require.ErrorIs(t, err, errInvalidCategory) + _, err = b.PlaceOrder(ctx, &PlaceOrderParams{ Category: "spot", }) - if !errors.Is(err, currency.ErrCurrencyPairEmpty) { - t.Fatalf("expected %v, got %v", currency.ErrCurrencyPairEmpty, err) - } + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + _, err = b.PlaceOrder(ctx, &PlaceOrderParams{ Category: "spot", Symbol: currency.Pair{Delimiter: "", Base: currency.BTC, Quote: currency.USDT}, }) - if !errors.Is(err, order.ErrSideIsInvalid) { - t.Fatalf("expected %v, got %v", order.ErrSideIsInvalid, err) - } + require.ErrorIs(t, err, order.ErrSideIsInvalid) + _, err = b.PlaceOrder(ctx, &PlaceOrderParams{ Category: "spot", Symbol: spotTradablePair, Side: "buy", }) - if !errors.Is(err, order.ErrTypeIsInvalid) { - t.Fatalf("expected %v, got %v", order.ErrTypeIsInvalid, err) - } + require.ErrorIs(t, err, order.ErrTypeIsInvalid) + _, err = b.PlaceOrder(ctx, &PlaceOrderParams{ Category: "spot", Symbol: spotTradablePair, Side: "buy", OrderType: "limit", }) - if !errors.Is(err, order.ErrAmountBelowMin) { - t.Fatalf("expected %v, got %v", order.ErrAmountBelowMin, err) - } + require.ErrorIs(t, err, order.ErrAmountBelowMin) + _, err = b.PlaceOrder(ctx, &PlaceOrderParams{ Category: "spot", Symbol: spotTradablePair, @@ -848,9 +822,8 @@ func TestPlaceOrder(t *testing.T) { OrderQuantity: 1, TriggerDirection: 3, }) - if !errors.Is(err, errInvalidTriggerDirection) { - t.Fatalf("expected %v, got %v", errInvalidTriggerDirection, err) - } + require.ErrorIs(t, err, errInvalidTriggerDirection) + _, err = b.PlaceOrder(t.Context(), &PlaceOrderParams{ Category: "spot", Symbol: spotTradablePair, @@ -928,33 +901,28 @@ func TestAmendOrder(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) _, err := b.AmendOrder(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Fatalf("expected %v, got %v", errNilArgument, err) - } + require.ErrorIs(t, err, errNilArgument) + _, err = b.AmendOrder(t.Context(), &AmendOrderParams{}) - if !errors.Is(err, errEitherOrderIDOROrderLinkIDRequired) { - t.Fatalf("expected %v, got %v", errEitherOrderIDOROrderLinkIDRequired, err) - } + require.ErrorIs(t, err, errEitherOrderIDOROrderLinkIDRequired) + _, err = b.AmendOrder(t.Context(), &AmendOrderParams{ OrderID: "c6f055d9-7f21-4079-913d-e6523a9cfffa", }) - if !errors.Is(err, errCategoryNotSet) { - t.Fatalf("expected %v, got %v", errCategoryNotSet, err) - } + require.ErrorIs(t, err, errCategoryNotSet) + _, err = b.AmendOrder(t.Context(), &AmendOrderParams{ OrderID: "c6f055d9-7f21-4079-913d-e6523a9cfffa", Category: "mycat", }) - if !errors.Is(err, errInvalidCategory) { - t.Fatalf("expected %v, got %v", errInvalidCategory, err) - } + require.ErrorIs(t, err, errInvalidCategory) + _, err = b.AmendOrder(t.Context(), &AmendOrderParams{ OrderID: "c6f055d9-7f21-4079-913d-e6523a9cfffa", Category: "option", }) - if !errors.Is(err, currency.ErrCurrencyPairEmpty) { - t.Fatalf("expected %v, got %v", currency.ErrCurrencyPairEmpty, err) - } + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + _, err = b.AmendOrder(t.Context(), &AmendOrderParams{ OrderID: "c6f055d9-7f21-4079-913d-e6523a9cfffa", Category: cSpot, @@ -977,33 +945,28 @@ func TestCancelTradeOrder(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) _, err := b.CancelTradeOrder(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Fatalf("expected %v, got %v", errNilArgument, err) - } + require.ErrorIs(t, err, errNilArgument) + _, err = b.CancelTradeOrder(t.Context(), &CancelOrderParams{}) - if !errors.Is(err, errEitherOrderIDOROrderLinkIDRequired) { - t.Fatalf("expected %v, got %v", errEitherOrderIDOROrderLinkIDRequired, err) - } + require.ErrorIs(t, err, errEitherOrderIDOROrderLinkIDRequired) + _, err = b.CancelTradeOrder(t.Context(), &CancelOrderParams{ OrderID: "c6f055d9-7f21-4079-913d-e6523a9cfffa", }) - if !errors.Is(err, errCategoryNotSet) { - t.Fatalf("expected %v, got %v", errCategoryNotSet, err) - } + require.ErrorIs(t, err, errCategoryNotSet) + _, err = b.CancelTradeOrder(t.Context(), &CancelOrderParams{ OrderID: "c6f055d9-7f21-4079-913d-e6523a9cfffa", Category: "mycat", }) - if !errors.Is(err, errInvalidCategory) { - t.Fatalf("expected %v, got %v", errInvalidCategory, err) - } + require.ErrorIs(t, err, errInvalidCategory) + _, err = b.CancelTradeOrder(t.Context(), &CancelOrderParams{ OrderID: "c6f055d9-7f21-4079-913d-e6523a9cfffa", Category: "option", }) - if !errors.Is(err, currency.ErrCurrencyPairEmpty) { - t.Fatalf("expected %v, got %v", currency.ErrCurrencyPairEmpty, err) - } + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + _, err = b.CancelTradeOrder(t.Context(), &CancelOrderParams{ OrderID: "c6f055d9-7f21-4079-913d-e6523a9cfffa", Category: "option", @@ -1020,9 +983,8 @@ func TestGetOpenOrders(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, b) } _, err := b.GetOpenOrders(t.Context(), "", "", "", "", "", "", "", "", 0, 100) - if !errors.Is(err, errCategoryNotSet) { - t.Fatalf("expected %v, got %v", errCategoryNotSet, err) - } + require.ErrorIs(t, err, errCategoryNotSet) + _, err = b.GetOpenOrders(t.Context(), "spot", "", "", "", "", "", "", "", 0, 0) if err != nil { t.Error(err) @@ -1036,13 +998,11 @@ func TestCancelAllTradeOrders(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) _, err := b.CancelAllTradeOrders(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Fatalf("expected %v, got %v", errNilArgument, err) - } + require.ErrorIs(t, err, errNilArgument) + _, err = b.CancelAllTradeOrders(t.Context(), &CancelAllOrdersParam{}) - if !errors.Is(err, errCategoryNotSet) { - t.Fatalf("expected %v, got %v", errCategoryNotSet, err) - } + require.ErrorIs(t, err, errCategoryNotSet) + _, err = b.CancelAllTradeOrders(t.Context(), &CancelAllOrdersParam{Category: "option"}) if err != nil { t.Error(err) @@ -1060,9 +1020,8 @@ func TestGetTradeOrderHistory(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, b) } _, err := b.GetTradeOrderHistory(t.Context(), "", "", "", "", "", "", "", "", "", start, end, 100) - if !errors.Is(err, errCategoryNotSet) { - t.Fatalf("expected %v, got %v", errCategoryNotSet, err) - } + require.ErrorIs(t, err, errCategoryNotSet) + _, err = b.GetTradeOrderHistory(t.Context(), "spot", spotTradablePair.String(), "", "", "BTC", "", "StopOrder", "", "", start, end, 100) if err != nil { t.Error(err) @@ -1076,19 +1035,16 @@ func TestPlaceBatchOrder(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) _, err := b.PlaceBatchOrder(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Fatalf("expected %v, got %v", errNilArgument, err) - } + require.ErrorIs(t, err, errNilArgument) + _, err = b.PlaceBatchOrder(t.Context(), &PlaceBatchOrderParam{}) - if !errors.Is(err, errCategoryNotSet) { - t.Fatalf("expected %v, got %v", errCategoryNotSet, err) - } + require.ErrorIs(t, err, errCategoryNotSet) + _, err = b.PlaceBatchOrder(t.Context(), &PlaceBatchOrderParam{ Category: "linear", }) - if !errors.Is(err, errNoOrderPassed) { - t.Fatalf("expected %v, got %v", errNoOrderPassed, err) - } + require.ErrorIs(t, err, errNoOrderPassed) + _, err = b.PlaceBatchOrder(t.Context(), &PlaceBatchOrderParam{ Category: "option", Request: []BatchOrderItemParam{ @@ -1158,9 +1114,8 @@ func TestBatchAmendOrder(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) _, err := b.BatchAmendOrder(t.Context(), "linear", nil) - if !errors.Is(err, errNilArgument) { - t.Fatalf("expected %v, got %v", errNilArgument, err) - } + require.ErrorIs(t, err, errNilArgument) + _, err = b.BatchAmendOrder(t.Context(), "", []BatchAmendOrderParamItem{ { Symbol: optionsTradablePair, @@ -1168,9 +1123,8 @@ func TestBatchAmendOrder(t *testing.T) { OrderID: "b551f227-7059-4fb5-a6a6-699c04dbd2f2", }, }) - if !errors.Is(err, errCategoryNotSet) { - t.Fatalf("expected %v, got %v", errCategoryNotSet, err) - } + require.ErrorIs(t, err, errCategoryNotSet) + _, err = b.BatchAmendOrder(t.Context(), "option", []BatchAmendOrderParamItem{ { Symbol: optionsTradablePair, @@ -1195,17 +1149,14 @@ func TestCancelBatchOrder(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) _, err := b.CancelBatchOrder(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Fatalf("expected %v, got %v", errNilArgument, err) - } + require.ErrorIs(t, err, errNilArgument) + _, err = b.CancelBatchOrder(t.Context(), &CancelBatchOrder{}) - if !errors.Is(err, errInvalidCategory) { - t.Fatalf("expected %v, got %v", errInvalidCategory, err) - } + require.ErrorIs(t, err, errInvalidCategory) + _, err = b.CancelBatchOrder(t.Context(), &CancelBatchOrder{Category: cOption}) - if !errors.Is(err, errNoOrderPassed) { - t.Fatalf("expected %v, got %v", errNoOrderPassed, err) - } + require.ErrorIs(t, err, errNoOrderPassed) + _, err = b.CancelBatchOrder(t.Context(), &CancelBatchOrder{ Category: "option", Request: []CancelOrderParams{ @@ -1230,17 +1181,14 @@ func TestGetBorrowQuota(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, b) } _, err := b.GetBorrowQuota(t.Context(), "", "BTCUSDT", "Buy") - if !errors.Is(err, errCategoryNotSet) { - t.Fatalf("expected %v, got %v", errCategoryNotSet, err) - } + require.ErrorIs(t, err, errCategoryNotSet) + _, err = b.GetBorrowQuota(t.Context(), "spot", "", "Buy") - if !errors.Is(err, errSymbolMissing) { - t.Fatalf("expected %v, got %v", errSymbolMissing, err) - } + require.ErrorIs(t, err, errSymbolMissing) + _, err = b.GetBorrowQuota(t.Context(), "spot", spotTradablePair.String(), "") - if !errors.Is(err, order.ErrSideIsInvalid) { - t.Error(err) - } + assert.ErrorIs(t, err, order.ErrSideIsInvalid) + _, err = b.GetBorrowQuota(t.Context(), "spot", spotTradablePair.String(), "Buy") if err != nil { t.Error(err) @@ -1254,9 +1202,8 @@ func TestSetDisconnectCancelAll(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) err := b.SetDisconnectCancelAll(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Fatalf("expected %v, got %v", errNilArgument, err) - } + require.ErrorIs(t, err, errNilArgument) + err = b.SetDisconnectCancelAll(t.Context(), &SetDCPParams{TimeWindow: 300}) if err != nil { t.Fatal(err) @@ -1269,13 +1216,11 @@ func TestGetPositionInfo(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, b) } _, err := b.GetPositionInfo(t.Context(), "", "", "", "", "", 20) - if !errors.Is(err, errCategoryNotSet) { - t.Fatalf("expected %v, got %v", errCategoryNotSet, err) - } + require.ErrorIs(t, err, errCategoryNotSet) + _, err = b.GetPositionInfo(t.Context(), "spot", "", "", "", "", 20) - if !errors.Is(err, errInvalidCategory) { - t.Fatalf("expected %v, got %v", errInvalidCategory, err) - } + require.ErrorIs(t, err, errInvalidCategory) + _, err = b.GetPositionInfo(t.Context(), "linear", "BTCUSDT", "", "", "", 20) if err != nil { t.Error(err) @@ -1293,25 +1238,20 @@ func TestSetLeverageLevel(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) err := b.SetLeverageLevel(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Errorf("expected %v, got %v", errNilArgument, err) - } + assert.ErrorIs(t, err, errNilArgument) + err = b.SetLeverageLevel(t.Context(), &SetLeverageParams{}) - if !errors.Is(err, errCategoryNotSet) { - t.Fatalf("expected %v, got %v", errCategoryNotSet, err) - } + require.ErrorIs(t, err, errCategoryNotSet) + err = b.SetLeverageLevel(t.Context(), &SetLeverageParams{Category: "spot"}) - if !errors.Is(err, errInvalidCategory) { - t.Fatalf("expected %v, got %v", errInvalidCategory, err) - } + require.ErrorIs(t, err, errInvalidCategory) + err = b.SetLeverageLevel(t.Context(), &SetLeverageParams{Category: "linear"}) - if !errors.Is(err, errSymbolMissing) { - t.Fatalf("expected %v, got %v", errSymbolMissing, err) - } + require.ErrorIs(t, err, errSymbolMissing) + err = b.SetLeverageLevel(t.Context(), &SetLeverageParams{Category: "linear", Symbol: "BTCUSDT"}) - if !errors.Is(err, errInvalidLeverage) { - t.Fatalf("expected %v, got %v", errInvalidLeverage, err) - } + require.ErrorIs(t, err, errInvalidLeverage) + err = b.SetLeverageLevel(t.Context(), &SetLeverageParams{Category: "linear", Symbol: "BTCUSDT", SellLeverage: 3, BuyLeverage: 3}) if err != nil { t.Error(err) @@ -1325,29 +1265,23 @@ func TestSwitchTradeMode(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) err := b.SwitchTradeMode(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Errorf("expected %v, got %v", errNilArgument, err) - } + assert.ErrorIs(t, err, errNilArgument) + err = b.SwitchTradeMode(t.Context(), &SwitchTradeModeParams{}) - if !errors.Is(err, errCategoryNotSet) { - t.Fatalf("expected %v, got %v", errCategoryNotSet, err) - } + require.ErrorIs(t, err, errCategoryNotSet) + err = b.SwitchTradeMode(t.Context(), &SwitchTradeModeParams{Category: "spot"}) - if !errors.Is(err, errInvalidCategory) { - t.Fatalf("expected %v, got %v", errInvalidCategory, err) - } + require.ErrorIs(t, err, errInvalidCategory) + err = b.SwitchTradeMode(t.Context(), &SwitchTradeModeParams{Category: "linear"}) - if !errors.Is(err, errSymbolMissing) { - t.Fatalf("expected %v, got %v", errSymbolMissing, err) - } + require.ErrorIs(t, err, errSymbolMissing) + err = b.SwitchTradeMode(t.Context(), &SwitchTradeModeParams{Category: "linear", Symbol: usdtMarginedTradablePair.String()}) - if !errors.Is(err, errInvalidLeverage) { - t.Fatalf("expected %v, got %v", errInvalidLeverage, err) - } + require.ErrorIs(t, err, errInvalidLeverage) + err = b.SwitchTradeMode(t.Context(), &SwitchTradeModeParams{Category: "linear", Symbol: usdcMarginedTradablePair.String(), SellLeverage: 3, BuyLeverage: 3, TradeMode: 2}) - if !errors.Is(err, errInvalidTradeModeValue) { - t.Fatalf("expected %v, got %v", errInvalidTradeModeValue, err) - } + require.ErrorIs(t, err, errInvalidTradeModeValue) + err = b.SwitchTradeMode(t.Context(), &SwitchTradeModeParams{Category: "linear", Symbol: usdtMarginedTradablePair.String(), SellLeverage: 3, BuyLeverage: 3, TradeMode: 1}) if err != nil { t.Error(err) @@ -1361,31 +1295,25 @@ func TestSetTakeProfitStopLossMode(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) _, err := b.SetTakeProfitStopLossMode(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Errorf("expected %v, got %v", errNilArgument, err) - } + assert.ErrorIs(t, err, errNilArgument) + _, err = b.SetTakeProfitStopLossMode(t.Context(), &TPSLModeParams{}) - if !errors.Is(err, errCategoryNotSet) { - t.Fatalf("expected %v, got %v", errCategoryNotSet, err) - } + require.ErrorIs(t, err, errCategoryNotSet) + _, err = b.SetTakeProfitStopLossMode(t.Context(), &TPSLModeParams{ Category: "spot", }) - if !errors.Is(err, errInvalidCategory) { - t.Fatalf("expected %v, got %v", errInvalidCategory, err) - } + require.ErrorIs(t, err, errInvalidCategory) + _, err = b.SetTakeProfitStopLossMode(t.Context(), &TPSLModeParams{Category: "spot"}) - if !errors.Is(err, errInvalidCategory) { - t.Fatalf("expected %v, got %v", errInvalidCategory, err) - } + require.ErrorIs(t, err, errInvalidCategory) + _, err = b.SetTakeProfitStopLossMode(t.Context(), &TPSLModeParams{Category: "linear"}) - if !errors.Is(err, errSymbolMissing) { - t.Fatalf("expected %v, got %v", errSymbolMissing, err) - } + require.ErrorIs(t, err, errSymbolMissing) + _, err = b.SetTakeProfitStopLossMode(t.Context(), &TPSLModeParams{Category: "linear", Symbol: "BTCUSDT"}) - if !errors.Is(err, errTakeProfitOrStopLossModeMissing) { - t.Fatalf("expected %v, got %v", errTakeProfitOrStopLossModeMissing, err) - } + require.ErrorIs(t, err, errTakeProfitOrStopLossModeMissing) + _, err = b.SetTakeProfitStopLossMode(t.Context(), &TPSLModeParams{Category: "linear", Symbol: "BTCUSDT", TpslMode: "Partial"}) if err != nil { t.Error(err) @@ -1399,17 +1327,14 @@ func TestSwitchPositionMode(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) err := b.SwitchPositionMode(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Fatalf("expected %v, got %v", errNilArgument, err) - } + require.ErrorIs(t, err, errNilArgument) + err = b.SwitchPositionMode(t.Context(), &SwitchPositionModeParams{}) - if !errors.Is(err, errCategoryNotSet) { - t.Fatalf("expected %v, got %v", errCategoryNotSet, err) - } + require.ErrorIs(t, err, errCategoryNotSet) + err = b.SwitchPositionMode(t.Context(), &SwitchPositionModeParams{Category: "linear"}) - if !errors.Is(err, errEitherSymbolOrCoinRequired) { - t.Fatalf("expected %v, got %v", errInvalidCategory, err) - } + require.ErrorIs(t, err, errEitherSymbolOrCoinRequired) + err = b.SwitchPositionMode(t.Context(), &SwitchPositionModeParams{Category: "linear", Symbol: usdtMarginedTradablePair, PositionMode: 3}) if err != nil { t.Error(err) @@ -1423,21 +1348,17 @@ func TestSetRiskLimit(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) _, err := b.SetRiskLimit(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Errorf("expected %v, got %v", errNilArgument, err) - } + assert.ErrorIs(t, err, errNilArgument) + _, err = b.SetRiskLimit(t.Context(), &SetRiskLimitParam{}) - if !errors.Is(err, errCategoryNotSet) { - t.Errorf("expected %v, got %v", errCategoryNotSet, err) - } + assert.ErrorIs(t, err, errCategoryNotSet) + _, err = b.SetRiskLimit(t.Context(), &SetRiskLimitParam{Category: "linear", PositionMode: -2}) - if !errors.Is(err, errInvalidPositionMode) { - t.Errorf("expected %v, got %v", errInvalidPositionMode, err) - } + assert.ErrorIs(t, err, errInvalidPositionMode) + _, err = b.SetRiskLimit(t.Context(), &SetRiskLimitParam{Category: "linear"}) - if !errors.Is(err, errSymbolMissing) { - t.Errorf("expected %v, got %v", errSymbolMissing, err) - } + assert.ErrorIs(t, err, errSymbolMissing) + _, err = b.SetRiskLimit(t.Context(), &SetRiskLimitParam{ Category: "linear", RiskID: 1234, @@ -1456,13 +1377,11 @@ func TestSetTradingStop(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) err := b.SetTradingStop(t.Context(), &TradingStopParams{}) - if !errors.Is(err, errCategoryNotSet) { - t.Errorf("expected %v, got %v", errCategoryNotSet, err) - } + assert.ErrorIs(t, err, errCategoryNotSet) + err = b.SetTradingStop(t.Context(), &TradingStopParams{Category: "spot"}) - if !errors.Is(err, errInvalidCategory) { - t.Errorf("expected %v, got %v", errInvalidCategory, err) - } + assert.ErrorIs(t, err, errInvalidCategory) + err = b.SetTradingStop(t.Context(), &TradingStopParams{ Category: "linear", Symbol: usdtMarginedTradablePair, @@ -1554,9 +1473,8 @@ func TestGetClosedPnL(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, b) } _, err := b.GetClosedPnL(t.Context(), "spot", "", "", time.Time{}, time.Time{}, 0) - if !errors.Is(err, errInvalidCategory) { - t.Fatalf("expected %v, got %v", err, errInvalidCategory) - } + require.ErrorIs(t, err, errInvalidCategory) + _, err = b.GetClosedPnL(t.Context(), "linear", "", "", time.Time{}, time.Time{}, 0) if err != nil { t.Fatal(err) @@ -1582,13 +1500,11 @@ func TestGetPreUpgradeOrderHistory(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) _, err := b.GetPreUpgradeOrderHistory(t.Context(), "", "", "", "", "", "", "", "", time.Time{}, time.Time{}, 100) - if !errors.Is(err, errCategoryNotSet) { - t.Fatalf("expected %v, got %v", errCategoryNotSet, err) - } + require.ErrorIs(t, err, errCategoryNotSet) + _, err = b.GetPreUpgradeOrderHistory(t.Context(), "option", "", "", "", "", "", "", "", time.Time{}, time.Time{}, 0) - if !errors.Is(err, errBaseNotSet) { - t.Fatalf("expected %v, got %v", errBaseNotSet, err) - } + require.ErrorIs(t, err, errBaseNotSet) + _, err = b.GetPreUpgradeOrderHistory(t.Context(), "linear", "", "", "", "", "", "", "", time.Time{}, time.Time{}, 0) if err != nil { t.Error(err) @@ -1602,13 +1518,11 @@ func TestGetPreUpgradeTradeHistory(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) _, err := b.GetPreUpgradeTradeHistory(t.Context(), "", "", "", "", "", "", "", time.Time{}, time.Time{}, 0) - if !errors.Is(err, errCategoryNotSet) { - t.Fatalf("found %v, expected %v", err, errCategoryNotSet) - } + require.ErrorIs(t, err, errCategoryNotSet) + _, err = b.GetPreUpgradeTradeHistory(t.Context(), "option", "", "", "", "", "", "", time.Time{}, time.Time{}, 0) - if !errors.Is(err, errInvalidCategory) { - t.Fatalf("found %v, expected %v", err, errInvalidCategory) - } + require.ErrorIs(t, err, errInvalidCategory) + _, err = b.GetPreUpgradeTradeHistory(t.Context(), "linear", "", "", "", "", "", "", time.Time{}, time.Time{}, 0) if err != nil { t.Error(err) @@ -1622,9 +1536,8 @@ func TestGetPreUpgradeClosedPnL(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) _, err := b.GetPreUpgradeClosedPnL(t.Context(), "option", "BTCUSDT", "", time.Time{}, time.Time{}, 0) - if !errors.Is(err, errInvalidCategory) { - t.Fatalf("expected %v, got %v", errInvalidCategory, err) - } + require.ErrorIs(t, err, errInvalidCategory) + _, err = b.GetPreUpgradeClosedPnL(t.Context(), "linear", "BTCUSDT", "", time.Time{}, time.Time{}, 0) if err != nil { t.Error(err) @@ -1638,9 +1551,8 @@ func TestGetPreUpgradeTransactionLog(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) _, err := b.GetPreUpgradeTransactionLog(t.Context(), "option", "", "", "", time.Time{}, time.Time{}, 0) - if !errors.Is(err, errInvalidCategory) { - t.Fatalf("found %v, expected %v", err, errInvalidCategory) - } + require.ErrorIs(t, err, errInvalidCategory) + _, err = b.GetPreUpgradeTransactionLog(t.Context(), "linear", "", "", "", time.Time{}, time.Time{}, 0) if err != nil { t.Error(err) @@ -1654,9 +1566,8 @@ func TestGetPreUpgradeOptionDeliveryRecord(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) _, err := b.GetPreUpgradeOptionDeliveryRecord(t.Context(), "linear", "", "", time.Time{}, 0) - if !errors.Is(err, errInvalidCategory) { - t.Error(err) - } + assert.ErrorIs(t, err, errInvalidCategory) + _, err = b.GetPreUpgradeOptionDeliveryRecord(t.Context(), "option", "", "", time.Time{}, 0) if err != nil { t.Error(err) @@ -1670,9 +1581,8 @@ func TestGetPreUpgradeUSDCSessionSettlement(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) _, err := b.GetPreUpgradeUSDCSessionSettlement(t.Context(), "option", "", "", 10) - if !errors.Is(err, errInvalidCategory) { - t.Fatalf("expected %v, got %v", errInvalidCategory, err) - } + require.ErrorIs(t, err, errInvalidCategory) + _, err = b.GetPreUpgradeUSDCSessionSettlement(t.Context(), "linear", "", "", 10) if err != nil { t.Error(err) @@ -1831,9 +1741,8 @@ func TestGetFeeRate(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, b) } _, err := b.GetFeeRate(t.Context(), "something", "", "BTC") - if !errors.Is(err, errInvalidCategory) { - t.Fatalf("expected %v, got %v", errInvalidCategory, err) - } + require.ErrorIs(t, err, errInvalidCategory) + _, err = b.GetFeeRate(t.Context(), "linear", "", "BTC") if err != nil { t.Error(err) @@ -1894,9 +1803,8 @@ func TestGetSubAccountALLAPIKeys(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) _, err := b.GetSubAccountAllAPIKeys(t.Context(), "", "", 10) - if !errors.Is(err, errMemberIDRequired) { - t.Errorf("expected %v, got %v", errMemberIDRequired, err) - } + assert.ErrorIs(t, err, errMemberIDRequired) + _, err = b.GetSubAccountAllAPIKeys(t.Context(), "1234", "", 10) if err != nil { t.Error(err) @@ -1910,9 +1818,8 @@ func TestSetMMP(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) err := b.SetMMP(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Fatalf("found %v, expected %v", err, errNilArgument) - } + require.ErrorIs(t, err, errNilArgument) + err = b.SetMMP(t.Context(), &MMPRequestParam{ BaseCoin: "ETH", TimeWindowMS: 5000, @@ -1932,9 +1839,8 @@ func TestResetMMP(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) err := b.ResetMMP(t.Context(), "USDT") - if !errors.Is(err, errNilArgument) { - t.Fatalf("found %v, expected %v", err, errNilArgument) - } + require.ErrorIs(t, err, errNilArgument) + err = b.ResetMMP(t.Context(), "BTC") if err != nil { t.Error(err) @@ -1972,13 +1878,9 @@ func TestGetDeliveryRecord(t *testing.T) { expiryTime = time.UnixMilli(1700216290093) } _, err := b.GetDeliveryRecord(t.Context(), "spot", "", "", expiryTime, 20) - if !errors.Is(err, errInvalidCategory) { - t.Fatal(err) - } + assert.ErrorIs(t, err, errInvalidCategory) _, err = b.GetDeliveryRecord(t.Context(), "linear", "", "", expiryTime, 20) - if err != nil { - t.Error(err) - } + assert.NoError(t, err, "GetDeliveryRecord should not error for linear category") } func TestGetUSDCSessionSettlement(t *testing.T) { @@ -1987,9 +1889,8 @@ func TestGetUSDCSessionSettlement(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, b) } _, err := b.GetUSDCSessionSettlement(t.Context(), "option", "", "", 10) - if !errors.Is(err, errInvalidCategory) { - t.Fatalf("expected %v, got %v", errInvalidCategory, err) - } + require.ErrorIs(t, err, errInvalidCategory) + _, err = b.GetUSDCSessionSettlement(t.Context(), "linear", "", "", 10) if err != nil { t.Error(err) @@ -2002,13 +1903,10 @@ func TestGetAssetInfo(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, b) } _, err := b.GetAssetInfo(t.Context(), "", "BTC") - if !errors.Is(err, errMissingAccountType) { - t.Fatal(err) - } + require.ErrorIs(t, err, errMissingAccountType) + _, err = b.GetAssetInfo(t.Context(), "SPOT", "BTC") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "GetAssetInfo should not error for SPOT account type") } func TestGetAllCoinBalance(t *testing.T) { @@ -2017,9 +1915,8 @@ func TestGetAllCoinBalance(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, b) } _, err := b.GetAllCoinBalance(t.Context(), "", "", "", 0) - if !errors.Is(err, errMissingAccountType) { - t.Fatalf("expected %v, got %v", errMissingAccountType, err) - } + require.ErrorIs(t, err, errMissingAccountType) + _, err = b.GetAllCoinBalance(t.Context(), "FUND", "", "", 0) if err != nil { t.Fatal(err) @@ -2032,9 +1929,8 @@ func TestGetSingleCoinBalance(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, b) } _, err := b.GetSingleCoinBalance(t.Context(), "", "", "", 0, 0) - if !errors.Is(err, errMissingAccountType) { - t.Fatalf("expected %v, got %v", errMissingAccountType, err) - } + require.ErrorIs(t, err, errMissingAccountType) + _, err = b.GetSingleCoinBalance(t.Context(), "SPOT", currency.BTC.String(), "", 0, 0) if err != nil { t.Fatal(err) @@ -2059,50 +1955,43 @@ func TestCreateInternalTransfer(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) _, err := b.CreateInternalTransfer(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Fatalf("expected %v, got %v", errNilArgument, err) - } + require.ErrorIs(t, err, errNilArgument) + _, err = b.CreateInternalTransfer(t.Context(), &TransferParams{}) - if !errors.Is(err, errMissingTransferID) { - t.Fatalf("expected %v, got %v", errMissingTransferID, err) - } + require.ErrorIs(t, err, errMissingTransferID) + transferID, err := uuid.NewV7() if err != nil { t.Fatal(err) } _, err = b.CreateInternalTransfer(t.Context(), &TransferParams{TransferID: transferID}) - if !errors.Is(err, currency.ErrCurrencyCodeEmpty) { - t.Fatalf("expected %v, got %v", currency.ErrCurrencyCodeEmpty, err) - } + require.ErrorIs(t, err, currency.ErrCurrencyCodeEmpty) + _, err = b.CreateInternalTransfer(t.Context(), &TransferParams{ TransferID: transferID, Coin: currency.BTC, }) - if !errors.Is(err, order.ErrAmountIsInvalid) { - t.Fatalf("expected %v, got %v", order.ErrAmountIsInvalid, err) - } + require.ErrorIs(t, err, order.ErrAmountIsInvalid) + _, err = b.CreateInternalTransfer(t.Context(), &TransferParams{ TransferID: transferID, Coin: currency.BTC, Amount: 123.456, }) - if !errors.Is(err, errMissingAccountType) { - t.Fatalf("expected %v, got %v", errMissingAccountType, err) - } + require.ErrorIs(t, err, errMissingAccountType) + _, err = b.CreateInternalTransfer(t.Context(), &TransferParams{ TransferID: transferID, Coin: currency.BTC, Amount: 123.456, }) - if !errors.Is(err, errMissingAccountType) { - t.Fatalf("expected %v, got %v", errMissingAccountType, err) - } + require.ErrorIs(t, err, errMissingAccountType) + _, err = b.CreateInternalTransfer(t.Context(), &TransferParams{ TransferID: transferID, Coin: currency.BTC, Amount: 123.456, FromAccountType: "UNIFIED", }) - if !errors.Is(err, errMissingAccountType) { - t.Fatalf("expected %v, got %v", errMissingAccountType, err) - } + require.ErrorIs(t, err, errMissingAccountType) + _, err = b.CreateInternalTransfer(t.Context(), &TransferParams{ TransferID: transferID, Coin: currency.BTC, Amount: 123.456, @@ -2150,9 +2039,8 @@ func TestEnableUniversalTransferForSubUID(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) err := b.EnableUniversalTransferForSubUID(t.Context()) - if !errors.Is(err, errMembersIDsNotSet) { - t.Fatalf("expected %v, got %v", errMembersIDsNotSet, err) - } + require.ErrorIs(t, err, errMembersIDsNotSet) + transferID1, err := uuid.NewV7() if err != nil { t.Fatal(err) @@ -2170,59 +2058,51 @@ func TestEnableUniversalTransferForSubUID(t *testing.T) { func TestCreateUniversalTransfer(t *testing.T) { t.Parallel() _, err := b.CreateUniversalTransfer(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Fatalf("expected %v, got %v", errNilArgument, err) - } + require.ErrorIs(t, err, errNilArgument) + _, err = b.CreateUniversalTransfer(t.Context(), &TransferParams{}) - if !errors.Is(err, errMissingTransferID) { - t.Fatalf("expected %v, got %v", errMissingTransferID, err) - } + require.ErrorIs(t, err, errMissingTransferID) + transferID, err := uuid.NewV7() if err != nil { t.Fatal(err) } _, err = b.CreateUniversalTransfer(t.Context(), &TransferParams{TransferID: transferID}) - if !errors.Is(err, currency.ErrCurrencyCodeEmpty) { - t.Fatalf("expected %v, got %v", currency.ErrCurrencyCodeEmpty, err) - } + require.ErrorIs(t, err, currency.ErrCurrencyCodeEmpty) + _, err = b.CreateUniversalTransfer(t.Context(), &TransferParams{ TransferID: transferID, Coin: currency.BTC, }) - if !errors.Is(err, order.ErrAmountIsInvalid) { - t.Fatalf("expected %v, got %v", order.ErrAmountIsInvalid, err) - } + require.ErrorIs(t, err, order.ErrAmountIsInvalid) + _, err = b.CreateUniversalTransfer(t.Context(), &TransferParams{ TransferID: transferID, Coin: currency.BTC, Amount: 123.456, }) - if !errors.Is(err, errMissingAccountType) { - t.Fatalf("expected %v, got %v", errMissingAccountType, err) - } + require.ErrorIs(t, err, errMissingAccountType) + _, err = b.CreateUniversalTransfer(t.Context(), &TransferParams{ TransferID: transferID, Coin: currency.BTC, Amount: 123.456, }) - if !errors.Is(err, errMissingAccountType) { - t.Fatalf("expected %v, got %v", errMissingAccountType, err) - } + require.ErrorIs(t, err, errMissingAccountType) + _, err = b.CreateUniversalTransfer(t.Context(), &TransferParams{ TransferID: transferID, Coin: currency.BTC, Amount: 123.456, FromAccountType: "UNIFIED", }) - if !errors.Is(err, errMissingAccountType) { - t.Fatalf("expected %v, got %v", errMissingAccountType, err) - } + require.ErrorIs(t, err, errMissingAccountType) + _, err = b.CreateUniversalTransfer(t.Context(), &TransferParams{ TransferID: transferID, Coin: currency.BTC, Amount: 123.456, ToAccountType: "CONTRACT", FromAccountType: "UNIFIED", }) - if !errors.Is(err, errMemberIDRequired) { - t.Fatalf("expected %v, got %v", errMemberIDRequired, err) - } + require.ErrorIs(t, err, errMemberIDRequired) + if mockTests { t.Skip(skipAuthenticatedFunctionsForMockTesting) } @@ -2379,25 +2259,20 @@ func TestWithdrawCurrency(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) _, err := b.WithdrawCurrency(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Fatalf("expected %v, got %v", errNilArgument, err) - } + require.ErrorIs(t, err, errNilArgument) + _, err = b.WithdrawCurrency(t.Context(), &WithdrawalParam{}) - if !errors.Is(err, currency.ErrCurrencyCodeEmpty) { - t.Fatalf("expected %v, got %v", currency.ErrCurrencyCodeEmpty, err) - } + require.ErrorIs(t, err, currency.ErrCurrencyCodeEmpty) + _, err = b.WithdrawCurrency(t.Context(), &WithdrawalParam{Coin: currency.BTC}) - if !errors.Is(err, errMissingChainInformation) { - t.Fatalf("expected %v, got %v", errMissingChainInformation, err) - } + require.ErrorIs(t, err, errMissingChainInformation) + _, err = b.WithdrawCurrency(t.Context(), &WithdrawalParam{Coin: currency.LTC, Chain: "LTC"}) - if !errors.Is(err, errMissingAddressInfo) { - t.Fatalf("expected %v, got %v", errMissingAddressInfo, err) - } + require.ErrorIs(t, err, errMissingAddressInfo) + _, err = b.WithdrawCurrency(t.Context(), &WithdrawalParam{Coin: currency.LTC, Chain: "LTC", Address: "234234234"}) - if !errors.Is(err, order.ErrAmountBelowMin) { - t.Fatalf("expected %v, got %v", order.ErrAmountBelowMin, err) - } + require.ErrorIs(t, err, order.ErrAmountBelowMin) + _, err = b.WithdrawCurrency(t.Context(), &WithdrawalParam{Coin: currency.LTC, Chain: "LTC", Address: "234234234", Amount: 123}) if err != nil { t.Fatal(err) @@ -2411,9 +2286,8 @@ func TestCancelWithdrawal(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) _, err := b.CancelWithdrawal(t.Context(), "") - if !errors.Is(err, errMissingWithdrawalID) { - t.Fatalf("expected %v, got %v", errMissingWithdrawalID, err) - } + require.ErrorIs(t, err, errMissingWithdrawalID) + _, err = b.CancelWithdrawal(t.Context(), "12314") if err != nil { t.Error(err) @@ -2423,17 +2297,14 @@ func TestCancelWithdrawal(t *testing.T) { func TestCreateNewSubUserID(t *testing.T) { t.Parallel() _, err := b.CreateNewSubUserID(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Fatalf("expected %v, got %v", errNilArgument, err) - } + require.ErrorIs(t, err, errNilArgument) + _, err = b.CreateNewSubUserID(t.Context(), &CreateSubUserParams{MemberType: 1, Switch: 1, Note: "test"}) - if !errors.Is(err, errMissingUsername) { - t.Fatalf("expected %v, got %v", errMissingUsername, err) - } + require.ErrorIs(t, err, errMissingUsername) + _, err = b.CreateNewSubUserID(t.Context(), &CreateSubUserParams{Username: "Sami", Switch: 1, Note: "test"}) - if !errors.Is(err, errInvalidMemberType) { - t.Fatalf("expected %v, got %v", errInvalidMemberType, err) - } + require.ErrorIs(t, err, errInvalidMemberType) + if mockTests { t.Skip(skipAuthenticatedFunctionsForMockTesting) } @@ -2447,13 +2318,11 @@ func TestCreateNewSubUserID(t *testing.T) { func TestCreateSubUIDAPIKey(t *testing.T) { t.Parallel() _, err := b.CreateSubUIDAPIKey(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Fatalf("expected %v, got %v", errNilArgument, err) - } + require.ErrorIs(t, err, errNilArgument) + _, err = b.CreateSubUIDAPIKey(t.Context(), &SubUIDAPIKeyParam{}) - if !errors.Is(err, errMissingUserID) { - t.Fatalf("expected %v, got %v", errMissingUserID, err) - } + require.ErrorIs(t, err, errMissingUserID) + if mockTests { t.Skip(skipAuthenticatedFunctionsForMockTesting) } @@ -2524,9 +2393,8 @@ func TestModifyMasterAPIKey(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) _, err := b.ModifyMasterAPIKey(t.Context(), &SubUIDAPIKeyUpdateParam{}) - if !errors.Is(err, errNilArgument) { - t.Fatalf("expected %v, got %v", errNilArgument, err) - } + require.ErrorIs(t, err, errNilArgument) + _, err = b.ModifyMasterAPIKey(t.Context(), &SubUIDAPIKeyUpdateParam{ ReadOnly: 0, IPs: "*", @@ -2551,9 +2419,8 @@ func TestModifySubAPIKey(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) _, err := b.ModifySubAPIKey(t.Context(), &SubUIDAPIKeyUpdateParam{}) - if !errors.Is(err, errNilArgument) { - t.Fatalf("expected %v, got %v", errNilArgument, err) - } + require.ErrorIs(t, err, errNilArgument) + _, err = b.ModifySubAPIKey(t.Context(), &SubUIDAPIKeyUpdateParam{ APIKey: "lnqQ8ACaoMLi4168He", ReadOnly: 0, @@ -2579,9 +2446,8 @@ func TestDeleteSubUID(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) err := b.DeleteSubUID(t.Context(), "") - if !errors.Is(err, errMemberIDRequired) { - t.Errorf("expected %v, got %v", errMemberIDRequired, err) - } + assert.ErrorIs(t, err, errMemberIDRequired) + err = b.DeleteSubUID(t.Context(), "1234") if err != nil { t.Error(err) @@ -2643,9 +2509,8 @@ func TestGetLeveragedTokenMarket(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) _, err := b.GetLeveragedTokenMarket(t.Context(), currency.EMPTYCODE) - if !errors.Is(err, currency.ErrCurrencyCodeEmpty) { - t.Fatalf("expected %v, got %v", currency.ErrCurrencyCodeEmpty, err) - } + require.ErrorIs(t, err, currency.ErrCurrencyCodeEmpty) + _, err = b.GetLeveragedTokenMarket(t.Context(), currency.NewCode("BTC3L")) if err != nil { t.Error(err) @@ -2738,18 +2603,16 @@ func TestGetBorrowableCoinInfo(t *testing.T) { func TestGetInterestAndQuota(t *testing.T) { t.Parallel() + _, err := b.GetInterestAndQuota(t.Context(), currency.EMPTYCODE) + assert.ErrorIs(t, err, currency.ErrCurrencyCodeEmpty) + if mockTests { t.Skip(skipAuthenticatedFunctionsForMockTesting) } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) - _, err := b.GetInterestAndQuota(t.Context(), currency.EMPTYCODE) - if !errors.Is(err, currency.ErrCurrencyCodeEmpty) { - t.Errorf("expected %v, got %v", currency.ErrCurrencyCodeEmpty, err) - } + _, err = b.GetInterestAndQuota(t.Context(), currency.BTC) - if err != nil && !errors.Is(err, errEndpointAvailableForNormalAPIKeyHolders) { - t.Error(err) - } + assert.NoError(t, err) } func TestGetLoanAccountInfo(t *testing.T) { @@ -2759,9 +2622,7 @@ func TestGetLoanAccountInfo(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) _, err := b.GetLoanAccountInfo(t.Context()) - if err != nil && !errors.Is(err, errEndpointAvailableForNormalAPIKeyHolders) { - t.Error(err) - } + assert.NoError(t, err) } func TestBorrow(t *testing.T) { @@ -2771,17 +2632,14 @@ func TestBorrow(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) _, err := b.Borrow(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Errorf("expected %v, got %v", errNilArgument, err) - } + assert.ErrorIs(t, err, errNilArgument) + _, err = b.Borrow(t.Context(), &LendArgument{}) - if !errors.Is(err, currency.ErrCurrencyCodeEmpty) { - t.Errorf("expected %v, got %v", currency.ErrCurrencyCodeEmpty, err) - } + assert.ErrorIs(t, err, currency.ErrCurrencyCodeEmpty) + _, err = b.Borrow(t.Context(), &LendArgument{Coin: currency.BTC}) - if !errors.Is(err, order.ErrAmountBelowMin) { - t.Errorf("expected %v, got %v", order.ErrAmountBelowMin, err) - } + assert.ErrorIs(t, err, order.ErrAmountBelowMin) + _, err = b.Borrow(t.Context(), &LendArgument{Coin: currency.BTC, AmountToBorrow: 0.1}) if err != nil { t.Error(err) @@ -2795,17 +2653,14 @@ func TestRepay(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) _, err := b.Repay(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Errorf("expected %v, got %v", errNilArgument, err) - } + assert.ErrorIs(t, err, errNilArgument) + _, err = b.Repay(t.Context(), &LendArgument{}) - if !errors.Is(err, currency.ErrCurrencyCodeEmpty) { - t.Errorf("expected %v, got %v", currency.ErrCurrencyCodeEmpty, err) - } + assert.ErrorIs(t, err, currency.ErrCurrencyCodeEmpty) + _, err = b.Repay(t.Context(), &LendArgument{Coin: currency.BTC}) - if !errors.Is(err, order.ErrAmountBelowMin) { - t.Errorf("expected %v, got %v", order.ErrAmountBelowMin, err) - } + assert.ErrorIs(t, err, order.ErrAmountBelowMin) + _, err = b.Repay(t.Context(), &LendArgument{Coin: currency.BTC, AmountToBorrow: 0.1}) if err != nil { t.Error(err) @@ -2819,9 +2674,7 @@ func TestGetBorrowOrderDetail(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) _, err := b.GetBorrowOrderDetail(t.Context(), time.Time{}, time.Time{}, currency.BTC, 0, 0) - if err != nil && !errors.Is(err, errEndpointAvailableForNormalAPIKeyHolders) { - t.Error(err) - } + assert.NoError(t, err) } func TestGetRepaymentOrderDetail(t *testing.T) { @@ -2831,9 +2684,7 @@ func TestGetRepaymentOrderDetail(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) _, err := b.GetRepaymentOrderDetail(t.Context(), time.Time{}, time.Time{}, currency.BTC, 0) - if err != nil && !errors.Is(err, errEndpointAvailableForNormalAPIKeyHolders) { - t.Error(err) - } + assert.NoError(t, err) } func TestToggleMarginTradeNormal(t *testing.T) { @@ -2843,9 +2694,7 @@ func TestToggleMarginTradeNormal(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) _, err := b.ToggleMarginTradeNormal(t.Context(), true) - if err != nil && !errors.Is(err, errEndpointAvailableForNormalAPIKeyHolders) { - t.Error(err) - } + assert.NoError(t, err) } func TestGetProductInfo(t *testing.T) { @@ -2871,9 +2720,7 @@ func TestGetInstitutionalLoanOrders(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) _, err := b.GetInstitutionalLoanOrders(t.Context(), "", time.Time{}, time.Time{}, 0) - if err != nil && !errors.Is(err, errEndpointAvailableForNormalAPIKeyHolders) { - t.Error(err) - } + assert.NoError(t, err) } func TestGetInstitutionalRepayOrders(t *testing.T) { @@ -2883,9 +2730,7 @@ func TestGetInstitutionalRepayOrders(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) _, err := b.GetInstitutionalRepayOrders(t.Context(), time.Time{}, time.Time{}, 0) - if err != nil && !errors.Is(err, errEndpointAvailableForNormalAPIKeyHolders) { - t.Error(err) - } + assert.NoError(t, err) } func TestGetLTV(t *testing.T) { @@ -2895,9 +2740,7 @@ func TestGetLTV(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) _, err := b.GetLTV(t.Context()) - if err != nil && !errors.Is(err, errEndpointAvailableForNormalAPIKeyHolders) { - t.Error(err) - } + assert.NoError(t, err) } func TestBindOrUnbindUID(t *testing.T) { @@ -2931,17 +2774,14 @@ func TestC2CDepositFunds(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) _, err := b.C2CDepositFunds(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Error(err) - } + assert.ErrorIs(t, err, errNilArgument) + _, err = b.C2CDepositFunds(t.Context(), &C2CLendingFundsParams{}) - if !errors.Is(err, currency.ErrCurrencyCodeEmpty) { - t.Errorf("expected %v, got %v", currency.ErrCurrencyCodeEmpty, err) - } + assert.ErrorIs(t, err, currency.ErrCurrencyCodeEmpty) + _, err = b.C2CDepositFunds(t.Context(), &C2CLendingFundsParams{Coin: currency.BTC}) - if !errors.Is(err, order.ErrAmountBelowMin) { - t.Errorf("expected %v, got %v", order.ErrAmountBelowMin, err) - } + assert.ErrorIs(t, err, order.ErrAmountBelowMin) + _, err = b.C2CDepositFunds(t.Context(), &C2CLendingFundsParams{Coin: currency.BTC, Quantity: 1232}) if err != nil { t.Error(err) @@ -2955,17 +2795,14 @@ func TestC2CRedeemFunds(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b, canManipulateRealOrders) _, err := b.C2CRedeemFunds(t.Context(), nil) - if !errors.Is(err, errNilArgument) { - t.Error(err) - } + assert.ErrorIs(t, err, errNilArgument) + _, err = b.C2CRedeemFunds(t.Context(), &C2CLendingFundsParams{}) - if !errors.Is(err, currency.ErrCurrencyCodeEmpty) { - t.Errorf("expected %v, got %v", currency.ErrCurrencyCodeEmpty, err) - } + assert.ErrorIs(t, err, currency.ErrCurrencyCodeEmpty) + _, err = b.C2CRedeemFunds(t.Context(), &C2CLendingFundsParams{Coin: currency.BTC}) - if !errors.Is(err, order.ErrAmountBelowMin) { - t.Errorf("expected %v, got %v", order.ErrAmountBelowMin, err) - } + assert.ErrorIs(t, err, order.ErrAmountBelowMin) + _, err = b.C2CRedeemFunds(t.Context(), &C2CLendingFundsParams{Coin: currency.BTC, Quantity: 1232}) if err != nil { t.Error(err) @@ -3054,9 +2891,8 @@ func TestGetWithdrawalsHistory(t *testing.T) { } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) _, err := b.GetWithdrawalsHistory(t.Context(), currency.BTC, asset.Futures) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("expected %v, got %v", asset.ErrNotSupported, err) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) + _, err = b.GetWithdrawalsHistory(t.Context(), currency.BTC, asset.Spot) if err != nil { t.Error("GetWithdrawalsHistory()", err) @@ -3139,9 +2975,8 @@ func TestCancelBatchOrders(t *testing.T) { AssetType: asset.USDTMarginedFutures, }} _, err := b.CancelBatchOrders(t.Context(), orderCancellationParams) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("expected %v, got %v", asset.ErrNotSupported, err) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) + orderCancellationParams = []order.Cancel{{ OrderID: "1", AccountID: "1", @@ -3175,9 +3010,7 @@ func TestWsLinearConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsLinearConnect() - if err != nil && !errors.Is(err, websocket.ErrWebsocketNotEnabled) { - t.Error(err) - } + assert.Truef(t, errors.Is(err, websocket.ErrWebsocketNotEnabled) || err == nil, "WsLinerConnect should not error: %s", err) } func TestWsInverseConnect(t *testing.T) { @@ -3186,9 +3019,7 @@ func TestWsInverseConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsInverseConnect() - if err != nil && !errors.Is(err, websocket.ErrWebsocketNotEnabled) { - t.Error(err) - } + assert.Truef(t, errors.Is(err, websocket.ErrWebsocketNotEnabled) || err == nil, "WsInverseConnect should not error: %s", err) } func TestWsOptionsConnect(t *testing.T) { @@ -3197,9 +3028,7 @@ func TestWsOptionsConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsOptionsConnect() - if err != nil && !errors.Is(err, websocket.ErrWebsocketNotEnabled) { - t.Error(err) - } + assert.Truef(t, errors.Is(err, websocket.ErrWebsocketNotEnabled) || err == nil, "WsOptionsConnect should not error: %s", err) } var pushDataMap = map[string]string{ @@ -3426,9 +3255,7 @@ func TestSetLeverage(t *testing.T) { } err = b.SetLeverage(ctx, asset.CoinMarginedFutures, inverseTradablePair, margin.Isolated, 5, order.UnknownSide) - if !errors.Is(err, order.ErrSideIsInvalid) { - t.Errorf("received '%v', expected '%v'", err, order.ErrSideIsInvalid) - } + assert.ErrorIs(t, err, order.ErrSideIsInvalid) err = b.SetLeverage(ctx, asset.USDTMarginedFutures, usdtMarginedTradablePair, margin.Isolated, 5, order.Buy) if err != nil { @@ -3441,22 +3268,17 @@ func TestSetLeverage(t *testing.T) { } err = b.SetLeverage(ctx, asset.USDTMarginedFutures, usdtMarginedTradablePair, margin.Isolated, 5, order.CouldNotBuy) - if !errors.Is(err, order.ErrSideIsInvalid) { - t.Errorf("received '%v', expected '%v'", err, order.ErrSideIsInvalid) - } + assert.ErrorIs(t, err, order.ErrSideIsInvalid) err = b.SetLeverage(ctx, asset.Spot, inverseTradablePair, margin.Multi, 5, order.UnknownSide) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("received '%v', expected '%v'", err, asset.ErrNotSupported) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) } func TestGetFuturesContractDetails(t *testing.T) { t.Parallel() _, err := b.GetFuturesContractDetails(t.Context(), asset.Spot) - if !errors.Is(err, futures.ErrNotFuturesAsset) { - t.Error(err) - } + assert.ErrorIs(t, err, futures.ErrNotFuturesAsset) + _, err = b.GetFuturesContractDetails(t.Context(), asset.CoinMarginedFutures) assert.NoError(t, err) @@ -3490,9 +3312,7 @@ func TestFetchTradablePairs(t *testing.T) { t.Fatal(err) } _, err = b.FetchTradablePairs(t.Context(), asset.Futures) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("expected %v, got %v", asset.ErrNotSupported, err) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) } func TestDeltaUpdateOrderbook(t *testing.T) { @@ -3525,9 +3345,7 @@ func TestGetLongShortRatio(t *testing.T) { t.Fatal(err) } _, err = b.GetLongShortRatio(t.Context(), "spot", "BTCUSDT", kline.FiveMin, 0) - if !errors.Is(err, errInvalidCategory) { - t.Fatalf("expected %v, got %v", errInvalidCategory, err) - } + require.ErrorIs(t, err, errInvalidCategory) } func TestStringToOrderStatus(t *testing.T) { @@ -3611,23 +3429,20 @@ func TestGetLatestFundingRates(t *testing.T) { Asset: asset.Futures, Pair: usdtMarginedTradablePair, }) - if !errors.Is(err, asset.ErrNotSupported) { - t.Error(err) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) + _, err = b.GetLatestFundingRates(t.Context(), &fundingrate.LatestRateRequest{ Asset: asset.Spot, Pair: spotTradablePair, }) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("expected %v, got %v", asset.ErrNotSupported, err) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) + _, err = b.GetLatestFundingRates(t.Context(), &fundingrate.LatestRateRequest{ Asset: asset.Options, Pair: optionsTradablePair, }) - if !errors.Is(err, asset.ErrNotSupported) { - t.Errorf("expected %v, got %v", asset.ErrNotSupported, err) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) + _, err = b.GetLatestFundingRates(t.Context(), &fundingrate.LatestRateRequest{ Asset: asset.USDTMarginedFutures, }) diff --git a/exchanges/collateral/collateral_test.go b/exchanges/collateral/collateral_test.go index 61cf5217..4d6f4eb3 100644 --- a/exchanges/collateral/collateral_test.go +++ b/exchanges/collateral/collateral_test.go @@ -1,10 +1,10 @@ package collateral import ( - "errors" "strings" "testing" + "github.com/stretchr/testify/assert" "github.com/thrasher-corp/gocryptotrader/encoding/json" ) @@ -69,9 +69,8 @@ func TestUnmarshalJSONCollateralType(t *testing.T) { jason = []byte(`{"collateral":"hello moto"}`) err = json.Unmarshal(jason, &alien) - if !errors.Is(err, ErrInvalidCollateralMode) { - t.Error(err) - } + assert.ErrorIs(t, err, ErrInvalidCollateralMode) + if alien.M != UnknownMode { t.Errorf("received '%v' expected 'UnknownMode'", alien.M) } @@ -143,9 +142,8 @@ func TestIsValidCollateralTypeString(t *testing.T) { func TestStringToCollateralType(t *testing.T) { t.Parallel() resp, err := StringToMode("lol") - if !errors.Is(err, ErrInvalidCollateralMode) { - t.Error(err) - } + assert.ErrorIs(t, err, ErrInvalidCollateralMode) + if resp != UnknownMode { t.Errorf("received '%v' expected '%v'", resp, UnknownMode) } diff --git a/exchanges/credentials_test.go b/exchanges/credentials_test.go index 49f4229b..7d458f28 100644 --- a/exchanges/credentials_test.go +++ b/exchanges/credentials_test.go @@ -2,9 +2,9 @@ package exchange import ( "context" - "errors" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/config" @@ -15,32 +15,25 @@ func TestGetCredentials(t *testing.T) { t.Parallel() var b Base _, err := b.GetCredentials(t.Context()) - if !errors.Is(err, ErrCredentialsAreEmpty) { - t.Fatalf("received: %v but expected: %v", err, ErrCredentialsAreEmpty) - } + require.ErrorIs(t, err, ErrCredentialsAreEmpty) b.API.CredentialsValidator.RequiresKey = true ctx := account.DeployCredentialsToContext(t.Context(), &account.Credentials{Secret: "wow"}) _, err = b.GetCredentials(ctx) - if !errors.Is(err, errRequiresAPIKey) { - t.Fatalf("received: %v but expected: %v", err, errRequiresAPIKey) - } + require.ErrorIs(t, err, errRequiresAPIKey) b.API.CredentialsValidator.RequiresSecret = true ctx = account.DeployCredentialsToContext(t.Context(), &account.Credentials{Key: "wow"}) _, err = b.GetCredentials(ctx) - if !errors.Is(err, errRequiresAPISecret) { - t.Fatalf("received: %v but expected: %v", err, errRequiresAPISecret) - } + require.ErrorIs(t, err, errRequiresAPISecret) b.API.CredentialsValidator.RequiresBase64DecodeSecret = true ctx = account.DeployCredentialsToContext(t.Context(), &account.Credentials{ Key: "meow", Secret: "invalidb64", }) - if _, err = b.GetCredentials(ctx); !errors.Is(err, errBase64DecodeFailure) { - t.Fatalf("received: %v but expected: %v", err, errBase64DecodeFailure) - } + _, err = b.GetCredentials(ctx) + require.ErrorIs(t, err, errBase64DecodeFailure) const expectedBase64DecodedOutput = "hello world" ctx = account.DeployCredentialsToContext(t.Context(), &account.Credentials{ @@ -92,9 +85,7 @@ func TestGetCredentials(t *testing.T) { ctx = account.DeployCredentialsToContext(t.Context(), lonelyCred) b.API.CredentialsValidator.RequiresClientID = true _, err = b.GetCredentials(ctx) - if !errors.Is(err, errRequiresAPIClientID) { - t.Fatalf("received: %v but expected: %v", err, errRequiresAPIClientID) - } + require.ErrorIs(t, err, errRequiresAPIClientID) b.API.SetKey("hello") b.API.SetSecret("sir") @@ -203,9 +194,8 @@ func TestVerifyAPICredentials(t *testing.T) { t.Run("", func(t *testing.T) { t.Parallel() b := setupBase(&tc) - if err := b.VerifyAPICredentials(&b.API.credentials); !errors.Is(err, tc.Expected) { - t.Errorf("Test %d: expected: %v: got %v", x+1, tc.Expected, err) - } + assert.ErrorIs(t, b.VerifyAPICredentials(&b.API.credentials), tc.Expected) + if tc.CheckBase64DecodedOutput { if b.API.credentials.Secret != expectedBase64DecodedOutput { t.Errorf("Test %d: expected: %v: got %v", x+1, expectedBase64DecodedOutput, b.API.credentials.Secret) @@ -301,9 +291,8 @@ func TestCheckCredentials(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() - if err := tc.base.CheckCredentials(&tc.base.API.credentials, false); !errors.Is(err, tc.expectedErr) { - t.Errorf("%s: received '%v' but expected '%v'", tc.name, err, tc.expectedErr) - } + assert.ErrorIs(t, tc.base.CheckCredentials(&tc.base.API.credentials, false), tc.expectedErr) + if tc.checkBase64Output { if tc.base.API.credentials.SecretBase64Decoded != true { t.Errorf("%s: expected secret to be base64 decoded", tc.name) diff --git a/exchanges/currencystate/currency_state_test.go b/exchanges/currencystate/currency_state_test.go index 9d8de761..eda41af9 100644 --- a/exchanges/currencystate/currency_state_test.go +++ b/exchanges/currencystate/currency_state_test.go @@ -1,7 +1,6 @@ package currencystate import ( - "errors" "sync" "testing" @@ -20,9 +19,7 @@ func TestNewCurrencyStates(t *testing.T) { func TestGetSnapshot(t *testing.T) { t.Parallel() _, err := (*States)(nil).GetCurrencyStateSnapshot() - if !errors.Is(err, errNilStates) { - t.Fatalf("received: %v, but expected: %v", err, errNilStates) - } + require.ErrorIs(t, err, errNilStates) o, err := (&States{ m: map[asset.Item]map[*currency.Item]*Currency{ @@ -43,20 +40,14 @@ func TestGetSnapshot(t *testing.T) { func TestCanTradePair(t *testing.T) { t.Parallel() err := (*States)(nil).CanTradePair(currency.EMPTYPAIR, asset.Empty) - if !errors.Is(err, errNilStates) { - t.Fatalf("received: %v, but expected: %v", err, errNilStates) - } + require.ErrorIs(t, err, errNilStates) err = (&States{}).CanTradePair(currency.EMPTYPAIR, asset.Empty) - if !errors.Is(err, errEmptyCurrency) { - t.Fatalf("received: %v, but expected: %v", err, errEmptyCurrency) - } + require.ErrorIs(t, err, errEmptyCurrency) cp := currency.NewBTCUSD() err = (&States{}).CanTradePair(cp, asset.Empty) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: %v, but expected: %v", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) err = (&States{}).CanTradePair(cp, asset.Spot) require.NoError(t, err) @@ -80,9 +71,7 @@ func TestCanTradePair(t *testing.T) { }, }, }).CanTradePair(cp, asset.Spot) - if !errors.Is(err, errTradingNotAllowed) { - t.Fatalf("received: %v, but expected: %v", err, errTradingNotAllowed) - } + require.ErrorIs(t, err, errTradingNotAllowed) err = (&States{ m: map[asset.Item]map[*currency.Item]*Currency{ @@ -92,9 +81,7 @@ func TestCanTradePair(t *testing.T) { }, }, }).CanTradePair(cp, asset.Spot) - if !errors.Is(err, errTradingNotAllowed) { - t.Fatalf("received: %v, but expected: %v", err, errTradingNotAllowed) - } + require.ErrorIs(t, err, errTradingNotAllowed) err = (&States{ m: map[asset.Item]map[*currency.Item]*Currency{ @@ -104,33 +91,25 @@ func TestCanTradePair(t *testing.T) { }, }, }).CanTradePair(cp, asset.Spot) - if !errors.Is(err, errTradingNotAllowed) { - t.Fatalf("received: %v, but expected: %v", err, errTradingNotAllowed) - } + require.ErrorIs(t, err, errTradingNotAllowed) } func TestStatesCanTrade(t *testing.T) { t.Parallel() err := (*States)(nil).CanTrade(currency.EMPTYCODE, asset.Empty) - if !errors.Is(err, errNilStates) { - t.Fatalf("received: %v, but expected: %v", err, errNilStates) - } + require.ErrorIs(t, err, errNilStates) + err = (&States{}).CanTrade(currency.EMPTYCODE, asset.Empty) - if !errors.Is(err, errEmptyCurrency) { - t.Fatalf("received: %v, but expected: %v", err, errEmptyCurrency) - } + require.ErrorIs(t, err, errEmptyCurrency) } func TestStatesCanWithdraw(t *testing.T) { t.Parallel() err := (*States)(nil).CanWithdraw(currency.EMPTYCODE, asset.Empty) - if !errors.Is(err, errNilStates) { - t.Fatalf("received: %v, but expected: %v", err, errNilStates) - } + require.ErrorIs(t, err, errNilStates) + err = (&States{}).CanWithdraw(currency.EMPTYCODE, asset.Empty) - if !errors.Is(err, errEmptyCurrency) { - t.Fatalf("received: %v, but expected: %v", err, errEmptyCurrency) - } + require.ErrorIs(t, err, errEmptyCurrency) err = (&States{ m: map[asset.Item]map[*currency.Item]*Currency{ @@ -148,21 +127,16 @@ func TestStatesCanWithdraw(t *testing.T) { }, }, }).CanWithdraw(currency.BTC, asset.Spot) - if !errors.Is(err, errWithdrawalsNotAllowed) { - t.Fatalf("received: %v, but expected: %v", err, errWithdrawalsNotAllowed) - } + require.ErrorIs(t, err, errWithdrawalsNotAllowed) } func TestStatesCanDeposit(t *testing.T) { t.Parallel() err := (*States)(nil).CanDeposit(currency.EMPTYCODE, asset.Empty) - if !errors.Is(err, errNilStates) { - t.Fatalf("received: %v, but expected: %v", err, errNilStates) - } + require.ErrorIs(t, err, errNilStates) + err = (&States{}).CanDeposit(currency.EMPTYCODE, asset.Empty) - if !errors.Is(err, errEmptyCurrency) { - t.Fatalf("received: %v, but expected: %v", err, errEmptyCurrency) - } + require.ErrorIs(t, err, errEmptyCurrency) err = (&States{ m: map[asset.Item]map[*currency.Item]*Currency{ @@ -180,27 +154,19 @@ func TestStatesCanDeposit(t *testing.T) { }, }, }).CanDeposit(currency.BTC, asset.Spot) - if !errors.Is(err, errDepositNotAllowed) { - t.Fatalf("received: %v, but expected: %v", err, errDepositNotAllowed) - } + require.ErrorIs(t, err, errDepositNotAllowed) } func TestStatesUpdateAll(t *testing.T) { t.Parallel() err := (*States)(nil).UpdateAll(asset.Empty, nil) - if !errors.Is(err, errNilStates) { - t.Fatalf("received: %v, but expected: %v", err, errNilStates) - } + require.ErrorIs(t, err, errNilStates) err = (&States{}).UpdateAll(asset.Empty, nil) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: %v, but expected: %v", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) err = (&States{}).UpdateAll(asset.Spot, nil) - if !errors.Is(err, errUpdatesAreNil) { - t.Fatalf("received: %v, but expected: %v", err, errUpdatesAreNil) - } + require.ErrorIs(t, err, errUpdatesAreNil) s := &States{ m: map[asset.Item]map[*currency.Item]*Currency{}, @@ -233,19 +199,13 @@ func TestStatesUpdateAll(t *testing.T) { func TestStatesUpdate(t *testing.T) { t.Parallel() err := (*States)(nil).Update(currency.EMPTYCODE, asset.Empty, Options{}) - if !errors.Is(err, errNilStates) { - t.Fatalf("received: %v, but expected: %v", err, errNilStates) - } + require.ErrorIs(t, err, errNilStates) err = (&States{}).Update(currency.EMPTYCODE, asset.Empty, Options{}) - if !errors.Is(err, errEmptyCurrency) { - t.Fatalf("received: %v, but expected: %v", err, errEmptyCurrency) - } + require.ErrorIs(t, err, errEmptyCurrency) err = (&States{}).Update(currency.BTC, asset.Empty, Options{}) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: %v, but expected: %v", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) err = (&States{ m: map[asset.Item]map[*currency.Item]*Currency{ @@ -258,24 +218,16 @@ func TestStatesUpdate(t *testing.T) { func TestStatesGet(t *testing.T) { t.Parallel() _, err := (*States)(nil).Get(currency.EMPTYCODE, asset.Empty) - if !errors.Is(err, errNilStates) { - t.Fatalf("received: %v, but expected: %v", err, errNilStates) - } + require.ErrorIs(t, err, errNilStates) _, err = (&States{}).Get(currency.EMPTYCODE, asset.Empty) - if !errors.Is(err, errEmptyCurrency) { - t.Fatalf("received: %v, but expected: %v", err, errEmptyCurrency) - } + require.ErrorIs(t, err, errEmptyCurrency) _, err = (&States{}).Get(currency.BTC, asset.Empty) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: %v, but expected: %v", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) _, err = (&States{}).Get(currency.BTC, asset.Spot) - if !errors.Is(err, ErrCurrencyStateNotFound) { - t.Fatalf("received: %v, but expected: %v", err, ErrCurrencyStateNotFound) - } + require.ErrorIs(t, err, ErrCurrencyStateNotFound) } func TestCurrencyGetState(t *testing.T) { diff --git a/exchanges/exchange_test.go b/exchanges/exchange_test.go index 71a3a263..6912fd7a 100644 --- a/exchanges/exchange_test.go +++ b/exchanges/exchange_test.go @@ -1548,9 +1548,7 @@ func TestCheckTransientError(t *testing.T) { func TestDisableEnableRateLimiter(t *testing.T) { b := Base{} err := b.EnableRateLimiter() - if !errors.Is(err, request.ErrRequestSystemIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, request.ErrRequestSystemIsNil) - } + require.ErrorIs(t, err, request.ErrRequestSystemIsNil) b.Requester, err = request.New("testingRateLimiter", common.NewHTTPClientWithTimeout(0)) if err != nil { @@ -1561,17 +1559,13 @@ func TestDisableEnableRateLimiter(t *testing.T) { require.NoError(t, err) err = b.DisableRateLimiter() - if !errors.Is(err, request.ErrRateLimiterAlreadyDisabled) { - t.Fatalf("received: '%v' but expected: '%v'", err, request.ErrRateLimiterAlreadyDisabled) - } + require.ErrorIs(t, err, request.ErrRateLimiterAlreadyDisabled) err = b.EnableRateLimiter() require.NoError(t, err) err = b.EnableRateLimiter() - if !errors.Is(err, request.ErrRateLimiterAlreadyEnabled) { - t.Fatalf("received: '%v' but expected: '%v'", err, request.ErrRateLimiterAlreadyEnabled) - } + require.ErrorIs(t, err, request.ErrRateLimiterAlreadyEnabled) } func TestGetWebsocket(t *testing.T) { @@ -1878,9 +1872,7 @@ func TestAssetWebsocketFunctionality(t *testing.T) { } err := b.DisableAssetWebsocketSupport(asset.Spot) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("expected error: %v but received: %v", asset.ErrNotSupported, err) - } + require.ErrorIs(t, err, asset.ErrNotSupported) err = b.SetAssetPairStore(asset.Spot, currency.PairStore{ RequestFormat: ¤cy.PairFormat{ @@ -1942,9 +1934,7 @@ func TestGetGetURLTypeFromString(t *testing.T) { t.Run(tt.Endpoint, func(t *testing.T) { t.Parallel() u, err := getURLTypeFromString(tt.Endpoint) - if !errors.Is(err, tt.Error) { - t.Fatalf("received: %v but expected: %v", err, tt.Error) - } + require.ErrorIs(t, err, tt.Error) if u != tt.Expected { t.Fatalf("received: %v but expected: %v", u, tt.Expected) @@ -1956,41 +1946,36 @@ func TestGetGetURLTypeFromString(t *testing.T) { func TestGetAvailableTransferChains(t *testing.T) { t.Parallel() var b Base - if _, err := b.GetAvailableTransferChains(t.Context(), currency.BTC); !errors.Is(err, common.ErrFunctionNotSupported) { - t.Errorf("received: %v, expected: %v", err, common.ErrFunctionNotSupported) - } + _, err := b.GetAvailableTransferChains(t.Context(), currency.BTC) + assert.ErrorIs(t, err, common.ErrFunctionNotSupported) } func TestCalculatePNL(t *testing.T) { t.Parallel() var b Base - if _, err := b.CalculatePNL(t.Context(), nil); !errors.Is(err, common.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, common.ErrNotYetImplemented) - } + _, err := b.CalculatePNL(t.Context(), nil) + assert.ErrorIs(t, err, common.ErrNotYetImplemented) } func TestScaleCollateral(t *testing.T) { t.Parallel() var b Base - if _, err := b.ScaleCollateral(t.Context(), nil); !errors.Is(err, common.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, common.ErrNotYetImplemented) - } + _, err := b.ScaleCollateral(t.Context(), nil) + assert.ErrorIs(t, err, common.ErrNotYetImplemented) } func TestCalculateTotalCollateral(t *testing.T) { t.Parallel() var b Base - if _, err := b.CalculateTotalCollateral(t.Context(), nil); !errors.Is(err, common.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, common.ErrNotYetImplemented) - } + _, err := b.CalculateTotalCollateral(t.Context(), nil) + assert.ErrorIs(t, err, common.ErrNotYetImplemented) } func TestUpdateCurrencyStates(t *testing.T) { t.Parallel() var b Base - if err := b.UpdateCurrencyStates(t.Context(), asset.Spot); !errors.Is(err, common.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, common.ErrNotYetImplemented) - } + err := b.UpdateCurrencyStates(t.Context(), asset.Spot) + assert.ErrorIs(t, err, common.ErrNotYetImplemented) } func TestSetTradeFeedStatus(t *testing.T) { @@ -2032,49 +2017,43 @@ func TestSetFillsFeedStatus(t *testing.T) { func TestGetMarginRateHistory(t *testing.T) { t.Parallel() var b Base - if _, err := b.GetMarginRatesHistory(t.Context(), nil); !errors.Is(err, common.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, common.ErrNotYetImplemented) - } + _, err := b.GetMarginRatesHistory(t.Context(), nil) + assert.ErrorIs(t, err, common.ErrNotYetImplemented) } func TestGetPositionSummary(t *testing.T) { t.Parallel() var b Base - if _, err := b.GetFuturesPositionSummary(t.Context(), nil); !errors.Is(err, common.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, common.ErrNotYetImplemented) - } + _, err := b.GetFuturesPositionSummary(t.Context(), nil) + assert.ErrorIs(t, err, common.ErrNotYetImplemented) } func TestGetFuturesPositions(t *testing.T) { t.Parallel() var b Base - if _, err := b.GetFuturesPositionOrders(t.Context(), nil); !errors.Is(err, common.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, common.ErrNotYetImplemented) - } + _, err := b.GetFuturesPositionOrders(t.Context(), nil) + assert.ErrorIs(t, err, common.ErrNotYetImplemented) } func TestGetHistoricalFundingRates(t *testing.T) { t.Parallel() var b Base - if _, err := b.GetHistoricalFundingRates(t.Context(), nil); !errors.Is(err, common.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, common.ErrNotYetImplemented) - } + _, err := b.GetHistoricalFundingRates(t.Context(), nil) + assert.ErrorIs(t, err, common.ErrNotYetImplemented) } func TestGetFundingRates(t *testing.T) { t.Parallel() var b Base - if _, err := b.GetHistoricalFundingRates(t.Context(), nil); !errors.Is(err, common.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, common.ErrNotYetImplemented) - } + _, err := b.GetHistoricalFundingRates(t.Context(), nil) + assert.ErrorIs(t, err, common.ErrNotYetImplemented) } func TestIsPerpetualFutureCurrency(t *testing.T) { t.Parallel() var b Base - if _, err := b.IsPerpetualFutureCurrency(asset.Spot, currency.NewBTCUSD()); !errors.Is(err, common.ErrNotYetImplemented) { - t.Errorf("received: %v, expected: %v", err, common.ErrNotYetImplemented) - } + _, err := b.IsPerpetualFutureCurrency(asset.Spot, currency.NewBTCUSD()) + assert.ErrorIs(t, err, common.ErrNotYetImplemented) } func TestGetPairAndAssetTypeRequestFormatted(t *testing.T) { @@ -2109,19 +2088,13 @@ func TestGetPairAndAssetTypeRequestFormatted(t *testing.T) { } _, _, err := b.GetPairAndAssetTypeRequestFormatted("") - if !errors.Is(err, currency.ErrCurrencyPairEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, currency.ErrCurrencyPairEmpty) - } + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) _, _, err = b.GetPairAndAssetTypeRequestFormatted("BTCAUD") - if !errors.Is(err, ErrSymbolCannotBeMatched) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrSymbolCannotBeMatched) - } + require.ErrorIs(t, err, ErrSymbolCannotBeMatched) _, _, err = b.GetPairAndAssetTypeRequestFormatted("BTCUSDT") - if !errors.Is(err, ErrSymbolCannotBeMatched) { - t.Fatalf("received: '%v' but expected: '%v'", err, ErrSymbolCannotBeMatched) - } + require.ErrorIs(t, err, ErrSymbolCannotBeMatched) p, a, err := b.GetPairAndAssetTypeRequestFormatted("BTC-USDT") require.NoError(t, err) @@ -2176,18 +2149,14 @@ func TestGetCollateralCurrencyForContract(t *testing.T) { t.Parallel() b := Base{} _, _, err := b.GetCollateralCurrencyForContract(asset.Futures, currency.NewPair(currency.XRP, currency.BABYDOGE)) - if !errors.Is(err, common.ErrNotYetImplemented) { - t.Fatalf("received: '%v' but expected: '%v'", err, common.ErrNotYetImplemented) - } + require.ErrorIs(t, err, common.ErrNotYetImplemented) } func TestGetCurrencyForRealisedPNL(t *testing.T) { t.Parallel() b := Base{} _, _, err := b.GetCurrencyForRealisedPNL(asset.Empty, currency.EMPTYPAIR) - if !errors.Is(err, common.ErrNotYetImplemented) { - t.Fatalf("received: '%v' but expected: '%v'", err, common.ErrNotYetImplemented) - } + require.ErrorIs(t, err, common.ErrNotYetImplemented) } func TestHasAssetTypeAccountSegregation(t *testing.T) { @@ -2361,63 +2330,50 @@ func TestSetCollateralMode(t *testing.T) { t.Parallel() b := Base{} err := b.SetCollateralMode(t.Context(), asset.Spot, collateral.SingleMode) - if !errors.Is(err, common.ErrNotYetImplemented) { - t.Error(err) - } + assert.ErrorIs(t, err, common.ErrNotYetImplemented) } func TestGetCollateralMode(t *testing.T) { t.Parallel() b := Base{} _, err := b.GetCollateralMode(t.Context(), asset.Spot) - if !errors.Is(err, common.ErrNotYetImplemented) { - t.Error(err) - } + assert.ErrorIs(t, err, common.ErrNotYetImplemented) } func TestSetMarginType(t *testing.T) { t.Parallel() b := Base{} err := b.SetMarginType(t.Context(), asset.Spot, currency.NewBTCUSD(), margin.Multi) - if !errors.Is(err, common.ErrNotYetImplemented) { - t.Error(err) - } + assert.ErrorIs(t, err, common.ErrNotYetImplemented) } func TestChangePositionMargin(t *testing.T) { t.Parallel() b := Base{} _, err := b.ChangePositionMargin(t.Context(), nil) - if !errors.Is(err, common.ErrNotYetImplemented) { - t.Error(err) - } + assert.ErrorIs(t, err, common.ErrNotYetImplemented) } func TestSetLeverage(t *testing.T) { t.Parallel() b := Base{} err := b.SetLeverage(t.Context(), asset.Spot, currency.NewBTCUSD(), margin.Multi, 1, order.UnknownSide) - if !errors.Is(err, common.ErrNotYetImplemented) { - t.Error(err) - } + assert.ErrorIs(t, err, common.ErrNotYetImplemented) } func TestGetLeverage(t *testing.T) { t.Parallel() b := Base{} _, err := b.GetLeverage(t.Context(), asset.Spot, currency.NewBTCUSD(), margin.Multi, order.UnknownSide) - if !errors.Is(err, common.ErrNotYetImplemented) { - t.Error(err) - } + assert.ErrorIs(t, err, common.ErrNotYetImplemented) } func TestEnsureOnePairEnabled(t *testing.T) { t.Parallel() b := Base{Name: "test"} err := b.EnsureOnePairEnabled() - if !errors.Is(err, currency.ErrCurrencyPairsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, currency.ErrCurrencyPairsEmpty) - } + require.ErrorIs(t, err, currency.ErrCurrencyPairsEmpty) + b.CurrencyPairs = currency.PairsManager{ Pairs: map[asset.Item]*currency.PairStore{ asset.Futures: {}, @@ -2449,15 +2405,11 @@ func TestGetStandardConfig(t *testing.T) { var b *Base _, err := b.GetStandardConfig() - if !errors.Is(err, errExchangeIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeIsNil) - } + require.ErrorIs(t, err, errExchangeIsNil) b = &Base{} _, err = b.GetStandardConfig() - if !errors.Is(err, errSetDefaultsNotCalled) { - t.Fatalf("received: '%v' but expected: '%v'", err, errSetDefaultsNotCalled) - } + require.ErrorIs(t, err, errSetDefaultsNotCalled) b.Name = "test" b.Features.Supports.Websocket = true @@ -2499,9 +2451,7 @@ func TestMatchSymbolWithAvailablePairs(t *testing.T) { } _, err = b.MatchSymbolWithAvailablePairs("sillBillies", asset.Futures, false) - if !errors.Is(err, currency.ErrPairNotFound) { - t.Fatalf("received: '%v' but expected: '%v'", err, currency.ErrPairNotFound) - } + require.ErrorIs(t, err, currency.ErrPairNotFound) whatIGot, err := b.MatchSymbolWithAvailablePairs("btcusdT", asset.Spot, false) require.NoError(t, err) @@ -2533,9 +2483,7 @@ func TestMatchSymbolCheckEnabled(t *testing.T) { } _, _, err = b.MatchSymbolCheckEnabled("sillBillies", asset.Futures, false) - if !errors.Is(err, currency.ErrPairNotFound) { - t.Fatalf("received: '%v' but expected: '%v'", err, currency.ErrPairNotFound) - } + require.ErrorIs(t, err, currency.ErrPairNotFound) whatIGot, enabled, err := b.MatchSymbolCheckEnabled("btcusdT", asset.Spot, false) require.NoError(t, err) @@ -2610,9 +2558,8 @@ func TestIsPairEnabled(t *testing.T) { func TestGetOpenInterest(t *testing.T) { t.Parallel() var b Base - if _, err := b.GetOpenInterest(t.Context()); !errors.Is(err, common.ErrFunctionNotSupported) { - t.Errorf("received: %v, expected: %v", err, common.ErrFunctionNotSupported) - } + _, err := b.GetOpenInterest(t.Context()) + assert.ErrorIs(t, err, common.ErrFunctionNotSupported) } func TestGetCachedOpenInterest(t *testing.T) { diff --git a/exchanges/fill/fill_test.go b/exchanges/fill/fill_test.go index fe15e1c1..87250501 100644 --- a/exchanges/fill/fill_test.go +++ b/exchanges/fill/fill_test.go @@ -1,9 +1,10 @@ package fill import ( - "errors" "testing" "time" + + "github.com/stretchr/testify/assert" ) // TestSetup tests the setup function of the Fills struct @@ -28,9 +29,7 @@ func TestUpdateDisabledFeed(t *testing.T) { // Send a test data to the Update function testData := Data{Timestamp: time.Now(), Price: 15.2, Amount: 3.2} - if err := fill.Update(testData); !errors.Is(err, ErrFeedDisabled) { - t.Errorf("Expected ErrFeedDisabled, got %v", err) - } + assert.ErrorIs(t, fill.Update(testData), ErrFeedDisabled) select { case <-channel: diff --git a/exchanges/futures/futures_test.go b/exchanges/futures/futures_test.go index f1797208..9c96bb74 100644 --- a/exchanges/futures/futures_test.go +++ b/exchanges/futures/futures_test.go @@ -2,7 +2,6 @@ package futures import ( "context" - "errors" "testing" "time" @@ -48,9 +47,8 @@ func TestUpsertPNLEntry(t *testing.T) { IsOrder: true, } _, err := upsertPNLEntry(results, result) - if !errors.Is(err, errTimeUnset) { - t.Error(err) - } + assert.ErrorIs(t, err, errTimeUnset) + tt := time.Now() result.Time = tt results, err = upsertPNLEntry(results, result) @@ -87,13 +85,10 @@ func TestTrackNewOrder(t *testing.T) { assert.NoError(t, err) err = c.TrackNewOrder(nil, false) - if !errors.Is(err, common.ErrNilPointer) { - t.Error(err) - } + assert.ErrorIs(t, err, common.ErrNilPointer) + err = c.TrackNewOrder(&order.Detail{}, false) - if !errors.Is(err, errExchangeNameEmpty) { - t.Error(err) - } + assert.ErrorIs(t, err, errExchangeNameEmpty) od := &order.Detail{ Exchange: exch, @@ -103,17 +98,14 @@ func TestTrackNewOrder(t *testing.T) { Price: 1337, } err = c.TrackNewOrder(od, false) - if !errors.Is(err, order.ErrSideIsInvalid) { - t.Error(err) - } + assert.ErrorIs(t, err, order.ErrSideIsInvalid) od.Side = order.Long od.Amount = 1 od.OrderID = "2" err = c.TrackNewOrder(od, false) - if !errors.Is(err, errTimeUnset) { - t.Error(err) - } + assert.ErrorIs(t, err, errTimeUnset) + c.openingDirection = order.Long od.Date = time.Now() err = c.TrackNewOrder(od, false) @@ -183,9 +175,8 @@ func TestTrackNewOrder(t *testing.T) { od.OrderID = "hellomoto" err = c.TrackNewOrder(od, false) - if !errors.Is(err, ErrPositionClosed) { - t.Errorf("received %v expected %v", err, ErrPositionClosed) - } + assert.ErrorIs(t, err, ErrPositionClosed) + if c.latestDirection != order.ClosePosition { t.Errorf("expected recognition that its closed, received '%v'", c.latestDirection) } @@ -194,9 +185,7 @@ func TestTrackNewOrder(t *testing.T) { } err = c.TrackNewOrder(od, true) - if !errors.Is(err, errCannotTrackInvalidParams) { - t.Error(err) - } + assert.ErrorIs(t, err, errCannotTrackInvalidParams) c, err = SetupPositionTracker(setup) assert.NoError(t, err) @@ -206,40 +195,30 @@ func TestTrackNewOrder(t *testing.T) { var ptp *PositionTracker err = ptp.TrackNewOrder(nil, false) - if !errors.Is(err, common.ErrNilPointer) { - t.Error(err) - } + assert.ErrorIs(t, err, common.ErrNilPointer) } func TestSetupMultiPositionTracker(t *testing.T) { t.Parallel() _, err := SetupMultiPositionTracker(nil) - if !errors.Is(err, errNilSetup) { - t.Error(err) - } + assert.ErrorIs(t, err, errNilSetup) setup := &MultiPositionTrackerSetup{} _, err = SetupMultiPositionTracker(setup) - if !errors.Is(err, errExchangeNameEmpty) { - t.Error(err) - } + assert.ErrorIs(t, err, errExchangeNameEmpty) + setup.Exchange = testExchange _, err = SetupMultiPositionTracker(setup) - if !errors.Is(err, ErrNotFuturesAsset) { - t.Error(err) - } + assert.ErrorIs(t, err, ErrNotFuturesAsset) + setup.Asset = asset.Futures _, err = SetupMultiPositionTracker(setup) - if !errors.Is(err, order.ErrPairIsEmpty) { - t.Error(err) - } + assert.ErrorIs(t, err, order.ErrPairIsEmpty) setup.Pair = currency.NewBTCUSDT() _, err = SetupMultiPositionTracker(setup) - if !errors.Is(err, errEmptyUnderlying) { - t.Error(err) - } + assert.ErrorIs(t, err, errEmptyUnderlying) setup.Underlying = currency.BTC _, err = SetupMultiPositionTracker(setup) @@ -247,9 +226,7 @@ func TestSetupMultiPositionTracker(t *testing.T) { setup.UseExchangePNLCalculation = true _, err = SetupMultiPositionTracker(setup) - if !errors.Is(err, errMissingPNLCalculationFunctions) { - t.Error(err) - } + assert.ErrorIs(t, err, errMissingPNLCalculationFunctions) setup.ExchangePNLCalculation = &FakePNL{} resp, err := SetupMultiPositionTracker(setup) @@ -272,9 +249,7 @@ func TestMultiPositionTrackerTrackNewOrder(t *testing.T) { ExchangePNLCalculation: &FakePNL{}, } _, err := SetupMultiPositionTracker(setup) - if !errors.Is(err, errExchangeNameEmpty) { - t.Error(err) - } + assert.ErrorIs(t, err, errExchangeNameEmpty) setup.Exchange = testExchange resp, err := SetupMultiPositionTracker(setup) @@ -289,9 +264,7 @@ func TestMultiPositionTrackerTrackNewOrder(t *testing.T) { OrderID: "1", Amount: 1, }) - if !errors.Is(err, errExchangeNameEmpty) { - t.Error(err) - } + assert.ErrorIs(t, err, errExchangeNameEmpty) err = resp.TrackNewOrder(&order.Detail{ Date: tt, @@ -351,9 +324,7 @@ func TestMultiPositionTrackerTrackNewOrder(t *testing.T) { OrderID: "4", Amount: 2, }) - if !errors.Is(err, errPositionDiscrepancy) { - t.Errorf("received '%v' expected '%v", err, errPositionDiscrepancy) - } + assert.ErrorIs(t, err, errPositionDiscrepancy) resp.positions = []*PositionTracker{resp.positions[0]} resp.positions[0].status = order.Closed @@ -397,14 +368,10 @@ func TestMultiPositionTrackerTrackNewOrder(t *testing.T) { OrderID: "5", Amount: 2, }) - if !errors.Is(err, errAssetMismatch) { - t.Error(err) - } + assert.ErrorIs(t, err, errAssetMismatch) err = resp.TrackNewOrder(nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Error(err) - } + assert.ErrorIs(t, err, common.ErrNilPointer) resp = nil err = resp.TrackNewOrder(&order.Detail{ @@ -416,9 +383,7 @@ func TestMultiPositionTrackerTrackNewOrder(t *testing.T) { OrderID: "5", Amount: 2, }) - if !errors.Is(err, common.ErrNilPointer) { - t.Error(err) - } + assert.ErrorIs(t, err, common.ErrNilPointer) } func TestSetupPositionControllerReal(t *testing.T) { @@ -433,9 +398,7 @@ func TestPositionControllerTestTrackNewOrder(t *testing.T) { t.Parallel() pc := SetupPositionController() err := pc.TrackNewOrder(nil) - if !errors.Is(err, errNilOrder) { - t.Error(err) - } + assert.ErrorIs(t, err, errNilOrder) err = pc.TrackNewOrder(&order.Detail{ Date: time.Now(), @@ -445,9 +408,7 @@ func TestPositionControllerTestTrackNewOrder(t *testing.T) { Side: order.Long, OrderID: "lol", }) - if !errors.Is(err, ErrNotFuturesAsset) { - t.Error(err) - } + assert.ErrorIs(t, err, ErrNotFuturesAsset) err = pc.TrackNewOrder(&order.Detail{ Date: time.Now(), @@ -456,9 +417,7 @@ func TestPositionControllerTestTrackNewOrder(t *testing.T) { Side: order.Long, OrderID: "lol", }) - if !errors.Is(err, errExchangeNameEmpty) { - t.Error(err) - } + assert.ErrorIs(t, err, errExchangeNameEmpty) err = pc.TrackNewOrder(&order.Detail{ Exchange: testExchange, @@ -472,18 +431,14 @@ func TestPositionControllerTestTrackNewOrder(t *testing.T) { var pcp *PositionController err = pcp.TrackNewOrder(nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Error(err) - } + assert.ErrorIs(t, err, common.ErrNilPointer) } func TestGetLatestPNLSnapshot(t *testing.T) { t.Parallel() pt := PositionTracker{} _, err := pt.GetLatestPNLSnapshot() - if !errors.Is(err, errNoPNLHistory) { - t.Error(err) - } + assert.ErrorIs(t, err, errNoPNLHistory) pnl := PNLResult{ Time: time.Now(), @@ -569,14 +524,11 @@ func TestGetPositionsForExchange(t *testing.T) { p := currency.NewBTCUSDT() _, err := c.GetPositionsForExchange("", asset.Futures, p) - if !errors.Is(err, errExchangeNameEmpty) { - t.Errorf("received '%v' expected '%v", err, errExchangeNameEmpty) - } + assert.ErrorIs(t, err, errExchangeNameEmpty) pos, err := c.GetPositionsForExchange(testExchange, asset.Futures, p) - if !errors.Is(err, ErrPositionNotFound) { - t.Errorf("received '%v' expected '%v", err, ErrPositionNotFound) - } + assert.ErrorIs(t, err, ErrPositionNotFound) + if len(pos) != 0 { t.Error("expected zero") } @@ -588,9 +540,8 @@ func TestGetPositionsForExchange(t *testing.T) { Asset: asset.Futures, }] = nil _, err = c.GetPositionsForExchange(testExchange, asset.Futures, p) - if !errors.Is(err, ErrPositionNotFound) { - t.Errorf("received '%v' expected '%v", err, ErrPositionNotFound) - } + assert.ErrorIs(t, err, ErrPositionNotFound) + c.multiPositionTrackers[key.ExchangePairAsset{ Exchange: testExchange, Base: p.Base.Item, @@ -598,13 +549,10 @@ func TestGetPositionsForExchange(t *testing.T) { Asset: asset.Futures, }] = nil _, err = c.GetPositionsForExchange(testExchange, asset.Futures, p) - if !errors.Is(err, ErrPositionNotFound) { - t.Errorf("received '%v' expected '%v", err, ErrPositionNotFound) - } + assert.ErrorIs(t, err, ErrPositionNotFound) + _, err = c.GetPositionsForExchange(testExchange, asset.Spot, p) - if !errors.Is(err, ErrNotFuturesAsset) { - t.Errorf("received '%v' expected '%v", err, ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, ErrNotFuturesAsset) c.multiPositionTrackers[key.ExchangePairAsset{ Exchange: testExchange, @@ -645,9 +593,7 @@ func TestGetPositionsForExchange(t *testing.T) { } c = nil _, err = c.GetPositionsForExchange(testExchange, asset.Futures, p) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v' expected '%v", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) } func TestClearPositionsForExchange(t *testing.T) { @@ -655,24 +601,17 @@ func TestClearPositionsForExchange(t *testing.T) { c := &PositionController{} p := currency.NewBTCUSDT() err := c.ClearPositionsForExchange("", asset.Futures, p) - if !errors.Is(err, errExchangeNameEmpty) { - t.Errorf("received '%v' expected '%v", err, errExchangeNameEmpty) - } + assert.ErrorIs(t, err, errExchangeNameEmpty) err = c.ClearPositionsForExchange(testExchange, asset.Futures, p) - if !errors.Is(err, ErrPositionNotFound) { - t.Errorf("received '%v' expected '%v", err, ErrPositionNotFound) - } + assert.ErrorIs(t, err, ErrPositionNotFound) + c.multiPositionTrackers = make(map[key.ExchangePairAsset]*MultiPositionTracker) err = c.ClearPositionsForExchange(testExchange, asset.Futures, p) - if !errors.Is(err, ErrPositionNotFound) { - t.Errorf("received '%v' expected '%v", err, ErrPositionNotFound) - } + assert.ErrorIs(t, err, ErrPositionNotFound) err = c.ClearPositionsForExchange(testExchange, asset.Spot, p) - if !errors.Is(err, ErrNotFuturesAsset) { - t.Errorf("received '%v' expected '%v", err, ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, ErrNotFuturesAsset) c.multiPositionTrackers[key.ExchangePairAsset{ Exchange: testExchange, @@ -701,9 +640,7 @@ func TestClearPositionsForExchange(t *testing.T) { } c = nil _, err = c.GetPositionsForExchange(testExchange, asset.Futures, p) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v' expected '%v", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) } func TestCalculateRealisedPNL(t *testing.T) { @@ -742,18 +679,16 @@ func TestCalculateRealisedPNL(t *testing.T) { func TestSetupPositionTracker(t *testing.T) { t.Parallel() p, err := SetupPositionTracker(nil) - if !errors.Is(err, errNilSetup) { - t.Errorf("received '%v' expected '%v", err, errNilSetup) - } + assert.ErrorIs(t, err, errNilSetup) + if p != nil { t.Error("expected nil") } p, err = SetupPositionTracker(&PositionTrackerSetup{ Asset: asset.Spot, }) - if !errors.Is(err, errExchangeNameEmpty) { - t.Errorf("received '%v' expected '%v", err, errExchangeNameEmpty) - } + assert.ErrorIs(t, err, errExchangeNameEmpty) + if p != nil { t.Error("expected nil") } @@ -762,9 +697,8 @@ func TestSetupPositionTracker(t *testing.T) { Exchange: testExchange, Asset: asset.Spot, }) - if !errors.Is(err, ErrNotFuturesAsset) { - t.Errorf("received '%v' expected '%v", err, ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, ErrNotFuturesAsset) + if p != nil { t.Error("expected nil") } @@ -773,9 +707,8 @@ func TestSetupPositionTracker(t *testing.T) { Exchange: testExchange, Asset: asset.Futures, }) - if !errors.Is(err, order.ErrPairIsEmpty) { - t.Errorf("received '%v' expected '%v", err, order.ErrPairIsEmpty) - } + assert.ErrorIs(t, err, order.ErrPairIsEmpty) + if p != nil { t.Error("expected nil") } @@ -801,9 +734,8 @@ func TestSetupPositionTracker(t *testing.T) { Pair: cp, UseExchangePNLCalculation: true, }) - if !errors.Is(err, ErrNilPNLCalculator) { - t.Errorf("received '%v' expected '%v", err, ErrNilPNLCalculator) - } + assert.ErrorIs(t, err, ErrNilPNLCalculator) + p, err = SetupPositionTracker(&PositionTrackerSetup{ Exchange: testExchange, Asset: asset.Futures, @@ -822,22 +754,17 @@ func TestCalculatePNL(t *testing.T) { t.Parallel() p := &PNLCalculator{} _, err := p.CalculatePNL(t.Context(), nil) - if !errors.Is(err, ErrNilPNLCalculator) { - t.Errorf("received '%v' expected '%v", err, ErrNilPNLCalculator) - } + assert.ErrorIs(t, err, ErrNilPNLCalculator) + _, err = p.CalculatePNL(t.Context(), &PNLCalculatorRequest{}) - if !errors.Is(err, errCannotCalculateUnrealisedPNL) { - t.Errorf("received '%v' expected '%v", err, errCannotCalculateUnrealisedPNL) - } + assert.ErrorIs(t, err, errCannotCalculateUnrealisedPNL) _, err = p.CalculatePNL(t.Context(), &PNLCalculatorRequest{ OrderDirection: order.Short, CurrentDirection: order.Long, }) - if !errors.Is(err, errCannotCalculateUnrealisedPNL) { - t.Errorf("received '%v' expected '%v", err, errCannotCalculateUnrealisedPNL) - } + assert.ErrorIs(t, err, errCannotCalculateUnrealisedPNL) } func TestTrackPNLByTime(t *testing.T) { @@ -854,9 +781,7 @@ func TestTrackPNLByTime(t *testing.T) { } p = nil err = p.TrackPNLByTime(time.Now(), 2) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v' expected '%v", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) } func TestUpdateOpenPositionUnrealisedPNL(t *testing.T) { @@ -864,19 +789,13 @@ func TestUpdateOpenPositionUnrealisedPNL(t *testing.T) { pc := SetupPositionController() _, err := pc.UpdateOpenPositionUnrealisedPNL("", asset.Futures, currency.NewBTCUSDT(), 2, time.Now()) - if !errors.Is(err, errExchangeNameEmpty) { - t.Errorf("received '%v' expected '%v", err, errExchangeNameEmpty) - } + assert.ErrorIs(t, err, errExchangeNameEmpty) _, err = pc.UpdateOpenPositionUnrealisedPNL("hi", asset.Futures, currency.NewBTCUSDT(), 2, time.Now()) - if !errors.Is(err, ErrPositionNotFound) { - t.Errorf("received '%v' expected '%v", err, ErrPositionNotFound) - } + assert.ErrorIs(t, err, ErrPositionNotFound) _, err = pc.UpdateOpenPositionUnrealisedPNL("hi", asset.Spot, currency.NewBTCUSDT(), 2, time.Now()) - if !errors.Is(err, ErrNotFuturesAsset) { - t.Errorf("received '%v' expected '%v", err, ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, ErrNotFuturesAsset) err = pc.TrackNewOrder(&order.Detail{ Date: time.Now(), @@ -891,19 +810,13 @@ func TestUpdateOpenPositionUnrealisedPNL(t *testing.T) { assert.NoError(t, err) _, err = pc.UpdateOpenPositionUnrealisedPNL("hi2", asset.Futures, currency.NewBTCUSDT(), 2, time.Now()) - if !errors.Is(err, ErrPositionNotFound) { - t.Errorf("received '%v' expected '%v", err, ErrPositionNotFound) - } + assert.ErrorIs(t, err, ErrPositionNotFound) _, err = pc.UpdateOpenPositionUnrealisedPNL("hi", asset.PerpetualSwap, currency.NewBTCUSDT(), 2, time.Now()) - if !errors.Is(err, ErrPositionNotFound) { - t.Errorf("received '%v' expected '%v", err, ErrPositionNotFound) - } + assert.ErrorIs(t, err, ErrPositionNotFound) _, err = pc.UpdateOpenPositionUnrealisedPNL("hi", asset.Futures, currency.NewPair(currency.BTC, currency.DOGE), 2, time.Now()) - if !errors.Is(err, ErrPositionNotFound) { - t.Errorf("received '%v' expected '%v", err, ErrPositionNotFound) - } + assert.ErrorIs(t, err, ErrPositionNotFound) pnl, err := pc.UpdateOpenPositionUnrealisedPNL("hi", asset.Futures, currency.NewBTCUSDT(), 2, time.Now()) assert.NoError(t, err) @@ -914,34 +827,25 @@ func TestUpdateOpenPositionUnrealisedPNL(t *testing.T) { var nilPC *PositionController _, err = nilPC.UpdateOpenPositionUnrealisedPNL("hi", asset.Futures, currency.NewBTCUSDT(), 2, time.Now()) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v' expected '%v", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) } func TestSetCollateralCurrency(t *testing.T) { t.Parallel() pc := SetupPositionController() err := pc.SetCollateralCurrency("", asset.Spot, currency.EMPTYPAIR, currency.Code{}) - if !errors.Is(err, errExchangeNameEmpty) { - t.Errorf("received '%v' expected '%v", err, errExchangeNameEmpty) - } + assert.ErrorIs(t, err, errExchangeNameEmpty) err = pc.SetCollateralCurrency("hi", asset.Spot, currency.EMPTYPAIR, currency.Code{}) - if !errors.Is(err, ErrNotFuturesAsset) { - t.Errorf("received '%v' expected '%v", err, ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, ErrNotFuturesAsset) + p := currency.NewBTCUSDT() pc.multiPositionTrackers = make(map[key.ExchangePairAsset]*MultiPositionTracker) err = pc.SetCollateralCurrency("hi", asset.Futures, p, currency.DOGE) - if !errors.Is(err, ErrPositionNotFound) { - t.Fatalf("received '%v' expected '%v", err, ErrPositionNotFound) - } + require.ErrorIs(t, err, ErrPositionNotFound) err = pc.SetCollateralCurrency("hi", asset.Futures, p, currency.DOGE) - if !errors.Is(err, ErrPositionNotFound) { - t.Fatalf("received '%v' expected '%v", err, ErrPositionNotFound) - } + require.ErrorIs(t, err, ErrPositionNotFound) mapKey := key.ExchangePairAsset{ Exchange: "hi", @@ -981,9 +885,7 @@ func TestSetCollateralCurrency(t *testing.T) { var nilPC *PositionController err = nilPC.SetCollateralCurrency("hi", asset.Spot, currency.EMPTYPAIR, currency.Code{}) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v' expected '%v", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) } func TestMPTUpdateOpenPositionUnrealisedPNL(t *testing.T) { @@ -1018,15 +920,11 @@ func TestMPTUpdateOpenPositionUnrealisedPNL(t *testing.T) { pc.multiPositionTrackers[mapKey].positions[0].status = order.Closed _, err = pc.multiPositionTrackers[mapKey].UpdateOpenPositionUnrealisedPNL(1337, time.Now()) - if !errors.Is(err, ErrPositionClosed) { - t.Fatalf("received '%v' expected '%v", err, ErrPositionClosed) - } + require.ErrorIs(t, err, ErrPositionClosed) pc.multiPositionTrackers[mapKey].positions = nil _, err = pc.multiPositionTrackers[mapKey].UpdateOpenPositionUnrealisedPNL(1337, time.Now()) - if !errors.Is(err, ErrPositionNotFound) { - t.Fatalf("received '%v' expected '%v", err, ErrPositionNotFound) - } + require.ErrorIs(t, err, ErrPositionNotFound) } func TestMPTLiquidate(t *testing.T) { @@ -1043,18 +941,14 @@ func TestMPTLiquidate(t *testing.T) { } err = e.Liquidate(decimal.Zero, time.Time{}) - if !errors.Is(err, ErrPositionNotFound) { - t.Error(err) - } + assert.ErrorIs(t, err, ErrPositionNotFound) setup := &PositionTrackerSetup{ Pair: pair, Asset: item, } _, err = SetupPositionTracker(setup) - if !errors.Is(err, errExchangeNameEmpty) { - t.Error(err) - } + assert.ErrorIs(t, err, errExchangeNameEmpty) setup.Exchange = "exch" _, err = SetupPositionTracker(setup) @@ -1074,9 +968,7 @@ func TestMPTLiquidate(t *testing.T) { assert.NoError(t, err) err = e.Liquidate(decimal.Zero, time.Time{}) - if !errors.Is(err, order.ErrCannotLiquidate) { - t.Error(err) - } + assert.ErrorIs(t, err, order.ErrCannotLiquidate) err = e.Liquidate(decimal.Zero, tt) assert.NoError(t, err) @@ -1090,9 +982,7 @@ func TestMPTLiquidate(t *testing.T) { e = nil err = e.Liquidate(decimal.Zero, tt) - if !errors.Is(err, common.ErrNilPointer) { - t.Error(err) - } + assert.ErrorIs(t, err, common.ErrNilPointer) } func TestPositionLiquidate(t *testing.T) { @@ -1124,9 +1014,7 @@ func TestPositionLiquidate(t *testing.T) { assert.NoError(t, err) err = p.Liquidate(decimal.Zero, time.Time{}) - if !errors.Is(err, order.ErrCannotLiquidate) { - t.Error(err) - } + assert.ErrorIs(t, err, order.ErrCannotLiquidate) err = p.Liquidate(decimal.Zero, tt) assert.NoError(t, err) @@ -1140,9 +1028,7 @@ func TestPositionLiquidate(t *testing.T) { p = nil err = p.Liquidate(decimal.Zero, tt) - if !errors.Is(err, common.ErrNilPointer) { - t.Error(err) - } + assert.ErrorIs(t, err, common.ErrNilPointer) } func TestGetOpenPosition(t *testing.T) { @@ -1152,14 +1038,10 @@ func TestGetOpenPosition(t *testing.T) { tn := time.Now() _, err := pc.GetOpenPosition("", asset.Futures, cp) - if !errors.Is(err, errExchangeNameEmpty) { - t.Errorf("received '%v' expected '%v", err, errExchangeNameEmpty) - } + assert.ErrorIs(t, err, errExchangeNameEmpty) _, err = pc.GetOpenPosition(testExchange, asset.Futures, cp) - if !errors.Is(err, ErrPositionNotFound) { - t.Errorf("received '%v' expected '%v", err, ErrPositionNotFound) - } + assert.ErrorIs(t, err, ErrPositionNotFound) err = pc.TrackNewOrder(&order.Detail{ Date: tn, @@ -1182,9 +1064,7 @@ func TestGetAllOpenPositions(t *testing.T) { pc := SetupPositionController() _, err := pc.GetAllOpenPositions() - if !errors.Is(err, ErrNoPositionsFound) { - t.Errorf("received '%v' expected '%v", err, ErrNoPositionsFound) - } + assert.ErrorIs(t, err, ErrNoPositionsFound) cp := currency.NewPair(currency.BTC, currency.PERP) tn := time.Now() @@ -1208,9 +1088,7 @@ func TestPCTrackFundingDetails(t *testing.T) { t.Parallel() pc := SetupPositionController() err := pc.TrackFundingDetails(nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v' expected '%v", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) p := currency.NewPair(currency.BTC, currency.PERP) rates := &fundingrate.HistoricalRates{ @@ -1218,15 +1096,11 @@ func TestPCTrackFundingDetails(t *testing.T) { Pair: p, } err = pc.TrackFundingDetails(rates) - if !errors.Is(err, errExchangeNameEmpty) { - t.Errorf("received '%v' expected '%v", err, errExchangeNameEmpty) - } + assert.ErrorIs(t, err, errExchangeNameEmpty) rates.Exchange = testExchange err = pc.TrackFundingDetails(rates) - if !errors.Is(err, ErrPositionNotFound) { - t.Errorf("received '%v' expected '%v", err, ErrPositionNotFound) - } + assert.ErrorIs(t, err, ErrPositionNotFound) tn := time.Now() err = pc.TrackNewOrder(&order.Detail{ @@ -1271,9 +1145,7 @@ func TestMPTTrackFundingDetails(t *testing.T) { } err := mpt.TrackFundingDetails(nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v' expected '%v", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) cp := currency.NewPair(currency.BTC, currency.PERP) rates := &fundingrate.HistoricalRates{ @@ -1281,9 +1153,7 @@ func TestMPTTrackFundingDetails(t *testing.T) { Pair: cp, } err = mpt.TrackFundingDetails(rates) - if !errors.Is(err, errExchangeNameEmpty) { - t.Errorf("received '%v' expected '%v", err, errExchangeNameEmpty) - } + assert.ErrorIs(t, err, errExchangeNameEmpty) mpt.exchange = testExchange rates = &fundingrate.HistoricalRates{ @@ -1292,16 +1162,12 @@ func TestMPTTrackFundingDetails(t *testing.T) { Pair: cp, } err = mpt.TrackFundingDetails(rates) - if !errors.Is(err, errAssetMismatch) { - t.Errorf("received '%v' expected '%v", err, errAssetMismatch) - } + assert.ErrorIs(t, err, errAssetMismatch) mpt.asset = rates.Asset mpt.pair = cp err = mpt.TrackFundingDetails(rates) - if !errors.Is(err, ErrPositionNotFound) { - t.Errorf("received '%v' expected '%v", err, ErrPositionNotFound) - } + assert.ErrorIs(t, err, ErrPositionNotFound) tn := time.Now() err = mpt.TrackNewOrder(&order.Detail{ @@ -1329,18 +1195,14 @@ func TestMPTTrackFundingDetails(t *testing.T) { mpt.orderPositions["lol"].lastUpdated = tn rates.Exchange = "lol" err = mpt.TrackFundingDetails(rates) - if !errors.Is(err, errExchangeNameMismatch) { - t.Errorf("received '%v' expected '%v", err, errExchangeNameMismatch) - } + assert.ErrorIs(t, err, errExchangeNameMismatch) } func TestPTTrackFundingDetails(t *testing.T) { t.Parallel() p := &PositionTracker{} err := p.TrackFundingDetails(nil) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v' expected '%v", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) cp := currency.NewPair(currency.BTC, currency.PERP) rates := &fundingrate.HistoricalRates{ @@ -1349,25 +1211,19 @@ func TestPTTrackFundingDetails(t *testing.T) { Pair: cp, } err = p.TrackFundingDetails(rates) - if !errors.Is(err, errDoesntMatch) { - t.Errorf("received '%v' expected '%v", err, errDoesntMatch) - } + assert.ErrorIs(t, err, errDoesntMatch) p.exchange = testExchange p.asset = asset.Futures p.contractPair = cp err = p.TrackFundingDetails(rates) - if !errors.Is(err, common.ErrDateUnset) { - t.Errorf("received '%v' expected '%v", err, common.ErrDateUnset) - } + assert.ErrorIs(t, err, common.ErrDateUnset) rates.StartDate = time.Now().Add(-time.Hour) rates.EndDate = time.Now() p.openingDate = rates.StartDate err = p.TrackFundingDetails(rates) - if !errors.Is(err, ErrNoPositionsFound) { - t.Errorf("received '%v' expected '%v", err, ErrNoPositionsFound) - } + assert.ErrorIs(t, err, ErrNoPositionsFound) p.pnlHistory = append(p.pnlHistory, PNLResult{ Time: rates.EndDate, @@ -1399,15 +1255,11 @@ func TestPTTrackFundingDetails(t *testing.T) { rates.Exchange = "" err = p.TrackFundingDetails(rates) - if !errors.Is(err, errExchangeNameEmpty) { - t.Errorf("received '%v' expected '%v", err, errExchangeNameEmpty) - } + assert.ErrorIs(t, err, errExchangeNameEmpty) p = nil err = p.TrackFundingDetails(rates) - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v' expected '%v", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) } func TestAreFundingRatePrerequisitesMet(t *testing.T) { @@ -1428,19 +1280,13 @@ func TestAreFundingRatePrerequisitesMet(t *testing.T) { assert.NoError(t, err) err = CheckFundingRatePrerequisites(false, false, true) - if !errors.Is(err, ErrGetFundingDataRequired) { - t.Errorf("received '%v' expected '%v", err, ErrGetFundingDataRequired) - } + assert.ErrorIs(t, err, ErrGetFundingDataRequired) err = CheckFundingRatePrerequisites(false, true, true) - if !errors.Is(err, ErrGetFundingDataRequired) { - t.Errorf("received '%v' expected '%v", err, ErrGetFundingDataRequired) - } + assert.ErrorIs(t, err, ErrGetFundingDataRequired) err = CheckFundingRatePrerequisites(false, true, false) - if !errors.Is(err, ErrGetFundingDataRequired) { - t.Errorf("received '%v' expected '%v", err, ErrGetFundingDataRequired) - } + assert.ErrorIs(t, err, ErrGetFundingDataRequired) } func TestLastUpdated(t *testing.T) { @@ -1461,9 +1307,7 @@ func TestLastUpdated(t *testing.T) { } p = nil _, err = p.LastUpdated() - if !errors.Is(err, common.ErrNilPointer) { - t.Errorf("received '%v' expected '%v", err, common.ErrNilPointer) - } + assert.ErrorIs(t, err, common.ErrNilPointer) } func TestGetCurrencyForRealisedPNL(t *testing.T) { @@ -1483,18 +1327,15 @@ func TestGetCurrencyForRealisedPNL(t *testing.T) { func TestCheckTrackerPrerequisitesLowerExchange(t *testing.T) { t.Parallel() _, err := checkTrackerPrerequisitesLowerExchange("", asset.Spot, currency.EMPTYPAIR) - if !errors.Is(err, errExchangeNameEmpty) { - t.Errorf("received '%v' expected '%v", err, errExchangeNameEmpty) - } + assert.ErrorIs(t, err, errExchangeNameEmpty) + upperExch := "IM UPPERCASE" _, err = checkTrackerPrerequisitesLowerExchange(upperExch, asset.Spot, currency.EMPTYPAIR) - if !errors.Is(err, ErrNotFuturesAsset) { - t.Errorf("received '%v' expected '%v", err, ErrNotFuturesAsset) - } + assert.ErrorIs(t, err, ErrNotFuturesAsset) + _, err = checkTrackerPrerequisitesLowerExchange(upperExch, asset.Futures, currency.EMPTYPAIR) - if !errors.Is(err, order.ErrPairIsEmpty) { - t.Errorf("received '%v' expected '%v", err, order.ErrPairIsEmpty) - } + assert.ErrorIs(t, err, order.ErrPairIsEmpty) + lowerExch, err := checkTrackerPrerequisitesLowerExchange(upperExch, asset.Futures, currency.NewBTCUSDT()) assert.NoError(t, err) diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index fedf394c..d459a1fb 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -3,7 +3,6 @@ package gateio import ( "bytes" "context" - "errors" "fmt" "log" "os" @@ -340,15 +339,13 @@ func TestAmendSpotOrder(t *testing.T) { _, err := g.AmendSpotOrder(t.Context(), "", getPair(t, asset.Spot), false, &PriceAndAmount{ Price: 1000, }) - if !errors.Is(err, errInvalidOrderID) { - t.Errorf("expecting %v, but found %v", errInvalidOrderID, err) - } + assert.ErrorIs(t, err, errInvalidOrderID) + _, err = g.AmendSpotOrder(t.Context(), "123", currency.EMPTYPAIR, false, &PriceAndAmount{ Price: 1000, }) - if !errors.Is(err, currency.ErrCurrencyPairEmpty) { - t.Errorf("expecting %v, but found %v", currency.ErrCurrencyPairEmpty, err) - } + assert.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) _, err = g.AmendSpotOrder(t.Context(), "123", getPair(t, asset.Spot), false, &PriceAndAmount{ Price: 1000, @@ -513,13 +510,13 @@ func TestRetriveOneSingleLoanDetail(t *testing.T) { func TestModifyALoan(t *testing.T) { t.Parallel() - if _, err := g.ModifyALoan(t.Context(), "1234", &ModifyLoanRequestParam{ + _, err := g.ModifyALoan(t.Context(), "1234", &ModifyLoanRequestParam{ Currency: currency.BTC, Side: "borrow", AutoRenew: false, - }); !errors.Is(err, currency.ErrCurrencyPairEmpty) { - t.Errorf("%s ModifyALoan() error %v", g.Name, err) - } + }) + assert.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) if _, err := g.ModifyALoan(t.Context(), "1234", &ModifyLoanRequestParam{ Currency: currency.BTC, @@ -1608,15 +1605,13 @@ func TestGetUsersPositionSpecifiedUnderlying(t *testing.T) { func TestGetSpecifiedContractPosition(t *testing.T) { t.Parallel() - sharedtestvalues.SkipTestIfCredentialsUnset(t, g) _, err := g.GetSpecifiedContractPosition(t.Context(), currency.EMPTYPAIR) - if err != nil && !errors.Is(err, errInvalidOrMissingContractParam) { - t.Errorf("%s GetSpecifiedContractPosition() error expecting %v, but found %v", g.Name, errInvalidOrMissingContractParam, err) - } + assert.ErrorIs(t, err, errInvalidOrMissingContractParam) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g) + _, err = g.GetSpecifiedContractPosition(t.Context(), getPair(t, asset.Options)) - if err != nil { - t.Errorf("%s GetSpecifiedContractPosition() error expecting %v, but found %v", g.Name, errInvalidOrMissingContractParam, err) - } + assert.NoError(t, err, "GetSpecifiedContractPosition should not error") } func TestGetUsersLiquidationHistoryForSpecifiedUnderlying(t *testing.T) { @@ -1661,13 +1656,13 @@ func TestCancelOptionOpenOrders(t *testing.T) { func TestGetSingleOptionOrder(t *testing.T) { t.Parallel() + _, err := g.GetSingleOptionOrder(t.Context(), "") + assert.ErrorIs(t, err, errInvalidOrderID) + sharedtestvalues.SkipTestIfCredentialsUnset(t, g) - if _, err := g.GetSingleOptionOrder(t.Context(), ""); err != nil && !errors.Is(errInvalidOrderID, err) { - t.Errorf("%s GetSingleOptionorder() expecting %v, but found %v", g.Name, errInvalidOrderID, err) - } - if _, err := g.GetSingleOptionOrder(t.Context(), "1234"); err != nil { - t.Errorf("%s GetSingleOptionOrder() error %v", g.Name, err) - } + + _, err = g.GetSingleOptionOrder(t.Context(), "1234") + assert.NoError(t, err, "GetSingleOptionOrder should not error") } func TestCancelSingleOrder(t *testing.T) { @@ -1688,11 +1683,9 @@ func TestGetMyOptionsTradingHistory(t *testing.T) { func TestWithdrawCurrency(t *testing.T) { t.Parallel() - sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) _, err := g.WithdrawCurrency(t.Context(), WithdrawalRequestParam{}) - if err != nil && !errors.Is(err, errInvalidAmount) { - t.Errorf("%s WithdrawCurrency() expecting error %v, but found %v", g.Name, errInvalidAmount, err) - } + assert.ErrorIs(t, err, errInvalidAmount) + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) _, err = g.WithdrawCurrency(t.Context(), WithdrawalRequestParam{ Currency: currency.BTC, Amount: 0.00000001, @@ -2563,14 +2556,10 @@ func TestUpdateOrderExecutionLimits(t *testing.T) { testexch.UpdatePairsOnce(t, g) err := g.UpdateOrderExecutionLimits(t.Context(), 1336) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received %v, expected %v", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) err = g.UpdateOrderExecutionLimits(t.Context(), asset.Options) - if !errors.Is(err, common.ErrNotYetImplemented) { - t.Fatalf("received %v, expected %v", err, common.ErrNotYetImplemented) - } + require.ErrorIs(t, err, common.ErrNotYetImplemented) err = g.UpdateOrderExecutionLimits(t.Context(), asset.Spot) if err != nil { diff --git a/exchanges/gemini/gemini_test.go b/exchanges/gemini/gemini_test.go index 1fba6697..20e18922 100644 --- a/exchanges/gemini/gemini_test.go +++ b/exchanges/gemini/gemini_test.go @@ -1,7 +1,6 @@ package gemini import ( - "errors" "net/url" "strings" "testing" @@ -1268,14 +1267,10 @@ func TestSetExchangeOrderExecutionLimits(t *testing.T) { t.Fatal(err) } err = g.UpdateOrderExecutionLimits(t.Context(), asset.Futures) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatal(err) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) availPairs, err := g.GetAvailablePairs(asset.Spot) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) for x := range availPairs { var limit order.MinMaxLevel limit, err = g.GetOrderExecutionLimits(asset.Spot, availPairs[x]) diff --git a/exchanges/kline/kline_test.go b/exchanges/kline/kline_test.go index cce581b5..63f6f95b 100644 --- a/exchanges/kline/kline_test.go +++ b/exchanges/kline/kline_test.go @@ -1,7 +1,6 @@ package kline import ( - "errors" "fmt" "math/rand" "os" @@ -89,9 +88,7 @@ func TestCreateKline(t *testing.T) { pair := currency.NewBTCUSD() _, err := CreateKline(nil, OneMin, pair, asset.Spot, "Binance") - if !errors.Is(err, errInsufficientTradeData) { - t.Fatalf("received: '%v' but expected '%v'", err, errInsufficientTradeData) - } + require.ErrorIs(t, err, errInsufficientTradeData) tradeTotal := 24000 trades := make([]order.TradeHistory, tradeTotal) @@ -107,9 +104,7 @@ func TestCreateKline(t *testing.T) { } _, err = CreateKline(trades, 0, pair, asset.Spot, "Binance") - if !errors.Is(err, ErrInvalidInterval) { - t.Fatalf("received: '%v' but expected '%v'", err, ErrInvalidInterval) - } + require.ErrorIs(t, err, ErrInvalidInterval) c, err := CreateKline(trades, OneMin, pair, asset.Spot, "Binance") if err != nil { @@ -801,14 +796,10 @@ func BenchmarkJustifyIntervalTimeStoringUnixValues2(b *testing.B) { func TestConvertToNewInterval(t *testing.T) { _, err := (*Item)(nil).ConvertToNewInterval(OneMin) - if !errors.Is(err, errNilKline) { - t.Errorf("received '%v' expected '%v'", err, errNilKline) - } + assert.ErrorIs(t, err, errNilKline) _, err = (&Item{}).ConvertToNewInterval(OneMin) - if !errors.Is(err, ErrInvalidInterval) { - t.Errorf("received '%v' expected '%v'", err, ErrInvalidInterval) - } + assert.ErrorIs(t, err, ErrInvalidInterval) old := &Item{ Exchange: "lol", @@ -844,18 +835,14 @@ func TestConvertToNewInterval(t *testing.T) { } _, err = old.ConvertToNewInterval(0) - if !errors.Is(err, ErrInvalidInterval) { - t.Errorf("received '%v' expected '%v'", err, ErrInvalidInterval) - } + assert.ErrorIs(t, err, ErrInvalidInterval) + _, err = old.ConvertToNewInterval(OneMin) - if !errors.Is(err, ErrCanOnlyUpscaleCandles) { - t.Errorf("received '%v' expected '%v'", err, ErrCanOnlyUpscaleCandles) - } + assert.ErrorIs(t, err, ErrCanOnlyUpscaleCandles) + old.Interval = ThreeDay _, err = old.ConvertToNewInterval(OneWeek) - if !errors.Is(err, ErrWholeNumberScaling) { - t.Errorf("received '%v' expected '%v'", err, ErrWholeNumberScaling) - } + assert.ErrorIs(t, err, ErrWholeNumberScaling) old.Interval = OneDay newInterval := ThreeDay @@ -889,9 +876,7 @@ func TestConvertToNewInterval(t *testing.T) { } _, err = old.ConvertToNewInterval(OneMonth) - if !errors.Is(err, ErrInsufficientCandleData) { - t.Errorf("received '%v' expected '%v'", err, ErrInsufficientCandleData) - } + assert.ErrorIs(t, err, ErrInsufficientCandleData) tn := time.Now().Truncate(time.Duration(OneDay)) @@ -946,9 +931,7 @@ func TestConvertToNewInterval(t *testing.T) { } _, err = old.ConvertToNewInterval(newInterval) - if !errors.Is(err, errCandleDataNotPadded) { - t.Errorf("received '%v' expected '%v'", err, errCandleDataNotPadded) - } + assert.ErrorIs(t, err, errCandleDataNotPadded) err = old.addPadding(tn, tn.AddDate(0, 0, 9), false) require.NoError(t, err) @@ -968,9 +951,7 @@ func TestAddPadding(t *testing.T) { var k *Item err := k.addPadding(tn, tn.AddDate(0, 0, 5), false) - if !errors.Is(err, errNilKline) { - t.Fatalf("received '%v' expected '%v'", err, errNilKline) - } + require.ErrorIs(t, err, errNilKline) k = &Item{} k.Candles = []Candle{ @@ -984,9 +965,7 @@ func TestAddPadding(t *testing.T) { }, } err = k.addPadding(tn, tn.AddDate(0, 0, 5), false) - if !errors.Is(err, ErrInvalidInterval) { - t.Fatalf("received '%v' expected '%v'", err, ErrInvalidInterval) - } + require.ErrorIs(t, err, ErrInvalidInterval) k.Interval = OneDay k.Candles = []Candle{ @@ -1008,9 +987,7 @@ func TestAddPadding(t *testing.T) { }, } err = k.addPadding(tn.AddDate(0, 0, 5), tn, false) - if !errors.Is(err, errCannotEstablishTimeWindow) { - t.Fatalf("received '%v' expected '%v'", err, errCannotEstablishTimeWindow) - } + require.ErrorIs(t, err, errCannotEstablishTimeWindow) k.Candles = []Candle{ { @@ -1040,9 +1017,7 @@ func TestAddPadding(t *testing.T) { } err = k.addPadding(tn, tn.AddDate(0, 0, 3), false) - if !errors.Is(err, errCandleOpenTimeIsNotUTCAligned) { - t.Fatalf("received '%v' expected '%v'", err, errCandleOpenTimeIsNotUTCAligned) - } + require.ErrorIs(t, err, errCandleOpenTimeIsNotUTCAligned) k.Candles = []Candle{ { @@ -1127,9 +1102,7 @@ func TestGetClosePriceAtTime(t *testing.T) { t.Errorf("received '%v' expected '%v'", price, 1337) } _, err = k.GetClosePriceAtTime(tt.Add(time.Minute)) - if !errors.Is(err, ErrNotFoundAtTime) { - t.Errorf("received '%v' expected '%v'", err, ErrNotFoundAtTime) - } + assert.ErrorIs(t, err, ErrNotFoundAtTime) } func TestDeployExchangeIntervals(t *testing.T) { @@ -1145,14 +1118,10 @@ func TestDeployExchangeIntervals(t *testing.T) { } _, err := exchangeIntervals.Construct(0) - if !errors.Is(err, ErrInvalidInterval) { - t.Errorf("received '%v' expected '%v'", err, ErrInvalidInterval) - } + assert.ErrorIs(t, err, ErrInvalidInterval) _, err = exchangeIntervals.Construct(OneMin) - if !errors.Is(err, ErrCannotConstructInterval) { - t.Errorf("received '%v' expected '%v'", err, ErrCannotConstructInterval) - } + assert.ErrorIs(t, err, ErrCannotConstructInterval) request, err := exchangeIntervals.Construct(OneWeek) assert.NoError(t, err) @@ -1214,16 +1183,12 @@ func TestGetIntervalResultLimit(t *testing.T) { var e *ExchangeCapabilitiesEnabled _, err := e.GetIntervalResultLimit(OneMin) - if !errors.Is(err, errExchangeCapabilitiesEnabledIsNil) { - t.Errorf("received '%v' expected '%v'", err, errExchangeCapabilitiesEnabledIsNil) - } + assert.ErrorIs(t, err, errExchangeCapabilitiesEnabledIsNil) e = &ExchangeCapabilitiesEnabled{} e.Intervals = ExchangeIntervals{} _, err = e.GetIntervalResultLimit(OneDay) - if !errors.Is(err, errIntervalNotSupported) { - t.Errorf("received '%v' expected '%v'", err, errIntervalNotSupported) - } + assert.ErrorIs(t, err, errIntervalNotSupported) e.Intervals = ExchangeIntervals{ supported: map[Interval]uint64{ @@ -1233,9 +1198,7 @@ func TestGetIntervalResultLimit(t *testing.T) { } _, err = e.GetIntervalResultLimit(OneMin) - if !errors.Is(err, errCannotFetchIntervalLimit) { - t.Errorf("received '%v' expected '%v'", err, errCannotFetchIntervalLimit) - } + assert.ErrorIs(t, err, errCannotFetchIntervalLimit) limit, err := e.GetIntervalResultLimit(OneDay) assert.NoError(t, err) diff --git a/exchanges/kline/request_test.go b/exchanges/kline/request_test.go index 9f010064..5080135c 100644 --- a/exchanges/kline/request_test.go +++ b/exchanges/kline/request_test.go @@ -1,7 +1,6 @@ package kline import ( - "errors" "sync" "testing" "time" @@ -15,53 +14,35 @@ import ( func TestCreateKlineRequest(t *testing.T) { t.Parallel() _, err := CreateKlineRequest("", currency.EMPTYPAIR, currency.EMPTYPAIR, 0, 0, 0, time.Time{}, time.Time{}, 0) - if !errors.Is(err, ErrUnsetName) { - t.Fatalf("received: '%v', but expected '%v'", err, ErrUnsetName) - } + require.ErrorIs(t, err, ErrUnsetName) _, err = CreateKlineRequest("name", currency.EMPTYPAIR, currency.EMPTYPAIR, 0, 0, 0, time.Time{}, time.Time{}, 0) - if !errors.Is(err, currency.ErrCurrencyPairEmpty) { - t.Fatalf("received: '%v', but expected '%v'", err, currency.ErrCurrencyPairEmpty) - } + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) pair := currency.NewBTCUSDT() _, err = CreateKlineRequest("name", pair, currency.EMPTYPAIR, 0, 0, 0, time.Time{}, time.Time{}, 0) - if !errors.Is(err, currency.ErrCurrencyPairEmpty) { - t.Fatalf("received: '%v', but expected '%v'", err, currency.ErrCurrencyPairEmpty) - } + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) pair2 := pair.Upper() _, err = CreateKlineRequest("name", pair, pair2, 0, 0, 0, time.Time{}, time.Time{}, 0) - if !errors.Is(err, asset.ErrNotSupported) { - t.Fatalf("received: '%v', but expected '%v'", err, asset.ErrNotSupported) - } + require.ErrorIs(t, err, asset.ErrNotSupported) _, err = CreateKlineRequest("name", pair, pair2, asset.Spot, 0, 0, time.Time{}, time.Time{}, 0) - if !errors.Is(err, ErrInvalidInterval) { - t.Fatalf("received: '%v', but expected '%v'", err, ErrInvalidInterval) - } + require.ErrorIs(t, err, ErrInvalidInterval) _, err = CreateKlineRequest("name", pair, pair2, asset.Spot, OneHour, 0, time.Time{}, time.Time{}, 0) - if !errors.Is(err, ErrInvalidInterval) { - t.Fatalf("received: '%v', but expected '%v'", err, ErrInvalidInterval) - } + require.ErrorIs(t, err, ErrInvalidInterval) _, err = CreateKlineRequest("name", pair, pair2, asset.Spot, OneHour, OneMin, time.Time{}, time.Time{}, 0) - if !errors.Is(err, common.ErrDateUnset) { - t.Fatalf("received: '%v', but expected '%v'", err, common.ErrDateUnset) - } + require.ErrorIs(t, err, common.ErrDateUnset) start := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) _, err = CreateKlineRequest("name", pair, pair2, asset.Spot, OneHour, OneMin, start, time.Time{}, 0) - if !errors.Is(err, common.ErrDateUnset) { - t.Fatalf("received: '%v', but expected '%v'", err, common.ErrDateUnset) - } + require.ErrorIs(t, err, common.ErrDateUnset) end := start.AddDate(0, 0, 1) _, err = CreateKlineRequest("name", pair, pair2, asset.Spot, OneHour, OneMin, start, end, 0) - if !errors.Is(err, errInvalidSpecificEndpointLimit) { - t.Fatalf("received: '%v', but expected '%v'", err, errInvalidSpecificEndpointLimit) - } + require.ErrorIs(t, err, errInvalidSpecificEndpointLimit) r, err := CreateKlineRequest("name", pair, pair2, asset.Spot, OneHour, OneMin, start, end, 1) require.NoError(t, err) @@ -119,9 +100,7 @@ func TestGetRanges(t *testing.T) { var r *Request _, err := r.GetRanges(100) - if !errors.Is(err, errNilRequest) { - t.Fatalf("received: '%v', but expected '%v'", err, errNilRequest) - } + require.ErrorIs(t, err, errNilRequest) r, err = CreateKlineRequest("name", pair, pair, asset.Spot, OneHour, OneMin, start, end, 1) require.NoError(t, err) @@ -195,20 +174,14 @@ func TestRequest_ProcessResponse(t *testing.T) { var r *Request _, err := r.ProcessResponse(nil) - if !errors.Is(err, errNilRequest) { - t.Fatalf("received: '%v', but expected '%v'", err, errNilRequest) - } + require.ErrorIs(t, err, errNilRequest) r = &Request{} _, err = r.ProcessResponse(nil) - if !errors.Is(err, ErrNoTimeSeriesDataToConvert) { - t.Fatalf("received: '%v', but expected '%v'", err, ErrNoTimeSeriesDataToConvert) - } + require.ErrorIs(t, err, ErrNoTimeSeriesDataToConvert) _, err = CreateKlineRequest("name", pair, pair, asset.Spot, OneHour, OneHour, start, end, 0) - if !errors.Is(err, errInvalidSpecificEndpointLimit) { - t.Fatalf("received: '%v', but expected '%v'", err, nil) - } + require.ErrorIs(t, err, errInvalidSpecificEndpointLimit) // no conversion r, err = CreateKlineRequest("name", pair, pair, asset.Spot, OneHour, OneHour, start, end, 1) @@ -310,15 +283,11 @@ func TestExtendedRequest_ProcessResponse(t *testing.T) { var rExt *ExtendedRequest _, err := rExt.ProcessResponse(nil) - if !errors.Is(err, errNilRequest) { - t.Fatalf("received: '%v', but expected '%v'", err, errNilRequest) - } + require.ErrorIs(t, err, errNilRequest) rExt = &ExtendedRequest{} _, err = rExt.ProcessResponse(nil) - if !errors.Is(err, ErrNoTimeSeriesDataToConvert) { - t.Fatalf("received: '%v', but expected '%v'", err, ErrNoTimeSeriesDataToConvert) - } + require.ErrorIs(t, err, ErrNoTimeSeriesDataToConvert) // no conversion r, err := CreateKlineRequest("name", pair, pair, asset.Spot, OneHour, OneHour, start, end, 1) diff --git a/exchanges/kline/technical_analysis_test.go b/exchanges/kline/technical_analysis_test.go index 01e4c781..ac519bf4 100644 --- a/exchanges/kline/technical_analysis_test.go +++ b/exchanges/kline/technical_analysis_test.go @@ -1,7 +1,6 @@ package kline import ( - "errors" "testing" "github.com/stretchr/testify/require" @@ -19,38 +18,26 @@ func TestGetAverageTrueRange(t *testing.T) { var ohlc *OHLC _, err := ohlc.GetAverageTrueRange(0) - if !errors.Is(err, errNilOHLC) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNilOHLC) - } + require.ErrorIs(t, err, errNilOHLC) ohlc = &OHLC{} _, err = ohlc.GetAverageTrueRange(0) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) _, err = ohlc.GetAverageTrueRange(9) - if !errors.Is(err, errNoData) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoData) - } + require.ErrorIs(t, err, errNoData) ohlc.High = append(ohlc.High, 1337) _, err = ohlc.GetAverageTrueRange(9) - if !errors.Is(err, errNoData) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoData) - } + require.ErrorIs(t, err, errNoData) ohlc.Low = append(ohlc.Low, 1337) _, err = ohlc.GetAverageTrueRange(9) - if !errors.Is(err, errNoData) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoData) - } + require.ErrorIs(t, err, errNoData) ohlc.Close = append(ohlc.Close, 1337) _, err = ohlc.GetAverageTrueRange(9) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) _, err = ohlc.GetAverageTrueRange(1) require.NoError(t, err) @@ -65,36 +52,24 @@ func TestGetBollingerBands(t *testing.T) { var ohlc *OHLC _, err := ohlc.GetBollingerBands(0, 0, 0, 5) - if !errors.Is(err, errNilOHLC) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNilOHLC) - } + require.ErrorIs(t, err, errNilOHLC) ohlc = &OHLC{} _, err = ohlc.GetBollingerBands(0, 0, 0, 5) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) _, err = ohlc.GetBollingerBands(9, 0, 0, 5) - if !errors.Is(err, errInvalidDeviationMultiplier) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidDeviationMultiplier) - } + require.ErrorIs(t, err, errInvalidDeviationMultiplier) _, err = ohlc.GetBollingerBands(9, 1, 0, 5) - if !errors.Is(err, errInvalidDeviationMultiplier) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidDeviationMultiplier) - } + require.ErrorIs(t, err, errInvalidDeviationMultiplier) _, err = ohlc.GetBollingerBands(9, 1, 1, 5) - if !errors.Is(err, errNoData) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoData) - } + require.ErrorIs(t, err, errNoData) ohlc.Close = append(ohlc.Close, 1337, 1337, 1337, 1337, 1337, 1337, 1337, 1337, 1337) _, err = ohlc.GetBollingerBands(10, 1, 1, 5) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) _, err = ohlc.GetBollingerBands(9, 1, 1, 5) require.NoError(t, err) @@ -109,48 +84,32 @@ func TestGetCorrelationCoefficient(t *testing.T) { var ohlc *OHLC _, err := ohlc.GetCorrelationCoefficient(nil, 0) - if !errors.Is(err, errNilOHLC) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNilOHLC) - } + require.ErrorIs(t, err, errNilOHLC) ohlc = &OHLC{} _, err = ohlc.GetCorrelationCoefficient(nil, 0) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) _, err = ohlc.GetCorrelationCoefficient(nil, 1) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) _, err = ohlc.GetCorrelationCoefficient(nil, 2) - if !errors.Is(err, errNilOHLC) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNilOHLC) - } + require.ErrorIs(t, err, errNilOHLC) _, err = ohlc.GetCorrelationCoefficient(&OHLC{}, 9) - if !errors.Is(err, errNoData) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoData) - } + require.ErrorIs(t, err, errNoData) ohlc.Close = append(ohlc.Close, 1337, 1337) _, err = ohlc.GetCorrelationCoefficient(&OHLC{}, 9) - if !errors.Is(err, errNoData) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoData) - } + require.ErrorIs(t, err, errNoData) _, err = ohlc.GetCorrelationCoefficient(&OHLC{Close: []float64{1337}}, 2) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) ohlc.Close = append(ohlc.Close, 1337) _, err = ohlc.GetCorrelationCoefficient(&OHLC{Close: []float64{1337, 1337}}, 2) - if !errors.Is(err, errInvalidDataSetLengths) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidDataSetLengths) - } + require.ErrorIs(t, err, errInvalidDataSetLengths) _, err = ohlc.GetCorrelationCoefficient(&OHLC{Close: []float64{1337, 1337, 1337}}, 2) require.NoError(t, err) @@ -165,25 +124,17 @@ func TestGetSimpleMovingAverage(t *testing.T) { var ohlc *OHLC _, err := ohlc.GetSimpleMovingAverage(nil, 0) - if !errors.Is(err, errNilOHLC) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNilOHLC) - } + require.ErrorIs(t, err, errNilOHLC) ohlc = &OHLC{} _, err = ohlc.GetSimpleMovingAverage(nil, 0) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) _, err = ohlc.GetSimpleMovingAverage(nil, 9) - if !errors.Is(err, errNoData) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoData) - } + require.ErrorIs(t, err, errNoData) _, err = ohlc.GetSimpleMovingAverage([]float64{1337}, 9) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) _, err = ohlc.GetSimpleMovingAverage([]float64{1337, 1337}, 2) require.NoError(t, err) @@ -198,25 +149,17 @@ func TestGetExponentialMovingAverage(t *testing.T) { var ohlc *OHLC _, err := ohlc.GetExponentialMovingAverage(nil, 0) - if !errors.Is(err, errNilOHLC) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNilOHLC) - } + require.ErrorIs(t, err, errNilOHLC) ohlc = &OHLC{} _, err = ohlc.GetExponentialMovingAverage(nil, 0) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) _, err = ohlc.GetExponentialMovingAverage(nil, 9) - if !errors.Is(err, errNoData) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoData) - } + require.ErrorIs(t, err, errNoData) _, err = ohlc.GetExponentialMovingAverage([]float64{1337}, 9) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) _, err = ohlc.GetExponentialMovingAverage([]float64{1337, 1337, 1337}, 2) require.NoError(t, err) @@ -231,40 +174,26 @@ func TestGetMovingAverageConvergenceDivergence(t *testing.T) { var ohlc *OHLC _, err := ohlc.GetMovingAverageConvergenceDivergence(nil, 0, 0, 0) - if !errors.Is(err, errNilOHLC) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNilOHLC) - } + require.ErrorIs(t, err, errNilOHLC) ohlc = &OHLC{} _, err = ohlc.GetMovingAverageConvergenceDivergence(nil, 0, 0, 0) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) _, err = ohlc.GetMovingAverageConvergenceDivergence(nil, 1, 0, 0) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) _, err = ohlc.GetMovingAverageConvergenceDivergence(nil, 1, 1, 0) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) _, err = ohlc.GetMovingAverageConvergenceDivergence(nil, 1, 2, 0) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) _, err = ohlc.GetMovingAverageConvergenceDivergence(nil, 1, 2, 1) - if !errors.Is(err, errNoData) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoData) - } + require.ErrorIs(t, err, errNoData) _, err = ohlc.GetMovingAverageConvergenceDivergence([]float64{1337}, 1, 2, 2) - if !errors.Is(err, errNotEnoughData) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNotEnoughData) - } + require.ErrorIs(t, err, errNotEnoughData) _, err = ohlc.GetMovingAverageConvergenceDivergence([]float64{1337, 1337, 1337, 1337, 1337, 1337, 1337, 1337}, 1, 2, 1) require.NoError(t, err) @@ -279,50 +208,34 @@ func TestGetMoneyFlowIndex(t *testing.T) { var ohlc *OHLC _, err := ohlc.GetMoneyFlowIndex(0) - if !errors.Is(err, errNilOHLC) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNilOHLC) - } + require.ErrorIs(t, err, errNilOHLC) ohlc = &OHLC{} _, err = ohlc.GetMoneyFlowIndex(0) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) _, err = ohlc.GetMoneyFlowIndex(9) - if !errors.Is(err, errNoData) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoData) - } + require.ErrorIs(t, err, errNoData) ohlc.High = append(ohlc.High, 1337, 1337, 1337, 1337, 1337, 1337) _, err = ohlc.GetMoneyFlowIndex(9) - if !errors.Is(err, errNoData) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoData) - } + require.ErrorIs(t, err, errNoData) ohlc.Low = append(ohlc.Low, 1337, 1337, 1337, 1337, 1337, 1337) _, err = ohlc.GetMoneyFlowIndex(9) - if !errors.Is(err, errNoData) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoData) - } + require.ErrorIs(t, err, errNoData) ohlc.Close = append(ohlc.Close, 1337, 1337, 1337, 1337, 1337, 1337) _, err = ohlc.GetMoneyFlowIndex(9) - if !errors.Is(err, errNoData) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoData) - } + require.ErrorIs(t, err, errNoData) ohlc.Volume = append(ohlc.Volume, 1337, 1337, 1337, 1337, 1337) _, err = ohlc.GetMoneyFlowIndex(5) - if !errors.Is(err, errInvalidDataSetLengths) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidDataSetLengths) - } + require.ErrorIs(t, err, errInvalidDataSetLengths) ohlc.Volume = append(ohlc.Volume, 1337) _, err = ohlc.GetMoneyFlowIndex(6) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) _, err = ohlc.GetMoneyFlowIndex(3) require.NoError(t, err) @@ -341,21 +254,15 @@ func TestGetOnBalanceVolume(t *testing.T) { var ohlc *OHLC _, err := ohlc.GetOnBalanceVolume() - if !errors.Is(err, errNilOHLC) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNilOHLC) - } + require.ErrorIs(t, err, errNilOHLC) ohlc = &OHLC{} _, err = ohlc.GetOnBalanceVolume() - if !errors.Is(err, errNoData) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoData) - } + require.ErrorIs(t, err, errNoData) ohlc.Close = append(ohlc.Close, 1337, 1337, 1337, 1337, 1337, 1337) _, err = ohlc.GetOnBalanceVolume() - if !errors.Is(err, errNoData) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoData) - } + require.ErrorIs(t, err, errNoData) ohlc.Volume = append(ohlc.Volume, 0.00000001) _, err = ohlc.GetOnBalanceVolume() @@ -371,25 +278,17 @@ func TestGetRelativeStrengthIndex(t *testing.T) { var ohlc *OHLC _, err := ohlc.GetRelativeStrengthIndex(nil, 0) - if !errors.Is(err, errNilOHLC) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNilOHLC) - } + require.ErrorIs(t, err, errNilOHLC) ohlc = &OHLC{} _, err = ohlc.GetRelativeStrengthIndex(nil, 0) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) _, err = ohlc.GetRelativeStrengthIndex(nil, 9) - if !errors.Is(err, errNotEnoughData) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNotEnoughData) - } + require.ErrorIs(t, err, errNotEnoughData) _, err = ohlc.GetRelativeStrengthIndex([]float64{1337, 1337, 1337}, 9) - if !errors.Is(err, errInvalidPeriod) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidPeriod) - } + require.ErrorIs(t, err, errInvalidPeriod) wrap := Item{Candles: []Candle{{Close: 1337}, {Close: 1337}, {Close: 1337}}} _, err = wrap.GetRelativeStrengthIndexOnClose(2) diff --git a/exchanges/kline/weighted_price_test.go b/exchanges/kline/weighted_price_test.go index acb8ab39..022b2c2b 100644 --- a/exchanges/kline/weighted_price_test.go +++ b/exchanges/kline/weighted_price_test.go @@ -1,12 +1,12 @@ package kline import ( - "errors" "math" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var accuracy10dp = 1 / math.Pow10(10) @@ -142,10 +142,8 @@ var expectVWAPs = []float64{245.05046666666664, 245.00156932123465, 245.07320400 func TestGetVWAPs(t *testing.T) { t.Parallel() candles := Item{} - if _, err := candles.GetVWAPs(); !errors.Is(err, errNoData) { - t.Fatal(err) - } - + _, err := candles.GetVWAPs() + require.ErrorIs(t, err, errNoData) candles.Candles = vwapdataset vwap, err := candles.GetVWAPs() assert.NoError(t, err, "GetVWAPs should not error") diff --git a/exchanges/kraken/kraken_test.go b/exchanges/kraken/kraken_test.go index 4c707a0f..23af0682 100644 --- a/exchanges/kraken/kraken_test.go +++ b/exchanges/kraken/kraken_test.go @@ -1511,13 +1511,10 @@ func TestWsOrderbookMax10Depth(t *testing.T) { func TestGetFuturesContractDetails(t *testing.T) { t.Parallel() _, err := k.GetFuturesContractDetails(t.Context(), asset.Spot) - if !errors.Is(err, futures.ErrNotFuturesAsset) { - t.Error(err) - } + assert.ErrorIs(t, err, futures.ErrNotFuturesAsset) + _, err = k.GetFuturesContractDetails(t.Context(), asset.USDTMarginedFutures) - if !errors.Is(err, asset.ErrNotSupported) { - t.Error(err) - } + assert.ErrorIs(t, err, asset.ErrNotSupported) _, err = k.GetFuturesContractDetails(t.Context(), asset.Futures) assert.NoError(t, err, "GetFuturesContractDetails should not error") @@ -1538,7 +1535,7 @@ func TestGetLatestFundingRates(t *testing.T) { assert.NoError(t, err, "GetLatestFundingRates should not error") err = k.CurrencyPairs.EnablePair(asset.Futures, futuresTestPair) - assert.True(t, err == nil || errors.Is(err, currency.ErrPairAlreadyEnabled), "EnablePair should not error") + assert.Truef(t, err == nil || errors.Is(err, currency.ErrPairAlreadyEnabled), "EnablePair should not error: %s", err) _, err = k.GetLatestFundingRates(t.Context(), &fundingrate.LatestRateRequest{ Asset: asset.Futures, Pair: futuresTestPair, diff --git a/exchanges/kucoin/kucoin_test.go b/exchanges/kucoin/kucoin_test.go index 891680e4..b461bced 100644 --- a/exchanges/kucoin/kucoin_test.go +++ b/exchanges/kucoin/kucoin_test.go @@ -2563,7 +2563,7 @@ func TestGetDepositAddress(t *testing.T) { t.Parallel() sharedtestvalues.SkipTestIfCredentialsUnset(t, ku) _, err := ku.GetDepositAddress(t.Context(), currency.BTC, "", "") - assert.True(t, err == nil || errors.Is(err, errNoDepositAddress), err) + assert.Truef(t, err == nil || errors.Is(err, errNoDepositAddress), "GetDepositAddress should not error: %s", err) } func TestWithdrawCryptocurrencyFunds(t *testing.T) { @@ -3944,7 +3944,7 @@ func TestGetMarginHFOrderDetailByOrderID(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, ku) _, err = ku.GetMarginHFOrderDetailByOrderID(t.Context(), "243432432423the-order-id", marginTradablePair.String()) - assert.True(t, errors.Is(err, order.ErrOrderNotFound) || err == nil) + assert.Truef(t, errors.Is(err, order.ErrOrderNotFound) || err == nil, "GetMarginHFOrderDetailByOrderID should not error: %s", err) } func TestGetMarginHFOrderDetailByClientOrderID(t *testing.T) { diff --git a/exchanges/okx/okx_test.go b/exchanges/okx/okx_test.go index 5aeee376..90aab7db 100644 --- a/exchanges/okx/okx_test.go +++ b/exchanges/okx/okx_test.go @@ -2267,7 +2267,7 @@ func TestSetLeverageRate(t *testing.T) { MarginMode: "cross", InstrumentID: perpetualSwapPair.String(), }) - assert.True(t, err == nil || errors.Is(err, common.ErrNoResponse)) + assert.Truef(t, err == nil || errors.Is(err, common.ErrNoResponse), "SetLeverageRate should not error: %s", err) } func TestGetMaximumBuySellAmountOROpenAmount(t *testing.T) { @@ -4346,9 +4346,6 @@ func TestGetCollateralMode(t *testing.T) { result, err := ok.GetCollateralMode(contextGenerate(), asset.Spot) assert.NoError(t, err) assert.NotNil(t, result) - - _, err = ok.GetCollateralMode(contextGenerate(), asset.Futures) - assert.True(t, err == nil || errors.Is(err, asset.ErrNotSupported)) } func TestSetCollateralMode(t *testing.T) { diff --git a/exchanges/orderbook/calculator_test.go b/exchanges/orderbook/calculator_test.go index 861ef254..67b1cbc0 100644 --- a/exchanges/orderbook/calculator_test.go +++ b/exchanges/orderbook/calculator_test.go @@ -1,7 +1,6 @@ package orderbook import ( - "errors" "math" "strings" "testing" @@ -31,9 +30,7 @@ func TestWhaleBomb(t *testing.T) { b := testSetup() _, err := b.WhaleBomb(-1, true) - if !errors.Is(err, errPriceTargetInvalid) { - t.Fatalf("received: '%v' but expected: '%v'", err, errPriceTargetInvalid) - } + require.ErrorIs(t, err, errPriceTargetInvalid) result, err := b.WhaleBomb(7001, true) // <- This price should not be wiped out on the book. require.NoError(t, err) @@ -100,14 +97,10 @@ func TestWhaleBomb(t *testing.T) { } _, err = b.WhaleBomb(6000, true) - if !errors.Is(err, errCannotShiftPrice) { - t.Fatalf("received: '%v' but expected: '%v'", err, errCannotShiftPrice) - } + require.ErrorIs(t, err, errCannotShiftPrice) _, err = b.WhaleBomb(-1, false) - if !errors.Is(err, errPriceTargetInvalid) { - t.Fatalf("received: '%v' but expected: '%v'", err, errPriceTargetInvalid) - } + require.ErrorIs(t, err, errPriceTargetInvalid) result, err = b.WhaleBomb(6998, false) // <- This price should not be wiped out on the book. require.NoError(t, err) @@ -174,9 +167,7 @@ func TestWhaleBomb(t *testing.T) { } _, err = b.WhaleBomb(7500, false) - if !errors.Is(err, errCannotShiftPrice) { - t.Fatalf("received: '%v' but expected: '%v'", err, errCannotShiftPrice) - } + require.ErrorIs(t, err, errCannotShiftPrice) } func TestSimulateOrder(t *testing.T) { @@ -185,14 +176,10 @@ func TestSimulateOrder(t *testing.T) { // Invalid _, err := b.SimulateOrder(-8000, true) - if !errors.Is(err, errQuoteAmountInvalid) { - t.Fatalf("received: '%v' but expected: '%v'", err, errQuoteAmountInvalid) - } + require.ErrorIs(t, err, errQuoteAmountInvalid) _, err = (&Base{}).SimulateOrder(1337, true) - if !errors.Is(err, errNoLiquidity) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoLiquidity) - } + require.ErrorIs(t, err, errNoLiquidity) // Full liquidity used result, err := b.SimulateOrder(21002, true) @@ -327,14 +314,10 @@ func TestSimulateOrder(t *testing.T) { // Invalid _, err = (&Base{}).SimulateOrder(-1, false) - if !errors.Is(err, errBaseAmountInvalid) { - t.Fatalf("received: '%v' but expected: '%v'", err, errBaseAmountInvalid) - } + require.ErrorIs(t, err, errBaseAmountInvalid) _, err = (&Base{}).SimulateOrder(2, false) - if !errors.Is(err, errNoLiquidity) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoLiquidity) - } + require.ErrorIs(t, err, errNoLiquidity) // Full liquidity used result, err = b.SimulateOrder(3, false) @@ -466,48 +449,34 @@ func TestSimulateOrder(t *testing.T) { } func TestGetAveragePrice(t *testing.T) { - var b Base - b.Exchange = "Binance" - cp, err := currency.NewPairFromString("ETH-USDT") - if err != nil { - t.Error(err) + b := Base{ + Exchange: "Binance", + Pair: currency.NewBTCUSD(), } - b.Pair = cp - b.Bids = []Tranche{} - _, err = b.GetAveragePrice(false, 5) - if errors.Is(errNotEnoughLiquidity, err) { - t.Error("expected: %w, received %w", errNotEnoughLiquidity, err) - } - b = Base{} - b.Pair = cp - b.Asks = []Tranche{ - {Amount: 5, Price: 1}, - {Amount: 5, Price: 2}, - {Amount: 5, Price: 3}, - {Amount: 5, Price: 4}, + _, err := b.GetAveragePrice(false, 5) + assert.ErrorIs(t, err, errNotEnoughLiquidity) + + b = Base{ + Asks: []Tranche{ + {Amount: 5, Price: 1}, + {Amount: 5, Price: 2}, + {Amount: 5, Price: 3}, + {Amount: 5, Price: 4}, + }, } _, err = b.GetAveragePrice(true, -2) - if !errors.Is(err, errAmountInvalid) { - t.Errorf("expected: %v, received %v", errAmountInvalid, err) - } + assert.ErrorIs(t, err, errAmountInvalid) + avgPrice, err := b.GetAveragePrice(true, 15) - if err != nil { - t.Error(err) - } - if avgPrice != 2 { - t.Errorf("avg price calculation failed: expected 2, received %f", avgPrice) - } + require.NoError(t, err) + assert.Equal(t, 2.0, avgPrice) + avgPrice, err = b.GetAveragePrice(true, 18) - if err != nil { - t.Error(err) - } - if math.Round(avgPrice*1000)/1000 != 2.333 { - t.Errorf("avg price calculation failed: expected 2.333, received %f", math.Round(avgPrice*1000)/1000) - } + require.NoError(t, err) + assert.Equal(t, 2.333, math.Round(avgPrice*1000)/1000) + _, err = b.GetAveragePrice(true, 25) - if !errors.Is(err, errNotEnoughLiquidity) { - t.Errorf("expected: %v, received %v", errNotEnoughLiquidity, err) - } + assert.ErrorIs(t, err, errNotEnoughLiquidity) } func TestFindNominalAmount(t *testing.T) { diff --git a/exchanges/orderbook/orderbook_test.go b/exchanges/orderbook/orderbook_test.go index c2831d2b..2d84ebad 100644 --- a/exchanges/orderbook/orderbook_test.go +++ b/exchanges/orderbook/orderbook_test.go @@ -1,7 +1,6 @@ package orderbook import ( - "errors" "log" "math/rand" "os" @@ -62,77 +61,55 @@ func TestVerify(t *testing.T) { b.Asks = []Tranche{{ID: 1337, Price: 99, Amount: 1}, {ID: 1337, Price: 100, Amount: 1}} err = b.Verify() - if !errors.Is(err, errIDDuplication) { - t.Fatalf("expecting %s error but received %v", errIDDuplication, err) - } + require.ErrorIs(t, err, errIDDuplication) b.Asks = []Tranche{{Price: 100, Amount: 1}, {Price: 100, Amount: 1}} err = b.Verify() - if !errors.Is(err, errDuplication) { - t.Fatalf("expecting %s error but received %v", errDuplication, err) - } + require.ErrorIs(t, err, errDuplication) b.Asks = []Tranche{{Price: 100, Amount: 1}, {Price: 99, Amount: 1}} b.IsFundingRate = true err = b.Verify() - if !errors.Is(err, errPeriodUnset) { - t.Fatalf("expecting %s error but received %v", errPeriodUnset, err) - } + require.ErrorIs(t, err, errPeriodUnset) + b.IsFundingRate = false err = b.Verify() - if !errors.Is(err, errPriceOutOfOrder) { - t.Fatalf("expecting %s error but received %v", errPriceOutOfOrder, err) - } + require.ErrorIs(t, err, errPriceOutOfOrder) b.Asks = []Tranche{{Price: 100, Amount: 1}, {Price: 100, Amount: 0}} err = b.Verify() - if !errors.Is(err, errAmountInvalid) { - t.Fatalf("expecting %s error but received %v", errAmountInvalid, err) - } + require.ErrorIs(t, err, errAmountInvalid) b.Asks = []Tranche{{Price: 100, Amount: 1}, {Price: 0, Amount: 100}} err = b.Verify() - if !errors.Is(err, errPriceNotSet) { - t.Fatalf("expecting %s error but received %v", errPriceNotSet, err) - } + require.ErrorIs(t, err, errPriceNotSet) b.Bids = []Tranche{{ID: 1337, Price: 100, Amount: 1}, {ID: 1337, Price: 99, Amount: 1}} err = b.Verify() - if !errors.Is(err, errIDDuplication) { - t.Fatalf("expecting %s error but received %v", errIDDuplication, err) - } + require.ErrorIs(t, err, errIDDuplication) b.Bids = []Tranche{{Price: 100, Amount: 1}, {Price: 100, Amount: 1}} err = b.Verify() - if !errors.Is(err, errDuplication) { - t.Fatalf("expecting %s error but received %v", errDuplication, err) - } + require.ErrorIs(t, err, errDuplication) b.Bids = []Tranche{{Price: 99, Amount: 1}, {Price: 100, Amount: 1}} b.IsFundingRate = true err = b.Verify() - if !errors.Is(err, errPeriodUnset) { - t.Fatalf("expecting %s error but received %v", errPeriodUnset, err) - } + require.ErrorIs(t, err, errPeriodUnset) + b.IsFundingRate = false err = b.Verify() - if !errors.Is(err, errPriceOutOfOrder) { - t.Fatalf("expecting %s error but received %v", errPriceOutOfOrder, err) - } + require.ErrorIs(t, err, errPriceOutOfOrder) b.Bids = []Tranche{{Price: 100, Amount: 1}, {Price: 100, Amount: 0}} err = b.Verify() - if !errors.Is(err, errAmountInvalid) { - t.Fatalf("expecting %s error but received %v", errAmountInvalid, err) - } + require.ErrorIs(t, err, errAmountInvalid) b.Bids = []Tranche{{Price: 100, Amount: 1}, {Price: 0, Amount: 100}} err = b.Verify() - if !errors.Is(err, errPriceNotSet) { - t.Fatalf("expecting %s error but received %v", errPriceNotSet, err) - } + require.ErrorIs(t, err, errPriceNotSet) } func TestCalculateTotalBids(t *testing.T) { @@ -539,9 +516,7 @@ func TestSorting(t *testing.T) { b.Asks = deployUnorderedSlice() err := b.Verify() - if !errors.Is(err, errPriceOutOfOrder) { - t.Fatalf("error expected %v received %v", errPriceOutOfOrder, err) - } + require.ErrorIs(t, err, errPriceOutOfOrder) b.Asks.SortAsks() err = b.Verify() @@ -551,9 +526,7 @@ func TestSorting(t *testing.T) { b.Bids = deployUnorderedSlice() err = b.Verify() - if !errors.Is(err, errPriceOutOfOrder) { - t.Fatalf("error expected %v received %v", errPriceOutOfOrder, err) - } + require.ErrorIs(t, err, errPriceOutOfOrder) b.Bids.SortBids() err = b.Verify() @@ -677,19 +650,14 @@ func TestCheckAlignment(t *testing.T) { t.Error(err) } err = checkAlignment(itemWithFunding, false, true, false, false, dsc, "Bitfinex") - if !errors.Is(err, errPriceNotSet) { - t.Fatalf("received: %v but expected: %v", err, errPriceNotSet) - } + require.ErrorIs(t, err, errPriceNotSet) + err = checkAlignment(itemWithFunding, true, true, false, false, dsc, "Binance") - if !errors.Is(err, errPriceNotSet) { - t.Fatalf("received: %v but expected: %v", err, errPriceNotSet) - } + require.ErrorIs(t, err, errPriceNotSet) itemWithFunding[0].Price = 1337 err = checkAlignment(itemWithFunding, true, true, false, true, dsc, "Binance") - if !errors.Is(err, errChecksumStringNotSet) { - t.Fatalf("received: %v but expected: %v", err, errChecksumStringNotSet) - } + require.ErrorIs(t, err, errChecksumStringNotSet) itemWithFunding[0].StrAmount = "1337.0000000" itemWithFunding[0].StrPrice = "1337.0000000" diff --git a/exchanges/orderbook/tranches_test.go b/exchanges/orderbook/tranches_test.go index 72a70b31..b89a6d83 100644 --- a/exchanges/orderbook/tranches_test.go +++ b/exchanges/orderbook/tranches_test.go @@ -1,7 +1,6 @@ package orderbook import ( - "errors" "fmt" "testing" "time" @@ -320,9 +319,7 @@ func TestUpdateByID(t *testing.T) { err = a.updateByID(Tranches{ {Price: 11, Amount: 1, ID: 1337}, }) - if !errors.Is(err, errIDCannotBeMatched) { - t.Fatalf("expecting %s but received %v", errIDCannotBeMatched, err) - } + require.ErrorIs(t, err, errIDCannotBeMatched) err = a.updateByID(Tranches{ // Simulate Bitmex updating {Price: 0, Amount: 1337, ID: 3}, @@ -402,9 +399,7 @@ func TestDeleteByID(t *testing.T) { // Intentional error err = a.deleteByID(Tranches{{Price: 11, Amount: 1, ID: 1337}}, false) - if !errors.Is(err, errIDCannotBeMatched) { - t.Fatalf("expecting %s but received %v", errIDCannotBeMatched, err) - } + require.ErrorIs(t, err, errIDCannotBeMatched) // Error bypass err = a.deleteByID(Tranches{{Price: 11, Amount: 1, ID: 1337}}, true) @@ -1001,9 +996,7 @@ func TestInsertUpdatesBid(t *testing.T) { {Price: 3, Amount: 1, ID: 3}, {Price: 1, Amount: 1, ID: 1}, }) - if !errors.Is(err, errCollisionDetected) { - t.Fatalf("expected error %s but received %v", errCollisionDetected, err) - } + require.ErrorIs(t, err, errCollisionDetected) Check(t, b, 6, 36, 6) @@ -1063,9 +1056,7 @@ func TestInsertUpdatesAsk(t *testing.T) { {Price: 3, Amount: 1, ID: 3}, {Price: 1, Amount: 1, ID: 1}, }) - if !errors.Is(err, errCollisionDetected) { - t.Fatalf("expected error %s but received %v", errCollisionDetected, err) - } + require.ErrorIs(t, err, errCollisionDetected) Check(t, a, 6, 36, 6) @@ -1260,9 +1251,7 @@ func TestGetMovementByBaseAmount(t *testing.T) { t.Fatal(err) } movement, err := depth.bidTranches.getMovementByBase(tt.BaseAmount, tt.ReferencePrice, false) - if !errors.Is(err, tt.ExpectedError) { - t.Fatalf("received: '%v' but expected: '%v'", err, tt.ExpectedError) - } + require.ErrorIs(t, err, tt.ExpectedError) if movement == nil { return @@ -1501,9 +1490,8 @@ func TestGetBaseAmountFromImpact(t *testing.T) { t.Fatal(err) } base, err := depth.bidTranches.hitBidsByImpactSlippage(tt.ImpactSlippage, tt.ReferencePrice) - if !errors.Is(err, tt.ExpectedError) { - t.Fatalf("%s received: '%v' but expected: '%v'", tt.Name, err, tt.ExpectedError) - } + require.ErrorIs(t, err, tt.ExpectedError) + if !base.IsEqual(tt.ExpectedShift) { t.Fatalf("%s quote received: '%+v' but expected: '%+v'", tt.Name, base, tt.ExpectedShift) @@ -1586,9 +1574,7 @@ func TestGetMovementByQuoteAmount(t *testing.T) { t.Fatal(err) } movement, err := depth.askTranches.getMovementByQuotation(tt.QuoteAmount, tt.ReferencePrice, false) - if !errors.Is(err, tt.ExpectedError) { - t.Fatalf("received: '%v' but expected: '%v'", err, tt.ExpectedError) - } + require.ErrorIs(t, err, tt.ExpectedError) if movement == nil { return @@ -1817,16 +1803,13 @@ func TestGetQuoteAmountFromImpact(t *testing.T) { func TestGetHeadPrice(t *testing.T) { t.Parallel() depth := NewDepth(id) - if _, err := depth.bidTranches.getHeadPriceNoLock(); !errors.Is(err, errNoLiquidity) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoLiquidity) - } - if _, err := depth.askTranches.getHeadPriceNoLock(); !errors.Is(err, errNoLiquidity) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoLiquidity) - } - err := depth.LoadSnapshot(bid, ask, 0, time.Now(), time.Now(), true) - if err != nil { - t.Fatalf("failed to load snapshot: %s", err) - } + _, err := depth.bidTranches.getHeadPriceNoLock() + require.ErrorIs(t, err, errNoLiquidity) + _, err = depth.askTranches.getHeadPriceNoLock() + require.ErrorIs(t, err, errNoLiquidity) + + err = depth.LoadSnapshot(bid, ask, 0, time.Now(), time.Now(), true) + require.NoError(t, err, "LoadSnapshot must not error") val, err := depth.bidTranches.getHeadPriceNoLock() require.NoError(t, err) diff --git a/exchanges/poloniex/currency_details_test.go b/exchanges/poloniex/currency_details_test.go index 65215d1d..e13642d6 100644 --- a/exchanges/poloniex/currency_details_test.go +++ b/exchanges/poloniex/currency_details_test.go @@ -1,7 +1,6 @@ package poloniex import ( - "errors" "testing" "github.com/stretchr/testify/require" @@ -16,54 +15,34 @@ func TestWsCurrencyMap(t *testing.T) { } err := m.loadPairs(nil) - if !errors.Is(err, errCannotLoadNoData) { - t.Fatalf("expected: %v but received: %v", errCannotLoadNoData, err) - } + require.ErrorIs(t, err, errCannotLoadNoData) err = m.loadCodes(nil) - if !errors.Is(err, errCannotLoadNoData) { - t.Fatalf("expected: %v but received: %v", errCannotLoadNoData, err) - } + require.ErrorIs(t, err, errCannotLoadNoData) _, err = m.GetPair(1337) - if !errors.Is(err, errPairMapIsNil) { - t.Fatalf("expected: %v but received: %v", errPairMapIsNil, err) - } + require.ErrorIs(t, err, errPairMapIsNil) _, err = m.GetCode(1337) - if !errors.Is(err, errCodeMapIsNil) { - t.Fatalf("expected: %v but received: %v", errCodeMapIsNil, err) - } + require.ErrorIs(t, err, errCodeMapIsNil) _, err = m.GetWithdrawalTXFee(currency.EMPTYCODE) - if !errors.Is(err, errCodeMapIsNil) { - t.Fatalf("expected: %v but received: %v", errCodeMapIsNil, err) - } + require.ErrorIs(t, err, errCodeMapIsNil) _, err = m.GetDepositAddress(currency.EMPTYCODE) - if !errors.Is(err, errCodeMapIsNil) { - t.Fatalf("expected: %v but received: %v", errCodeMapIsNil, err) - } + require.ErrorIs(t, err, errCodeMapIsNil) _, err = m.IsWithdrawAndDepositsEnabled(currency.EMPTYCODE) - if !errors.Is(err, errCodeMapIsNil) { - t.Fatalf("expected: %v but received: %v", errCodeMapIsNil, err) - } + require.ErrorIs(t, err, errCodeMapIsNil) _, err = m.IsTradingEnabledForCurrency(currency.EMPTYCODE) - if !errors.Is(err, errCodeMapIsNil) { - t.Fatalf("expected: %v but received: %v", errCodeMapIsNil, err) - } + require.ErrorIs(t, err, errCodeMapIsNil) _, err = m.IsTradingEnabledForPair(currency.EMPTYPAIR) - if !errors.Is(err, errCodeMapIsNil) { - t.Fatalf("expected: %v but received: %v", errCodeMapIsNil, err) - } + require.ErrorIs(t, err, errCodeMapIsNil) _, err = m.IsPostOnlyForPair(currency.EMPTYPAIR) - if !errors.Is(err, errCodeMapIsNil) { - t.Fatalf("expected: %v but received: %v", errCodeMapIsNil, err) - } + require.ErrorIs(t, err, errCodeMapIsNil) c, err := p.GetCurrencies(t.Context()) if err != nil { @@ -86,18 +65,14 @@ func TestWsCurrencyMap(t *testing.T) { } pTest, err := m.GetPair(1337) - if !errors.Is(err, errIDNotFoundInPairMap) { - t.Fatalf("expected: %v but received: %v", errIDNotFoundInPairMap, err) - } + require.ErrorIs(t, err, errIDNotFoundInPairMap) if pTest.String() != "1337" { t.Fatal("unexpected value") } _, err = m.GetCode(1337) - if !errors.Is(err, errIDNotFoundInCodeMap) { - t.Fatalf("expected: %v but received: %v", errIDNotFoundInCodeMap, err) - } + require.ErrorIs(t, err, errIDNotFoundInCodeMap) btcusdt, err := m.GetPair(121) require.NoError(t, err) @@ -123,9 +98,7 @@ func TestWsCurrencyMap(t *testing.T) { } _, err = m.GetDepositAddress(eth) - if !errors.Is(err, errNoDepositAddress) { - t.Fatalf("expected: %v but received: %v", errNoDepositAddress, err) - } + require.ErrorIs(t, err, errNoDepositAddress) dAddr, err := m.GetDepositAddress(currency.NewCode("BCN")) require.NoError(t, err) @@ -165,32 +138,20 @@ func TestWsCurrencyMap(t *testing.T) { } _, err = m.GetWithdrawalTXFee(currency.EMPTYCODE) - if !errors.Is(err, errCurrencyNotFoundInMap) { - t.Fatalf("expected: %v but received: %v", errCurrencyNotFoundInMap, err) - } + require.ErrorIs(t, err, errCurrencyNotFoundInMap) _, err = m.GetDepositAddress(currency.EMPTYCODE) - if !errors.Is(err, errCurrencyNotFoundInMap) { - t.Fatalf("expected: %v but received: %v", errCurrencyNotFoundInMap, err) - } + require.ErrorIs(t, err, errCurrencyNotFoundInMap) _, err = m.IsWithdrawAndDepositsEnabled(currency.EMPTYCODE) - if !errors.Is(err, errCurrencyNotFoundInMap) { - t.Fatalf("expected: %v but received: %v", errCurrencyNotFoundInMap, err) - } + require.ErrorIs(t, err, errCurrencyNotFoundInMap) _, err = m.IsTradingEnabledForCurrency(currency.EMPTYCODE) - if !errors.Is(err, errCurrencyNotFoundInMap) { - t.Fatalf("expected: %v but received: %v", errCurrencyNotFoundInMap, err) - } + require.ErrorIs(t, err, errCurrencyNotFoundInMap) _, err = m.IsTradingEnabledForPair(currency.EMPTYPAIR) - if !errors.Is(err, errCurrencyNotFoundInMap) { - t.Fatalf("expected: %v but received: %v", errCurrencyNotFoundInMap, err) - } + require.ErrorIs(t, err, errCurrencyNotFoundInMap) _, err = m.IsPostOnlyForPair(currency.EMPTYPAIR) - if !errors.Is(err, errCurrencyNotFoundInMap) { - t.Fatalf("expected: %v but received: %v", errCurrencyNotFoundInMap, err) - } + require.ErrorIs(t, err, errCurrencyNotFoundInMap) } diff --git a/exchanges/poloniex/poloniex_test.go b/exchanges/poloniex/poloniex_test.go index 8c37e83c..d19f73a0 100644 --- a/exchanges/poloniex/poloniex_test.go +++ b/exchanges/poloniex/poloniex_test.go @@ -1,7 +1,6 @@ package poloniex import ( - "errors" "net/http" "strings" "testing" @@ -699,27 +698,19 @@ func TestProcessAccountMarginPosition(t *testing.T) { margin := []byte(`[1000,"",[["m", 23432933, 28, "-0.06000000"]]]`) err = p.wsHandleData(margin) - if !errors.Is(err, errNotEnoughData) { - t.Fatalf("expected: %v but received: %v", errNotEnoughData, err) - } + require.ErrorIs(t, err, errNotEnoughData) margin = []byte(`[1000,"",[["m", "23432933", 28, "-0.06000000", null]]]`) err = p.wsHandleData(margin) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) margin = []byte(`[1000,"",[["m", 23432933, "28", "-0.06000000", null]]]`) err = p.wsHandleData(margin) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) margin = []byte(`[1000,"",[["m", 23432933, 28, -0.06000000, null]]]`) err = p.wsHandleData(margin) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) margin = []byte(`[1000,"",[["m", 23432933, 28, "-0.06000000", null]]]`) err = p.wsHandleData(margin) @@ -736,39 +727,27 @@ func TestProcessAccountPendingOrder(t *testing.T) { pending := []byte(`[1000,"",[["p",431682155857,127,"1000.00000000","1.00000000","0"]]]`) err = p.wsHandleData(pending) - if !errors.Is(err, errNotEnoughData) { - t.Fatalf("expected: %v but received: %v", errNotEnoughData, err) - } + require.ErrorIs(t, err, errNotEnoughData) pending = []byte(`[1000,"",[["p","431682155857",127,"1000.00000000","1.00000000","0",null]]]`) err = p.wsHandleData(pending) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) pending = []byte(`[1000,"",[["p",431682155857,"127","1000.00000000","1.00000000","0",null]]]`) err = p.wsHandleData(pending) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) pending = []byte(`[1000,"",[["p",431682155857,127,1000.00000000,"1.00000000","0",null]]]`) err = p.wsHandleData(pending) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) pending = []byte(`[1000,"",[["p",431682155857,127,"1000.00000000",1.00000000,"0",null]]]`) err = p.wsHandleData(pending) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) pending = []byte(`[1000,"",[["p",431682155857,127,"1000.00000000","1.00000000",0,null]]]`) err = p.wsHandleData(pending) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) pending = []byte(`[1000,"",[["p",431682155857,127,"1000.00000000","1.00000000","0",null]]]`) err = p.wsHandleData(pending) @@ -787,33 +766,23 @@ func TestProcessAccountPendingOrder(t *testing.T) { func TestProcessAccountOrderUpdate(t *testing.T) { orderUpdate := []byte(`[1000,"",[["o",431682155857,"0.00000000","f"]]]`) err := p.wsHandleData(orderUpdate) - if !errors.Is(err, errNotEnoughData) { - t.Fatalf("expected: %v but received: %v", errNotEnoughData, err) - } + require.ErrorIs(t, err, errNotEnoughData) orderUpdate = []byte(`[1000,"",[["o","431682155857","0.00000000","f",null]]]`) err = p.wsHandleData(orderUpdate) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) orderUpdate = []byte(`[1000,"",[["o",431682155857,0.00000000,"f",null]]]`) err = p.wsHandleData(orderUpdate) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) orderUpdate = []byte(`[1000,"",[["o",431682155857,"0.00000000",123,null]]]`) err = p.wsHandleData(orderUpdate) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) orderUpdate = []byte(`[1000,"",[["o",431682155857,"0.00000000","c",null]]]`) err = p.wsHandleData(orderUpdate) - if !errors.Is(err, errNotEnoughData) { - t.Fatalf("expected: %v but received: %v", errNotEnoughData, err) - } + require.ErrorIs(t, err, errNotEnoughData) orderUpdate = []byte(`[1000,"",[["o",431682155857,"0.50000000","c",null,"0.50000000"]]]`) err = p.wsHandleData(orderUpdate) @@ -848,51 +817,35 @@ func TestProcessAccountOrderLimit(t *testing.T) { accountTrade := []byte(`[1000,"",[["n",127,431682155857,"0","1000.00000000","1.00000000","2021-04-13 07:19:56","1.00000000"]]]`) err = p.wsHandleData(accountTrade) - if !errors.Is(err, errNotEnoughData) { - t.Fatalf("expected: %v but received: %v", errNotEnoughData, err) - } + require.ErrorIs(t, err, errNotEnoughData) accountTrade = []byte(`[1000,"",[["n","127",431682155857,"0","1000.00000000","1.00000000","2021-04-13 07:19:56","1.00000000",null]]]`) err = p.wsHandleData(accountTrade) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) accountTrade = []byte(`[1000,"",[["n",127,"431682155857","0","1000.00000000","1.00000000","2021-04-13 07:19:56","1.00000000",null]]]`) err = p.wsHandleData(accountTrade) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) accountTrade = []byte(`[1000,"",[["n",127,431682155857,0,"1000.00000000","1.00000000","2021-04-13 07:19:56","1.00000000",null]]]`) err = p.wsHandleData(accountTrade) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) accountTrade = []byte(`[1000,"",[["n",127,431682155857,"0",1000.00000000,"1.00000000","2021-04-13 07:19:56","1.00000000",null]]]`) err = p.wsHandleData(accountTrade) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) accountTrade = []byte(`[1000,"",[["n",127,431682155857,"0","1000.00000000",1.00000000,"2021-04-13 07:19:56","1.00000000",null]]]`) err = p.wsHandleData(accountTrade) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) accountTrade = []byte(`[1000,"",[["n",127,431682155857,"0","1000.00000000","1.00000000",1234,"1.00000000",null]]]`) err = p.wsHandleData(accountTrade) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) accountTrade = []byte(`[1000,"",[["n",127,431682155857,"0","1000.00000000","1.00000000","2021-04-13 07:19:56",1.00000000,null]]]`) err = p.wsHandleData(accountTrade) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) accountTrade = []byte(`[1000,"",[["n",127,431682155857,"0","1000.00000000","1.00000000","2021-04-13 07:19:56","1.00000000",null]]]`) err = p.wsHandleData(accountTrade) @@ -909,27 +862,19 @@ func TestProcessAccountBalanceUpdate(t *testing.T) { balance := []byte(`[1000,"",[["b",243,"e"]]]`) err = p.wsHandleData(balance) - if !errors.Is(err, errNotEnoughData) { - t.Fatalf("expected: %v but received: %v", errNotEnoughData, err) - } + require.ErrorIs(t, err, errNotEnoughData) balance = []byte(`[1000,"",[["b","243","e","-1.00000000"]]]`) err = p.wsHandleData(balance) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) balance = []byte(`[1000,"",[["b",243,1234,"-1.00000000"]]]`) err = p.wsHandleData(balance) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) balance = []byte(`[1000,"",[["b",243,"e",-1.00000000]]]`) err = p.wsHandleData(balance) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) balance = []byte(`[1000,"",[["b",243,"e","-1.00000000"]]]`) err = p.wsHandleData(balance) @@ -941,45 +886,31 @@ func TestProcessAccountBalanceUpdate(t *testing.T) { func TestProcessAccountTrades(t *testing.T) { accountTrades := []byte(`[1000,"",[["t", 12345, "0.03000000", "0.50000000", "0.00250000", 0, 6083059, "0.00000375", "2018-09-08 05:54:09", "12345"]]]`) err := p.wsHandleData(accountTrades) - if !errors.Is(err, errNotEnoughData) { - t.Fatalf("expected: %v but received: %v", errNotEnoughData, err) - } + require.ErrorIs(t, err, errNotEnoughData) accountTrades = []byte(`[1000,"",[["t", "12345", "0.03000000", "0.50000000", "0.00250000", 0, 6083059, "0.00000375", "2018-09-08 05:54:09", "12345", "0.015"]]]`) err = p.wsHandleData(accountTrades) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) accountTrades = []byte(`[1000,"",[["t", 12345, 0.03000000, "0.50000000", "0.00250000", 0, 6083059, "0.00000375", "2018-09-08 05:54:09", "12345", "0.015"]]]`) err = p.wsHandleData(accountTrades) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) accountTrades = []byte(`[1000,"",[["t", 12345, "0.03000000", 0.50000000, "0.00250000", 0, 6083059, "0.00000375", "2018-09-08 05:54:09", "12345", "0.015"]]]`) err = p.wsHandleData(accountTrades) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) accountTrades = []byte(`[1000,"",[["t", 12345, "0.03000000", "0.50000000", "0.00250000", 0, 6083059, 0.00000375, "2018-09-08 05:54:09", "12345", "0.015"]]]`) err = p.wsHandleData(accountTrades) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) accountTrades = []byte(`[1000,"",[["t", 12345, "0.03000000", "0.50000000", "0.00250000", 0, 6083059, 0.0000037, "2018-09-08 05:54:09", "12345", "0.015"]]]`) err = p.wsHandleData(accountTrades) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) accountTrades = []byte(`[1000,"",[["t", 12345, "0.03000000", "0.50000000", "0.00250000", 0, 6083059, "0.00000375", 12345, "12345", 0.015]]]`) err = p.wsHandleData(accountTrades) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) accountTrades = []byte(`[1000,"",[["t", 12345, "0.03000000", "0.50000000", "0.00250000", 0, 6083059, "0.00000375", "2018-09-08 05:54:09", "12345", "0.015"]]]`) err = p.wsHandleData(accountTrades) @@ -991,15 +922,11 @@ func TestProcessAccountTrades(t *testing.T) { func TestProcessAccountKilledOrder(t *testing.T) { kill := []byte(`[1000,"",[["k", 1337]]]`) err := p.wsHandleData(kill) - if !errors.Is(err, errNotEnoughData) { - t.Fatalf("expected: %v but received: %v", errNotEnoughData, err) - } + require.ErrorIs(t, err, errNotEnoughData) kill = []byte(`[1000,"",[["k", "1337", null]]]`) err = p.wsHandleData(kill) - if !errors.Is(err, errTypeAssertionFailure) { - t.Fatalf("expected: %v but received: %v", errTypeAssertionFailure, err) - } + require.ErrorIs(t, err, errTypeAssertionFailure) kill = []byte(`[1000,"",[["k", 1337, null]]]`) err = p.wsHandleData(kill) diff --git a/exchanges/request/client_test.go b/exchanges/request/client_test.go index 6850eaaf..04ce532e 100644 --- a/exchanges/request/client_test.go +++ b/exchanges/request/client_test.go @@ -1,7 +1,6 @@ package request import ( - "errors" "net/http" "net/url" "slices" @@ -22,9 +21,7 @@ func (c *clientTracker) contains(check *http.Client) bool { func TestCheckAndRegister(t *testing.T) { t.Parallel() err := tracker.checkAndRegister(nil) - if !errors.Is(err, errHTTPClientIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errHTTPClientIsNil) - } + require.ErrorIs(t, err, errHTTPClientIsNil) newLovelyClient := new(http.Client) err = tracker.checkAndRegister(newLovelyClient) @@ -35,23 +32,17 @@ func TestCheckAndRegister(t *testing.T) { } err = tracker.checkAndRegister(newLovelyClient) - if !errors.Is(err, errCannotReuseHTTPClient) { - t.Fatalf("received: '%v' but expected: '%v'", err, errCannotReuseHTTPClient) - } + require.ErrorIs(t, err, errCannotReuseHTTPClient) } func TestDeRegister(t *testing.T) { t.Parallel() err := tracker.deRegister(nil) - if !errors.Is(err, errHTTPClientIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errHTTPClientIsNil) - } + require.ErrorIs(t, err, errHTTPClientIsNil) newLovelyClient := new(http.Client) err = tracker.deRegister(newLovelyClient) - if !errors.Is(err, errHTTPClientNotFound) { - t.Fatalf("received: '%v' but expected: '%v'", err, errHTTPClientNotFound) - } + require.ErrorIs(t, err, errHTTPClientNotFound) err = tracker.checkAndRegister(newLovelyClient) require.NoError(t, err) @@ -70,9 +61,8 @@ func TestDeRegister(t *testing.T) { func TestNewProtectedClient(t *testing.T) { t.Parallel() - if _, err := newProtectedClient(nil); !errors.Is(err, errHTTPClientIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errHTTPClientIsNil) - } + _, err := newProtectedClient(nil) + require.ErrorIs(t, err, errHTTPClientIsNil) newLovelyClient := new(http.Client) protec, err := newProtectedClient(newLovelyClient) @@ -86,17 +76,15 @@ func TestNewProtectedClient(t *testing.T) { func TestClientSetProxy(t *testing.T) { t.Parallel() err := (&client{}).setProxy(nil) - if !errors.Is(err, errNoProxyURLSupplied) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoProxyURLSupplied) - } + require.ErrorIs(t, err, errNoProxyURLSupplied) + pp, err := url.Parse("lol.com") if err != nil { t.Fatal(err) } err = (&client{protected: new(http.Client)}).setProxy(pp) - if !errors.Is(err, errTransportNotSet) { - t.Fatalf("received: '%v' but expected: '%v'", err, errTransportNotSet) - } + require.ErrorIs(t, err, errTransportNotSet) + err = (&client{protected: common.NewHTTPClientWithTimeout(0)}).setProxy(pp) require.NoError(t, err) } @@ -104,9 +92,8 @@ func TestClientSetProxy(t *testing.T) { func TestClientSetHTTPClientTimeout(t *testing.T) { t.Parallel() err := (&client{protected: new(http.Client)}).setHTTPClientTimeout(time.Second) - if !errors.Is(err, errTransportNotSet) { - t.Fatalf("received: '%v' but expected: '%v'", err, errTransportNotSet) - } + require.ErrorIs(t, err, errTransportNotSet) + err = (&client{protected: common.NewHTTPClientWithTimeout(0)}).setHTTPClientTimeout(time.Second) require.NoError(t, err) } diff --git a/exchanges/request/request_test.go b/exchanges/request/request_test.go index 266aae99..7af7a1fd 100644 --- a/exchanges/request/request_test.go +++ b/exchanges/request/request_test.go @@ -390,9 +390,7 @@ func TestDoRequest_RetryNonRecoverable(t *testing.T) { Path: testURL + "/always-retry", }, nil }, UnauthenticatedRequest) - if !errors.Is(err, errFailedToRetryRequest) { - t.Fatalf("received: %v but expected: %v", err, errFailedToRetryRequest) - } + require.ErrorIs(t, err, errFailedToRetryRequest) } func TestDoRequest_NotRetryable(t *testing.T) { @@ -415,9 +413,7 @@ func TestDoRequest_NotRetryable(t *testing.T) { Path: testURL + "/always-retry", }, nil }, UnauthenticatedRequest) - if !errors.Is(err, notRetryErr) { - t.Fatalf("received: %v but expected: %v", err, notRetryErr) - } + require.ErrorIs(t, err, notRetryErr) } func TestGetNonce(t *testing.T) { @@ -457,9 +453,8 @@ func TestSetProxy(t *testing.T) { t.Parallel() var r *Requester err := r.SetProxy(nil) - if !errors.Is(err, ErrRequestSystemIsNil) { - t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil) - } + require.ErrorIs(t, err, ErrRequestSystemIsNil) + r, err = New("test", &http.Client{Transport: new(http.Transport)}, WithLimiter(globalshell)) if err != nil { t.Fatal(err) @@ -506,9 +501,7 @@ func TestBasicLimiter(t *testing.T) { ctx, cancel := context.WithDeadline(ctx, tn.Add(time.Nanosecond)) defer cancel() err = r.SendPayload(ctx, Unset, func() (*Item, error) { return &i, nil }, UnauthenticatedRequest) - if !errors.Is(err, context.DeadlineExceeded) { - t.Fatalf("received: %v but expected: %v", err, context.DeadlineExceeded) - } + require.ErrorIs(t, err, context.DeadlineExceeded) } func TestEnableDisableRateLimit(t *testing.T) { @@ -554,26 +547,22 @@ func TestEnableDisableRateLimit(t *testing.T) { func TestSetHTTPClient(t *testing.T) { var r *Requester err := r.SetHTTPClient(nil) - if !errors.Is(err, ErrRequestSystemIsNil) { - t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil) - } + require.ErrorIs(t, err, ErrRequestSystemIsNil) + client := new(http.Client) r = new(Requester) err = r.SetHTTPClient(client) require.NoError(t, err) err = r.SetHTTPClient(client) - if !errors.Is(err, errCannotReuseHTTPClient) { - t.Fatalf("received: '%v', but expected: '%v'", err, errCannotReuseHTTPClient) - } + require.ErrorIs(t, err, errCannotReuseHTTPClient) } func TestSetHTTPClientTimeout(t *testing.T) { var r *Requester err := r.SetHTTPClientTimeout(0) - if !errors.Is(err, ErrRequestSystemIsNil) { - t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil) - } + require.ErrorIs(t, err, ErrRequestSystemIsNil) + r = new(Requester) err = r.SetHTTPClient(common.NewHTTPClientWithTimeout(2)) if err != nil { @@ -586,9 +575,8 @@ func TestSetHTTPClientTimeout(t *testing.T) { func TestSetHTTPClientUserAgent(t *testing.T) { var r *Requester err := r.SetHTTPClientUserAgent("") - if !errors.Is(err, ErrRequestSystemIsNil) { - t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil) - } + require.ErrorIs(t, err, ErrRequestSystemIsNil) + r = new(Requester) err = r.SetHTTPClientUserAgent("") require.NoError(t, err) @@ -597,9 +585,8 @@ func TestSetHTTPClientUserAgent(t *testing.T) { func TestGetHTTPClientUserAgent(t *testing.T) { var r *Requester _, err := r.GetHTTPClientUserAgent() - if !errors.Is(err, ErrRequestSystemIsNil) { - t.Fatalf("received: '%v', but expected: '%v'", err, ErrRequestSystemIsNil) - } + require.ErrorIs(t, err, ErrRequestSystemIsNil) + r = new(Requester) err = r.SetHTTPClientUserAgent("sillyness") require.NoError(t, err) diff --git a/exchanges/ticker/ticker_test.go b/exchanges/ticker/ticker_test.go index ce1ee819..08e384bc 100644 --- a/exchanges/ticker/ticker_test.go +++ b/exchanges/ticker/ticker_test.go @@ -1,7 +1,6 @@ package ticker import ( - "errors" "log" "math/rand" "os" @@ -293,9 +292,7 @@ func TestProcessTicker(t *testing.T) { // non-appending function to tickers Bid: 1338, Ask: 1336, }) - if !errors.Is(err, errBidGreaterThanAsk) { - t.Errorf("received: %v but expected: %v", err, errBidGreaterThanAsk) - } + assert.ErrorIs(t, err, errBidGreaterThanAsk) err = ProcessTicker(&Price{ ExchangeName: "Bitfinex", diff --git a/gctscript/modules/gct/errors_test.go b/gctscript/modules/gct/errors_test.go index 8c72dcaf..1f2a8e2e 100644 --- a/gctscript/modules/gct/errors_test.go +++ b/gctscript/modules/gct/errors_test.go @@ -1,23 +1,19 @@ package gct import ( - "errors" "testing" + "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common" ) func TestErrorResponse(t *testing.T) { t.Parallel() _, err := errorResponsef("") - if !errors.Is(err, errFormatStringIsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, errFormatStringIsEmpty) - } + require.ErrorIs(t, err, errFormatStringIsEmpty) _, err = errorResponsef("--") - if !errors.Is(err, errNoArguments) { - t.Fatalf("received: '%v' but expected: '%v'", err, errNoArguments) - } + require.ErrorIs(t, err, errNoArguments) errResp, err := errorResponsef("error %s", "hello") if err != nil { @@ -32,7 +28,5 @@ func TestErrorResponse(t *testing.T) { func TestConstructRuntimeError(t *testing.T) { t.Parallel() err := constructRuntimeError(0, "", "", nil) - if !errors.Is(err, common.ErrTypeAssertFailure) { - t.Fatalf("received: '%v' but expected: '%v'", err, common.ErrTypeAssertFailure) - } + require.ErrorIs(t, err, common.ErrTypeAssertFailure) } diff --git a/gctscript/modules/gct/gct_test.go b/gctscript/modules/gct/gct_test.go index 219419aa..2c1121d6 100644 --- a/gctscript/modules/gct/gct_test.go +++ b/gctscript/modules/gct/gct_test.go @@ -1,7 +1,6 @@ package gct import ( - "errors" "os" "reflect" "testing" @@ -42,9 +41,8 @@ var ( Value: "", } - tv = objects.TrueValue - fv = objects.FalseValue - errTestFailed = errors.New("test failed") + tv = objects.TrueValue + fv = objects.FalseValue ) func TestMain(m *testing.M) { @@ -55,37 +53,19 @@ func TestMain(m *testing.M) { func TestExchangeOrderbook(t *testing.T) { t.Parallel() _, err := ExchangeOrderbook(ctx, exch, currencyPair, delimiter, assetType) - if err != nil { - t.Error(err) - } - - _, err = ExchangeOrderbook(exchError, currencyPair, delimiter, assetType) - if err != nil && errors.Is(err, errTestFailed) { - t.Error(err) - } + assert.NoError(t, err) _, err = ExchangeOrderbook() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) } func TestExchangeTicker(t *testing.T) { t.Parallel() _, err := ExchangeTicker(ctx, exch, currencyPair, delimiter, assetType) - if err != nil { - t.Error(err) - } - - _, err = ExchangeTicker(exchError, currencyPair, delimiter, assetType) - if err != nil && errors.Is(err, errTestFailed) { - t.Error(err) - } + assert.NoError(t, err) _, err = ExchangeTicker() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) } func TestExchangeExchanges(t *testing.T) { @@ -107,74 +87,52 @@ func TestExchangeExchanges(t *testing.T) { } _, err = ExchangeExchanges() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) } func TestExchangePairs(t *testing.T) { t.Parallel() _, err := ExchangePairs(exch, tv, assetType) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) _, err = ExchangePairs(exchError, tv, assetType) - if err != nil && errors.Is(err, errTestFailed) { - t.Error(err) - } + assert.NoError(t, err) _, err = ExchangePairs() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) } func TestAccountInfo(t *testing.T) { t.Parallel() _, err := ExchangeAccountInfo() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) _, err = ExchangeAccountInfo(ctx, exch, assetType) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) _, err = ExchangeAccountInfo(ctx, exchError, assetType) - if err != nil && !errors.Is(err, errTestFailed) { - t.Error(err) - } + assert.NoError(t, err) } func TestExchangeOrderQuery(t *testing.T) { t.Parallel() _, err := ExchangeOrderQuery() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) _, err = ExchangeOrderQuery(ctx, exch, orderID) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) _, err = ExchangeOrderQuery(ctx, exchError, orderID) - if err != nil && !errors.Is(err, errTestFailed) { - t.Error(err) - } + assert.NoError(t, err) } func TestExchangeOrderCancel(t *testing.T) { t.Parallel() _, err := ExchangeOrderCancel() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) _, err = ExchangeOrderCancel(blank, orderID, currencyPair, assetType) if err == nil { @@ -205,9 +163,7 @@ func TestExchangeOrderCancel(t *testing.T) { func TestExchangeOrderSubmit(t *testing.T) { t.Parallel() _, err := ExchangeOrderSubmit() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) orderSide := &objects.String{Value: "ASK"} orderType := &objects.String{Value: "LIMIT"} @@ -217,21 +173,15 @@ func TestExchangeOrderSubmit(t *testing.T) { _, err = ExchangeOrderSubmit(ctx, exch, currencyPair, delimiter, orderType, orderSide, orderPrice, orderAmount, orderID, orderAsset) - if err != nil && !errors.Is(err, errTestFailed) { - t.Error(err) - } + assert.NoError(t, err) _, err = ExchangeOrderSubmit(ctx, exch, currencyPair, delimiter, orderType, orderSide, orderPrice, orderAmount, orderID, orderAsset) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) _, err = ExchangeOrderSubmit(ctx, objects.TrueValue, currencyPair, delimiter, orderType, orderSide, orderPrice, orderAmount, orderID, orderAsset) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) } func TestAllModuleNames(t *testing.T) { @@ -246,9 +196,7 @@ func TestAllModuleNames(t *testing.T) { func TestExchangeDepositAddress(t *testing.T) { t.Parallel() _, err := ExchangeDepositAddress() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) currCode := &objects.String{Value: "BTC"} chain := &objects.String{Value: ""} @@ -258,17 +206,13 @@ func TestExchangeDepositAddress(t *testing.T) { } _, err = ExchangeDepositAddress(exchError, currCode, chain) - if err != nil && !errors.Is(err, errTestFailed) { - t.Error(err) - } + assert.NoError(t, err) } func TestExchangeWithdrawCrypto(t *testing.T) { t.Parallel() _, err := ExchangeWithdrawCrypto() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) currCode := &objects.String{Value: "BTC"} desc := &objects.String{Value: "HELLO"} @@ -284,9 +228,7 @@ func TestExchangeWithdrawCrypto(t *testing.T) { func TestExchangeWithdrawFiat(t *testing.T) { t.Parallel() _, err := ExchangeWithdrawFiat() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) currCode := &objects.String{Value: "AUD"} desc := &objects.String{Value: "Hello"} @@ -339,14 +281,10 @@ func TestParseInterval(t *testing.T) { func TestSetVerbose(t *testing.T) { t.Parallel() _, err := setVerbose() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Fatalf("received: '%v' but expected: '%v'", err, objects.ErrWrongNumArguments) - } + require.ErrorIs(t, err, objects.ErrWrongNumArguments) _, err = setVerbose(objects.TrueValue) - if !errors.Is(err, common.ErrTypeAssertFailure) { - t.Fatalf("received: '%v' but expected: '%v'", err, common.ErrTypeAssertFailure) - } + require.ErrorIs(t, err, common.ErrTypeAssertFailure) resp, err := setVerbose(&Context{}) require.NoError(t, err) @@ -367,44 +305,28 @@ var dummyStr = &objects.String{Value: "xxxx"} func TestSetAccount(t *testing.T) { t.Parallel() _, err := setAccount() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Fatalf("received: '%v' but expected: '%v'", err, objects.ErrWrongNumArguments) - } + require.ErrorIs(t, err, objects.ErrWrongNumArguments) _, err = setAccount(objects.TrueValue, objects.TrueValue, objects.TrueValue) - if !errors.Is(err, common.ErrTypeAssertFailure) { - t.Fatalf("received: '%v' but expected: '%v'", err, common.ErrTypeAssertFailure) - } + require.ErrorIs(t, err, common.ErrTypeAssertFailure) _, err = setAccount(&Context{}, objects.TrueValue, objects.TrueValue) - if !errors.Is(err, common.ErrTypeAssertFailure) { - t.Fatalf("received: '%v' but expected: '%v'", err, common.ErrTypeAssertFailure) - } + require.ErrorIs(t, err, common.ErrTypeAssertFailure) _, err = setAccount(&Context{}, dummyStr, objects.TrueValue) - if !errors.Is(err, common.ErrTypeAssertFailure) { - t.Fatalf("received: '%v' but expected: '%v'", err, common.ErrTypeAssertFailure) - } + require.ErrorIs(t, err, common.ErrTypeAssertFailure) _, err = setAccount(&Context{}, dummyStr, dummyStr, objects.TrueValue) - if !errors.Is(err, common.ErrTypeAssertFailure) { - t.Fatalf("received: '%v' but expected: '%v'", err, common.ErrTypeAssertFailure) - } + require.ErrorIs(t, err, common.ErrTypeAssertFailure) _, err = setAccount(&Context{}, dummyStr, dummyStr, dummyStr, objects.TrueValue) - if !errors.Is(err, common.ErrTypeAssertFailure) { - t.Fatalf("received: '%v' but expected: '%v'", err, common.ErrTypeAssertFailure) - } + require.ErrorIs(t, err, common.ErrTypeAssertFailure) _, err = setAccount(&Context{}, dummyStr, dummyStr, dummyStr, dummyStr, objects.TrueValue) - if !errors.Is(err, common.ErrTypeAssertFailure) { - t.Fatalf("received: '%v' but expected: '%v'", err, common.ErrTypeAssertFailure) - } + require.ErrorIs(t, err, common.ErrTypeAssertFailure) _, err = setAccount(&Context{}, dummyStr, dummyStr, dummyStr, dummyStr, dummyStr, objects.TrueValue) - if !errors.Is(err, common.ErrTypeAssertFailure) { - t.Fatalf("received: '%v' but expected: '%v'", err, common.ErrTypeAssertFailure) - } + require.ErrorIs(t, err, common.ErrTypeAssertFailure) resp, err := setAccount(&Context{}, dummyStr, dummyStr, dummyStr, dummyStr, dummyStr, dummyStr) require.NoError(t, err) @@ -443,19 +365,13 @@ func TestSetAccount(t *testing.T) { func TestSetSubAccount(t *testing.T) { t.Parallel() _, err := setSubAccount() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Fatalf("received: '%v' but expected: '%v'", err, objects.ErrWrongNumArguments) - } + require.ErrorIs(t, err, objects.ErrWrongNumArguments) _, err = setSubAccount(objects.TrueValue, objects.TrueValue) - if !errors.Is(err, common.ErrTypeAssertFailure) { - t.Fatalf("received: '%v' but expected: '%v'", err, common.ErrTypeAssertFailure) - } + require.ErrorIs(t, err, common.ErrTypeAssertFailure) _, err = setSubAccount(&Context{}, objects.TrueValue) - if !errors.Is(err, common.ErrTypeAssertFailure) { - t.Fatalf("received: '%v' but expected: '%v'", err, common.ErrTypeAssertFailure) - } + require.ErrorIs(t, err, common.ErrTypeAssertFailure) subby, err := setSubAccount(&Context{}, dummyStr) require.NoError(t, err) diff --git a/gctscript/modules/ta/indicators/indicators_test.go b/gctscript/modules/ta/indicators/indicators_test.go index 9286d9b5..c40a8d94 100644 --- a/gctscript/modules/ta/indicators/indicators_test.go +++ b/gctscript/modules/ta/indicators/indicators_test.go @@ -1,7 +1,6 @@ package indicators import ( - "errors" "math/rand" "os" "reflect" @@ -55,9 +54,7 @@ func TestMain(m *testing.M) { func TestMfi(t *testing.T) { _, err := mfi() if err != nil { - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) } v := &objects.String{Value: testString} @@ -124,9 +121,7 @@ func TestRsi(t *testing.T) { func TestEMA(t *testing.T) { _, err := ema() if err != nil { - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) } v := &objects.String{Value: testString} @@ -166,9 +161,7 @@ func TestEMA(t *testing.T) { func TestSMA(t *testing.T) { _, err := sma() if err != nil { - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) } v := &objects.String{Value: testString} @@ -208,9 +201,7 @@ func TestSMA(t *testing.T) { func TestMACD(t *testing.T) { _, err := macd() if err != nil { - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) } v := &objects.String{Value: testString} @@ -256,9 +247,7 @@ func TestMACD(t *testing.T) { func TestAtr(t *testing.T) { _, err := atr() if err != nil { - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) } v := &objects.String{Value: testString} @@ -298,9 +287,7 @@ func TestAtr(t *testing.T) { func TestBbands(t *testing.T) { _, err := bbands() if err != nil { - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) } _, err = bbands(&objects.String{Value: testString}, ohlcvData, @@ -361,9 +348,7 @@ func TestBbands(t *testing.T) { &objects.Float{Value: 2.0}, &objects.String{Value: testString}) if err != nil { - if !errors.Is(err, errInvalidSelector) { - t.Error(err) - } + assert.ErrorIs(t, err, errInvalidSelector) } _, err = bbands(objects.UndefinedValue, ohlcvData, @@ -379,9 +364,7 @@ func TestBbands(t *testing.T) { func TestOBV(t *testing.T) { _, err := obv() if err != nil { - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) } _, err = obv(ohlcvData) diff --git a/gctscript/vm/vm_test.go b/gctscript/vm/vm_test.go index 3467f5ba..69873cec 100644 --- a/gctscript/vm/vm_test.go +++ b/gctscript/vm/vm_test.go @@ -9,6 +9,8 @@ import ( "time" "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const ( @@ -49,25 +51,14 @@ func TestVMLoad(t *testing.T) { started: 1, } testVM := manager.New() - err := testVM.Load(testScript) - if err != nil { - t.Fatal(err) - } + require.NoError(t, testVM.Load(testScript)) testScript = testScript[0 : len(testScript)-4] testVM = manager.New() - err = testVM.Load(testScript) - if err != nil { - t.Fatal(err) - } + require.NoError(t, testVM.Load(testScript)) manager.config = configHelper(false, false, maxTestVirtualMachines) - err = testVM.Load(testScript) - if err != nil { - if !errors.Is(err, ErrScriptingDisabled) { - t.Fatal(err) - } - } + require.NoError(t, testVM.Load(testScript)) } func TestVMLoad1s(t *testing.T) { @@ -76,18 +67,10 @@ func TestVMLoad1s(t *testing.T) { started: 1, } testVM := manager.New() - err := testVM.Load(testScriptRunner1s) - if err != nil { - t.Fatal(err) - } + require.NoError(t, testVM.Load(testScriptRunner1s)) testVM.CompileAndRun() - err = testVM.Shutdown() - if err != nil { - if !errors.Is(err, ErrNoVMLoaded) { - t.Fatal(err) - } - } + require.NoError(t, testVM.Shutdown()) } func TestVMLoadNegativeTimer(t *testing.T) { @@ -96,17 +79,10 @@ func TestVMLoadNegativeTimer(t *testing.T) { started: 1, } testVM := manager.New() - err := testVM.Load(testScriptRunnerNegative) - if err != nil { - if !errors.Is(err, ErrNoVMLoaded) { - t.Fatal(err) - } - } + require.NoError(t, testVM.Load(testScriptRunnerNegative)) + testVM.CompileAndRun() - err = testVM.Shutdown() - if err == nil { - t.Fatal("expect error on shutdown due to invalid VM") - } + require.Error(t, testVM.Shutdown()) } func TestVMLoadNilVM(t *testing.T) { @@ -115,19 +91,10 @@ func TestVMLoadNilVM(t *testing.T) { started: 1, } testVM := manager.New() - err := testVM.Load(testScript) - if err != nil { - if !errors.Is(err, ErrNoVMLoaded) { - t.Fatal(err) - } - } + require.NoError(t, testVM.Load(testScript)) + testVM = nil - err = testVM.Load(testScript) - if err != nil { - if !errors.Is(err, ErrNoVMLoaded) { - t.Fatal(err) - } - } + require.ErrorIs(t, testVM.Load(testScript), ErrNoVMLoaded) } func TestCompileAndRunNilVM(t *testing.T) { @@ -137,28 +104,14 @@ func TestCompileAndRunNilVM(t *testing.T) { } vmcount := VMSCount.Len() testVM := manager.New() - err := testVM.Load(testScript) - if err != nil { - if !errors.Is(err, ErrNoVMLoaded) { - t.Fatal(err) - } - } - err = testVM.Load(testScript) - if err != nil { - if !errors.Is(err, ErrNoVMLoaded) { - t.Fatal(err) - } - } + require.NoError(t, testVM.Load(testScript)) + + require.NoError(t, testVM.Load(testScript)) testVM = nil testVM.CompileAndRun() - err = testVM.Shutdown() - if err == nil { - t.Fatal("VM should not be running with invalid timer") - } - if VMSCount.Len() == vmcount-1 { - t.Fatal("expected VM count to decrease") - } + require.ErrorIs(t, testVM.Shutdown(), ErrNoVMLoaded) + assert.NotEqual(t, vmcount-1, VMSCount.Len(), "Expected vmcount to decrease") } func TestVMLoadNoFile(t *testing.T) { @@ -167,12 +120,7 @@ func TestVMLoadNoFile(t *testing.T) { started: 1, } testVM := manager.New() - err := testVM.Load("missing file") - if err != nil { - if !errors.Is(err, os.ErrNotExist) { - t.Fatal(err) - } - } + assert.ErrorIs(t, testVM.Load("missing file"), os.ErrNotExist) } func TestVMCompile(t *testing.T) { diff --git a/gctscript/wrappers/gct/gctwrapper_test.go b/gctscript/wrappers/gct/gctwrapper_test.go index 5cc38b63..3a4a23f8 100644 --- a/gctscript/wrappers/gct/gctwrapper_test.go +++ b/gctscript/wrappers/gct/gctwrapper_test.go @@ -2,7 +2,6 @@ package gct import ( "context" - "errors" "log" "os" "path/filepath" @@ -10,8 +9,8 @@ import ( "testing" objects "github.com/d5/tengo/v2" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/engine" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" @@ -102,45 +101,29 @@ var ( ctx = &gct.Context{} - tv = objects.TrueValue - fv = objects.FalseValue - errTestFailed = errors.New("test failed") + tv = objects.TrueValue + fv = objects.FalseValue ) func TestExchangeOrderbook(t *testing.T) { t.Parallel() _, err := gct.ExchangeOrderbook(ctx, exch, currencyPair, delimiter, assetType) - if err != nil { - t.Fatal(err) - } - - _, err = gct.ExchangeOrderbook(ctx, exchError, currencyPair, delimiter, assetType) - if err != nil && errors.Is(err, errTestFailed) { - t.Fatal(err) - } + assert.NoError(t, err) _, err = gct.ExchangeOrderbook() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) } func TestExchangeTicker(t *testing.T) { t.Parallel() _, err := gct.ExchangeTicker(ctx, exch, currencyPair, delimiter, assetType) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) _, err = gct.ExchangeTicker(ctx, exchError, currencyPair, delimiter, assetType) - if err != nil && errors.Is(err, errTestFailed) { - t.Fatal(err) - } + assert.NoError(t, err) _, err = gct.ExchangeTicker() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) } func TestExchangeExchanges(t *testing.T) { @@ -161,27 +144,19 @@ func TestExchangeExchanges(t *testing.T) { } _, err = gct.ExchangeExchanges() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) } func TestExchangePairs(t *testing.T) { t.Parallel() _, err := gct.ExchangePairs(exch, tv, assetType) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) _, err = gct.ExchangePairs(exchError, tv, assetType) - if err != nil && errors.Is(err, errTestFailed) { - t.Fatal(err) - } + assert.NoError(t, err) _, err = gct.ExchangePairs() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Error(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) } func TestExchangeAccountInfo(t *testing.T) { @@ -199,26 +174,18 @@ func TestExchangeOrderQuery(t *testing.T) { t.Parallel() _, err := gct.ExchangeOrderQuery() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Fatal(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) _, err = gct.ExchangeOrderQuery(ctx, exch, orderID) - if err != nil && err != common.ErrNotYetImplemented { - t.Error(err) - } + assert.NoError(t, err) } func TestExchangeOrderCancel(t *testing.T) { t.Parallel() _, err := gct.ExchangeOrderCancel() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Fatal(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) _, err = gct.ExchangeOrderCancel(ctx, exch, orderID, currencyPair, assetType) - if err != nil && err != common.ErrNotYetImplemented { - t.Error(err) - } + assert.NoError(t, err) } func TestExchangeOrderSubmit(t *testing.T) { @@ -251,66 +218,43 @@ func TestExchangeOrderSubmit(t *testing.T) { func TestAllModuleNames(t *testing.T) { t.Parallel() - x := gct.AllModuleNames() - xType := reflect.TypeOf(x).Kind() - if xType != reflect.Slice { - t.Errorf("AllModuleNames() should return slice instead received: %v", x) - } + assert.IsType(t, []string{}, gct.AllModuleNames(), "AllModuleNames should return a slice of strings") } func TestExchangeDepositAddress(t *testing.T) { t.Parallel() _, err := gct.ExchangeDepositAddress() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Fatal(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) currCode := &objects.String{Value: "BTC"} chain := &objects.String{Value: ""} _, err = gct.ExchangeDepositAddress(exch, currCode, chain) - if err != nil && err.Error() != "deposit address store is nil" { - t.Error(err) - } + assert.NoError(t, err) } func TestExchangeWithdrawCrypto(t *testing.T) { t.Parallel() _, err := gct.ExchangeWithdrawCrypto() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Fatal(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) currCode := &objects.String{Value: "BTC"} desc := &objects.String{Value: "HELLO"} address := &objects.String{Value: "0xTHISISALEGITBTCADDRESSS"} amount := &objects.Float{Value: 1.0} - _, err = gct.ExchangeWithdrawCrypto(ctx, - exch, - currCode, - address, - address, - amount, - amount, - desc) - if err != nil { - t.Error(err) - } + _, err = gct.ExchangeWithdrawCrypto(ctx, exch, currCode, address, address, amount, amount, desc) + assert.NoError(t, err) } func TestExchangeWithdrawFiat(t *testing.T) { t.Parallel() _, err := gct.ExchangeWithdrawFiat() - if !errors.Is(err, objects.ErrWrongNumArguments) { - t.Fatal(err) - } + assert.ErrorIs(t, err, objects.ErrWrongNumArguments) currCode := &objects.String{Value: "TEST"} amount := &objects.Float{Value: 1.0} desc := &objects.String{Value: "2"} bankID := &objects.String{Value: "3!"} _, err = gct.ExchangeWithdrawFiat(ctx, exch, currCode, desc, amount, bankID) - if err != nil && err.Error() != "exchange Bitstamp bank details not found for TEST" { - t.Error(err) - } + assert.NoError(t, err) } diff --git a/log/logger_test.go b/log/logger_test.go index bd1cf436..3d9eb77c 100644 --- a/log/logger_test.go +++ b/log/logger_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/common/convert" "github.com/thrasher-corp/gocryptotrader/encoding/json" @@ -102,9 +103,8 @@ func SetupDisabled() error { func TestSetGlobalLogConfig(t *testing.T) { t.Parallel() err := SetGlobalLogConfig(nil) - if !errors.Is(err, errConfigNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errConfigNil) - } + require.ErrorIs(t, err, errConfigNil) + err = SetGlobalLogConfig(testConfigEnabled) require.NoError(t, err) } @@ -112,9 +112,7 @@ func TestSetGlobalLogConfig(t *testing.T) { func TestSetLogPath(t *testing.T) { t.Parallel() err := SetLogPath("") - if !errors.Is(err, errLogPathIsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, errLogPathIsEmpty) - } + require.ErrorIs(t, err, errLogPathIsEmpty) err = SetLogPath(tempDir) require.NoError(t, err) @@ -147,9 +145,7 @@ func getFileLoggingState() bool { func TestAddWriter(t *testing.T) { t.Parallel() _, err := multiWriter(io.Discard, io.Discard) - if !errors.Is(err, errWriterAlreadyLoaded) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWriterAlreadyLoaded) - } + require.ErrorIs(t, err, errWriterAlreadyLoaded) mw, err := multiWriter() require.NoError(t, err) @@ -168,9 +164,7 @@ func TestAddWriter(t *testing.T) { } err = mw.add(nil) - if !errors.Is(err, errWriterIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWriterIsNil) - } + require.ErrorIs(t, err, errWriterIsNil) if total := len(mw.writers); total != 3 { t.Errorf("expected m.Writers to be 3 %v", total) @@ -230,18 +224,15 @@ func TestMultiWriterWrite(t *testing.T) { func TestGetWriters(t *testing.T) { t.Parallel() err := getWritersProtected(nil) - if !errors.Is(err, errSubloggerConfigIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errSubloggerConfigIsNil) - } + require.ErrorIs(t, err, errSubloggerConfigIsNil) outputWriters := "stDout|stderr|filE" mu.Lock() fileLoggingConfiguredCorrectly = false _, err = getWriters(&SubLoggerConfig{Output: outputWriters}) - if !errors.Is(err, errFileLoggingNotConfiguredCorrectly) { - t.Fatalf("received: '%v' but expected: '%v'", err, errFileLoggingNotConfiguredCorrectly) - } + require.ErrorIs(t, err, errFileLoggingNotConfiguredCorrectly) + fileLoggingConfiguredCorrectly = true _, err = getWriters(&SubLoggerConfig{Output: outputWriters}) require.NoError(t, err) @@ -250,9 +241,7 @@ func TestGetWriters(t *testing.T) { outputWriters = "stdout|stderr|noobs" err = getWritersProtected(&SubLoggerConfig{Output: outputWriters}) - if !errors.Is(err, errUnhandledOutputWriter) { - t.Fatalf("received: '%v' but expected: '%v'", err, errUnhandledOutputWriter) - } + require.ErrorIs(t, err, errUnhandledOutputWriter) } func getWritersProtected(s *SubLoggerConfig) error { @@ -503,9 +492,7 @@ func TestError(t *testing.T) { } sl.setLevelsProtected(splitLevel("ERROR")) err = sl.setOutputProtected(nil) - if !errors.Is(err, errMultiWriterHolderIsNil) { - t.Errorf("received: '%v' but expected: '%v'", err, errMultiWriterHolderIsNil) - } + assert.ErrorIs(t, err, errMultiWriterHolderIsNil) err = sl.setOutputProtected(mw) if err != nil { @@ -579,9 +566,7 @@ func TestSubLoggerName(t *testing.T) { func TestNewSubLogger(t *testing.T) { t.Parallel() _, err := NewSubLogger("") - if !errors.Is(err, errEmptyLoggerName) { - t.Fatalf("received: %v but expected: %v", err, errEmptyLoggerName) - } + require.ErrorIs(t, err, errEmptyLoggerName) sl, err := NewSubLogger("TESTERINOS") require.NoError(t, err) @@ -589,9 +574,7 @@ func TestNewSubLogger(t *testing.T) { Debugln(sl, "testerinos") _, err = NewSubLogger("TESTERINOS") - if !errors.Is(err, ErrSubLoggerAlreadyRegistered) { - t.Fatalf("received: %v but expected: %v", err, ErrSubLoggerAlreadyRegistered) - } + require.ErrorIs(t, err, ErrSubLoggerAlreadyRegistered) } func TestRotateWrite(t *testing.T) { @@ -599,16 +582,12 @@ func TestRotateWrite(t *testing.T) { empty := Rotate{Rotate: convert.BoolPtr(true), FileName: "test.txt"} payload := make([]byte, defaultMaxSize*megabyte+1) _, err := empty.Write(payload) - if !errors.Is(err, errExceedsMaxFileSize) { - t.Fatalf("received: %v but expected: %v", err, errExceedsMaxFileSize) - } + require.ErrorIs(t, err, errExceedsMaxFileSize) empty.MaxSize = 1 payload = make([]byte, 1*megabyte+1) _, err = empty.Write(payload) - if !errors.Is(err, errExceedsMaxFileSize) { - t.Fatalf("received: %v but expected: %v", err, errExceedsMaxFileSize) - } + require.ErrorIs(t, err, errExceedsMaxFileSize) // test write payload = make([]byte, 1*megabyte-1) @@ -628,9 +607,7 @@ func TestOpenNew(t *testing.T) { t.Parallel() empty := Rotate{} err := empty.openNew() - if !errors.Is(err, errFileNameIsEmpty) { - t.Fatalf("received: %v but expected: %v", err, errFileNameIsEmpty) - } + require.ErrorIs(t, err, errFileNameIsEmpty) empty.FileName = "wow.txt" err = empty.openNew()