diff --git a/config/config.go b/config/config.go index 211ebe5d..77372155 100644 --- a/config/config.go +++ b/config/config.go @@ -3,6 +3,7 @@ package config import ( "encoding/json" "errors" + "flag" "fmt" "log" "os" @@ -80,10 +81,10 @@ type Config struct { Name string EncryptConfig int Cryptocurrencies string - Portfolio portfolio.Base `json:"PortfolioAddresses"` - SMS SMSGlobalConfig `json:"SMSGlobal"` - Webserver WebserverConfig `json:"Webserver"` - Exchanges []ExchangeConfig `json:"Exchanges"` + Portfolio portfolio.Base `json:"PortfolioAddresses"` + SMS SMSGlobalConfig `json:"SMSGlobal"` + Webserver WebserverConfig `json:"Webserver"` + Exchanges []ExchangeConfig `json:"Exchanges"` } // ExchangeConfig holds all the information needed for each enabled Exchange. @@ -284,6 +285,18 @@ func (c *Config) RetrieveConfigCurrencyPairs() error { return nil } +// GetFilePath returns the desired config file or the default config file name +// based on if the application is being run under test or normal mode. +func GetFilePath(file string) string { + if file != "" { + return file + } + if flag.Lookup("test.v") == nil { + return ConfigFile + } + return ConfigTestFile +} + // CheckConfig checks to see if there is an old configuration filename and path // if found it will change it to correct filename. func CheckConfig() error { @@ -301,14 +314,7 @@ func CheckConfig() error { // ReadConfig verifies and checks for encryption and verifies the unencrypted // file contains JSON. func (c *Config) ReadConfig(configPath string) error { - var defaultPath string - - if configPath == "" { - defaultPath = ConfigTestFile - } else { - defaultPath = configPath - } - + defaultPath := GetFilePath(configPath) err := CheckConfig() if err != nil { return err @@ -356,14 +362,7 @@ func (c *Config) ReadConfig(configPath string) error { // SaveConfig saves your configuration to your desired path func (c *Config) SaveConfig(configPath string) error { - var defaultPath string - - if configPath == "" { - defaultPath = ConfigFile - } else { - defaultPath = configPath - } - + defaultPath := GetFilePath(configPath) payload, err := json.MarshalIndent(c, "", " ") if c.EncryptConfig == configFileEncryptionEnabled { @@ -389,7 +388,7 @@ func (c *Config) SaveConfig(configPath string) error { func (c *Config) LoadConfig(configPath string) error { err := c.ReadConfig(configPath) if err != nil { - return fmt.Errorf(ErrFailureOpeningConfig, ConfigFile, err) + return fmt.Errorf(ErrFailureOpeningConfig, configPath, err) } err = c.CheckExchangeConfigValues() diff --git a/config/config_test.go b/config/config_test.go index 1da0b2b0..a47c10e1 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -305,3 +305,17 @@ func TestSaveConfig(t *testing.T) { t.Errorf("Test failed. TestSaveConfig.SaveConfig, %s", err2.Error()) } } + +func TestGetFilePath(t *testing.T) { + expected := "blah.json" + result := GetFilePath("blah.json") + if result != "blah.json" { + t.Errorf("Test failed. TestGetFilePath: expected %s got %s", expected, result) + } + + expected = ConfigTestFile + result = GetFilePath("") + if result != expected { + t.Errorf("Test failed. TestGetFilePath: expected %s got %s", expected, result) + } +} diff --git a/config_routes.go b/config_routes.go index 28413d7a..8de3474e 100644 --- a/config_routes.go +++ b/config_routes.go @@ -39,11 +39,11 @@ func SaveAllSettings(w http.ResponseWriter, r *http.Request) { } } //Reload the configuration - err := bot.config.SaveConfig("") + err := bot.config.SaveConfig(bot.configFile) if err != nil { panic(err) } - err = bot.config.LoadConfig("") + err = bot.config.LoadConfig(bot.configFile) if err != nil { panic(err) } diff --git a/main.go b/main.go index 5e30c2d7..df192f48 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "flag" "log" "net/http" "os" @@ -60,12 +61,13 @@ type ExchangeMain struct { // Bot contains configuration, portfolio, exchange & ticker data and is the // overarching type across this code base. type Bot struct { - config *config.Config - portfolio *portfolio.Base - exchange ExchangeMain - exchanges []exchange.IBotExchange - tickers []ticker.Ticker - shutdown chan bool + config *config.Config + portfolio *portfolio.Base + exchange ExchangeMain + exchanges []exchange.IBotExchange + tickers []ticker.Ticker + shutdown chan bool + configFile string } var bot Bot @@ -98,10 +100,15 @@ func setupBotExchanges() { func main() { HandleInterrupt() - bot.config = &config.Cfg - log.Printf("Loading config file %s..\n", config.ConfigFile) - err := bot.config.LoadConfig("") + //Handle flags + flag.StringVar(&bot.configFile, "config", config.GetFilePath(""), "config file to load") + flag.Parse() + + bot.config = &config.Cfg + log.Printf("Loading config file %s..\n", bot.configFile) + + err := bot.config.LoadConfig(bot.configFile) if err != nil { log.Fatal(err) } @@ -238,7 +245,7 @@ func HandleInterrupt() { func Shutdown() { log.Println("Bot shutting down..") bot.config.Portfolio = portfolio.Portfolio - err := bot.config.SaveConfig("") + err := bot.config.SaveConfig(bot.configFile) if err != nil { log.Println("Unable to save config.")