Files
gocryptotrader/config/config_encryption_test.go
Adrian Gallagher 98a390b181 config: Fix TestPromptForConfigEncryption race (#1929)
* Here's how I resolved a race condition in the encryption prompt tests:

I've refactored `promptForConfigEncryption` to accept an `io.Reader`. This allows tests to use `strings.NewReader` instead of relying on the global `os.Stdin`. This change isolates the input source for `TestPromptForConfigEncryption`, preventing concurrent access conflicts with `TestPromptForConfigKey` which uses `withInteractiveResponse` to manipulate `os.Stdin`.

The original issue manifested as a data race detected by `go test -race` between these two test functions due to their parallel execution (`t.Parallel()`) and their interaction with the shared `os.Stdin` resource.

Here are the changes I made:
- `config/config_encryption.go`:
    - Modified `promptForConfigEncryption` to `promptForConfigEncryption(r io.Reader) (bool, error)`.
    - Introduced `PromptForConfigEncryption()` as a public wrapper that calls the refactored function with `os.Stdin` for application use.
- `config/config.go`:
    - Updated the call site for prompting config encryption to use the new `PromptForConfigEncryption()` wrapper.
- `config/config_encryption_test.go`:
    - Updated `TestPromptForConfigEncryption` to call the (now unexported) `promptForConfigEncryption` with `strings.NewReader` for various input scenarios.
    - `t.Parallel()` was maintained for `TestPromptForConfigEncryption`.

To verify, I confirmed that running `go test -race ./config/...` shows the previously reported race condition is no longer present. All tests in the `config` package now pass with the race detector enabled.

* Refactor(config): Remove redundant loop variable capture in test

I've removed the explicit `tc := tc` line in `TestPromptForConfigEncryption`
as it is no longer necessary for Go versions 1.22 and later.
The project's Go version (1.24.3 as per CI) handles loop variable
scoping correctly for parallel subtests, making this capture redundant.

This change is a follow-up to the fix for the race condition in issue #1928,
addressing your feedback on code style. No functional changes are introduced
by this commit.

* Style(config): Remove explicit empty string in test case

Refactors the "input_empty_eof" test case in
`TestPromptForConfigEncryption` to remove the explicit
assignment of `input: ""`. This relies on Go's default
behavior for struct field initialization (a string field
defaults to an empty string), making the code more concise.

This change is a follow-up to previous refactorings for issue #1928,
addressing your feedback on code style. No functional changes
are introduced by this commit.

* Update config/config_encryption_test.go

Co-authored-by: Ryan O'Hara-Reid <oharareid.ryan@gmail.com>

* linter: Make indenting a happy bing

* config: Rid PromptForConfigEncryption

---------

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: Ryan O'Hara-Reid <oharareid.ryan@gmail.com>
2025-06-07 12:46:25 +10:00

431 lines
12 KiB
Go

package config
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"io"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestPromptForConfigEncryption feeds canned answers to promptForConfigEncryption
// through an in-memory reader and checks the returned decision and error.
func TestPromptForConfigEncryption(t *testing.T) {
	t.Parallel()

	type promptCase struct {
		name    string
		input   string
		want    bool
		wantErr error
	}
	cases := []promptCase{
		{name: "input_y", input: "y\n", want: true},
		{name: "input_n", input: "n\n"},
		{name: "input_yes", input: "yes\n", want: true},
		{name: "input_no", input: "no\n"},
		{name: "input_invalid", input: "invalid\n"},
		// The zero-value input gives an empty reader, so the prompt must hit EOF.
		{name: "input_empty_eof", wantErr: io.EOF},
	}

	for _, pc := range cases {
		t.Run(pc.name, func(t *testing.T) {
			t.Parallel()
			got, err := promptForConfigEncryption(strings.NewReader(pc.input))
			require.ErrorIs(t, err, pc.wantErr)
			require.Equal(t, pc.want, got)
		})
	}
}
// TestPromptForConfigKey drives PromptForConfigKey with scripted stdin content
// and checks the returned key or error for each script.
//
// The test is deliberately not parallel: withInteractiveResponse reassigns the
// process-wide os.Stdin for the duration of each callback, so running it
// concurrently with other tests could race on that global.
func TestPromptForConfigKey(t *testing.T) {
	// Only blank lines are supplied, so the prompt is expected to run out of input.
	withInteractiveResponse(t, "\n\n", func() {
		_, err := PromptForConfigKey(false)
		require.ErrorIs(t, err, io.EOF)
	})

	// A single answer is enough when called with false.
	withInteractiveResponse(t, "pass\n", func() {
		k, err := PromptForConfigKey(false)
		require.NoError(t, err)
		assert.Equal(t, "pass", string(k))
	})

	// Called with true, two matching answers are supplied (presumably key plus confirmation).
	withInteractiveResponse(t, "what\nwhat\n", func() {
		k, err := PromptForConfigKey(true)
		require.NoError(t, err)
		assert.Equal(t, "what", string(k))
	})

	// Mismatching answers followed by end of input must surface io.EOF.
	withInteractiveResponse(t, "what\nno\n", func() {
		_, err := PromptForConfigKey(true)
		require.ErrorIs(t, err, io.EOF, "PromptForConfigKey must EOF when asking for another input but none is given")
	})

	// Repeated mismatches never converge, so the input is eventually exhausted.
	withInteractiveResponse(t, "what\nno\nwhat\nno\nwhat\nno\n", func() {
		_, err := PromptForConfigKey(true)
		require.ErrorIs(t, err, io.EOF, "PromptForConfigKey must EOF when asking for another input but none is given")
	})

	// Mismatches followed by a consistent pair must succeed.
	withInteractiveResponse(t, "what\nno\nwhat\nno\nwhat\nwhat\n", func() {
		k, err := PromptForConfigKey(true)
		require.NoError(t, err, "PromptForConfigKey must not error if the user eventually answers consistently")
		assert.Equal(t, "what", string(k))
	})
}
// TestEncryptConfigData walks through the failure cases of EncryptConfigData and
// Config.encryptConfigData, then confirms that a session key produced by
// makeNewSessionDK encrypts without error.
func TestEncryptConfigData(t *testing.T) {
	t.Parallel()

	// A nil key must be rejected.
	_, err := EncryptConfigData([]byte("test"), nil)
	require.ErrorIs(t, err, errKeyIsEmpty)

	// The same one-byte session key is used for the next two failure cases.
	shortKeyCfg := &Config{sessionDK: []byte("a")}

	// A payload that is not a JSON object is expected to fail with ErrSettingEncryptConfig.
	_, err = shortKeyCfg.encryptConfigData([]byte("test"))
	require.ErrorIs(t, err, ErrSettingEncryptConfig)

	// A JSON payload gets far enough for the one-byte key to be rejected by AES.
	_, err = shortKeyCfg.encryptConfigData([]byte(`{"test":1}`))
	require.Error(t, err)
	require.IsType(t, aes.KeySizeError(1), err)

	derivedKey, salt, err := makeNewSessionDK([]byte("asdf"))
	require.NoError(t, err, "makeNewSessionDK must not error")

	validCfg := &Config{sessionDK: derivedKey, storedSalt: salt}
	_, err = validCfg.encryptConfigData([]byte(`{"test":1}`))
	require.NoError(t, err)
}
// TestDecryptConfigData round-trips data through EncryptConfigData and
// DecryptConfigData, checks the rejected inputs, decrypts a payload built by
// the legacyEncrypt test helper, and verifies that an unknown version number
// is refused.
func TestDecryptConfigData(t *testing.T) {
	t.Parallel()

	password := []byte("key")

	encrypted, err := EncryptConfigData([]byte(`{"test":1}`), password)
	require.NoError(t, err)
	decrypted, err := DecryptConfigData(encrypted, password)
	require.NoError(t, err)
	assert.Equal(t, `{"test":1,"encryptConfig":1}`, string(decrypted), "encryptConfig should be set to 1 after first encryption")

	// Missing key.
	_, err = DecryptConfigData(encrypted, nil)
	require.ErrorIs(t, err, errKeyIsEmpty)

	// Data without the encryption prefix.
	_, err = DecryptConfigData([]byte("test"), nil)
	require.ErrorIs(t, err, errNoPrefix)

	// Prefix only, with nothing after it.
	_, err = DecryptConfigData(encryptionPrefix, []byte("AAAAAAAAAAAAAAAA"))
	require.ErrorIs(t, err, errAESBlockSize)

	// Payload laid out by legacyEncrypt: prefix, then salt, then AES-CFB ciphertext.
	legacyKey, salt, err := makeNewSessionDK(password)
	require.NoError(t, err, "makeNewSessionDK must not error")
	legacyData, err := legacyEncrypt(t, salt, []byte(`{"test":123}`), legacyKey)
	require.NoError(t, err)
	legacyPlain, err := DecryptConfigData(legacyData, password)
	require.NoError(t, err)
	assert.Equal(t, `{"test":123}`, string(legacyPlain))

	// Header carrying version number 69, which the test expects to be unsupported.
	versionOffset := len(encryptionPrefix) + len(encryptionVersionPrefix)
	badVersion := make([]byte, versionOffset+versionSize)
	copy(badVersion, encryptionPrefix)
	copy(badVersion[len(encryptionPrefix):], encryptionVersionPrefix)
	binary.BigEndian.PutUint16(badVersion[versionOffset:], 69)
	_, err = DecryptConfigData(badVersion, password)
	require.ErrorIs(t, err, errUnsupportedEncryptionVersion)
}
// legacyEncrypt is a test helper that AES-CFB encrypts data with key and
// returns encryptionPrefix, salt and the ciphertext concatenated in that order.
func legacyEncrypt(t *testing.T, salt, data, key []byte) ([]byte, error) {
	t.Helper()

	ciphertext, err := aesCFBEncrypt(t, data, key)
	if err != nil {
		return nil, err
	}

	// Build into a fresh buffer so encryptionPrefix itself is never appended to.
	out := make([]byte, 0, len(encryptionPrefix)+len(salt)+len(ciphertext))
	out = append(out, encryptionPrefix...)
	out = append(out, salt...)
	return append(out, ciphertext...), nil
}
// aesCFBEncrypt is a test helper that encrypts data with AES in CFB mode.
// The result is a random IV of aes.BlockSize bytes followed by the ciphertext.
// It returns the error from aes.NewCipher when key is not a valid AES key size.
func aesCFBEncrypt(t *testing.T, data, key []byte) ([]byte, error) {
	t.Helper()

	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}

	out := make([]byte, aes.BlockSize+len(data))
	iv, payload := out[:aes.BlockSize], out[aes.BlockSize:]
	if _, err = io.ReadFull(rand.Reader, iv); err != nil {
		return nil, err
	}

	cipher.NewCFBEncrypter(block, iv).XORKeyStream(payload, data) //nolint:staticcheck // For testing purposes only
	return out, nil
}
// TestEncryptAESGCMCiphertext exercises decryptAESGCMCiphertext (despite the
// test's name): an empty key must be rejected, and a message sealed here with
// AES-GCM must decrypt back to the original plaintext.
func TestEncryptAESGCMCiphertext(t *testing.T) {
	t.Parallel()

	_, err := decryptAESGCMCiphertext(nil, nil)
	require.ErrorIs(t, err, aes.KeySizeError(0))

	const message = "MEOWMEOWMEOWMEOWMEOW"
	key := []byte(strings.Repeat("A", 16))

	block, err := aes.NewCipher(key)
	require.NoError(t, err)
	gcm, err := cipher.NewGCMWithRandomNonce(block)
	require.NoError(t, err)

	sealed := gcm.Seal(nil, nil, []byte(message), nil)

	plain, err := decryptAESGCMCiphertext(sealed, key)
	require.NoError(t, err)
	assert.Equal(t, message, string(plain))
}
// TestDecryptAESCFBCiphertext checks that decryptAESCFBCiphertext rejects empty
// input and an invalid key size, and that it recovers plaintext produced by
// the aesCFBEncrypt test helper.
func TestDecryptAESCFBCiphertext(t *testing.T) {
	t.Parallel()

	const message = "WOOFWOOFWOOFWOOFWOOF"

	_, err := decryptAESCFBCiphertext(nil, nil)
	require.ErrorIs(t, err, errAESBlockSize)

	_, err = decryptAESCFBCiphertext([]byte(message), []byte("A"))
	require.ErrorIs(t, err, aes.KeySizeError(1))

	key := []byte(strings.Repeat("A", 16))
	ciphertext, err := aesCFBEncrypt(t, []byte(message), key)
	require.NoError(t, err)

	plain, err := decryptAESCFBCiphertext(ciphertext, key)
	require.NoError(t, err)
	assert.Equal(t, message, string(plain))
}
// TestIsEncrypted checks that IsEncrypted reports true for the bare encryption
// prefix and false for unrelated bytes.
func TestIsEncrypted(t *testing.T) {
	t.Parallel()

	unrelated := []byte("mhmmm. Donuts.")

	assert.True(t, IsEncrypted(encryptionPrefix))
	assert.False(t, IsEncrypted(unrelated))
}
// TestMakeNewSessionDK verifies that makeNewSessionDK returns an error when
// given a nil key.
func TestMakeNewSessionDK(t *testing.T) {
	t.Parallel()

	_, _, err := makeNewSessionDK(nil)
	if err != nil {
		return
	}
	t.Fatal("makeNewSessionDK passed with nil key")
}
// TestEncryptTwiceReusesSaltButNewCipher saves the same encrypted Config to two
// files and checks that both outputs share the same prefix-plus-salt header
// while the files as a whole differ.
//
// The test is not parallel because withInteractiveResponse reassigns the
// process-wide os.Stdin.
func TestEncryptTwiceReusesSaltButNewCipher(t *testing.T) {
	c := &Config{
		EncryptConfig: 1,
	}
	tempDir := t.TempDir()
	enc1 := filepath.Join(tempDir, "encrypted.dat")
	enc2 := filepath.Join(tempDir, "encrypted2.dat")

	// Both saves run against the same scripted stdin; any failure while
	// preparing that input aborts the test inside the helper instead of
	// letting it continue with a broken fixture.
	withInteractiveResponse(t, "pass\npass\n", func() {
		require.NoError(t, c.SaveConfigToFile(enc1), "SaveConfigToFile must not error on the first save")
		require.NoError(t, c.SaveConfigToFile(enc2), "SaveConfigToFile must not error on the second save")
	})

	data1, err := os.ReadFile(enc1)
	require.NoError(t, err, "ReadFile must not error")
	data2, err := os.ReadFile(enc2)
	require.NoError(t, err, "ReadFile must not error")

	// length of prefix + salt
	l := len(encryptionPrefix) + len(saltPrefix) + saltRandomLength

	// Guard the slicing below so a truncated file fails the test rather than panicking.
	require.GreaterOrEqual(t, len(data1), l, "first file must contain the prefix and salt")
	require.GreaterOrEqual(t, len(data2), l, "second file must contain the prefix and salt")

	// The prefix, including the salt with its random bytes, is the same...
	assert.True(t, bytes.Equal(data1[:l], data2[:l]), "salt should be reused between saves")
	// ...but the cipher text should not be.
	assert.False(t, bytes.Equal(data1, data2), "cipher texts should differ between saves")
}
// TestSaveAndReopenEncryptedConfig writes an encrypted config to disk and reads
// it back into a fresh Config, supplying the password through scripted stdin.
func TestSaveAndReopenEncryptedConfig(t *testing.T) {
	original := &Config{
		Name:          "myCustomName",
		EncryptConfig: 1,
	}
	encPath := filepath.Join(t.TempDir(), "encrypted.dat")

	// Save encrypted config
	withInteractiveResponse(t, "pass\npass\n", func() {
		require.NoError(t, original.SaveConfigToFile(encPath), "SaveConfigToFile must not error")
	})

	// Load with no existing state, key is read from the prepared file
	reloaded := new(Config)
	withInteractiveResponse(t, "pass\n", func() {
		require.NoError(t, reloaded.ReadConfigFromFile(encPath, true), "ReadConfigFromFile must not error")
	})

	assert.Equal(t, "myCustomName", reloaded.Name, "Name should be correct")
	assert.Equal(t, 1, reloaded.EncryptConfig, "EncryptConfig should be set correctly")
}
// TestReadConfigWithPrompt saves a config with EncryptConfig set to 0, reloads
// it while answering the prompts on stdin, and checks that the file on disk
// and the loaded Config both end up flagged as encrypted.
func TestReadConfigWithPrompt(t *testing.T) {
	// Ensure we'll get the prompt when loading
	c := &Config{
		EncryptConfig: 0,
	}

	// Save config
	testConfigFile := filepath.Join(t.TempDir(), "config.json")
	require.NoError(t, c.SaveConfigToFile(testConfigFile), "SaveConfigToFile must not error")

	// Run the test: "y" followed by the same password twice (presumably
	// accepting encryption, then entering and confirming the key).
	c = &Config{}
	withInteractiveResponse(t, "y\npass\npass\n", func() {
		// require aborts the test on failure, so no follow-up error check is
		// needed once the helper returns.
		err := c.ReadConfigFromFile(testConfigFile, false)
		require.NoError(t, err, "ReadConfigFromFile must not error")
	})

	// Verify results
	data, err := os.ReadFile(testConfigFile)
	require.NoError(t, err, "ReadFile must not error")
	assert.Equal(t, fileEncryptionEnabled, c.EncryptConfig, "EncryptConfig should be set after prompts")
	assert.True(t, IsEncrypted(data), "data should be encrypted after prompts")
}
// TestReadEncryptedConfigFromReader decrypts a fixed encrypted fixture through
// Config.readConfig using a key provider that always returns "pass", then
// corrupts one header byte and expects decryption to fail.
func TestReadEncryptedConfigFromReader(t *testing.T) {
	t.Parallel()

	cfg := &Config{
		EncryptionKeyProvider: func(_ bool) ([]byte, error) { return []byte("pass"), nil },
	}

	// Encrypted conf for: `{"name":"test"}` with key `pass`
	encrypted := []byte{
		84, 72, 79, 82, 83, 45, 72, 65, 77, 77, 69, 82,
		126, 71, 67, 84, 126, 83, 79, 126, 83, 65, 76, 84, 89, 126,
		246, 110, 128, 3, 30, 168, 172, 160, 198, 176, 136, 62, 152, 155, 253, 176,
		16, 48, 52, 246, 44, 29, 151, 47, 217, 226, 178, 12, 218, 113, 248, 172,
		195, 232, 136, 104, 9, 199, 20, 4, 71, 4, 253, 249,
	}

	require.NoError(t, cfg.readConfig(bytes.NewReader(encrypted)))
	assert.Equal(t, "test", cfg.Name)

	// Change the salt
	encrypted[20] = 0
	require.ErrorIs(t, cfg.readConfig(bytes.NewReader(encrypted)), errDecryptFailed)
}
// TestSaveConfigToFileWithErrorInPasswordPrompt should preserve the original file:
// when the password prompt only receives blank lines, SaveConfigToFile must
// fail with io.EOF and leave the existing target file's contents untouched.
func TestSaveConfigToFileWithErrorInPasswordPrompt(t *testing.T) {
	c := &Config{
		Name:          "test",
		EncryptConfig: fileEncryptionEnabled,
	}

	// Seed the target with known content. os.WriteFile closes the file on
	// every path, so no handle is left open if the write fails.
	testData := []byte("testdata")
	targetFile := filepath.Join(t.TempDir(), "config.dat")
	require.NoError(t, os.WriteFile(targetFile, testData, 0o600), "WriteFile must not error")

	withInteractiveResponse(t, "\n\n", func() {
		err := c.SaveConfigToFile(targetFile)
		require.ErrorIs(t, err, io.EOF, "SaveConfigToFile must return io.EOF when the password prompt receives no key")
	})

	data, err := os.ReadFile(targetFile)
	require.NoError(t, err, "ReadFile must not error")
	assert.Equal(t, testData, data, "original file contents should be preserved")
}
// withInteractiveResponse writes response to a temporary file, installs that
// file as os.Stdin while fn runs, and restores the previous os.Stdin before
// returning. Because it mutates a process-wide variable, callers should not
// run it concurrently with anything else that uses os.Stdin.
func withInteractiveResponse(tb testing.TB, response string, fn func()) {
	tb.Helper()

	input, err := os.CreateTemp(tb.TempDir(), "*.in")
	require.NoError(tb, err, "CreateTemp must not error")
	defer input.Close()

	_, err = input.WriteString(response)
	require.NoError(tb, err, "WriteString must not error")

	// Rewind so that reads start at the beginning of the scripted response.
	_, err = input.Seek(0, io.SeekStart)
	require.NoError(tb, err, "Seek must not error")

	previous := os.Stdin
	defer func() { os.Stdin = previous }()
	os.Stdin = input

	fn()
}