Config: Fix config version downgrade (#1770)

* Config: Rename DecryptConfigFile to DecryptConfigData

Because this isn't really a file, it's a byte slice

* Config: Rename EncryptConfigFile to EncryptConfigData

Because it's not actually a file

* Config: Fix config version downgrade

Fixes #1769
This commit is contained in:
Gareth Kirwan
2025-02-19 22:27:52 +00:00
committed by GitHub
parent dc2d7770fb
commit 3748c97b12
6 changed files with 157 additions and 122 deletions

View File

@@ -1,7 +1,7 @@
package main
import (
"errors"
"context"
"flag"
"fmt"
"os"
@@ -11,9 +11,10 @@ import (
"github.com/buger/jsonparser"
"github.com/thrasher-corp/gocryptotrader/common/file"
"github.com/thrasher-corp/gocryptotrader/config"
"github.com/thrasher-corp/gocryptotrader/config/versions"
)
var commands = []string{"upgrade", "encrypt", "decrypt"}
var commands = []string{"upgrade", "downgrade", "encrypt", "decrypt"}
func main() {
fmt.Println("GoCryptoTrader: config-helper tool")
@@ -22,6 +23,7 @@ func main() {
var in, out, keyStr string
var inplace bool
var version int
fs := flag.NewFlagSet("config", flag.ExitOnError)
fs.Usage = func() { usage(fs) }
@@ -29,6 +31,7 @@ func main() {
fs.StringVar(&out, "out", "[in].out", "The config output file")
fs.BoolVar(&inplace, "edit", false, "Edit; Save result to the original file")
fs.StringVar(&keyStr, "key", "", "The key to use for AES encryption")
fs.IntVar(&version, "version", 0, "The version to downgrade to")
cmd, args := parseCommand(os.Args[1:])
if cmd == "" {
@@ -46,85 +49,68 @@ func main() {
out = in + ".out"
}
key := []byte(keyStr)
var err error
switch cmd {
case "upgrade":
err = upgradeFile(in, out, key)
case "decrypt":
err = encryptWrapper(in, out, key, false, decryptFile)
case "encrypt":
err = encryptWrapper(in, out, key, true, encryptFile)
key := []byte(keyStr)
data := readFile(in)
isEncrypted := config.IsEncrypted(data)
if cmd == "encrypt" && isEncrypted {
fatal("Error: File is already encrypted")
}
if cmd == "decrypt" && !isEncrypted {
fatal("Error: File is already decrypted")
}
if err != nil {
fatal(err.Error())
if len(key) == 0 && (isEncrypted || cmd == "encrypt") {
if key, err = config.PromptForConfigKey(cmd == "encrypt"); err != nil {
fatal(err.Error())
}
}
if isEncrypted {
if data, err = config.DecryptConfigData(data, key); err != nil {
fatal(err.Error())
}
}
switch cmd {
case "decrypt":
if data, err = jsonparser.Set(data, []byte("-1"), "encryptConfig"); err != nil {
fatal("Unable to decrypt config data; Error: " + err.Error())
}
case "downgrade", "upgrade":
if version == 0 {
if cmd == "downgrade" {
fmt.Fprintln(os.Stderr, "Error: downgrade requires a version")
usage(fs)
os.Exit(3)
}
version = versions.LatestVersion
} else if version < 0 {
fmt.Fprintln(os.Stderr, "Error: version must be positive")
usage(fs)
os.Exit(3)
}
if data, err = versions.Manager.Deploy(context.Background(), data, version); err != nil {
fatal("Unable to " + cmd + " config; Error: " + err.Error())
}
if !isEncrypted {
break
}
fallthrough
case "encrypt":
if data, err = config.EncryptConfigData(data, key); err != nil {
fatal("Unable to encrypt config data; Error: " + err.Error())
}
}
if err := file.Write(out, data); err != nil {
fatal("Unable to write output file `" + out + "`; Error: " + err.Error())
}
fmt.Println("Success! File written to " + out)
}
func upgradeFile(in, out string, key []byte) error {
c := &config.Config{
EncryptionKeyProvider: func(_ bool) ([]byte, error) {
if len(key) != 0 {
return key, nil
}
return config.PromptForConfigKey(false)
},
}
if err := c.ReadConfigFromFile(in, true); err != nil {
return err
}
return c.SaveConfigToFile(out)
}
type encryptFunc func(string, []byte) ([]byte, error)
func encryptWrapper(in, out string, key []byte, confirmKey bool, fn encryptFunc) error {
if len(key) == 0 {
var err error
if key, err = config.PromptForConfigKey(confirmKey); err != nil {
return err
}
}
outData, err := fn(in, key)
if err != nil {
return err
}
if err := file.Write(out, outData); err != nil {
return fmt.Errorf("unable to write output file %s; Error: %w", out, err)
}
return nil
}
func encryptFile(in string, key []byte) ([]byte, error) {
if config.IsFileEncrypted(in) {
return nil, errors.New("file is already encrypted")
}
outData, err := config.EncryptConfigFile(readFile(in), key)
if err != nil {
return nil, fmt.Errorf("unable to encrypt config data. Error: %w", err)
}
return outData, nil
}
func decryptFile(in string, key []byte) ([]byte, error) {
if !config.IsFileEncrypted(in) {
return nil, errors.New("file is already decrypted")
}
outData, err := config.DecryptConfigFile(readFile(in), key)
if err != nil {
return nil, fmt.Errorf("unable to decrypt config data. Error: %w", err)
}
if outData, err = jsonparser.Set(outData, []byte("-1"), "encryptConfig"); err != nil {
return nil, fmt.Errorf("unable to decrypt config data. Error: %w", err)
}
return outData, nil
}
func readFile(in string) []byte {
fileData, err := os.ReadFile(in)
if err != nil {
@@ -152,7 +138,7 @@ func parseCommand(a []string) (cmd string, args []string) {
switch len(cmds) {
case 0:
fmt.Fprintln(os.Stderr, "No command provided")
case 1: //
case 1:
return cmds[0], rem
default:
fmt.Fprintln(os.Stderr, "Too many commands provided: "+strings.Join(cmds, ", "))
@@ -171,6 +157,7 @@ The commands are:
encrypt encrypt infile and write to outfile
decrypt decrypt infile and write to outfile
upgrade upgrade the version of a decrypted config file
downgrade downgrade the version of a decrypted config file to a specific version
The arguments are:`)
fs.PrintDefaults()

View File

@@ -1504,7 +1504,7 @@ func (c *Config) readConfig(d io.Reader) error {
}
}
if j, err = versions.Manager.Deploy(context.Background(), j); err != nil {
if j, err = versions.Manager.Deploy(context.Background(), j, versions.LatestVersion); err != nil {
return err
}
@@ -1595,7 +1595,7 @@ func (c *Config) Save(writerProvider func() (io.Writer, error)) error {
}
c.sessionDK, c.storedSalt = sessionDK, storedSalt
}
payload, err = c.encryptConfigFile(payload)
payload, err = c.encryptConfigData(payload)
if err != nil {
return err
}

View File

@@ -95,8 +95,8 @@ func getSensitiveInput(prompt string) (resp []byte, err error) {
return bytes.TrimRight(resp, "\r\n"), err
}
// EncryptConfigFile encrypts json config data with a key
func EncryptConfigFile(configData, key []byte) ([]byte, error) {
// EncryptConfigData encrypts json config data with a key
func EncryptConfigData(configData, key []byte) ([]byte, error) {
sessionDK, salt, err := makeNewSessionDK(key)
if err != nil {
return nil, err
@@ -105,12 +105,12 @@ func EncryptConfigFile(configData, key []byte) ([]byte, error) {
sessionDK: sessionDK,
storedSalt: salt,
}
return c.encryptConfigFile(configData)
return c.encryptConfigData(configData)
}
// encryptConfigFile encrypts json config data with a key
// encryptConfigData encrypts json config data with a key
// The EncryptConfig field is set to config enabled (1)
func (c *Config) encryptConfigFile(configData []byte) ([]byte, error) {
func (c *Config) encryptConfigData(configData []byte) ([]byte, error) {
configData, err := jsonparser.Set(configData, []byte("1"), "encryptConfig")
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrSettingEncryptConfig, err)
@@ -135,8 +135,8 @@ func (c *Config) encryptConfigFile(configData []byte) ([]byte, error) {
return appendedFile, nil
}
// DecryptConfigFile decrypts config data with a key
func DecryptConfigFile(d, key []byte) ([]byte, error) {
// DecryptConfigData decrypts config data with a key
func DecryptConfigData(d, key []byte) ([]byte, error) {
return (&Config{}).decryptConfigData(d, key)
}

View File

@@ -59,16 +59,16 @@ func TestPromptForConfigKey(t *testing.T) {
func TestEncryptConfigFile(t *testing.T) {
t.Parallel()
_, err := EncryptConfigFile([]byte("test"), nil)
_, err := EncryptConfigData([]byte("test"), nil)
require.ErrorIs(t, err, errKeyIsEmpty)
c := &Config{
sessionDK: []byte("a"),
}
_, err = c.encryptConfigFile([]byte(`test`))
_, err = c.encryptConfigData([]byte(`test`))
require.ErrorIs(t, err, ErrSettingEncryptConfig)
_, err = c.encryptConfigFile([]byte(`{"test":1}`))
_, err = c.encryptConfigData([]byte(`{"test":1}`))
require.Error(t, err)
require.IsType(t, aes.KeySizeError(1), err)
@@ -79,26 +79,26 @@ func TestEncryptConfigFile(t *testing.T) {
sessionDK: sessDk,
storedSalt: salt,
}
_, err = c.encryptConfigFile([]byte(`{"test":1}`))
_, err = c.encryptConfigData([]byte(`{"test":1}`))
require.NoError(t, err)
}
func TestDecryptConfigFile(t *testing.T) {
t.Parallel()
e, err := EncryptConfigFile([]byte(`{"test":1}`), []byte("key"))
e, err := EncryptConfigData([]byte(`{"test":1}`), []byte("key"))
require.NoError(t, err)
d, err := DecryptConfigFile(e, []byte("key"))
d, err := DecryptConfigData(e, []byte("key"))
require.NoError(t, err)
assert.Equal(t, `{"test":1,"encryptConfig":1}`, string(d), "encryptConfig should be set to 1 after first encryption")
_, err = DecryptConfigFile(e, nil)
_, err = DecryptConfigData(e, nil)
require.ErrorIs(t, err, errKeyIsEmpty)
_, err = DecryptConfigFile([]byte("test"), nil)
_, err = DecryptConfigData([]byte("test"), nil)
require.ErrorIs(t, err, errNoPrefix)
_, err = DecryptConfigFile(encryptionPrefix, []byte("AAAAAAAAAAAAAAAA"))
_, err = DecryptConfigData(encryptionPrefix, []byte("AAAAAAAAAAAAAAAA"))
require.ErrorIs(t, err, errAESBlockSize)
}

View File

@@ -13,9 +13,11 @@ package versions
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log"
"os"
"slices"
"strconv"
"sync"
@@ -24,12 +26,17 @@ import (
"github.com/thrasher-corp/gocryptotrader/common"
)
// LatestVersion used as version param to Deploy to automatically use the latest version
const LatestVersion = -1
var (
errMissingVersion = errors.New("missing version")
errVersionIncompatible = errors.New("version does not implement ConfigVersion or ExchangeVersion")
errModifyingExchange = errors.New("error modifying exchange config")
errNoVersions = errors.New("error retrieving latest config version: No config versions are registered")
errApplyingVersion = errors.New("error applying version")
errConfigVersion = errors.New("version in config file is higher than the latest available version")
errTargetVersion = errors.New("target downgrade version is higher than the latest available version")
)
// ConfigVersion is a version that affects the general configuration
@@ -55,16 +62,23 @@ type manager struct {
var Manager = &manager{}
// Deploy upgrades or downgrades the config between versions
func (m *manager) Deploy(ctx context.Context, j []byte) ([]byte, error) {
// Pass LatestVersion for version to use the latest version automatically
// Prints an error an exits if the config file version or version param is not registered
func (m *manager) Deploy(ctx context.Context, j []byte, version int) ([]byte, error) {
if err := m.checkVersions(); err != nil {
return j, err
}
target, err := m.latest()
latest, err := m.latest()
if err != nil {
return j, err
}
target := latest
if version != LatestVersion {
target = version
}
m.m.RLock()
defer m.m.RUnlock()
@@ -75,49 +89,66 @@ func (m *manager) Deploy(ctx context.Context, j []byte) ([]byte, error) {
current = -1
case err != nil:
return j, fmt.Errorf("%w `version`: %w", common.ErrGettingField, err)
}
switch {
case target == current:
return j, nil
case latest < current:
warnVersionNotRegistered(current, latest, errConfigVersion)
return j, errConfigVersion
case target > latest:
warnVersionNotRegistered(target, latest, errTargetVersion)
return j, errTargetVersion
}
for current != target {
next := current + 1
action := "upgrade"
patchVersion := current + 1
action := "upgrade to"
configMethod := ConfigVersion.UpgradeConfig
exchMethod := ExchangeVersion.UpgradeExchange
if target < current {
next = current - 1
action = "downgrade"
patchVersion = current
action = "downgrade from"
configMethod = ConfigVersion.DowngradeConfig
exchMethod = ExchangeVersion.DowngradeExchange
}
log.Printf("Running %s to config version %v\n", action, next)
log.Printf("Running %s config version %v\n", action, patchVersion)
patch := m.versions[next]
patch := m.versions[patchVersion]
if cPatch, ok := patch.(ConfigVersion); ok {
if j, err = configMethod(cPatch, ctx, j); err != nil {
return j, fmt.Errorf("%w %s to %v: %w", errApplyingVersion, action, next, err)
return j, fmt.Errorf("%w %s %v: %w", errApplyingVersion, action, patchVersion, err)
}
}
if ePatch, ok := patch.(ExchangeVersion); ok {
if j, err = exchangeDeploy(ctx, ePatch, exchMethod, j); err != nil {
return j, fmt.Errorf("%w %s to %v: %w", errApplyingVersion, action, next, err)
return j, fmt.Errorf("%w %s %v: %w", errApplyingVersion, action, patchVersion, err)
}
}
current = next
current = patchVersion
if target < current {
current = patchVersion - 1
}
if j, err = jsonparser.Set(j, []byte(strconv.Itoa(current)), "version"); err != nil {
return j, fmt.Errorf("%w `version` during %s to %v: %w", common.ErrSettingField, action, next, err)
return j, fmt.Errorf("%w `version` during %s %v: %w", common.ErrSettingField, action, patchVersion, err)
}
}
var out bytes.Buffer
if err = json.Indent(&out, j, "", " "); err != nil {
return j, fmt.Errorf("error formatting json: %w", err)
}
log.Println("Version management finished")
return j, nil
return out.Bytes(), nil
}
func exchangeDeploy(ctx context.Context, patch ExchangeVersion, method func(ExchangeVersion, context.Context, []byte) ([]byte, error), j []byte) ([]byte, error) {
@@ -196,3 +227,11 @@ func (m *manager) checkVersions() error {
}
return nil
}
func warnVersionNotRegistered(current, latest int, msg error) {
fmt.Fprintf(os.Stderr, `
%s ('%d' > '%d')
Switch back to the version of GoCryptoTrader containing config version '%d' and run:
$ ./cmd/config downgrade %d
`, msg, current, latest, current, latest)
}

View File

@@ -13,47 +13,56 @@ import (
func TestDeploy(t *testing.T) {
t.Parallel()
m := manager{}
_, err := m.Deploy(context.Background(), []byte(``))
_, err := m.Deploy(context.Background(), []byte(``), LatestVersion)
assert.ErrorIs(t, err, errNoVersions)
m.registerVersion(1, &TestVersion1{})
_, err = m.Deploy(context.Background(), []byte(``))
_, err = m.Deploy(context.Background(), []byte(``), LatestVersion)
require.ErrorIs(t, err, errVersionIncompatible)
m = manager{}
m.registerVersion(0, &Version0{})
_, err = m.Deploy(context.Background(), []byte(`not an object`))
_, err = m.Deploy(context.Background(), []byte(`not an object`), LatestVersion)
require.ErrorIs(t, err, jsonparser.KeyPathNotFoundError, "Must throw the correct error trying to add version to bad json")
require.ErrorIs(t, err, common.ErrSettingField, "Must throw the correct error trying to add version to bad json")
require.ErrorContains(t, err, "version", "Must throw the correct error trying to add version to bad json")
_, err = m.Deploy(context.Background(), []byte(`{"version":"not an int"}`))
_, err = m.Deploy(context.Background(), []byte(`{"version":"not an int"}`), LatestVersion)
require.ErrorIs(t, err, common.ErrGettingField, "Must throw the correct error trying to get version from bad json")
in := []byte(`{"version":0,"exchanges":[{"name":"Juan"}]}`)
j, err := m.Deploy(context.Background(), in)
j, err := m.Deploy(context.Background(), in, LatestVersion)
require.NoError(t, err)
require.Equal(t, string(in), string(j))
assert.Equal(t, string(in), string(j))
m.registerVersion(1, &Version1{})
j, err = m.Deploy(context.Background(), in)
j, err = m.Deploy(context.Background(), in, LatestVersion)
require.NoError(t, err)
require.Contains(t, string(j), `"version":1`)
assert.Contains(t, string(j), `"version": 1`)
m.versions = m.versions[:1]
j, err = m.Deploy(context.Background(), j)
require.NoError(t, err)
require.Contains(t, string(j), `"version":0`)
_, err = m.Deploy(context.Background(), j, 2)
assert.ErrorIs(t, err, errTargetVersion, "Downgrade to a unregistered version should not be allowed")
m.versions = append(m.versions, &TestVersion2{ConfigErr: true, ExchErr: false}) // Bit hacky, but this will actually work
_, err = m.Deploy(context.Background(), j)
m.versions = append(m.versions, &TestVersion2{ConfigErr: true, ExchErr: false})
_, err = m.Deploy(context.Background(), j, LatestVersion)
require.ErrorIs(t, err, errUpgrade)
m.versions[1] = &TestVersion2{ConfigErr: false, ExchErr: true}
_, err = m.Deploy(context.Background(), in)
m.versions[len(m.versions)-1] = &TestVersion2{ConfigErr: false, ExchErr: true}
_, err = m.Deploy(context.Background(), in, LatestVersion)
require.Implements(t, (*ExchangeVersion)(nil), m.versions[1])
require.ErrorIs(t, err, errUpgrade)
j2, err := m.Deploy(context.Background(), j, 0)
require.NoError(t, err)
assert.Contains(t, string(j2), `"version": 0`, "Explicit downgrade should work correctly")
m.versions = m.versions[:1]
_, err = m.Deploy(context.Background(), j, LatestVersion)
assert.ErrorIs(t, err, errConfigVersion, "Config version ahead of latest version should error")
_, err = m.Deploy(context.Background(), j, 0)
assert.ErrorIs(t, err, errConfigVersion, "Config version ahead of latest version should error")
}
// TestExchangeDeploy exercises exchangeDeploy
@@ -61,7 +70,7 @@ func TestDeploy(t *testing.T) {
func TestExchangeDeploy(t *testing.T) {
t.Parallel()
m := manager{}
_, err := m.Deploy(context.Background(), []byte(``))
_, err := m.Deploy(context.Background(), []byte(``), LatestVersion)
assert.ErrorIs(t, err, errNoVersions)
v := &TestVersion2{}