Use key derivitive function for encryption/decryption of config data

Fixes https://github.com/thrasher-/gocryptotrader/issues/115
This commit is contained in:
Adrian Gallagher
2018-06-04 18:43:13 +10:00
parent e80aaf1448
commit 4903c788b1
12 changed files with 377 additions and 88 deletions

View File

@@ -16,6 +16,7 @@ install:
- go get github.com/thrasher-/socketio
- go get github.com/beatgammit/turnpike
- go get github.com/gorilla/mux
- go get golang.org/x/crypto/scrypt
after_success:
- bash <(curl -s https://codecov.io/bash)

11
Gopkg.lock generated
View File

@@ -49,6 +49,15 @@
revision = "9831f2c3ac1068a78f50999a30db84270f647af6"
version = "v1.1"
[[projects]]
branch = "master"
name = "golang.org/x/crypto"
packages = [
"pbkdf2",
"scrypt"
]
revision = "df8d4716b3472e4a531c33cedbe537dae921a1a9"
[[projects]]
branch = "master"
name = "golang.org/x/net"
@@ -58,6 +67,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "044625dbd4ca2222e3d52af7e25d4636b9c7e7c3e5211694b549e8ed42c2b6d7"
inputs-digest = "b77f3524104c74cc3e20314a7eff2b79bb9cbd19fd81671226ef998e4e816412"
solver-name = "gps-cdcl"
solver-version = 1

View File

@@ -40,3 +40,7 @@
[[constraint]]
branch = "master"
name = "github.com/toorop/go-pusher"
[[constraint]]
branch = "master"
name = "golang.org/x/crypto"

View File

@@ -3,6 +3,7 @@ package common
import (
"crypto/hmac"
"crypto/md5"
"crypto/rand"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
@@ -59,6 +60,24 @@ func NewHTTPClientWithTimeout(t time.Duration) *http.Client {
return h
}
// GetRandomSalt returns a random salt
func GetRandomSalt(input []byte, saltLen int) ([]byte, error) {
if saltLen <= 0 {
return nil, errors.New("salt length is too small")
}
salt := make([]byte, saltLen)
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
return nil, err
}
var result []byte
if input != nil {
result = input
}
result = append(result, salt...)
return result, nil
}
// GetMD5 returns a MD5 hash of a byte array
func GetMD5(input []byte) []byte {
hash := md5.New()

View File

@@ -71,6 +71,33 @@ func TestIsValidCryptoAddress(t *testing.T) {
}
}
func TestGetRandomSalt(t *testing.T) {
t.Parallel()
_, err := GetRandomSalt(nil, -1)
if err == nil {
t.Fatal("Test failed. Expected err on negative salt length")
}
salt, err := GetRandomSalt(nil, 10)
if err != nil {
t.Fatal(err)
}
if len(salt) != 10 {
t.Fatal("Test failed. Expected salt of len=10")
}
salt, err = GetRandomSalt([]byte("RAWR"), 12)
if err != nil {
t.Fatal(err)
}
if len(salt) != 16 {
t.Fatal("Test failed. Expected salt of len=16")
}
}
func TestGetMD5(t *testing.T) {
t.Parallel()
var originalString = []byte("I am testing the MD5 function in common!")

View File

@@ -29,6 +29,7 @@ const (
configFileEncryptionDisabled = -1
configPairsLastUpdatedWarningThreshold = 30 // 30 days
configDefaultHTTPTimeout = time.Duration(time.Second * 15)
configMaxAuthFailres = 3
)
// Variables here are mainly alerts and a configuration object
@@ -53,6 +54,8 @@ var (
WarningCurrencyExchangeProvider = "WARNING -- Currency exchange provider invalid valid. Reset to Fixer."
WarningPairsLastUpdatedThresholdExceeded = "WARNING -- Exchange %s: Last manual update of available currency pairs has exceeded %d days. Manual update required!"
Cfg Config
IsInitialSetup bool
testBypass bool
)
// WebserverConfig struct holds the prestart variables for the webserver.
@@ -304,7 +307,7 @@ func (c *Config) CheckExchangeConfigValues() error {
if !exch.SupportsAutoPairUpdates {
lastUpdated := common.UnixTimestampToTime(exch.PairsLastUpdated)
lastUpdated.AddDate(0, 0, configPairsLastUpdatedWarningThreshold)
if lastUpdated.Unix() >= time.Now().Unix() {
if lastUpdated.Unix() <= time.Now().Unix() {
log.Printf(WarningPairsLastUpdatedThresholdExceeded, exch.Name, configPairsLastUpdatedWarningThreshold)
}
}
@@ -406,18 +409,19 @@ func (c *Config) RetrieveConfigCurrencyPairs(enabledOnly bool) error {
// GetFilePath returns the desired config file or the default config file name
// based on if the application is being run under test or normal mode.
func GetFilePath(file string) string {
func GetFilePath(file string) (string, error) {
if file != "" {
return file
return file, nil
}
if flag.Lookup("test.v") != nil {
return ConfigTestFile
if flag.Lookup("test.v") != nil && !testBypass {
return ConfigTestFile, nil
}
exePath, err := common.GetExecutablePath()
if err != nil {
log.Fatalf("Unable to get executable path: %s", err)
return "", err
}
tempPath := exePath + common.GetOSPathSlash()
@@ -427,32 +431,38 @@ func GetFilePath(file string) string {
data, err := common.ReadFile(encPath)
if err == nil {
if ConfirmECS(data) {
return encPath
return encPath, nil
}
err = os.Rename(encPath, cfgPath)
if err != nil {
log.Fatalf("Unable to rename config file: %s", err)
return "", err
}
log.Printf("Renaming non-encrypted config file from %s to %s",
encPath, cfgPath)
return cfgPath
return cfgPath, nil
}
if !ConfirmECS(data) {
return cfgPath
return cfgPath, nil
}
err = os.Rename(cfgPath, encPath)
if err != nil {
log.Fatalf("Unable to rename config file: %s", err)
return "", err
}
log.Printf("Renamed encrypted config file from %s to %s", cfgPath,
encPath)
return encPath
return encPath, nil
}
// ReadConfig verifies and checks for encryption and verifies the unencrypted
// file contains JSON.
func (c *Config) ReadConfig(configPath string) error {
defaultPath := GetFilePath(configPath)
defaultPath, err := GetFilePath(configPath)
if err != nil {
return err
}
file, err := common.ReadFile(defaultPath)
if err != nil {
return err
@@ -469,25 +479,43 @@ func (c *Config) ReadConfig(configPath string) error {
}
if c.EncryptConfig == configFileEncryptionPrompt {
IsInitialSetup = true
if c.PromptForConfigEncryption() {
c.EncryptConfig = configFileEncryptionEnabled
return c.SaveConfig("")
return c.SaveConfig(defaultPath)
}
}
} else {
key, err := PromptForConfigKey()
if err != nil {
return err
}
errCounter := 0
for {
if errCounter >= configMaxAuthFailres {
return errors.New("failed to decrypt config after 3 attempts")
}
key, err := PromptForConfigKey(IsInitialSetup)
if err != nil {
log.Printf("PromptForConfigKey err: %s", err)
errCounter++
continue
}
data, err := DecryptConfigFile(file, key)
if err != nil {
return err
}
var f []byte
f = append(f, file...)
data, err := DecryptConfigFile(f, key)
if err != nil {
log.Printf("DecryptConfigFile err: %s", err)
errCounter++
continue
}
err = ConfirmConfigJSON(data, &c)
if err != nil {
return err
err = ConfirmConfigJSON(data, &c)
if err != nil {
if errCounter < configMaxAuthFailres {
log.Printf("Invalid password.")
}
errCounter++
continue
}
break
}
}
return nil
@@ -495,13 +523,26 @@ func (c *Config) ReadConfig(configPath string) error {
// SaveConfig saves your configuration to your desired path
func (c *Config) SaveConfig(configPath string) error {
defaultPath := GetFilePath(configPath)
defaultPath, err := GetFilePath(configPath)
if err != nil {
return err
}
payload, err := json.MarshalIndent(c, "", " ")
if err != nil {
return err
}
if c.EncryptConfig == configFileEncryptionEnabled {
key, err2 := PromptForConfigKey()
if err2 != nil {
return err
var key []byte
var err error
if IsInitialSetup {
key, err = PromptForConfigKey(true)
if err != nil {
return err
}
IsInitialSetup = false
}
payload, err = EncryptConfigFile(payload, key)

View File

@@ -9,17 +9,26 @@ import (
"fmt"
"io"
"log"
"reflect"
"github.com/thrasher-/gocryptotrader/common"
"golang.org/x/crypto/scrypt"
)
const (
// EncryptConfirmString has a the general confirmation string to allow us to
// see if the file is correctly encrypted
EncryptConfirmString = "THORS-HAMMER"
errAESBlockSize = "The config file data is too small for the AES required block size"
errNotAPointer = "Error: parameter interface is not a pointer"
// SaltPrefix string
SaltPrefix = "~GCT~SO~SALTY~"
// SaltRandomLength is the number of random bytes to append after the prefix string
SaltRandomLength = 12
errAESBlockSize = "The config file data is too small for the AES required block size"
)
var (
storedSalt []byte
sessionDK []byte
)
// PromptForConfigEncryption asks for encryption key
@@ -41,33 +50,62 @@ func (c *Config) PromptForConfigEncryption() bool {
}
// PromptForConfigKey asks for configuration key
func PromptForConfigKey() ([]byte, error) {
func PromptForConfigKey(initialSetup bool) ([]byte, error) {
var cryptoKey []byte
for len(cryptoKey) != 32 {
log.Println("Enter password (32 characters):")
for {
log.Println("Please enter in your password: ")
pwPrompt := func(i *[]byte) error {
_, err := fmt.Scanln(i)
if err != nil {
return err
}
_, err := fmt.Scanln(&cryptoKey)
return nil
}
var p1 []byte
err := pwPrompt(&p1)
if err != nil {
return nil, err
}
if len(cryptoKey) > 32 || len(cryptoKey) < 32 {
log.Println("Please re-enter password (32 characters):")
if !initialSetup {
cryptoKey = p1
break
}
var p2 []byte
log.Println("Please re-enter your password: ")
err = pwPrompt(&p2)
if err != nil {
return nil, err
}
if bytes.Equal(p1, p2) {
cryptoKey = p1
break
} else {
log.Printf("Passwords did not match, please try again.")
continue
}
}
nonce := make([]byte, 12)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
return cryptoKey, nil
}
// EncryptConfigFile encrypts configuration data that is parsed in with a key
// and returns it as a byte array with an error
func EncryptConfigFile(configData, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
var err error
if len(sessionDK) == 0 {
sessionDK, err = makeNewSessionDK(key)
if err != nil {
return nil, err
}
}
block, err := aes.NewCipher(sessionDK)
if err != nil {
return nil, err
}
@@ -82,6 +120,7 @@ func EncryptConfigFile(configData, key []byte) ([]byte, error) {
stream.XORKeyStream(ciphertext[aes.BlockSize:], configData)
appendedFile := []byte(EncryptConfirmString)
appendedFile = append(appendedFile, storedSalt...)
appendedFile = append(appendedFile, ciphertext...)
return appendedFile, nil
}
@@ -90,6 +129,21 @@ func EncryptConfigFile(configData, key []byte) ([]byte, error) {
// returns the un-encrypted file as a byte array with an error
func DecryptConfigFile(configData, key []byte) ([]byte, error) {
configData = RemoveECS(configData)
origKey := key
if ConfirmSalt(configData) {
salt := make([]byte, len(SaltPrefix)+SaltRandomLength)
salt = configData[0:len(salt)]
dk, err := getScryptDK(key, salt)
if err != nil {
return nil, err
}
configData = configData[len(salt):]
key = dk
}
blockDecrypt, err := aes.NewCipher(key)
if err != nil {
return nil, err
@@ -105,24 +159,53 @@ func DecryptConfigFile(configData, key []byte) ([]byte, error) {
stream := cipher.NewCFBDecrypter(blockDecrypt, iv)
stream.XORKeyStream(configData, configData)
result := configData
sessionDK, err = makeNewSessionDK(origKey)
if err != nil {
return nil, err
}
return result, nil
}
// ConfirmConfigJSON confirms JSON in file
func ConfirmConfigJSON(file []byte, result interface{}) error {
if !common.StringContains(reflect.TypeOf(result).String(), "*") {
return errors.New(errNotAPointer)
}
return common.JSONDecode(file, &result)
}
// ConfirmSalt checks whether the encrypted data contains a salt
func ConfirmSalt(file []byte) bool {
return bytes.Contains(file, []byte(SaltPrefix))
}
// ConfirmECS confirms that the encryption confirmation string is found
func ConfirmECS(file []byte) bool {
subslice := []byte(EncryptConfirmString)
return bytes.Contains(file, subslice)
return bytes.Contains(file, []byte(EncryptConfirmString))
}
// RemoveECS removes encryption confirmation string
func RemoveECS(file []byte) []byte {
return bytes.Trim(file, EncryptConfirmString)
}
func getScryptDK(key, salt []byte) ([]byte, error) {
if len(key) == 0 {
return nil, errors.New("key is empty")
}
return scrypt.Key(key, salt, 32768, 8, 1, 32)
}
func makeNewSessionDK(key []byte) ([]byte, error) {
var err error
storedSalt, err = common.GetRandomSalt([]byte(SaltPrefix), SaltRandomLength)
if err != nil {
return nil, err
}
dk, err := getScryptDK(key, storedSalt)
if err != nil {
return nil, err
}
return dk, nil
}

View File

@@ -1,7 +1,6 @@
package config
import (
"reflect"
"testing"
"github.com/thrasher-/gocryptotrader/common"
@@ -18,39 +17,72 @@ func TestPromptForConfigEncryption(t *testing.T) {
func TestPromptForConfigKey(t *testing.T) {
t.Parallel()
byteyBite, err := PromptForConfigKey()
byteyBite, err := PromptForConfigKey(true)
if err == nil && len(byteyBite) > 1 {
t.Errorf("Test failed. PromptForConfigKey: %s", err)
}
_, err = PromptForConfigKey(false)
if err == nil {
t.Fatal(err)
}
}
func TestEncryptDecryptConfigFile(t *testing.T) { //Dual function Test
testKey := []byte("12345678901234567890123456789012")
func TestEncryptConfigFile(t *testing.T) {
_, err := EncryptConfigFile([]byte("test"), nil)
if err == nil {
t.Fatal("Test failed. Expected different result")
}
testConfigData, err := common.ReadFile(ConfigTestFile)
sessionDK = []byte("a")
_, err = EncryptConfigFile([]byte("test"), nil)
if err == nil {
t.Fatal("Test failed. Expected different result")
}
sessionDK, err = makeNewSessionDK([]byte("asdf"))
if err != nil {
t.Errorf("Test failed. EncryptConfigFile: %s", err)
}
encryptedFile, err2 := EncryptConfigFile(testConfigData, testKey)
if err2 != nil {
t.Errorf("Test failed. EncryptConfigFile: %s", err2)
}
if reflect.TypeOf(encryptedFile).String() != "[]uint8" {
t.Errorf("Test failed. EncryptConfigFile: Incorrect Type")
t.Fatal(err)
}
decryptedFile, err3 := DecryptConfigFile(encryptedFile, testKey)
if err3 != nil {
t.Errorf("Test failed. DecryptConfigFile: %s", err3)
_, err = EncryptConfigFile([]byte("test"), []byte("key"))
if err != nil {
t.Fatal(err)
}
if reflect.TypeOf(decryptedFile).String() != "[]uint8" {
t.Errorf("Test failed. DecryptConfigFile: Incorrect Type")
}
func TestDecryptConfigFile(t *testing.T) {
sessionDK = nil
result, err := EncryptConfigFile([]byte("test"), []byte("key"))
if err != nil {
t.Fatal(err)
}
result, err = DecryptConfigFile(result, nil)
if err == nil {
t.Fatal("Test failed. Expected different result")
}
result, err = DecryptConfigFile([]byte("test"), nil)
if err == nil {
t.Fatal("Test failed. Expected different result")
}
result, err = DecryptConfigFile([]byte("test"), []byte("AAAAAAAAAAAAAAAA"))
if err == nil {
t.Fatalf("Test failed. Expected %s", errAESBlockSize)
}
result, err = EncryptConfigFile([]byte("test"), []byte("key"))
if err != nil {
t.Fatal(err)
}
result, err = DecryptConfigFile(result, []byte("key"))
if err != nil {
t.Fatal(err)
}
// unmarshalled := Config{} // racecondition
// err4 := json.Unmarshal(decryptedFile, &unmarshalled)
// if err4 != nil {
// t.Errorf("Test failed. DecryptConfigFile: %s", err3)
// }
}
func TestConfirmConfigJSON(t *testing.T) {
@@ -60,16 +92,9 @@ func TestConfirmConfigJSON(t *testing.T) {
t.Errorf("Test failed. testConfirmJSON: %s", err)
}
err2 := ConfirmConfigJSON(testConfirmJSON, &result)
if err2 != nil {
t.Errorf("Test failed. testConfirmJSON: %s", err2)
}
if result == nil {
t.Errorf("Test failed. testConfirmJSON: Error Unmarshalling JSON")
}
err3 := ConfirmConfigJSON(testConfirmJSON, result)
if err3 == nil {
t.Errorf("Test failed. testConfirmJSON: %s", err3)
err = ConfirmConfigJSON(testConfirmJSON, &result)
if err != nil || result == nil {
t.Errorf("Test failed. testConfirmJSON: %s", err)
}
}
@@ -92,3 +117,12 @@ func TestRemoveECS(t *testing.T) {
t.Errorf("Test failed. TestConfirmECS: Error ECS not deleted.")
}
}
func TestMakeNewSessionDK(t *testing.T) {
t.Parallel()
_, err := makeNewSessionDK(nil)
if err == nil {
t.Fatal("Test failed. makeNewSessionDK passed with nil key")
}
}

View File

@@ -325,6 +325,12 @@ func TestCheckExchangeConfigValues(t *testing.T) {
)
}
checkExchangeConfigValues.Exchanges[0].HTTPTimeout = 0
checkExchangeConfigValues.CheckExchangeConfigValues()
if checkExchangeConfigValues.Exchanges[0].HTTPTimeout == 0 {
t.Fatalf("Test failed. Expected exchange %s to have updated HTTPTimeout value", checkExchangeConfigValues.Exchanges[0].Name)
}
checkExchangeConfigValues.Exchanges[0].APIKey = "Key"
checkExchangeConfigValues.Exchanges[0].APISecret = "Secret"
checkExchangeConfigValues.Exchanges[0].AuthenticatedAPISupport = true
@@ -428,6 +434,14 @@ func TestCheckWebserverConfigValues(t *testing.T) {
)
}
checkWebserverConfigValues.Webserver.WebsocketMaxAuthFailures = -1
checkWebserverConfigValues.CheckWebserverConfigValues()
if checkWebserverConfigValues.Webserver.WebsocketMaxAuthFailures != 3 {
t.Error(
"Test failed. checkWebserverConfigValues.CheckWebserverConfigValues error",
)
}
checkWebserverConfigValues.Webserver.ListenAddress = ":0"
err = checkWebserverConfigValues.CheckWebserverConfigValues()
if err == nil {
@@ -531,14 +545,55 @@ func TestSaveConfig(t *testing.T) {
func TestGetFilePath(t *testing.T) {
expected := "blah.json"
result := GetFilePath("blah.json")
result, _ := GetFilePath("blah.json")
if result != "blah.json" {
t.Errorf("Test failed. TestGetFilePath: expected %s got %s", expected, result)
}
expected = ConfigTestFile
result = GetFilePath("")
result, _ = GetFilePath("")
if result != expected {
t.Errorf("Test failed. TestGetFilePath: expected %s got %s", expected, result)
}
testBypass = true
result, _ = GetFilePath("")
}
func TestCheckConfig(t *testing.T) {
var c Config
err := c.LoadConfig(ConfigTestFile)
if err != nil {
t.Errorf("Test failed. %s", err)
}
err = c.CheckConfig()
if err != nil {
t.Fatal(err)
}
}
func TestUpdateConfig(t *testing.T) {
var c Config
err := c.LoadConfig(ConfigTestFile)
if err != nil {
t.Errorf("Test failed. %s", err)
}
newCfg := c
err = c.UpdateConfig("", newCfg)
if err != nil {
t.Fatalf("Test failed. %s", err)
}
err = c.UpdateConfig("//non-existantpath\\", newCfg)
if err == nil {
t.Fatalf("Test failed. Error should of been thrown for invalid path")
}
newCfg.Cryptocurrencies = ""
err = c.UpdateConfig("", newCfg)
if err == nil {
t.Fatalf("Test failed. Error should of been thrown for empty cryptocurrencies")
}
}

View File

@@ -46,8 +46,13 @@ func main() {
bot.shutdown = make(chan bool)
HandleInterrupt()
defaultPath, err := config.GetFilePath("")
if err != nil {
log.Fatal(err)
}
//Handle flags
flag.StringVar(&bot.configFile, "config", config.GetFilePath(""), "config file to load")
flag.StringVar(&bot.configFile, "config", defaultPath, "config file to load")
dryrun := flag.Bool("dryrun", false, "dry runs bot, doesn't save config file")
version := flag.Bool("version", false, "retrieves current GoCryptoTrader version")
flag.Parse()
@@ -66,7 +71,7 @@ func main() {
fmt.Println(BuildVersion(false))
log.Printf("Loading config file %s..\n", bot.configFile)
err := bot.config.LoadConfig(bot.configFile)
err = bot.config.LoadConfig(bot.configFile)
if err != nil {
log.Fatal(err)
}

View File

@@ -20,7 +20,12 @@ func main() {
var inFile, outFile, key string
var encrypt bool
var err error
configFile := config.GetFilePath("")
configFile, err := config.GetFilePath("")
if err != nil {
log.Fatal(err)
}
flag.StringVar(&inFile, "infile", configFile, "The config input file to process.")
flag.StringVar(&outFile, "outfile", configFile+".out", "The config output file.")
flag.BoolVar(&encrypt, "encrypt", true, "Whether to encrypt or decrypt.")
@@ -30,7 +35,7 @@ func main() {
log.Println("GoCryptoTrader: config-helper tool.")
if key == "" {
result, errf := config.PromptForConfigKey()
result, errf := config.PromptForConfigKey(false)
if errf != nil {
log.Fatal("Unable to obtain encryption/decryption key.")
}

View File

@@ -57,14 +57,20 @@ func getOnlineOfflinePortfolio(coins []portfolio.Coin, online bool) {
func main() {
var inFile, key string
flag.StringVar(&inFile, "infile", config.GetFilePath(""), "The config input file to process.")
defaultCfg, err := config.GetFilePath("")
if err != nil {
log.Fatal(err)
}
flag.StringVar(&inFile, "infile", defaultCfg, "The config input file to process.")
flag.StringVar(&key, "key", "", "The key to use for AES encryption.")
flag.Parse()
log.Println("GoCryptoTrader: portfolio tool.")
var cfg config.Config
var err = cfg.LoadConfig(inFile)
err = cfg.LoadConfig(inFile)
if err != nil {
log.Fatal(err)
}