From 4ccb495baf9e3c1c28486d303620e8d98ac17b37 Mon Sep 17 00:00:00 2001 From: Ryan O'Hara-Reid Date: Mon, 19 Oct 2020 13:59:50 +1100 Subject: [PATCH] Asset package update (#581) * Rewrite new function and deploy where we can minimise the chance of setting an asset type that is different to supported list - sets validation to exact supported list * change wording --- cmd/exchange_wrapper_issues/main.go | 7 +- cmd/gctcli/validation.go | 3 +- config/config.go | 2 +- engine/rpcserver.go | 118 ++++++++++++++++++---------- engine/websocket.go | 15 ++-- exchanges/asset/asset.go | 40 ++++------ exchanges/asset/asset_test.go | 41 +++++----- gctscript/modules/gct/exchange.go | 28 +++++-- 8 files changed, 153 insertions(+), 101 deletions(-) diff --git a/cmd/exchange_wrapper_issues/main.go b/cmd/exchange_wrapper_issues/main.go index 41aa8c13..2fa80c95 100644 --- a/cmd/exchange_wrapper_issues/main.go +++ b/cmd/exchange_wrapper_issues/main.go @@ -280,10 +280,11 @@ func testWrappers(e exchange.IBotExchange, base *exchange.Base, config *Config) testOrderType := parseOrderType(config.OrderSubmission.OrderType) assetTypes := base.GetAssetTypes() if assetTypeOverride != "" { - if asset.IsValid(asset.Item(assetTypeOverride)) { - assetTypes = asset.Items{asset.Item(assetTypeOverride)} - } else { + a, err := asset.New(assetTypeOverride) + if err != nil { log.Printf("%v Asset Type '%v' not recognised, defaulting to exchange defaults", base.GetName(), assetTypeOverride) + } else { + assetTypes = asset.Items{a} } } for i := range assetTypes { diff --git a/cmd/gctcli/validation.go b/cmd/gctcli/validation.go index fd308b06..d8aa4eb4 100644 --- a/cmd/gctcli/validation.go +++ b/cmd/gctcli/validation.go @@ -23,5 +23,6 @@ func validExchange(exch string) bool { } func validAsset(i string) bool { - return asset.IsValid(asset.Item(i)) + _, err := asset.New(i) + return err == nil } diff --git a/config/config.go b/config/config.go index ebe5a971..14d8490e 100644 --- a/config/config.go +++ b/config/config.go @@ -358,7 +358,7 @@ func (c *Config) SupportsExchangeAssetType(exchName string, assetType asset.Item return fmt.Errorf("exchange %s currency pairs is nil", exchName) } - if !asset.IsValid(assetType) { + if !assetType.IsValid() { return fmt.Errorf("exchange %s invalid asset type %s", exchName, assetType) diff --git a/engine/rpcserver.go b/engine/rpcserver.go index d485ad4f..5d188a44 100644 --- a/engine/rpcserver.go +++ b/engine/rpcserver.go @@ -327,13 +327,18 @@ func (s *RPCServer) GetExchangeInfo(_ context.Context, r *gctrpc.GenericExchange // GetTicker returns the ticker for a specified exchange, currency pair and // asset type func (s *RPCServer) GetTicker(_ context.Context, r *gctrpc.GetTickerRequest) (*gctrpc.TickerResponse, error) { + a, err := asset.New(r.AssetType) + if err != nil { + return nil, err + } + t, err := s.GetSpecificTicker(currency.Pair{ Delimiter: r.Pair.Delimiter, Base: currency.NewCode(r.Pair.Base), Quote: currency.NewCode(r.Pair.Quote), }, r.Exchange, - asset.Item(r.AssetType), + a, ) if err != nil { return nil, err @@ -390,13 +395,18 @@ func (s *RPCServer) GetTickers(_ context.Context, r *gctrpc.GetTickersRequest) ( // GetOrderbook returns an orderbook for a specific exchange, currency pair // and asset type func (s *RPCServer) GetOrderbook(_ context.Context, r *gctrpc.GetOrderbookRequest) (*gctrpc.OrderbookResponse, error) { + a, err := asset.New(r.AssetType) + if err != nil { + return nil, err + } + ob, err := s.GetSpecificOrderbook(currency.Pair{ Delimiter: r.Pair.Delimiter, Base: currency.NewCode(r.Pair.Base), Quote: currency.NewCode(r.Pair.Quote), }, r.Exchange, - asset.Item(r.AssetType), + a, ) if err != nil { return nil, err @@ -842,9 +852,9 @@ func (s *RPCServer) SubmitOrder(_ context.Context, r *gctrpc.SubmitOrderRequest) return nil, err } - a := asset.Item(r.AssetType) - if !asset.IsValid(a) { - return nil, fmt.Errorf("asset type: %s is invalid", a) + a, err := asset.New(r.AssetType) + if err != nil { + return nil, err } submission := &order.Submit{ @@ -964,9 +974,9 @@ func (s *RPCServer) CancelOrder(_ context.Context, r *gctrpc.CancelOrderRequest) return nil, err } - a := asset.Item(r.AssetType) - if !asset.IsValid(a) { - return nil, fmt.Errorf("asset type: %s is invalid", a) + a, err := asset.New(r.AssetType) + if err != nil { + return nil, err } err = exch.CancelOrder(&order.Cancel{ @@ -1007,7 +1017,12 @@ func (s *RPCServer) AddEvent(_ context.Context, r *gctrpc.AddEventRequest) (*gct p := currency.NewPairWithDelimiter(r.Pair.Base, r.Pair.Quote, r.Pair.Delimiter) - id, err := Add(r.Exchange, r.Item, evtCondition, p, asset.Item(r.AssetType), r.Action) + a, err := asset.New(r.AssetType) + if err != nil { + return nil, err + } + + id, err := Add(r.Exchange, r.Item, evtCondition, p, a, r.Action) if err != nil { return nil, err } @@ -1261,14 +1276,21 @@ func (s *RPCServer) GetExchangePairs(_ context.Context, r *gctrpc.GetExchangePai return nil, err } - if r.Asset != "" && - !exchCfg.CurrencyPairs.GetAssetTypes().Contains(asset.Item(r.Asset)) { - return nil, errors.New("specified asset type does not exist") + assetTypes := exchCfg.CurrencyPairs.GetAssetTypes() + + var a asset.Item + if r.Asset != "" { + a, err = asset.New(r.Asset) + if err != nil { + return nil, err + } + if !assetTypes.Contains(a) { + return nil, fmt.Errorf("specified asset %s is not supported by exchange", a) + } } var resp gctrpc.GetExchangePairsResponse resp.SupportedAssets = make(map[string]*gctrpc.PairsSupported) - assetTypes := exchCfg.CurrencyPairs.GetAssetTypes() for x := range assetTypes { if r.Asset != "" && !strings.EqualFold(assetTypes[x].String(), r.Asset) { continue @@ -1298,11 +1320,14 @@ func (s *RPCServer) SetExchangePair(_ context.Context, r *gctrpc.SetExchangePair return nil, errors.New("asset type must be specified") } - if !exchCfg.CurrencyPairs.GetAssetTypes().Contains(asset.Item(r.AssetType)) { - return nil, errors.New("specified asset type does not exist") + a, err := asset.New(r.AssetType) + if err != nil { + return nil, err } - a := asset.Item(r.AssetType) + if !exchCfg.CurrencyPairs.GetAssetTypes().Contains(a) { + return nil, fmt.Errorf("specified asset %s is not supported by exchange", a) + } exch := s.GetExchangeByName(r.Exchange) if exch == nil { @@ -1334,7 +1359,7 @@ func (s *RPCServer) SetExchangePair(_ context.Context, r *gctrpc.SetExchangePair newErrors = append(newErrors, err) continue } - err = base.CurrencyPairs.EnablePair(asset.Item(r.AssetType), p) + err = base.CurrencyPairs.EnablePair(a, p) if err != nil { newErrors = append(newErrors, err) continue @@ -1343,13 +1368,13 @@ func (s *RPCServer) SetExchangePair(_ context.Context, r *gctrpc.SetExchangePair continue } - err = exchCfg.CurrencyPairs.DisablePair(asset.Item(r.AssetType), + err = exchCfg.CurrencyPairs.DisablePair(a, p.Format(pairFmt.Delimiter, pairFmt.Uppercase)) if err != nil { newErrors = append(newErrors, err) continue } - err = base.CurrencyPairs.DisablePair(asset.Item(r.AssetType), p) + err = base.CurrencyPairs.DisablePair(a, p) if err != nil { newErrors = append(newErrors, err) continue @@ -1390,7 +1415,12 @@ func (s *RPCServer) GetOrderbookStream(r *gctrpc.GetOrderbookStreamRequest, stre return err } - pipe, err := orderbook.SubscribeOrderbook(r.Exchange, p, asset.Item(r.AssetType)) + a, err := asset.New(r.AssetType) + if err != nil { + return err + } + + pipe, err := orderbook.SubscribeOrderbook(r.Exchange, p, a) if err != nil { return err } @@ -1499,7 +1529,12 @@ func (s *RPCServer) GetTickerStream(r *gctrpc.GetTickerStreamRequest, stream gct return err } - pipe, err := ticker.SubscribeTicker(r.Exchange, p, asset.Item(r.AssetType)) + a, err := asset.New(r.AssetType) + if err != nil { + return err + } + + pipe, err := ticker.SubscribeTicker(r.Exchange, p, a) if err != nil { return err } @@ -1640,14 +1675,21 @@ func (s *RPCServer) GetHistoricCandles(_ context.Context, req *gctrpc.GetHistori End: req.End, } + a, err := asset.New(req.AssetType) + if err != nil { + return nil, err + } + + pair := currency.Pair{ + Delimiter: req.Pair.Delimiter, + Base: currency.NewCode(req.Pair.Base), + Quote: currency.NewCode(req.Pair.Quote), + } + if req.UseDb { candles, err = kline.LoadFromDatabase(req.Exchange, - currency.Pair{ - Delimiter: req.Pair.Delimiter, - Base: currency.NewCode(req.Pair.Base), - Quote: currency.NewCode(req.Pair.Quote), - }, - asset.Item(strings.ToLower(req.AssetType)), + pair, + a, kline.Interval(req.TimeInterval), time.Unix(req.Start, 0), time.Unix(req.End, 0), @@ -1658,22 +1700,14 @@ func (s *RPCServer) GetHistoricCandles(_ context.Context, req *gctrpc.GetHistori return nil, errors.New("Exchange " + req.Exchange + " not found") } if req.ExRequest { - candles, err = exchangeEngine.GetHistoricCandlesExtended(currency.Pair{ - Delimiter: req.Pair.Delimiter, - Base: currency.NewCode(req.Pair.Base), - Quote: currency.NewCode(req.Pair.Quote), - }, - asset.Item(strings.ToLower(req.AssetType)), + candles, err = exchangeEngine.GetHistoricCandlesExtended(pair, + a, time.Unix(req.Start, 0), time.Unix(req.End, 0), kline.Interval(req.TimeInterval)) } else { - candles, err = exchangeEngine.GetHistoricCandles(currency.Pair{ - Delimiter: req.Pair.Delimiter, - Base: currency.NewCode(req.Pair.Base), - Quote: currency.NewCode(req.Pair.Quote), - }, - asset.Item(strings.ToLower(req.AssetType)), + candles, err = exchangeEngine.GetHistoricCandles(pair, + a, time.Unix(req.Start, 0), time.Unix(req.End, 0), kline.Interval(req.TimeInterval)) @@ -2020,7 +2054,11 @@ func (s *RPCServer) SetExchangeAsset(_ context.Context, r *gctrpc.SetExchangeAss return nil, errors.New("asset type must be specified") } - a := asset.Item(r.Asset) + a, err := asset.New(r.Asset) + if err != nil { + return nil, err + } + err = base.CurrencyPairs.SetAssetEnabled(a, r.Enable) if err != nil { return nil, err diff --git a/engine/websocket.go b/engine/websocket.go index 78ba3d81..7b57bca1 100644 --- a/engine/websocket.go +++ b/engine/websocket.go @@ -356,10 +356,12 @@ func wsGetTicker(client *WebsocketClient, data interface{}) error { return err } - result, err := Bot.GetSpecificTicker(p, - tickerReq.Exchange, - asset.Item(tickerReq.AssetType)) + a, err := asset.New(tickerReq.AssetType) + if err != nil { + return err + } + result, err := Bot.GetSpecificTicker(p, tickerReq.Exchange, a) if err != nil { wsResp.Error = err.Error() client.SendWebsocketMessage(wsResp) @@ -394,9 +396,12 @@ func wsGetOrderbook(client *WebsocketClient, data interface{}) error { return err } - result, err := Bot.GetSpecificOrderbook(p, - orderbookReq.Exchange, asset.Item(orderbookReq.AssetType)) + a, err := asset.New(orderbookReq.AssetType) + if err != nil { + return err + } + result, err := Bot.GetSpecificOrderbook(p, orderbookReq.Exchange, a) if err != nil { wsResp.Error = err.Error() client.SendWebsocketMessage(wsResp) diff --git a/exchanges/asset/asset.go b/exchanges/asset/asset.go index 5e80d0f7..015b27a4 100644 --- a/exchanges/asset/asset.go +++ b/exchanges/asset/asset.go @@ -1,6 +1,7 @@ package asset import ( + "fmt" "strings" ) @@ -59,12 +60,12 @@ func (a Items) Strings() []string { // Contains returns whether or not the supplied asset exists // in the list of Items func (a Items) Contains(i Item) bool { - if !IsValid(i) { + if !i.IsValid() { return false } for x := range a { - if strings.EqualFold(a[x].String(), i.String()) { + if a[x].String() == i.String() { return true } } @@ -79,35 +80,24 @@ func (a Items) JoinToString(separator string) string { // IsValid returns whether or not the supplied asset type is valid or // not -func IsValid(input Item) bool { - a := Supported() - for x := range a { - if strings.EqualFold(a[x].String(), input.String()) { +func (a Item) IsValid() bool { + for x := range supported { + if supported[x].String() == a.String() { return true } } return false } -// New takes an input of asset types as string and returns an Items -// array -func New(input string) Items { - if !strings.Contains(input, ",") { - if IsValid(Item(input)) { - return Items{ - Item(input), - } +// New takes an input matches to relevant package assets +func New(input string) (Item, error) { + input = strings.ToLower(input) + for i := range supported { + if string(supported[i]) == input { + return supported[i], nil } - return nil } - - assets := strings.Split(input, ",") - var result Items - for x := range assets { - if !IsValid(Item(assets[x])) { - return nil - } - result = append(result, Item(assets[x])) - } - return result + return "", fmt.Errorf("cannot create new asset: input %s mismatch to supported asset list %s", + input, + supported) } diff --git a/exchanges/asset/asset_test.go b/exchanges/asset/asset_test.go index e1db02ba..4553b0f1 100644 --- a/exchanges/asset/asset_test.go +++ b/exchanges/asset/asset_test.go @@ -37,7 +37,9 @@ func TestContains(t *testing.T) { t.Fatal("TestContains returned an unexpected result") } - if !a.Contains("SpOt") { + // Every asset should be created and matched with func New so this should + // not be matched against list + if a.Contains("SpOt") { t.Error("TestContains returned an unexpected result") } } @@ -50,36 +52,39 @@ func TestJoinToString(t *testing.T) { } func TestIsValid(t *testing.T) { - if IsValid("rawr") { + if Item("rawr").IsValid() { t.Fatal("TestIsValid returned an unexpected result") } - if !IsValid(Spot) { + if !Spot.IsValid() { t.Fatal("TestIsValid returned an unexpected result") } } func TestNew(t *testing.T) { - a := New("Spota") - if a != nil { + _, err := New("Spota") + if err == nil { t.Fatal("TestNew returned an unexpected result") } - a = New("SpOt") - if a == nil { - t.Fatal("TestNew returned an unexpected result") + a, err := New("SpOt") + if err != nil { + t.Fatal("TestNew returned an unexpected result", err) } - a = New("spot,futures") - if a.JoinToString(",") != "spot,futures" { - t.Fatal("TestNew returned an unexpected result") - } - - if a := New("Spot_rawr"); a != nil { - t.Fatal("TestNew returned an unexpected result") - } - - if a := New("Spot,Rawr"); a != nil { + if a != Spot { t.Fatal("TestNew returned an unexpected result") } } + +func TestSupported(t *testing.T) { + s := Supported() + if len(supported) != len(s) { + t.Fatal("TestSupported mismatched lengths") + } + for i := 0; i < len(supported); i++ { + if s[i] != supported[i] { + t.Fatal("TestSupported returned an unexpected result") + } + } +} diff --git a/gctscript/modules/gct/exchange.go b/gctscript/modules/gct/exchange.go index d2542f3d..125e2a5d 100644 --- a/gctscript/modules/gct/exchange.go +++ b/gctscript/modules/gct/exchange.go @@ -2,7 +2,6 @@ package gct import ( "fmt" - "strings" "time" objects "github.com/d5/tengo/v2" @@ -58,7 +57,11 @@ func ExchangeOrderbook(args ...objects.Object) (objects.Object, error) { if err != nil { return nil, err } - assetType := asset.Item(assetTypeParam) + + assetType, err := asset.New(assetTypeParam) + if err != nil { + return nil, err + } ob, err := wrappers.GetWrapper().Orderbook(exchangeName, pair, assetType) if err != nil { @@ -121,7 +124,10 @@ func ExchangeTicker(args ...objects.Object) (objects.Object, error) { return nil, err } - assetType := asset.Item(assetTypeParam) + assetType, err := asset.New(assetTypeParam) + if err != nil { + return nil, err + } tx, err := wrappers.GetWrapper().Ticker(exchangeName, pair, assetType) if err != nil { @@ -187,7 +193,10 @@ func ExchangePairs(args ...objects.Object) (objects.Object, error) { if !ok { return nil, fmt.Errorf(ErrParameterConvertFailed, assetTypeParam) } - assetType := asset.Item(strings.ToLower(assetTypeParam)) + assetType, err := asset.New(assetTypeParam) + if err != nil { + return nil, err + } rtnValue, err := wrappers.GetWrapper().Pairs(exchangeName, enabledOnly, assetType) if err != nil { @@ -362,9 +371,9 @@ func ExchangeOrderSubmit(args ...objects.Object) (objects.Object, error) { if !ok { return nil, fmt.Errorf(ErrParameterConvertFailed, orderClientID) } - a := asset.Item(assetType) - if !asset.IsValid(a) { - return nil, fmt.Errorf("asset type: %s is invalid", a) + a, err := asset.New(assetType) + if err != nil { + return nil, err } tempSubmit := &order.Submit{ @@ -571,7 +580,10 @@ func exchangeOHLCV(args ...objects.Object) (objects.Object, error) { if err != nil { return nil, err } - assetType := asset.Item(assetTypeParam) + assetType, err := asset.New(assetTypeParam) + if err != nil { + return nil, err + } ret, err := wrappers.GetWrapper().OHLCV(exchangeName, pair, assetType, startTime, endTime, kline.Interval(interval)) if err != nil {