mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-13 23:16:45 +00:00
Config: refactor config file loaders (#577)
* Config: fix don't create empty dir when resolving path * Config: refactor config file loaders * add a layer of abstraction so that config can be loaded from non-files * use io.Reader / io.Writer abstraction to separate data operations from file operations * remove dryrun option from SaveConfig - now it always saves * rename read and save methods to mention file operations * log error when encryption prompt fails * as the user didn't make a choice, we'd prompt again next time the file is loaded * add file.Writer tests * skip permissions test for windows * defer creating the writer on save to the last moment * this avoids truncating file when there is error with password prompt * add a test * tests with StdIn cannot run in parallel
This commit is contained in:
@@ -227,7 +227,7 @@ func makeExchange(exch *exchange) error {
|
||||
}
|
||||
|
||||
configTestFile.Exchanges = append(configTestFile.Exchanges, newExchConfig)
|
||||
err = configTestFile.SaveConfig(exchangeConfigPath, false)
|
||||
err = configTestFile.SaveConfigToFile(exchangeConfigPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ func TestNewExchange(t *testing.T) {
|
||||
t.Fatalf("unable to remove exchange config for %s, manual removal required\n",
|
||||
testExchangeName)
|
||||
}
|
||||
if err := cfg.SaveConfig(exchangeConfigPath, false); err != nil {
|
||||
if err := cfg.SaveConfigToFile(exchangeConfigPath); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,6 +24,19 @@ func Write(file string, data []byte) error {
|
||||
return ioutil.WriteFile(file, data, 0770)
|
||||
}
|
||||
|
||||
// Writer creates a writer to a file or returns an error if it fails. This
|
||||
// func also ensures that all files are set to this permission (only rw access
|
||||
// for the running user and the group the user is a member of)
|
||||
func Writer(file string) (*os.File, error) {
|
||||
basePath := filepath.Dir(file)
|
||||
if !Exists(basePath) {
|
||||
if err := os.MkdirAll(basePath, 0770); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0770)
|
||||
}
|
||||
|
||||
// Move moves a file from a source path to a destination path
|
||||
// This must be used across the codebase for compatibility with Docker volumes
|
||||
// and Golang (fixes Invalid cross-device link when using os.Rename)
|
||||
|
||||
@@ -194,3 +194,92 @@ func TestWriteAsCSV(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter(t *testing.T) {
|
||||
type args struct {
|
||||
file string
|
||||
}
|
||||
tmp, err := ioutil.TempDir("", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmp)
|
||||
|
||||
testData := `data`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *os.File
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "invalid",
|
||||
args: args{"//invalid-nofile\\"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
args: args{""},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "relative newfile",
|
||||
args: args{"newfile"},
|
||||
},
|
||||
{
|
||||
name: "deep file",
|
||||
args: args{filepath.Join(tmp, "new", "file", "multiple", "sub", "paths")},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Writer(tt.args.file)
|
||||
if err != nil {
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Writer() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
defer os.Remove(got.Name())
|
||||
fileInfo, err := os.Stat(got.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !fileInfo.Mode().IsRegular() {
|
||||
t.Fatalf("Writer() error = expected to get a file %s", got.Name())
|
||||
}
|
||||
_, err = got.WriteString(testData)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = got.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if data, err := ioutil.ReadFile(got.Name()); err != nil || string(data) != testData {
|
||||
t.Errorf("Could not write the file, or contents were wrong: expected = %s, got =%s", testData, string(data))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriterNoPermissionFails(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skip file permissions")
|
||||
}
|
||||
temp, err := ioutil.TempDir("", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(temp)
|
||||
err = os.Chmod(temp, 0555)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = Writer(filepath.Join(temp, "path", "to", "somefile"))
|
||||
if err == nil {
|
||||
t.Error("Expected to fail when no permissions, but writer succeeded")
|
||||
}
|
||||
}
|
||||
|
||||
182
config/config.go
182
config/config.go
@@ -2,11 +2,13 @@ package config
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@@ -1545,94 +1547,133 @@ func migrateConfig(configFile, targetDir string) (string, error) {
|
||||
return target, nil
|
||||
}
|
||||
|
||||
// ReadConfig verifies and checks for encryption and verifies the unencrypted
|
||||
// file contains JSON.
|
||||
// Prompts for decryption key, if target file is encrypted
|
||||
func (c *Config) ReadConfig(configPath string, dryrun bool) error {
|
||||
// ReadConfigFromFile reads the configuration from the given file
|
||||
// if target file is encrypted, prompts for encryption key
|
||||
// Also - if not in dryrun mode - it checks if the configuration needs to be encrypted
|
||||
// and stores the file as encrypted, if necessary (prompting for enryption key)
|
||||
func (c *Config) ReadConfigFromFile(configPath string, dryrun bool) error {
|
||||
defaultPath, _, err := GetFilePath(configPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fileData, err := ioutil.ReadFile(defaultPath)
|
||||
confFile, err := os.Open(defaultPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer confFile.Close()
|
||||
result, wasEncrypted, err := ReadConfig(confFile, func() ([]byte, error) { return PromptForConfigKey(false) })
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading config %w", err)
|
||||
}
|
||||
// Override values in the current config
|
||||
*c = *result
|
||||
|
||||
if !ConfirmECS(fileData) {
|
||||
err = json.Unmarshal(fileData, c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.EncryptConfig == fileEncryptionDisabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.EncryptConfig == fileEncryptionPrompt {
|
||||
confirm, err := promptForConfigEncryption()
|
||||
if err == nil {
|
||||
if confirm {
|
||||
c.EncryptConfig = fileEncryptionEnabled
|
||||
return c.SaveConfig(defaultPath, dryrun)
|
||||
}
|
||||
|
||||
c.EncryptConfig = fileEncryptionDisabled
|
||||
err := c.SaveConfig(configPath, dryrun)
|
||||
if err != nil {
|
||||
log.Errorf(log.ConfigMgr, "Cannot save config. Error: %s\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if dryrun || wasEncrypted || c.EncryptConfig == fileEncryptionDisabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
errCounter := 0
|
||||
for {
|
||||
if errCounter >= maxAuthFailures {
|
||||
return errors.New("failed to decrypt config after 3 attempts")
|
||||
}
|
||||
key, err := PromptForConfigKey(false)
|
||||
if c.EncryptConfig == fileEncryptionPrompt {
|
||||
confirm, err := promptForConfigEncryption()
|
||||
if err != nil {
|
||||
log.Errorf(log.ConfigMgr, "PromptForConfigKey err: %s", err)
|
||||
errCounter++
|
||||
continue
|
||||
log.Errorf(log.ConfigMgr, "The encryption prompt failed, ignoring for now, next time we will prompt again. Error: %s\n", err)
|
||||
return nil
|
||||
}
|
||||
if confirm {
|
||||
c.EncryptConfig = fileEncryptionEnabled
|
||||
return c.SaveConfigToFile(defaultPath)
|
||||
}
|
||||
|
||||
var f []byte
|
||||
f = append(f, fileData...)
|
||||
data, err := c.decryptConfigData(f, key)
|
||||
c.EncryptConfig = fileEncryptionDisabled
|
||||
err = c.SaveConfigToFile(defaultPath)
|
||||
if err != nil {
|
||||
log.Errorf(log.ConfigMgr, "decryptConfigData err: %s", err)
|
||||
errCounter++
|
||||
continue
|
||||
log.Errorf(log.ConfigMgr, "Cannot save config. Error: %s\n", err)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(data, c)
|
||||
if err != nil {
|
||||
if errCounter < maxAuthFailures {
|
||||
log.Error(log.ConfigMgr, "Invalid password.")
|
||||
}
|
||||
errCounter++
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveConfig saves your configuration to your desired path
|
||||
// prompts for encryption key, if necessary
|
||||
func (c *Config) SaveConfig(configPath string, dryrun bool) error {
|
||||
if dryrun {
|
||||
return nil
|
||||
// ReadConfig verifies and checks for encryption and loads the config from a JSON object.
|
||||
// Prompts for decryption key, if target data is encrypted.
|
||||
// Returns the loaded configuration and whether it was encrypted.
|
||||
func ReadConfig(configReader io.Reader, keyProvider func() ([]byte, error)) (*Config, bool, error) {
|
||||
reader := bufio.NewReader(configReader)
|
||||
|
||||
pref, err := reader.Peek(len(EncryptConfirmString))
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if !ConfirmECS(pref) {
|
||||
// Read unencrypted configuration
|
||||
decoder := json.NewDecoder(reader)
|
||||
c := &Config{}
|
||||
err = decoder.Decode(c)
|
||||
return c, false, err
|
||||
}
|
||||
|
||||
conf, err := readEncryptedConfWithKey(reader, keyProvider)
|
||||
return conf, true, err
|
||||
}
|
||||
|
||||
// readEncryptedConf reads encrypted configuration and requests key from provider
|
||||
func readEncryptedConfWithKey(reader *bufio.Reader, keyProvider func() ([]byte, error)) (*Config, error) {
|
||||
fileData, err := ioutil.ReadAll(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for errCounter := 0; errCounter < maxAuthFailures; errCounter++ {
|
||||
key, err := keyProvider()
|
||||
if err != nil {
|
||||
log.Errorf(log.ConfigMgr, "PromptForConfigKey err: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
var c *Config
|
||||
c, err = readEncryptedConf(bytes.NewReader(fileData), key)
|
||||
if err != nil {
|
||||
log.Error(log.ConfigMgr, "Could not decrypt and deserialise data with given key. Invalid password?", err)
|
||||
continue
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
return nil, errors.New("failed to decrypt config after 3 attempts")
|
||||
}
|
||||
|
||||
func readEncryptedConf(reader io.Reader, key []byte) (*Config, error) {
|
||||
c := &Config{}
|
||||
data, err := c.decryptConfigData(reader, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = json.Unmarshal(data, c)
|
||||
return c, err
|
||||
}
|
||||
|
||||
// SaveConfigToFile saves your configuration to your desired path as a JSON object.
|
||||
// The function encrypts the data and prompts for encryption key, if necessary
|
||||
func (c *Config) SaveConfigToFile(configPath string) error {
|
||||
defaultPath, _, err := GetFilePath(configPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var writer *os.File
|
||||
provider := func() (io.Writer, error) {
|
||||
writer, err = file.Writer(defaultPath)
|
||||
return writer, err
|
||||
}
|
||||
defer func() {
|
||||
if writer != nil {
|
||||
writer.Close()
|
||||
}
|
||||
}()
|
||||
return c.Save(provider, func() ([]byte, error) { return PromptForConfigKey(true) })
|
||||
}
|
||||
|
||||
// Save saves your configuration to the writer as a JSON object
|
||||
// with encryption, if configured
|
||||
// If there is an error when preparing the data to store, the writer is never requested
|
||||
func (c *Config) Save(writerProvider func() (io.Writer, error), keyProvider func() ([]byte, error)) error {
|
||||
payload, err := json.MarshalIndent(c, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1642,7 +1683,7 @@ func (c *Config) SaveConfig(configPath string, dryrun bool) error {
|
||||
// Ensure we have the key from session or from user
|
||||
if len(c.sessionDK) == 0 {
|
||||
var key []byte
|
||||
key, err = PromptForConfigKey(true)
|
||||
key, err = keyProvider()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1658,7 +1699,12 @@ func (c *Config) SaveConfig(configPath string, dryrun bool) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return file.Write(defaultPath, payload)
|
||||
configWriter, err := writerProvider()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = io.Copy(configWriter, bytes.NewReader(payload))
|
||||
return err
|
||||
}
|
||||
|
||||
// CheckRemoteControlConfig checks to see if the old c.Webserver field is used
|
||||
@@ -1759,7 +1805,7 @@ func (c *Config) CheckConfig() error {
|
||||
|
||||
// LoadConfig loads your configuration file into your configuration object
|
||||
func (c *Config) LoadConfig(configPath string, dryrun bool) error {
|
||||
err := c.ReadConfig(configPath, dryrun)
|
||||
err := c.ReadConfigFromFile(configPath, dryrun)
|
||||
if err != nil {
|
||||
return fmt.Errorf(ErrFailureOpeningConfig, configPath, err)
|
||||
}
|
||||
@@ -1783,9 +1829,11 @@ func (c *Config) UpdateConfig(configPath string, newCfg *Config, dryrun bool) er
|
||||
c.Webserver = newCfg.Webserver
|
||||
c.Exchanges = newCfg.Exchanges
|
||||
|
||||
err = c.SaveConfig(configPath, dryrun)
|
||||
if err != nil {
|
||||
return err
|
||||
if !dryrun {
|
||||
err = c.SaveConfigToFile(configPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return c.LoadConfig(configPath, dryrun)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
|
||||
"github.com/thrasher-corp/gocryptotrader/common"
|
||||
@@ -41,6 +42,11 @@ func promptForConfigEncryption() (bool, error) {
|
||||
return common.YesOrNo(input), nil
|
||||
}
|
||||
|
||||
// Unencrypted provides the default key provider implementation for unencrypted files
|
||||
func Unencrypted() ([]byte, error) {
|
||||
return nil, errors.New("encryption key was requested, no key provided")
|
||||
}
|
||||
|
||||
// PromptForConfigKey asks for configuration key
|
||||
// if initialSetup is true, the password needs to be repeated
|
||||
func PromptForConfigKey(initialSetup bool) ([]byte, error) {
|
||||
@@ -94,7 +100,7 @@ func EncryptConfigFile(configData, key []byte) ([]byte, error) {
|
||||
return c.encryptConfigFile(configData)
|
||||
}
|
||||
|
||||
// EncryptConfigFile encrypts configuration data that is parsed in with a key
|
||||
// encryptConfigFile encrypts configuration data that is parsed in with a key
|
||||
// and returns it as a byte array with an error
|
||||
func (c *Config) encryptConfigFile(configData []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(c.sessionDK)
|
||||
@@ -120,26 +126,33 @@ func (c *Config) encryptConfigFile(configData []byte) ([]byte, error) {
|
||||
// DecryptConfigFile decrypts configuration data with the supplied key and
|
||||
// returns the un-encrypted data as a byte array with an error
|
||||
func DecryptConfigFile(configData, key []byte) ([]byte, error) {
|
||||
return (&Config{}).decryptConfigData(configData, key)
|
||||
reader := bytes.NewReader(configData)
|
||||
return (&Config{}).decryptConfigData(reader, key)
|
||||
}
|
||||
|
||||
// decryptConfigData decrypts configuration data with the supplied key and
|
||||
// returns the un-encrypted data as a byte array with an error
|
||||
func (c *Config) decryptConfigData(configData, key []byte) ([]byte, error) {
|
||||
configData = removeECS(configData)
|
||||
func (c *Config) decryptConfigData(configReader io.Reader, key []byte) ([]byte, error) {
|
||||
err := skipECS(configReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
origKey := key
|
||||
configData, err := ioutil.ReadAll(configReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ConfirmSalt(configData) {
|
||||
salt := make([]byte, len(SaltPrefix)+SaltRandomLength)
|
||||
salt = configData[0:len(salt)]
|
||||
|
||||
dk, err := getScryptDK(key, salt)
|
||||
key, err = getScryptDK(key, salt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
configData = configData[len(salt):]
|
||||
key = dk
|
||||
}
|
||||
|
||||
blockDecrypt, err := aes.NewCipher(key)
|
||||
@@ -177,9 +190,18 @@ func ConfirmECS(file []byte) bool {
|
||||
return bytes.Contains(file, []byte(EncryptConfirmString))
|
||||
}
|
||||
|
||||
// removeECS removes encryption confirmation string
|
||||
func removeECS(file []byte) []byte {
|
||||
return bytes.Trim(file, EncryptConfirmString)
|
||||
// skipECS skips encryption confirmation string
|
||||
// or errors, if the prefix wasn't found
|
||||
func skipECS(file io.Reader) error {
|
||||
buf := make([]byte, len(EncryptConfirmString))
|
||||
_, err := io.ReadFull(file, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if string(buf) != EncryptConfirmString {
|
||||
return errors.New("data does not start with ECS")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getScryptDK(key, salt []byte) ([]byte, error) {
|
||||
|
||||
@@ -2,9 +2,12 @@ package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -108,9 +111,16 @@ func TestRemoveECS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ECStest := []byte(EncryptConfirmString)
|
||||
isremoved := removeECS(ECStest)
|
||||
reader := bytes.NewReader(ECStest)
|
||||
err := skipECS(reader)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if string(isremoved) != "" {
|
||||
// Attempt read
|
||||
var buf []byte
|
||||
_, err = reader.Read(buf)
|
||||
if err != io.EOF {
|
||||
t.Errorf("TestConfirmECS: Error ECS not deleted.")
|
||||
}
|
||||
}
|
||||
@@ -151,13 +161,13 @@ func TestEncryptTwiceReusesSaltButNewCipher(t *testing.T) {
|
||||
|
||||
// Save encrypted config
|
||||
enc1 := filepath.Join(tempDir, "encrypted.dat")
|
||||
err = c.SaveConfig(enc1, false)
|
||||
err = c.SaveConfigToFile(enc1)
|
||||
if err != nil {
|
||||
t.Fatalf("Problem storing config in file %s: %s\n", enc1, err)
|
||||
}
|
||||
// Save again
|
||||
enc2 := filepath.Join(tempDir, "encrypted2.dat")
|
||||
err = c.SaveConfig(enc2, false)
|
||||
err = c.SaveConfigToFile(enc2)
|
||||
if err != nil {
|
||||
t.Fatalf("Problem storing config in file %s: %s\n", enc2, err)
|
||||
}
|
||||
@@ -191,37 +201,20 @@ func TestSaveAndReopenEncryptedConfig(t *testing.T) {
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Prepare password
|
||||
passFile, err := ioutil.TempFile(tempDir, "*.pw")
|
||||
if err != nil {
|
||||
t.Fatalf("Problem creating temp file at %s: %s\n", tempDir, err)
|
||||
}
|
||||
passFile.WriteString("pass\npass\n")
|
||||
passFile.Close()
|
||||
|
||||
// Temporarily replace Stdin with a custom input
|
||||
cleanup := setAnswersFile(t, passFile.Name())
|
||||
defer cleanup()
|
||||
|
||||
// Save encrypted config
|
||||
enc := filepath.Join(tempDir, "encrypted.dat")
|
||||
err = c.SaveConfig(enc, false)
|
||||
err = withInteractiveResponse(t, "pass\npass\n", func() error {
|
||||
return c.SaveConfigToFile(enc)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Problem storing config in file %s: %s\n", enc, err)
|
||||
}
|
||||
|
||||
// Prepare password input for decryption
|
||||
passFile, err = os.Open(passFile.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("Problem opening temp file at %s: %s\n", passFile.Name(), err)
|
||||
}
|
||||
defer passFile.Close()
|
||||
os.Stdin = passFile
|
||||
|
||||
// Clean session
|
||||
readConf := &Config{}
|
||||
// Load with no existing state, key is read from the prepared file
|
||||
err = readConf.ReadConfig(enc, true)
|
||||
err = withInteractiveResponse(t, "pass\n", func() error {
|
||||
// Load with no existing state, key is read from the prepared file
|
||||
return readConf.ReadConfigFromFile(enc, true)
|
||||
})
|
||||
|
||||
// Verify
|
||||
if err != nil {
|
||||
@@ -264,26 +257,19 @@ func TestReadConfigWithPrompt(t *testing.T) {
|
||||
|
||||
// Save config
|
||||
testConfigFile := filepath.Join(tempDir, "config.json")
|
||||
err = c.SaveConfig(testConfigFile, false)
|
||||
err = c.SaveConfigToFile(testConfigFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Problem saving config file in %s: %s\n", tempDir, err)
|
||||
}
|
||||
|
||||
// Answers to the prompt
|
||||
responseFile, err := ioutil.TempFile(tempDir, "*.in")
|
||||
if err != nil {
|
||||
t.Fatalf("Problem creating temp file at %s: %s\n", tempDir, err)
|
||||
}
|
||||
responseFile.WriteString("y\npass\npass\n")
|
||||
responseFile.Close()
|
||||
|
||||
// Temporarily replace Stdin with a custom input
|
||||
cleanup := setAnswersFile(t, responseFile.Name())
|
||||
defer cleanup()
|
||||
|
||||
// Run the test
|
||||
c = &Config{}
|
||||
c.ReadConfig(testConfigFile, false)
|
||||
err = withInteractiveResponse(t, "y\npass\npass\n", func() error {
|
||||
return c.ReadConfigFromFile(testConfigFile, false)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Problem reading config file at %s: %s\n", testConfigFile, err)
|
||||
}
|
||||
|
||||
// Verify results
|
||||
data, err := ioutil.ReadFile(testConfigFile)
|
||||
@@ -297,3 +283,89 @@ func TestReadConfigWithPrompt(t *testing.T) {
|
||||
t.Error("Config file should be encrypted after prompts")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadEncryptedConfigFromReader(t *testing.T) {
|
||||
keyProvider := func() ([]byte, error) { return []byte("pass"), nil }
|
||||
// Encrypted conf for: `{"name":"test"}` with key `pass`
|
||||
confBytes := []byte{84, 72, 79, 82, 83, 45, 72, 65, 77, 77, 69, 82, 126, 71, 67, 84, 126, 83, 79, 126, 83, 65, 76, 84, 89, 126, 246, 110, 128, 3, 30, 168, 172, 160, 198, 176, 136, 62, 152, 155, 253, 176, 16, 48, 52, 246, 44, 29, 151, 47, 217, 226, 178, 12, 218, 113, 248, 172, 195, 232, 136, 104, 9, 199, 20, 4, 71, 4, 253, 249}
|
||||
conf, encrypted, err := ReadConfig(bytes.NewReader(confBytes), keyProvider)
|
||||
if err != nil {
|
||||
t.Errorf("TestReadConfig %s", err)
|
||||
}
|
||||
if !encrypted {
|
||||
t.Errorf("Expected encrypted config %s", err)
|
||||
}
|
||||
if conf.Name != "test" {
|
||||
t.Errorf("Conf not properly loaded %s", err)
|
||||
}
|
||||
|
||||
// Change the salt
|
||||
confBytes[20] = 0
|
||||
conf, _, err = ReadConfig(bytes.NewReader(confBytes), keyProvider)
|
||||
if err == nil {
|
||||
t.Errorf("Expected unable to decrypt, but got %+v", conf)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveConfigToFileWithErrorInPasswordPrompt should preserve the original file
|
||||
func TestSaveConfigToFileWithErrorInPasswordPrompt(t *testing.T) {
|
||||
c := &Config{
|
||||
Name: "test",
|
||||
EncryptConfig: fileEncryptionEnabled,
|
||||
}
|
||||
testData := []byte("testdata")
|
||||
f, err := ioutil.TempFile("", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
targetFile := f.Name()
|
||||
defer os.Remove(targetFile)
|
||||
|
||||
_, err = io.Copy(f, bytes.NewReader(testData))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = f.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = withInteractiveResponse(t, "\n\n", func() error {
|
||||
err = c.SaveConfigToFile(targetFile)
|
||||
if err == nil {
|
||||
t.Error("Expected error")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := ioutil.ReadFile(targetFile)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(data, testData) {
|
||||
t.Errorf("Expected contents %s, but was %s", testData, data)
|
||||
}
|
||||
}
|
||||
|
||||
func withInteractiveResponse(t *testing.T, response string, body func() error) error {
|
||||
// Answers to the prompt
|
||||
responseFile, err := ioutil.TempFile("", "*.in")
|
||||
if err != nil {
|
||||
return fmt.Errorf("problem creating temp file: %w", err)
|
||||
}
|
||||
_, err = responseFile.WriteString(response)
|
||||
if err != nil {
|
||||
return fmt.Errorf("problem writing to temp file at %s: %w", responseFile.Name(), err)
|
||||
}
|
||||
err = responseFile.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("problem closing temp file at %s: %w", responseFile.Name(), err)
|
||||
}
|
||||
defer os.Remove(responseFile.Name())
|
||||
|
||||
// Temporarily replace Stdin with a custom input
|
||||
cleanup := setAnswersFile(t, responseFile.Name())
|
||||
defer cleanup()
|
||||
return body()
|
||||
}
|
||||
|
||||
@@ -1646,14 +1646,33 @@ func TestRetrieveConfigCurrencyPairs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadConfig(t *testing.T) {
|
||||
func TestReadConfigFromFile(t *testing.T) {
|
||||
readConfig := GetConfig()
|
||||
err := readConfig.ReadConfig(TestFile, true)
|
||||
err := readConfig.ReadConfigFromFile(TestFile, true)
|
||||
if err != nil {
|
||||
t.Errorf("TestReadConfig %s", err.Error())
|
||||
}
|
||||
|
||||
err = readConfig.ReadConfig("bla", true)
|
||||
err = readConfig.ReadConfigFromFile("bla", true)
|
||||
if err == nil {
|
||||
t.Error("TestReadConfig error cannot be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadConfigFromReader(t *testing.T) {
|
||||
confString := `{"name":"test"}`
|
||||
conf, encrypted, err := ReadConfig(strings.NewReader(confString), Unencrypted)
|
||||
if err != nil {
|
||||
t.Errorf("TestReadConfig %s", err)
|
||||
}
|
||||
if encrypted {
|
||||
t.Errorf("Expected unencrypted config %s", err)
|
||||
}
|
||||
if conf.Name != "test" {
|
||||
t.Errorf("Conf not properly loaded %s", err)
|
||||
}
|
||||
|
||||
_, _, err = ReadConfig(strings.NewReader("{}"), Unencrypted)
|
||||
if err == nil {
|
||||
t.Error("TestReadConfig error cannot be nil")
|
||||
}
|
||||
@@ -1672,13 +1691,19 @@ func TestLoadConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveConfig(t *testing.T) {
|
||||
func TestSaveConfigToFile(t *testing.T) {
|
||||
saveConfig := GetConfig()
|
||||
err := saveConfig.LoadConfig(TestFile, true)
|
||||
if err != nil {
|
||||
t.Errorf("TestSaveConfig.LoadConfig: %s", err.Error())
|
||||
}
|
||||
err2 := saveConfig.SaveConfig(TestFile, true)
|
||||
f, err := ioutil.TempFile("", "")
|
||||
if err != nil {
|
||||
t.Errorf("TestSaveConfig create file: %s", err)
|
||||
}
|
||||
f.Close()
|
||||
defer os.Remove(f.Name())
|
||||
err2 := saveConfig.SaveConfigToFile(f.Name())
|
||||
if err2 != nil {
|
||||
t.Errorf("TestSaveConfig.SaveConfig, %s", err2.Error())
|
||||
}
|
||||
@@ -1804,7 +1829,7 @@ func TestUpdateConfig(t *testing.T) {
|
||||
t.Fatalf("%s", err)
|
||||
}
|
||||
|
||||
err = c.UpdateConfig("//non-existantpath\\", &newCfg, true)
|
||||
err = c.UpdateConfig("//non-existantpath\\", &newCfg, false)
|
||||
if err == nil {
|
||||
t.Fatalf("Error should have been thrown for invalid path")
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ func loadConfigWithSettings(settings *Settings, flagSet map[string]bool) (*confi
|
||||
log.Printf("Loading config file %s..\n", filePath)
|
||||
|
||||
conf := &config.Cfg
|
||||
err = conf.ReadConfig(filePath, settings.EnableDryRun)
|
||||
err = conf.ReadConfigFromFile(filePath, settings.EnableDryRun)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(config.ErrFailureOpeningConfig, filePath, err)
|
||||
}
|
||||
@@ -541,7 +541,7 @@ func (bot *Engine) Stop() {
|
||||
}
|
||||
|
||||
if !bot.Settings.EnableDryRun {
|
||||
err := bot.Config.SaveConfig(bot.Settings.ConfigFile, false)
|
||||
err := bot.Config.SaveConfigToFile(bot.Settings.ConfigFile)
|
||||
if err != nil {
|
||||
gctlog.Errorln(gctlog.Global, "Unable to save config.")
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user