diff --git a/config/config.go b/config/config.go index 656aafd5..d6f916e1 100644 --- a/config/config.go +++ b/config/config.go @@ -1303,7 +1303,7 @@ func (c *Config) CheckLoggerConfig() error { log.GlobalLogConfig = &c.Logging log.RWM.Unlock() - logPath := filepath.Join(common.GetDefaultDataDir(runtime.GOOS), "logs") + logPath := c.GetDataPath("logs") err := common.CreateDir(logPath) if err != nil { return err @@ -1325,7 +1325,7 @@ func (c *Config) checkGCTScriptConfig() error { c.GCTScript.MaxVirtualMachines = gctscript.DefaultMaxVirtualMachines } - scriptPath := filepath.Join(common.GetDefaultDataDir(runtime.GOOS), "scripts") + scriptPath := c.GetDataPath("scripts") err := common.CreateDir(scriptPath) if err != nil { return err @@ -1362,7 +1362,7 @@ func (c *Config) checkDatabaseConfig() error { } if c.Database.Driver == database.DBSQLite || c.Database.Driver == database.DBSQLite3 { - databaseDir := filepath.Join(common.GetDefaultDataDir(runtime.GOOS), "database") + databaseDir := c.GetDataPath("database") err := common.CreateDir(databaseDir) if err != nil { return err @@ -1845,3 +1845,14 @@ func (c *Config) AssetTypeEnabled(a asset.Item, exch string) (bool, error) { } return true, nil } + +// GetDataPath gets the data path for the given subpath +func (c *Config) GetDataPath(elem ...string) string { + var baseDir string + if c.DataDirectory != "" { + baseDir = c.DataDirectory + } else { + baseDir = common.GetDefaultDataDir(runtime.GOOS) + } + return filepath.Join(append([]string{baseDir}, elem...)...) +} diff --git a/config/config_test.go b/config/config_test.go index 39985340..68c3d7bd 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1,6 +1,8 @@ package config import ( + "path/filepath" + "runtime" "strings" "testing" @@ -2034,3 +2036,43 @@ func TestRemoveExchange(t *testing.T) { t.Fatal("exchange shouldn't exist") } } + +func TestGetDataPath(t *testing.T) { + tests := []struct { + name string + dir string + elem []string + want string + }{ + { + name: "empty", + dir: "", + elem: []string{}, + want: common.GetDefaultDataDir(runtime.GOOS), + }, + { + name: "empty a b", + dir: "", + elem: []string{"a", "b"}, + want: filepath.Join(common.GetDefaultDataDir(runtime.GOOS), "a", "b"), + }, + { + name: "target", + dir: "target", + elem: []string{"a", "b"}, + want: filepath.Join("target", "a", "b"), + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + c := &Config{ + DataDirectory: tt.dir, + } + if got := c.GetDataPath(tt.elem...); got != tt.want { + t.Errorf("Config.GetDataPath() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/config/config_types.go b/config/config_types.go index 57e08dcf..0af2cb2a 100644 --- a/config/config_types.go +++ b/config/config_types.go @@ -78,6 +78,7 @@ var ( // Exchanges type Config struct { Name string `json:"name"` + DataDirectory string `json:"dataDirectory"` EncryptConfig int `json:"encryptConfig"` GlobalHTTPTimeout time.Duration `json:"globalHTTPTimeout"` Database database.Config `json:"database"` diff --git a/config_example.json b/config_example.json index 1bd16444..a43e7974 100644 --- a/config_example.json +++ b/config_example.json @@ -1,5 +1,6 @@ { "name": "Skynet", + "dataDirectory": "", "encryptConfig": 0, "globalHTTPTimeout": 15000000000, "database": { diff --git a/engine/engine.go b/engine/engine.go index 303c3dbc..2b3627c4 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -70,16 +70,13 @@ func NewFromSettings(settings *Settings) (*Engine, error) { if settings == nil { return nil, errors.New("engine: settings is nil") } + // collect flags + flag.Visit(func(f *flag.Flag) { flagSet[f.Name] = true }) var b Engine - b.Config = &config.Cfg - filePath, err := config.GetFilePath(settings.ConfigFile) - if err != nil { - return nil, err - } + var err error - log.Printf("Loading config file %s..\n", filePath) - err = b.Config.LoadConfig(filePath, settings.EnableDryRun) + b.Config, err = loadConfigWithSettings(settings) if err != nil { return nil, fmt.Errorf("failed to load config. Err: %s", err) } @@ -95,8 +92,8 @@ func NewFromSettings(settings *Settings) (*Engine, error) { gctlog.Infoln(gctlog.Global, "Logger initialised.") } - b.Settings.ConfigFile = filePath - b.Settings.DataDir = settings.DataDir + b.Settings.ConfigFile = settings.ConfigFile + b.Settings.DataDir = b.Config.GetDataPath() b.Settings.CheckParamInteraction = settings.CheckParamInteraction err = utils.AdjustGoMaxProcs(settings.GoMaxProcs) @@ -104,14 +101,38 @@ func NewFromSettings(settings *Settings) (*Engine, error) { return nil, fmt.Errorf("unable to adjust runtime GOMAXPROCS value. Err: %s", err) } - ValidateSettings(&b, settings) + validateSettings(&b, settings) return &b, nil } -// ValidateSettings validates and sets all bot settings -func ValidateSettings(b *Engine, s *Settings) { - flag.Visit(func(f *flag.Flag) { flagSet[f.Name] = true }) +// loadConfigWithSettings creates configuration based on the provided settings +func loadConfigWithSettings(settings *Settings) (*config.Config, error) { + filePath, err := config.GetFilePath(settings.ConfigFile) + if err != nil { + return nil, err + } + log.Printf("Loading config file %s..\n", filePath) + conf := &config.Cfg + err = conf.ReadConfig(filePath, settings.EnableDryRun) + if err != nil { + return nil, fmt.Errorf(config.ErrFailureOpeningConfig, filePath, err) + } + // Apply overrides from settings + if flagSet["datadir"] { + // warn if dryrun isn't enabled + if !settings.EnableDryRun { + log.Println("Command line argument '-datadir' induces dry run mode.") + } + settings.EnableDryRun = true + conf.DataDirectory = settings.DataDir + } + + return conf, conf.CheckConfig() +} + +// validateSettings validates and sets all bot settings +func validateSettings(b *Engine, s *Settings) { b.Settings.Verbose = s.Verbose b.Settings.EnableDryRun = s.EnableDryRun b.Settings.EnableAllExchanges = s.EnableAllExchanges diff --git a/engine/engine_test.go b/engine/engine_test.go new file mode 100644 index 00000000..e53387c2 --- /dev/null +++ b/engine/engine_test.go @@ -0,0 +1,73 @@ +package engine + +import ( + "os" + "testing" + + "github.com/thrasher-corp/gocryptotrader/config" +) + +func TestLoadConfigWithSettings(t *testing.T) { + empty := "" + somePath := "somePath" + // Clean up after the tests + defer os.RemoveAll(somePath) + tests := []struct { + name string + flags []string + settings *Settings + want *string + wantErr bool + }{ + { + name: "invalid file", + settings: &Settings{ + ConfigFile: "nonExistent.json", + }, + wantErr: true, + }, + { + name: "test file", + settings: &Settings{ + ConfigFile: config.TestFile, + EnableDryRun: true, + }, + want: &empty, + wantErr: false, + }, + { + name: "data dir in settings overrides config data dir", + flags: []string{"datadir"}, + settings: &Settings{ + ConfigFile: config.TestFile, + DataDir: somePath, + EnableDryRun: true, + }, + want: &somePath, + wantErr: false, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + // prepare the 'flags' + flagSet = make(map[string]bool) + for _, v := range tt.flags { + flagSet[v] = true + } + // Run the test + got, err := loadConfigWithSettings(tt.settings) + if (err != nil) != tt.wantErr { + t.Errorf("loadConfigWithSettings() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != nil || tt.want != nil { + if (got == nil && tt.want != nil) || (got != nil && tt.want == nil) { + t.Errorf("loadConfigWithSettings() = is nil %v, want nil %v", got == nil, tt.want == nil) + } else if got.DataDirectory != *tt.want { + t.Errorf("loadConfigWithSettings() = %v, want %v", got.DataDirectory, *tt.want) + } + } + }) + } +}