From 7a90aecf6fd6f489574239616aaee0c680d5ce47 Mon Sep 17 00:00:00 2001 From: Adrian Gallagher Date: Thu, 16 Jan 2020 08:10:25 +1100 Subject: [PATCH] Bugfix: Introduces a new config.DefaultFilePath func (#415) * Introduces a new config.DefaultFilePath func * FiX GrAmMeRiNo --- cmd/config/config.go | 18 ++++++------------ cmd/dbmigrate/main.go | 15 ++++----------- cmd/gen_otp/otp_gen.go | 7 +------ cmd/gen_sqlboiler_config/main.go | 15 ++++----------- cmd/portfolio/portfolio.go | 11 ++--------- config/config.go | 28 +++++++++++++++++++++++++++- config/config_test.go | 14 ++++++++++++++ main.go | 3 ++- 8 files changed, 60 insertions(+), 51 deletions(-) diff --git a/cmd/config/config.go b/cmd/config/config.go index 3274b9ed..b1e28e48 100644 --- a/cmd/config/config.go +++ b/cmd/config/config.go @@ -20,15 +20,9 @@ func EncryptOrDecrypt(encrypt bool) string { func main() { var inFile, outFile, key string var encrypt bool - var err error - - configFile, err := config.GetFilePath("") - if err != nil { - log.Fatal(err) - } - - flag.StringVar(&inFile, "infile", configFile, "The config input file to process.") - flag.StringVar(&outFile, "outfile", configFile+".out", "The config output file.") + defaultCfgFile := config.DefaultFilePath() + flag.StringVar(&inFile, "infile", defaultCfgFile, "The config input file to process.") + flag.StringVar(&outFile, "outfile", defaultCfgFile+".out", "The config output file.") flag.BoolVar(&encrypt, "encrypt", true, "Whether to encrypt or decrypt.") flag.StringVar(&key, "key", "", "The key to use for AES encryption.") flag.Parse() @@ -36,9 +30,9 @@ func main() { log.Println("GoCryptoTrader: config-helper tool.") if key == "" { - result, errf := config.PromptForConfigKey(false) - if errf != nil { - log.Fatal("Unable to obtain encryption/decryption key.") + result, err := config.PromptForConfigKey(false) + if err != nil { + log.Fatalf("Unable to obtain encryption/decryption key: %s", err) } key = string(result) } diff --git a/cmd/dbmigrate/main.go b/cmd/dbmigrate/main.go index 25b240bd..9fbc3b8b 100644 --- a/cmd/dbmigrate/main.go +++ b/cmd/dbmigrate/main.go @@ -48,24 +48,16 @@ func main() { fmt.Println(core.Copyright) fmt.Println() - defaultPath, err := config.GetFilePath("") - if err != nil { - fmt.Println(err) - os.Exit(1) - } - flag.StringVar(&command, "command", "", "command to run status|up|up-by-one|up-to|down|create") flag.StringVar(&args, "args", "", "arguments to pass to goose") - - flag.StringVar(&configFile, "config", defaultPath, "config file to load") + flag.StringVar(&configFile, "config", config.DefaultFilePath(), "config file to load") flag.StringVar(&defaultDataDir, "datadir", common.GetDefaultDataDir(runtime.GOOS), "default data directory for GoCryptoTrader files") flag.StringVar(&migrationDir, "migrationdir", database.MigrationDir, "override migration folder") flag.Parse() - conf := config.GetConfig() - - err = conf.LoadConfig(configFile, true) + var conf config.Config + err := conf.LoadConfig(configFile, true) if err != nil { fmt.Println(err) os.Exit(1) @@ -75,6 +67,7 @@ func main() { fmt.Println("Database support is disabled") os.Exit(1) } + err = openDbConnection(conf.Database.Driver) if err != nil { fmt.Println(err) diff --git a/cmd/gen_otp/otp_gen.go b/cmd/gen_otp/otp_gen.go index f48c4fe2..66a6380d 100644 --- a/cmd/gen_otp/otp_gen.go +++ b/cmd/gen_otp/otp_gen.go @@ -27,12 +27,7 @@ func main() { var single bool var err error - defaultCfg, err := config.GetFilePath("") - if err != nil { - log.Fatal(err) - } - - flag.StringVar(&cfgFile, "config", defaultCfg, "The config input file to process.") + flag.StringVar(&cfgFile, "config", config.DefaultFilePath(), "The config input file to process.") flag.BoolVar(&single, "single", false, "prompt for single use OTP code gen") flag.Parse() diff --git a/cmd/gen_sqlboiler_config/main.go b/cmd/gen_sqlboiler_config/main.go index 9d739a2c..37567780 100644 --- a/cmd/gen_sqlboiler_config/main.go +++ b/cmd/gen_sqlboiler_config/main.go @@ -40,26 +40,19 @@ func main() { fmt.Println(core.Copyright) fmt.Println() - defaultPath, err := config.GetFilePath("") - if err != nil { - fmt.Println(err) - os.Exit(1) - } - - flag.StringVar(&configFile, "config", defaultPath, "config file to load") + flag.StringVar(&configFile, "config", config.DefaultFilePath(), "config file to load") flag.StringVar(&defaultDataDir, "datadir", common.GetDefaultDataDir(runtime.GOOS), "default data directory for GoCryptoTrader files") flag.StringVar(&outputFolder, "outdir", "", "overwrite default output folder") flag.Parse() - conf := config.GetConfig() - - err = conf.LoadConfig(configFile, true) + var cfg config.Config + err := cfg.LoadConfig(configFile, true) if err != nil { fmt.Println(err) os.Exit(1) } - convertGCTtoSQLBoilerConfig(&conf.Database) + convertGCTtoSQLBoilerConfig(&cfg.Database) jsonOutput, err := json.MarshalIndent(sqlboilerConfig, "", " ") if err != nil { diff --git a/cmd/portfolio/portfolio.go b/cmd/portfolio/portfolio.go index d9594674..8c5d30a2 100644 --- a/cmd/portfolio/portfolio.go +++ b/cmd/portfolio/portfolio.go @@ -63,21 +63,14 @@ func getOnlineOfflinePortfolio(coins []portfolio.Coin, online bool) { func main() { var inFile, key string - - defaultCfg, err := config.GetFilePath("") - if err != nil { - log.Println(err) - os.Exit(1) - } - - flag.StringVar(&inFile, "infile", defaultCfg, "The config input file to process.") + flag.StringVar(&inFile, "infile", config.DefaultFilePath(), "The config input file to process.") flag.StringVar(&key, "key", "", "The key to use for AES encryption.") flag.Parse() log.Println("GoCryptoTrader: portfolio tool.") var cfg config.Config - err = cfg.LoadConfig(inFile, true) + err := cfg.LoadConfig(inFile, true) if err != nil { log.Println(err) os.Exit(1) diff --git a/config/config.go b/config/config.go index 43c7fb27..742ce211 100644 --- a/config/config.go +++ b/config/config.go @@ -1328,8 +1328,34 @@ func (c *Config) CheckConnectionMonitorConfig() { } } +// DefaultFilePath returns the default config file path +// MacOS/Linux: $HOME/.gocryptotrader/config.json or config.dat +// Windows: %APPDATA%\GoCryptoTrader\config.json or config.dat +// Helpful for printing application usage +func DefaultFilePath() string { + f := filepath.Join(common.GetDefaultDataDir(runtime.GOOS), File) + _, err := os.Stat(f) + if os.IsNotExist(err) { + encFile := filepath.Join(common.GetDefaultDataDir(runtime.GOOS), EncryptedFile) + _, err = os.Stat(encFile) + if !os.IsNotExist(err) { + return encFile + } + } + return f +} + // GetFilePath returns the desired config file or the default config file name -// based on if the application is being run under test or normal mode. +// based on if the application is being run under test or normal mode. It will +// also move/rename the config file under the following conditions: +// 1) If a config file is found in the executable path directory and no explicit +// config path is set, plus no config is found in the GCT data dir, it will +// move it to the GCT data dir. If a config already exists in the GCT data +// dir, it will warn the user and load the config found in the GCT data dir +// 2) If a config file in the GCT data dir has the file extension .dat but +// contains json data, it will rename to the file to config.json +// 3) If a config file in the GCT data dir has the file extension .json but +// contains encrypted data, it will rename the file to config.dat func GetFilePath(configfile string) (string, error) { if configfile != "" { return configfile, nil diff --git a/config/config_test.go b/config/config_test.go index 1543dd44..20fd889c 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1608,6 +1608,20 @@ func TestCheckConnectionMonitorConfig(t *testing.T) { } } +func TestDefaultFilePath(t *testing.T) { + // This is tricky to test because we're dealing with a config file stored + // in a persons default directory and to properly test it, it would + // require causing os.Stat to return !os.IsNotExist and os.IsNotExist (which + // means moving a users config file around), a way of getting around this is + // to pass the datadir as a param line but adds a burden to everyone who + // uses it + result := DefaultFilePath() + if !strings.Contains(result, File) && + !strings.Contains(result, EncryptedFile) { + t.Error("result should have contained config.json or config.dat") + } +} + func TestGetFilePath(t *testing.T) { expected := "blah.json" result, _ := GetFilePath("blah.json") diff --git a/main.go b/main.go index fdb02f25..18b0f825 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( "time" "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/core" "github.com/thrasher-corp/gocryptotrader/dispatch" "github.com/thrasher-corp/gocryptotrader/engine" @@ -22,7 +23,7 @@ func main() { versionFlag := flag.Bool("version", false, "retrieves current GoCryptoTrader version") // Core settings - flag.StringVar(&settings.ConfigFile, "config", "", "config file to load") + flag.StringVar(&settings.ConfigFile, "config", config.DefaultFilePath(), "config file to load") flag.StringVar(&settings.DataDir, "datadir", common.GetDefaultDataDir(runtime.GOOS), "default data directory for GoCryptoTrader files") flag.IntVar(&settings.GoMaxProcs, "gomaxprocs", runtime.GOMAXPROCS(-1), "sets the runtime GOMAXPROCS value") flag.BoolVar(&settings.EnableDryRun, "dryrun", false, "dry runs bot, doesn't save config file")