Fix default config loading behaviour and add config flag

This commit is contained in:
Adrian Gallagher
2017-08-02 15:43:47 +10:00
parent 6afaefa5bf
commit 9e8397225f
4 changed files with 53 additions and 33 deletions

View File

@@ -3,6 +3,7 @@ package config
import (
"encoding/json"
"errors"
"flag"
"fmt"
"log"
"os"
@@ -80,10 +81,10 @@ type Config struct {
Name string
EncryptConfig int
Cryptocurrencies string
Portfolio portfolio.Base `json:"PortfolioAddresses"`
SMS SMSGlobalConfig `json:"SMSGlobal"`
Webserver WebserverConfig `json:"Webserver"`
Exchanges []ExchangeConfig `json:"Exchanges"`
Portfolio portfolio.Base `json:"PortfolioAddresses"`
SMS SMSGlobalConfig `json:"SMSGlobal"`
Webserver WebserverConfig `json:"Webserver"`
Exchanges []ExchangeConfig `json:"Exchanges"`
}
// ExchangeConfig holds all the information needed for each enabled Exchange.
@@ -284,6 +285,18 @@ func (c *Config) RetrieveConfigCurrencyPairs() error {
return nil
}
// GetFilePath returns the desired config file or the default config file name
// based on if the application is being run under test or normal mode.
func GetFilePath(file string) string {
if file != "" {
return file
}
if flag.Lookup("test.v") == nil {
return ConfigFile
}
return ConfigTestFile
}
// CheckConfig checks to see if there is an old configuration filename and path
// if found it will change it to correct filename.
func CheckConfig() error {
@@ -301,14 +314,7 @@ func CheckConfig() error {
// ReadConfig verifies and checks for encryption and verifies the unencrypted
// file contains JSON.
func (c *Config) ReadConfig(configPath string) error {
var defaultPath string
if configPath == "" {
defaultPath = ConfigTestFile
} else {
defaultPath = configPath
}
defaultPath := GetFilePath(configPath)
err := CheckConfig()
if err != nil {
return err
@@ -356,14 +362,7 @@ func (c *Config) ReadConfig(configPath string) error {
// SaveConfig saves your configuration to your desired path
func (c *Config) SaveConfig(configPath string) error {
var defaultPath string
if configPath == "" {
defaultPath = ConfigFile
} else {
defaultPath = configPath
}
defaultPath := GetFilePath(configPath)
payload, err := json.MarshalIndent(c, "", " ")
if c.EncryptConfig == configFileEncryptionEnabled {
@@ -389,7 +388,7 @@ func (c *Config) SaveConfig(configPath string) error {
func (c *Config) LoadConfig(configPath string) error {
err := c.ReadConfig(configPath)
if err != nil {
return fmt.Errorf(ErrFailureOpeningConfig, ConfigFile, err)
return fmt.Errorf(ErrFailureOpeningConfig, configPath, err)
}
err = c.CheckExchangeConfigValues()

View File

@@ -305,3 +305,17 @@ func TestSaveConfig(t *testing.T) {
t.Errorf("Test failed. TestSaveConfig.SaveConfig, %s", err2.Error())
}
}
func TestGetFilePath(t *testing.T) {
expected := "blah.json"
result := GetFilePath("blah.json")
if result != "blah.json" {
t.Errorf("Test failed. TestGetFilePath: expected %s got %s", expected, result)
}
expected = ConfigTestFile
result = GetFilePath("")
if result != expected {
t.Errorf("Test failed. TestGetFilePath: expected %s got %s", expected, result)
}
}

View File

@@ -39,11 +39,11 @@ func SaveAllSettings(w http.ResponseWriter, r *http.Request) {
}
}
//Reload the configuration
err := bot.config.SaveConfig("")
err := bot.config.SaveConfig(bot.configFile)
if err != nil {
panic(err)
}
err = bot.config.LoadConfig("")
err = bot.config.LoadConfig(bot.configFile)
if err != nil {
panic(err)
}

27
main.go
View File

@@ -1,6 +1,7 @@
package main
import (
"flag"
"log"
"net/http"
"os"
@@ -60,12 +61,13 @@ type ExchangeMain struct {
// Bot contains configuration, portfolio, exchange & ticker data and is the
// overarching type across this code base.
type Bot struct {
config *config.Config
portfolio *portfolio.Base
exchange ExchangeMain
exchanges []exchange.IBotExchange
tickers []ticker.Ticker
shutdown chan bool
config *config.Config
portfolio *portfolio.Base
exchange ExchangeMain
exchanges []exchange.IBotExchange
tickers []ticker.Ticker
shutdown chan bool
configFile string
}
var bot Bot
@@ -98,10 +100,15 @@ func setupBotExchanges() {
func main() {
HandleInterrupt()
bot.config = &config.Cfg
log.Printf("Loading config file %s..\n", config.ConfigFile)
err := bot.config.LoadConfig("")
//Handle flags
flag.StringVar(&bot.configFile, "config", config.GetFilePath(""), "config file to load")
flag.Parse()
bot.config = &config.Cfg
log.Printf("Loading config file %s..\n", bot.configFile)
err := bot.config.LoadConfig(bot.configFile)
if err != nil {
log.Fatal(err)
}
@@ -238,7 +245,7 @@ func HandleInterrupt() {
func Shutdown() {
log.Println("Bot shutting down..")
bot.config.Portfolio = portfolio.Portfolio
err := bot.config.SaveConfig("")
err := bot.config.SaveConfig(bot.configFile)
if err != nil {
log.Println("Unable to save config.")