From 78382afb14e429981b77487b99a2654f694fba1b Mon Sep 17 00:00:00 2001 From: suranmiao Date: Wed, 10 Dec 2025 07:54:54 +0800 Subject: [PATCH] refactor: use reflect.TypeFor instead of reflect.TypeOf and improve related tests (#2101) * refactor: using reflect.TypeFor Signed-off-by: suranmiao * refactor: remove unused reflect.TypeFor calls and improve test assertions * refactor: simplify TestSetup by removing reflect.TypeFor * test: enhance test assertions and improve parallel execution in TestSetup --------- Signed-off-by: suranmiao Co-authored-by: Adrian Gallagher --- cmd/exchange_wrapper_coverage/main.go | 7 +-- .../exchange_wrapper_standards_test.go | 54 ++++++++--------- engine/rpcserver_test.go | 21 +++---- exchanges/order/orders.go | 3 +- gctscript/modules/gct/gct_test.go | 7 +-- gctscript/modules/loader/loader_test.go | 14 ++--- .../modules/ta/indicators/indicators_test.go | 60 +++++++------------ gctscript/modules/ta/ta_test.go | 12 +--- gctscript/vm/vm_test.go | 12 +--- gctscript/wrappers/gct/gctwrapper_test.go | 8 +-- 10 files changed, 75 insertions(+), 123 deletions(-) diff --git a/cmd/exchange_wrapper_coverage/main.go b/cmd/exchange_wrapper_coverage/main.go index 2471b8c6..efd48f21 100644 --- a/cmd/exchange_wrapper_coverage/main.go +++ b/cmd/exchange_wrapper_coverage/main.go @@ -64,8 +64,7 @@ func main() { wg.Wait() log.Println("Done.") - var dummyInterface exchange.IBotExchange - totalWrappers := reflect.TypeOf(&dummyInterface).Elem().NumMethod() + totalWrappers := reflect.TypeFor[exchange.IBotExchange]().NumMethod() log.Println() for name, funcs := range results { @@ -89,11 +88,11 @@ func main() { // error common.ErrNotYetImplemented to verify whether the wrapper function has // been implemented yet. func testWrappers(e exchange.IBotExchange) ([]string, error) { - iExchange := reflect.TypeOf(&e).Elem() + iExchange := reflect.TypeFor[exchange.IBotExchange]() actualExchange := reflect.ValueOf(e) errType := reflect.TypeOf(common.ErrNotYetImplemented) - contextParam := reflect.TypeOf((*context.Context)(nil)).Elem() + contextParam := reflect.TypeFor[context.Context]() var funcs []string for x := range iExchange.NumMethod() { diff --git a/cmd/exchange_wrapper_standards/exchange_wrapper_standards_test.go b/cmd/exchange_wrapper_standards/exchange_wrapper_standards_test.go index 4bb78963..77c78abf 100644 --- a/cmd/exchange_wrapper_standards/exchange_wrapper_standards_test.go +++ b/cmd/exchange_wrapper_standards/exchange_wrapper_standards_test.go @@ -181,7 +181,7 @@ type testCtxKey string func executeExchangeWrapperTests(ctx context.Context, t *testing.T, exch exchange.IBotExchange, assetParams []assetPair) { t.Helper() - iExchange := reflect.TypeOf(&exch).Elem() + iExchange := reflect.TypeFor[exchange.IBotExchange]() actualExchange := reflect.ValueOf(exch) for x := range iExchange.NumMethod() { methodName := iExchange.Method(x).Name @@ -277,33 +277,33 @@ type MethodArgumentGenerator struct { } var ( - currencyPairParam = reflect.TypeOf((*currency.Pair)(nil)).Elem() - klineParam = reflect.TypeOf((*kline.Interval)(nil)).Elem() - contextParam = reflect.TypeOf((*context.Context)(nil)).Elem() - timeParam = reflect.TypeOf((*time.Time)(nil)).Elem() - codeParam = reflect.TypeOf((*currency.Code)(nil)).Elem() - currencyPairsParam = reflect.TypeOf((*currency.Pairs)(nil)).Elem() - withdrawRequestParam = reflect.TypeOf((**withdraw.Request)(nil)).Elem() - stringParam = reflect.TypeOf((*string)(nil)).Elem() - feeBuilderParam = reflect.TypeOf((**exchange.FeeBuilder)(nil)).Elem() - credentialsParam = reflect.TypeOf((**accounts.Credentials)(nil)).Elem() - orderSideParam = reflect.TypeOf((*order.Side)(nil)).Elem() - collateralModeParam = reflect.TypeOf((*collateral.Mode)(nil)).Elem() - marginTypeParam = reflect.TypeOf((*margin.Type)(nil)).Elem() - int64Param = reflect.TypeOf((*int64)(nil)).Elem() - float64Param = reflect.TypeOf((*float64)(nil)).Elem() + currencyPairParam = reflect.TypeFor[currency.Pair]() + klineParam = reflect.TypeFor[kline.Interval]() + contextParam = reflect.TypeFor[context.Context]() + timeParam = reflect.TypeFor[time.Time]() + codeParam = reflect.TypeFor[currency.Code]() + currencyPairsParam = reflect.TypeFor[currency.Pairs]() + withdrawRequestParam = reflect.TypeFor[*withdraw.Request]() + stringParam = reflect.TypeFor[string]() + feeBuilderParam = reflect.TypeFor[*exchange.FeeBuilder]() + credentialsParam = reflect.TypeFor[*accounts.Credentials]() + orderSideParam = reflect.TypeFor[order.Side]() + collateralModeParam = reflect.TypeFor[collateral.Mode]() + marginTypeParam = reflect.TypeFor[margin.Type]() + int64Param = reflect.TypeFor[int64]() + float64Param = reflect.TypeFor[float64]() // types with asset in params - assetParam = reflect.TypeOf((*asset.Item)(nil)).Elem() - orderSubmitParam = reflect.TypeOf((**order.Submit)(nil)).Elem() - orderModifyParam = reflect.TypeOf((**order.Modify)(nil)).Elem() - orderCancelParam = reflect.TypeOf((**order.Cancel)(nil)).Elem() - orderCancelsParam = reflect.TypeOf((*[]order.Cancel)(nil)).Elem() - getOrdersRequestParam = reflect.TypeOf((**order.MultiOrderRequest)(nil)).Elem() - positionChangeRequestParam = reflect.TypeOf((**margin.PositionChangeRequest)(nil)).Elem() - positionSummaryRequestParam = reflect.TypeOf((**futures.PositionSummaryRequest)(nil)).Elem() - positionsRequestParam = reflect.TypeOf((**futures.PositionsRequest)(nil)).Elem() - latestRateRequest = reflect.TypeOf((**fundingrate.LatestRateRequest)(nil)).Elem() - pairKeySliceParam = reflect.TypeOf((*[]key.PairAsset)(nil)).Elem() + assetParam = reflect.TypeFor[asset.Item]() + orderSubmitParam = reflect.TypeFor[*order.Submit]() + orderModifyParam = reflect.TypeFor[*order.Modify]() + orderCancelParam = reflect.TypeFor[*order.Cancel]() + orderCancelsParam = reflect.TypeFor[[]order.Cancel]() + getOrdersRequestParam = reflect.TypeFor[*order.MultiOrderRequest]() + positionChangeRequestParam = reflect.TypeFor[*margin.PositionChangeRequest]() + positionSummaryRequestParam = reflect.TypeFor[*futures.PositionSummaryRequest]() + positionsRequestParam = reflect.TypeFor[*futures.PositionsRequest]() + latestRateRequest = reflect.TypeFor[*fundingrate.LatestRateRequest]() + pairKeySliceParam = reflect.TypeFor[[]key.PairAsset]() ) // generateMethodArg determines the argument type and returns a pre-made diff --git a/engine/rpcserver_test.go b/engine/rpcserver_test.go index 001d5f73..c20dbafc 100644 --- a/engine/rpcserver_test.go +++ b/engine/rpcserver_test.go @@ -13,7 +13,6 @@ import ( "net/http/httptest" "os" "path/filepath" - "reflect" "strconv" "strings" "sync" @@ -1530,22 +1529,18 @@ func TestParseEvents(t *testing.T) { testData[x] = resp } v := parseMultipleEvents(testData) - if reflect.TypeOf(v).String() != "*gctrpc.WithdrawalEventsByExchangeResponse" { - t.Fatal("expected type to be *gctrpc.WithdrawalEventsByExchangeResponse") - } - if len(testData) < 2 { - t.Fatal("expected at least 2") - } + require.NotNil(t, v, "parseMultipleEvents must not return nil") + require.Len(t, v.Event, 5, "parseMultipleEvents must return 5 events") v = parseSingleEvents(testData[0]) - if reflect.TypeOf(v).String() != "*gctrpc.WithdrawalEventsByExchangeResponse" { - t.Fatal("expected type to be *gctrpc.WithdrawalEventsByExchangeResponse") - } + require.NotNil(t, v, "parseSingleEvents must not return nil") + require.NotEmpty(t, v.Event, "parseSingleEvents must return an event") + assert.Equal(t, int64(1), v.Event[0].Request.Type, "parseSingleEvents should return an event with the correct request type") v = parseSingleEvents(testData[1]) - if v.Event[0].Request.Type != 0 { - t.Fatal("Expected second entry in slice to return a Request.Type of Crypto") - } + require.NotNil(t, v, "parseSingleEvents must not return nil") + require.NotEmpty(t, v.Event, "parseSingleEvents must return an event") + assert.Zero(t, v.Event[0].Request.Type, "parseSingleEvents should return an event with the correct request type") } func TestRPCServerUpsertDataHistoryJob(t *testing.T) { diff --git a/exchanges/order/orders.go b/exchanges/order/orders.go index 678fa138..0fda8030 100644 --- a/exchanges/order/orders.go +++ b/exchanges/order/orders.go @@ -1084,8 +1084,7 @@ func StringToOrderSide(side string) (Side, error) { func (s *Side) UnmarshalJSON(data []byte) (err error) { if !bytes.HasPrefix(data, []byte(`"`)) { // Note that we don't need to worry about invalid JSON here, it wouldn't have made it past the deserialiser far - // TODO: Can use reflect.TypeFor[s]() when it's released, probably 1.21 - return &json.UnmarshalTypeError{Value: string(data), Type: reflect.TypeOf(s), Offset: 1} + return &json.UnmarshalTypeError{Value: string(data), Type: reflect.TypeFor[*Side]()} } *s, err = StringToOrderSide(string(data[1 : len(data)-1])) // Remove quotes return diff --git a/gctscript/modules/gct/gct_test.go b/gctscript/modules/gct/gct_test.go index 21a96704..86e9d954 100644 --- a/gctscript/modules/gct/gct_test.go +++ b/gctscript/modules/gct/gct_test.go @@ -2,7 +2,6 @@ package gct import ( "os" - "reflect" "testing" "time" @@ -186,11 +185,7 @@ func TestExchangeOrderSubmit(t *testing.T) { func TestAllModuleNames(t *testing.T) { t.Parallel() - x := AllModuleNames() - xType := reflect.TypeOf(x).Kind() - if xType != reflect.Slice { - t.Errorf("AllModuleNames() should return slice instead received: %v", x) - } + require.NotEmpty(t, AllModuleNames(), "AllModuleNames must not return an empty slice") } func TestExchangeDepositAddress(t *testing.T) { diff --git a/gctscript/modules/loader/loader_test.go b/gctscript/modules/loader/loader_test.go index 5941acf2..1ecfb7df 100644 --- a/gctscript/modules/loader/loader_test.go +++ b/gctscript/modules/loader/loader_test.go @@ -1,18 +1,14 @@ package loader import ( - "reflect" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGetModuleMap(t *testing.T) { x := GetModuleMap() - xType := reflect.TypeOf(x).String() - if xType != "*tengo.ModuleMap" { - t.Fatalf("GetModuleMap() should return pointer to ModuleMap instead received: %v", x) - } - - if x.Len() == 0 { - t.Fatal("expected GetModuleMap() to contain module results instead received 0 value") - } + require.NotNil(t, x, "GetModuleMap must not return nil") + assert.NotZero(t, x.Len(), "GetModuleMap should return a map with entries") } diff --git a/gctscript/modules/ta/indicators/indicators_test.go b/gctscript/modules/ta/indicators/indicators_test.go index c40a8d94..e0be8579 100644 --- a/gctscript/modules/ta/indicators/indicators_test.go +++ b/gctscript/modules/ta/indicators/indicators_test.go @@ -3,7 +3,6 @@ package indicators import ( "math/rand" "os" - "reflect" "testing" "time" @@ -394,42 +393,29 @@ func TestOBV(t *testing.T) { } func TestToFloat64(t *testing.T) { - value := 54.0 - v, err := toFloat64(value) - if err != nil { - t.Fatal(err) - } - if reflect.TypeOf(v).Kind() != reflect.Float64 { - t.Fatalf("expected toFloat to return kind float64 received: %v", reflect.TypeOf(v).Kind()) - } - - v, err = toFloat64(int(value)) - if err != nil { - t.Fatal(err) - } - if reflect.TypeOf(v).Kind() != reflect.Float64 { - t.Fatalf("expected toFloat to return kind float64 received: %v", reflect.TypeOf(v).Kind()) - } - - v, err = toFloat64(int32(value)) - if err != nil { - t.Fatal(err) - } - if reflect.TypeOf(v).Kind() != reflect.Float64 { - t.Fatalf("expected toFloat to return kind float64 received: %v", reflect.TypeOf(v).Kind()) - } - - v, err = toFloat64(int64(value)) - if err != nil { - t.Fatal(err) - } - if reflect.TypeOf(v).Kind() != reflect.Float64 { - t.Fatalf("expected toFloat to return kind float64 received: %v", reflect.TypeOf(v).Kind()) - } - - _, err = toFloat64("54") - if err == nil { - t.Fatalf("attempting to convert a string should fail but test passed") + t.Parallel() + for _, tc := range []struct { + name string + input any + expected float64 + expectErr bool + }{ + {"float64", 45.67, 45.67, false}, + {"int", int(45), 45.0, false}, + {"int32", int32(45), 45.0, false}, + {"int64", int64(45), 45.0, false}, + {"string", "45.67", 0, true}, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + result, err := toFloat64(tc.input) + if tc.expectErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tc.expected, result) + }) } } diff --git a/gctscript/modules/ta/ta_test.go b/gctscript/modules/ta/ta_test.go index 04813a72..bb2ed63b 100644 --- a/gctscript/modules/ta/ta_test.go +++ b/gctscript/modules/ta/ta_test.go @@ -1,17 +1,11 @@ package ta import ( - "reflect" "testing" + + "github.com/stretchr/testify/require" ) func TestGetModuleMap(t *testing.T) { - x := AllModuleNames() - xType := reflect.TypeOf(x).Kind() - if xType != reflect.Slice { - t.Fatalf("AllModuleNames() should return slice instead received: %v", x) - } - if len(x) != 9 { - t.Fatalf("unexpected results received expected 9 received: %v", len(x)) - } + require.Len(t, AllModuleNames(), 9, "AllModuleNames must return 9 modules") } diff --git a/gctscript/vm/vm_test.go b/gctscript/vm/vm_test.go index 69873cec..c982bb14 100644 --- a/gctscript/vm/vm_test.go +++ b/gctscript/vm/vm_test.go @@ -4,7 +4,6 @@ import ( "errors" "os" "path/filepath" - "reflect" "testing" "time" @@ -33,16 +32,9 @@ func TestNewVM(t *testing.T) { manager := GctScriptManager{ config: configHelper(true, true, maxTestVirtualMachines), } - x := manager.New() - if x != nil { - t.Error("Should not create a VM when manager not started") - } + require.Nil(t, manager.New(), "New must not create a VM when manager not started") manager.started = 1 - x = manager.New() - xType := reflect.TypeOf(x).String() - if xType != "*vm.VM" { - t.Fatalf("vm.New should return pointer to VM instead received: %v", x) - } + require.NotNil(t, manager.New(), "New must create a VM when manager is started") } func TestVMLoad(t *testing.T) { diff --git a/gctscript/wrappers/gct/gctwrapper_test.go b/gctscript/wrappers/gct/gctwrapper_test.go index c05d6fc5..665e38a9 100644 --- a/gctscript/wrappers/gct/gctwrapper_test.go +++ b/gctscript/wrappers/gct/gctwrapper_test.go @@ -5,7 +5,6 @@ import ( "log" "os" "path/filepath" - "reflect" "testing" objects "github.com/d5/tengo/v2" @@ -72,11 +71,8 @@ func TestMain(m *testing.M) { } func TestSetup(t *testing.T) { - x := Setup() - xType := reflect.TypeOf(x).String() - if xType != "*gct.Wrapper" { - t.Fatalf("SetupCommunicationManager() should return pointer to Wrapper instead received: %v", x) - } + t.Parallel() + require.NotNil(t, Setup(), "Setup must not return nil") } var (