Config overwrite bugfix (#363)

* Fix bug where on parsing an alternate new config it will overwrite main config.json in gct dir

* Stop movement of config.json file from root dir when a new config is parsed in

* Stop overiding config.json at gct dir with new config.json from root directory

* RM LN :D

* Fix bug where promptforconfig in config_encryption.go overwrites default config
Ensure periphery command packages do not interact or save over configuration
Ensure tests to not save over or change current testdata/config
This commit is contained in:
Ryan O'Hara-Reid
2019-09-27 16:03:41 +10:00
committed by Adrian Gallagher
parent 6bdbe236c0
commit e2d57540a6
55 changed files with 408 additions and 173 deletions

View File

@@ -1423,24 +1423,38 @@ func GetFilePath(file string) (string, error) {
filepath.Join(newDir, EncryptedConfigFile),
}
// First upgrade the old dir config file if it exists to the corresponding new one
// First upgrade the old dir config file if it exists to the corresponding
// new one
for x := range oldDirs {
_, err := os.Stat(oldDirs[x])
if os.IsNotExist(err) {
continue
}
_, err = os.Stat(newDirs[x])
if !os.IsNotExist(err) {
log.Warnf(log.ConfigMgr,
"config.json file found in root dir and gct dir; cannot overwrite, defaulting to gct dir config.json at %s",
newDirs[x])
return newDirs[x], nil
}
if filepath.Ext(oldDirs[x]) == ".json" {
err = os.Rename(oldDirs[x], newDirs[0])
if err != nil {
return "", err
}
log.Debugf(log.ConfigMgr, "Renamed old config file %s to %s\n", oldDirs[x], newDirs[0])
log.Debugf(log.ConfigMgr,
"Renamed old config file %s to %s\n",
oldDirs[x],
newDirs[0])
} else {
err = os.Rename(oldDirs[x], newDirs[1])
if err != nil {
return "", err
}
log.Debugf(log.ConfigMgr, "Renamed old config file %s to %s\n", oldDirs[x], newDirs[1])
log.Debugf(log.ConfigMgr,
"Renamed old config file %s to %s\n",
oldDirs[x],
newDirs[1])
}
}
@@ -1485,7 +1499,7 @@ func GetFilePath(file string) (string, error) {
// ReadConfig verifies and checks for encryption and verifies the unencrypted
// file contains JSON.
func (c *Config) ReadConfig(configPath string) error {
func (c *Config) ReadConfig(configPath string, dryrun bool) error {
defaultPath, err := GetFilePath(configPath)
if err != nil {
return err
@@ -1510,9 +1524,9 @@ func (c *Config) ReadConfig(configPath string) error {
m.Lock()
IsInitialSetup = true
m.Unlock()
if c.PromptForConfigEncryption() {
if c.PromptForConfigEncryption(configPath, dryrun) {
c.EncryptConfig = configFileEncryptionEnabled
return c.SaveConfig(defaultPath)
return c.SaveConfig(defaultPath, dryrun)
}
}
} else {
@@ -1552,7 +1566,11 @@ func (c *Config) ReadConfig(configPath string) error {
}
// SaveConfig saves your configuration to your desired path
func (c *Config) SaveConfig(configPath string) error {
func (c *Config) SaveConfig(configPath string, dryrun bool) error {
if dryrun {
return nil
}
defaultPath, err := GetFilePath(configPath)
if err != nil {
return err
@@ -1666,8 +1684,8 @@ func (c *Config) CheckConfig() error {
}
// LoadConfig loads your configuration file into your configuration object
func (c *Config) LoadConfig(configPath string) error {
err := c.ReadConfig(configPath)
func (c *Config) LoadConfig(configPath string, dryrun bool) error {
err := c.ReadConfig(configPath, dryrun)
if err != nil {
return fmt.Errorf(ErrFailureOpeningConfig, configPath, err)
}
@@ -1676,7 +1694,7 @@ func (c *Config) LoadConfig(configPath string) error {
}
// UpdateConfig updates the config with a supplied config file
func (c *Config) UpdateConfig(configPath string, newCfg *Config) error {
func (c *Config) UpdateConfig(configPath string, newCfg *Config, dryrun bool) error {
err := newCfg.CheckConfig()
if err != nil {
return err
@@ -1691,12 +1709,12 @@ func (c *Config) UpdateConfig(configPath string, newCfg *Config) error {
c.Webserver = newCfg.Webserver
c.Exchanges = newCfg.Exchanges
err = c.SaveConfig(configPath)
err = c.SaveConfig(configPath, dryrun)
if err != nil {
return err
}
return c.LoadConfig(configPath)
return c.LoadConfig(configPath, dryrun)
}
// GetConfig returns a pointer to a configuration object

View File

@@ -11,6 +11,7 @@ import (
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/crypto"
log "github.com/thrasher-corp/gocryptotrader/logger"
"golang.org/x/crypto/scrypt"
)
@@ -32,7 +33,7 @@ var (
)
// PromptForConfigEncryption asks for encryption key
func (c *Config) PromptForConfigEncryption() bool {
func (c *Config) PromptForConfigEncryption(configPath string, dryrun bool) bool {
fmt.Println("Would you like to encrypt your config file (y/n)?")
input := ""
@@ -43,7 +44,10 @@ func (c *Config) PromptForConfigEncryption() bool {
if !common.YesOrNo(input) {
c.EncryptConfig = configFileEncryptionDisabled
c.SaveConfig("")
err := c.SaveConfig(configPath, dryrun)
if err != nil {
log.Errorf(log.ConfigMgr, "cannot save config %s", err)
}
return false
}
return true

View File

@@ -8,7 +8,7 @@ import (
func TestPromptForConfigEncryption(t *testing.T) {
t.Parallel()
if Cfg.PromptForConfigEncryption() {
if Cfg.PromptForConfigEncryption("", true) {
t.Error("Test failed. PromptForConfigEncryption return incorrect bool")
}
}

View File

@@ -22,7 +22,7 @@ const (
func TestGetCurrencyConfig(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Error("Test failed. GetCurrencyConfig LoadConfig error", err)
}
@@ -31,7 +31,7 @@ func TestGetCurrencyConfig(t *testing.T) {
func TestGetExchangeBankAccounts(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Error("Test failed. GetExchangeBankAccounts LoadConfig error", err)
}
@@ -47,7 +47,7 @@ func TestGetExchangeBankAccounts(t *testing.T) {
func TestUpdateExchangeBankAccounts(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Error("Test failed. UpdateExchangeBankAccounts LoadConfig error", err)
}
@@ -77,7 +77,7 @@ func TestUpdateExchangeBankAccounts(t *testing.T) {
func TestGetClientBankAccounts(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Error("Test failed. GetClientBankAccounts LoadConfig error", err)
}
@@ -97,7 +97,7 @@ func TestGetClientBankAccounts(t *testing.T) {
func TestUpdateClientBankAccounts(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Error("Test failed. UpdateClientBankAccounts LoadConfig error", err)
}
@@ -127,7 +127,7 @@ func TestUpdateClientBankAccounts(t *testing.T) {
func TestCheckClientBankAccounts(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Error("Test failed. CheckClientBankAccounts LoadConfig error", err)
}
@@ -262,7 +262,7 @@ func TestPurgeExchangeCredentials(t *testing.T) {
func TestGetCommunicationsConfig(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Error("Test failed. GetCommunicationsConfig LoadConfig error", err)
}
@@ -271,7 +271,7 @@ func TestGetCommunicationsConfig(t *testing.T) {
func TestUpdateCommunicationsConfig(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Error("Test failed. UpdateCommunicationsConfig LoadConfig error", err)
}
@@ -283,7 +283,7 @@ func TestUpdateCommunicationsConfig(t *testing.T) {
func TestGetCryptocurrencyProviderConfig(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Error("Test failed. GetCryptocurrencyProviderConfig LoadConfig error", err)
}
@@ -292,7 +292,7 @@ func TestGetCryptocurrencyProviderConfig(t *testing.T) {
func TestUpdateCryptocurrencyProviderConfig(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Error("Test failed. UpdateCryptocurrencyProviderConfig LoadConfig error", err)
}
@@ -308,7 +308,7 @@ func TestUpdateCryptocurrencyProviderConfig(t *testing.T) {
func TestCheckCommunicationsConfig(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Error("Test failed. CheckCommunicationsConfig LoadConfig error", err)
}
@@ -782,7 +782,7 @@ func TestCheckPairConsistency(t *testing.T) {
func TestSupportsPair(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Errorf(
"Test failed. TestSupportsPair. LoadConfig Error: %s", err.Error(),
@@ -979,7 +979,7 @@ func TestGetEnabledPairs(t *testing.T) {
func TestGetEnabledExchanges(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Errorf(
"Test failed. TestGetEnabledExchanges. LoadConfig Error: %s", err.Error(),
@@ -1002,7 +1002,7 @@ func TestGetEnabledExchanges(t *testing.T) {
func TestGetDisabledExchanges(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Errorf(
"Test failed. TestGetDisabledExchanges. LoadConfig Error: %s", err.Error(),
@@ -1040,7 +1040,7 @@ func TestGetDisabledExchanges(t *testing.T) {
func TestCountEnabledExchanges(t *testing.T) {
GetConfigEnabledExchanges := GetConfig()
err := GetConfigEnabledExchanges.LoadConfig(ConfigTestFile)
err := GetConfigEnabledExchanges.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Error(
"Test failed. GetConfigEnabledExchanges load config error: " + err.Error(),
@@ -1054,7 +1054,7 @@ func TestCountEnabledExchanges(t *testing.T) {
func TestGetConfigCurrencyPairFormat(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Errorf(
"Test failed. TestGetConfigCurrencyPairFormat. LoadConfig Error: %s", err.Error(),
@@ -1080,7 +1080,7 @@ func TestGetConfigCurrencyPairFormat(t *testing.T) {
func TestGetRequestCurrencyPairFormat(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Errorf(
"Test failed. TestGetRequestCurrencyPairFormat. LoadConfig Error: %s", err.Error(),
@@ -1107,7 +1107,7 @@ func TestGetRequestCurrencyPairFormat(t *testing.T) {
func TestGetCurrencyPairDisplayConfig(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Errorf(
"Test failed. GetCurrencyPairDisplayConfig. LoadConfig Error: %s", err.Error(),
@@ -1123,7 +1123,7 @@ func TestGetCurrencyPairDisplayConfig(t *testing.T) {
func TestGetAllExchangeConfigs(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Error("Test failed. GetAllExchangeConfigs. LoadConfig error", err)
}
@@ -1134,7 +1134,7 @@ func TestGetAllExchangeConfigs(t *testing.T) {
func TestGetExchangeConfig(t *testing.T) {
GetExchangeConfig := GetConfig()
err := GetExchangeConfig.LoadConfig(ConfigTestFile)
err := GetExchangeConfig.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Errorf(
"Test failed. GetExchangeConfig.LoadConfig Error: %s", err.Error(),
@@ -1153,7 +1153,7 @@ func TestGetExchangeConfig(t *testing.T) {
func TestGetForexProviderConfig(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Error("Test failed. GetForexProviderConfig. LoadConfig error", err)
}
@@ -1170,7 +1170,7 @@ func TestGetForexProviderConfig(t *testing.T) {
func TestGetForexProvidersConfig(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Error(err)
}
@@ -1182,7 +1182,7 @@ func TestGetForexProvidersConfig(t *testing.T) {
func TestGetPrimaryForexProvider(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Error("Test failed. GetPrimaryForexProvider. LoadConfig error", err)
}
@@ -1202,7 +1202,7 @@ func TestGetPrimaryForexProvider(t *testing.T) {
func TestUpdateExchangeConfig(t *testing.T) {
c := GetConfig()
err := c.LoadConfig(ConfigTestFile)
err := c.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Error(err)
}
@@ -1232,7 +1232,7 @@ func TestCheckExchangeConfigValues(t *testing.T) {
t.Error("nil exchanges should throw an err")
}
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Fatal(err)
}
@@ -1548,7 +1548,7 @@ func TestCheckExchangeConfigValues(t *testing.T) {
func TestRetrieveConfigCurrencyPairs(t *testing.T) {
cfg := GetConfig()
err := cfg.LoadConfig(ConfigTestFile)
err := cfg.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Errorf(
"Test failed. TestRetrieveConfigCurrencyPairs.LoadConfig: %s", err.Error(),
@@ -1573,17 +1573,17 @@ func TestRetrieveConfigCurrencyPairs(t *testing.T) {
func TestReadConfig(t *testing.T) {
readConfig := GetConfig()
err := readConfig.ReadConfig(ConfigTestFile)
err := readConfig.ReadConfig(ConfigTestFile, true)
if err != nil {
t.Errorf("Test failed. TestReadConfig %s", err.Error())
}
err = readConfig.ReadConfig("bla")
err = readConfig.ReadConfig("bla", true)
if err == nil {
t.Error("Test failed. TestReadConfig error cannot be nil")
}
err = readConfig.ReadConfig("")
err = readConfig.ReadConfig("", true)
if err != nil {
t.Error("Test failed. TestReadConfig error")
}
@@ -1591,12 +1591,12 @@ func TestReadConfig(t *testing.T) {
func TestLoadConfig(t *testing.T) {
loadConfig := GetConfig()
err := loadConfig.LoadConfig(ConfigTestFile)
err := loadConfig.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Error("Test failed. TestLoadConfig " + err.Error())
}
err = loadConfig.LoadConfig("testy")
err = loadConfig.LoadConfig("testy", true)
if err == nil {
t.Error("Test failed. TestLoadConfig ")
}
@@ -1604,11 +1604,11 @@ func TestLoadConfig(t *testing.T) {
func TestSaveConfig(t *testing.T) {
saveConfig := GetConfig()
err := saveConfig.LoadConfig(ConfigTestFile)
err := saveConfig.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Errorf("Test failed. TestSaveConfig.LoadConfig: %s", err.Error())
}
err2 := saveConfig.SaveConfig(ConfigTestFile)
err2 := saveConfig.SaveConfig(ConfigTestFile, true)
if err2 != nil {
t.Errorf("Test failed. TestSaveConfig.SaveConfig, %s", err2.Error())
}
@@ -1687,7 +1687,7 @@ func TestCheckRemoteControlConfig(t *testing.T) {
func TestCheckConfig(t *testing.T) {
var c Config
err := c.LoadConfig(ConfigTestFile)
err := c.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Errorf("Test failed. %s", err)
}
@@ -1700,24 +1700,24 @@ func TestCheckConfig(t *testing.T) {
func TestUpdateConfig(t *testing.T) {
var c Config
err := c.LoadConfig(ConfigTestFile)
err := c.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Errorf("Test failed. %s", err)
}
newCfg := c
err = c.UpdateConfig(ConfigTestFile, &newCfg)
err = c.UpdateConfig(ConfigTestFile, &newCfg, true)
if err != nil {
t.Fatalf("Test failed. %s", err)
}
err = c.UpdateConfig("//non-existantpath\\", &newCfg)
err = c.UpdateConfig("//non-existantpath\\", &newCfg, true)
if err == nil {
t.Fatalf("Test failed. Error should of been thrown for invalid path")
}
newCfg.Currency.Cryptocurrencies = currency.NewCurrenciesFromStringArray([]string{""})
err = c.UpdateConfig(ConfigTestFile, &newCfg)
err = c.UpdateConfig(ConfigTestFile, &newCfg, true)
if err != nil {
t.Errorf("Test failed. %s", err)
}
@@ -1728,14 +1728,14 @@ func TestUpdateConfig(t *testing.T) {
func BenchmarkUpdateConfig(b *testing.B) {
var c Config
err := c.LoadConfig(ConfigTestFile)
err := c.LoadConfig(ConfigTestFile, true)
if err != nil {
b.Errorf("Unable to benchmark UpdateConfig(): %s", err)
}
newCfg := c
for i := 0; i < b.N; i++ {
_ = c.UpdateConfig(ConfigTestFile, &newCfg)
_ = c.UpdateConfig(ConfigTestFile, &newCfg, true)
}
}
@@ -1768,7 +1768,7 @@ func TestCheckLoggerConfig(t *testing.T) {
t.Error("unexpected result")
}
c.LoadConfig(ConfigTestFile)
c.LoadConfig(ConfigTestFile, true)
err = c.CheckLoggerConfig()
if err != nil {
t.Errorf("Failed to create logger with user settings: reason: %v", err)
@@ -1779,7 +1779,7 @@ func TestDisableNTPCheck(t *testing.T) {
t.Parallel()
c := GetConfig()
err := c.LoadConfig(ConfigTestFile)
err := c.LoadConfig(ConfigTestFile, true)
if err != nil {
t.Fatal(err)
}