mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-13 23:16:45 +00:00
* Here's how I resolved a race condition in the encryption prompt tests:
I've refactored `promptForConfigEncryption` to accept an `io.Reader`. This allows tests to use `strings.NewReader` instead of relying on the global `os.Stdin`. This change isolates the input source for `TestPromptForConfigEncryption`, preventing concurrent access conflicts with `TestPromptForConfigKey` which uses `withInteractiveResponse` to manipulate `os.Stdin`.
The original issue manifested as a data race detected by `go test -race` between these two test functions due to their parallel execution (`t.Parallel()`) and their interaction with the shared `os.Stdin` resource.
Here are the changes I made:
- `config/config_encryption.go`:
- Modified `promptForConfigEncryption` to `promptForConfigEncryption(r io.Reader) (bool, error)`.
- Introduced `PromptForConfigEncryption()` as a public wrapper that calls the refactored function with `os.Stdin` for application use.
- `config/config.go`:
- Updated the call site for prompting config encryption to use the new `PromptForConfigEncryption()` wrapper.
- `config/config_encryption_test.go`:
- Updated `TestPromptForConfigEncryption` to call the (now unexported) `promptForConfigEncryption` with `strings.NewReader` for various input scenarios.
- `t.Parallel()` was maintained for `TestPromptForConfigEncryption`.
To verify, I confirmed that running `go test -race ./config/...` shows the previously reported race condition is no longer present. All tests in the `config` package now pass with the race detector enabled.
* Refactor(config): Remove redundant loop variable capture in test
I've removed the explicit `tc := tc` line in `TestPromptForConfigEncryption`
as it is no longer necessary for Go versions 1.22 and later.
The project's Go version (1.24.3 as per CI) handles loop variable
scoping correctly for parallel subtests, making this capture redundant.
This change is a follow-up to the fix for the race condition in issue #1928,
addressing your feedback on code style. No functional changes are introduced
by this commit.
* Style(config): Remove explicit empty string in test case
Refactors the "input_empty_eof" test case in
`TestPromptForConfigEncryption` to remove the explicit
assignment of `input: ""`. This relies on Go's default
behavior for struct field initialization (a string field
defaults to an empty string), making the code more concise.
This change is a follow-up to previous refactorings for issue #1928,
addressing your feedback on code style. No functional changes
are introduced by this commit.
* Update config/config_encryption_test.go
Co-authored-by: Ryan O'Hara-Reid <oharareid.ryan@gmail.com>
* linter: Make indenting a happy bing
* config: Rid PromptForConfigEncryption
---------
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: Ryan O'Hara-Reid <oharareid.ryan@gmail.com>
431 lines
12 KiB
Go
431 lines
12 KiB
Go
package config
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestPromptForConfigEncryption(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
input string
|
|
expectedBool bool
|
|
expectedError error
|
|
}{
|
|
{
|
|
name: "input_y",
|
|
input: "y\n",
|
|
expectedBool: true,
|
|
},
|
|
{
|
|
name: "input_n",
|
|
input: "n\n",
|
|
},
|
|
{
|
|
name: "input_yes",
|
|
input: "yes\n",
|
|
expectedBool: true,
|
|
},
|
|
{
|
|
name: "input_no",
|
|
input: "no\n",
|
|
},
|
|
{
|
|
name: "input_invalid",
|
|
input: "invalid\n",
|
|
},
|
|
{
|
|
name: "input_empty_eof",
|
|
expectedError: io.EOF,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
reader := strings.NewReader(tc.input)
|
|
confirm, err := promptForConfigEncryption(reader)
|
|
require.ErrorIs(t, err, tc.expectedError)
|
|
require.Equal(t, tc.expectedBool, confirm)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestPromptForConfigKey(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
withInteractiveResponse(t, "\n\n", func() {
|
|
_, err := PromptForConfigKey(false)
|
|
require.ErrorIs(t, err, io.EOF)
|
|
})
|
|
|
|
withInteractiveResponse(t, "pass\n", func() {
|
|
k, err := PromptForConfigKey(false)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "pass", string(k))
|
|
})
|
|
|
|
withInteractiveResponse(t, "what\nwhat\n", func() {
|
|
k, err := PromptForConfigKey(true)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "what", string(k))
|
|
})
|
|
|
|
withInteractiveResponse(t, "what\nno\n", func() {
|
|
_, err := PromptForConfigKey(true)
|
|
require.ErrorIs(t, err, io.EOF, "PromptForConfigKey must EOF when asking for another input but none is given")
|
|
})
|
|
|
|
withInteractiveResponse(t, "what\nno\nwhat\nno\nwhat\nno\n", func() {
|
|
_, err := PromptForConfigKey(true)
|
|
require.ErrorIs(t, err, io.EOF, "PromptForConfigKey must EOF when asking for another input but none is given")
|
|
})
|
|
|
|
withInteractiveResponse(t, "what\nno\nwhat\nno\nwhat\nwhat\n", func() {
|
|
k, err := PromptForConfigKey(true)
|
|
require.NoError(t, err, "PromptForConfigKey must not error if the user eventually answers consistently")
|
|
assert.Equal(t, "what", string(k))
|
|
})
|
|
}
|
|
|
|
func TestEncryptConfigData(t *testing.T) {
|
|
t.Parallel()
|
|
_, err := EncryptConfigData([]byte("test"), nil)
|
|
require.ErrorIs(t, err, errKeyIsEmpty)
|
|
|
|
c := &Config{
|
|
sessionDK: []byte("a"),
|
|
}
|
|
_, err = c.encryptConfigData([]byte(`test`))
|
|
require.ErrorIs(t, err, ErrSettingEncryptConfig)
|
|
|
|
_, err = c.encryptConfigData([]byte(`{"test":1}`))
|
|
require.Error(t, err)
|
|
require.IsType(t, aes.KeySizeError(1), err)
|
|
|
|
sessDk, salt, err := makeNewSessionDK([]byte("asdf"))
|
|
require.NoError(t, err, "makeNewSessionDK must not error")
|
|
|
|
c = &Config{
|
|
sessionDK: sessDk,
|
|
storedSalt: salt,
|
|
}
|
|
_, err = c.encryptConfigData([]byte(`{"test":1}`))
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestDecryptConfigData(t *testing.T) {
|
|
t.Parallel()
|
|
e, err := EncryptConfigData([]byte(`{"test":1}`), []byte("key"))
|
|
require.NoError(t, err)
|
|
|
|
d, err := DecryptConfigData(e, []byte("key"))
|
|
require.NoError(t, err)
|
|
assert.Equal(t, `{"test":1,"encryptConfig":1}`, string(d), "encryptConfig should be set to 1 after first encryption")
|
|
|
|
_, err = DecryptConfigData(e, nil)
|
|
require.ErrorIs(t, err, errKeyIsEmpty)
|
|
|
|
_, err = DecryptConfigData([]byte("test"), nil)
|
|
require.ErrorIs(t, err, errNoPrefix)
|
|
|
|
_, err = DecryptConfigData(encryptionPrefix, []byte("AAAAAAAAAAAAAAAA"))
|
|
require.ErrorIs(t, err, errAESBlockSize)
|
|
|
|
sessionDK, salt, err := makeNewSessionDK([]byte("key"))
|
|
require.NoError(t, err, "makeNewSessionDK must not error")
|
|
|
|
encData, err := legacyEncrypt(t, salt, []byte(`{"test":123}`), sessionDK)
|
|
require.NoError(t, err)
|
|
|
|
data, err := DecryptConfigData(encData, []byte("key"))
|
|
require.NoError(t, err)
|
|
assert.Equal(t, `{"test":123}`, string(data))
|
|
|
|
badVersion := make([]byte, len(encryptionPrefix)+len(encryptionVersionPrefix)+versionSize)
|
|
copy(badVersion, encryptionPrefix)
|
|
copy(badVersion[len(encryptionPrefix):], encryptionVersionPrefix)
|
|
binary.BigEndian.PutUint16(badVersion[len(encryptionPrefix)+len(encryptionVersionPrefix):], 69)
|
|
_, err = DecryptConfigData(badVersion, []byte("key"))
|
|
require.ErrorIs(t, err, errUnsupportedEncryptionVersion)
|
|
}
|
|
|
|
func legacyEncrypt(t *testing.T, salt, data, key []byte) ([]byte, error) {
|
|
t.Helper()
|
|
|
|
ciphertext, err := aesCFBEncrypt(t, data, key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
encData := append(bytes.Clone(encryptionPrefix), salt...)
|
|
encData = append(encData, ciphertext...)
|
|
return encData, nil
|
|
}
|
|
|
|
// aesCFBEncrypt is a test helper that encrypts data with AES in CFB mode under
// key. The result is a random IV of aes.BlockSize bytes followed by the
// encrypted payload. An error is returned when the cipher cannot be created
// from key or when the IV cannot be read from crypto/rand.
func aesCFBEncrypt(t *testing.T, data, key []byte) ([]byte, error) {
	t.Helper()

	blk, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}

	// One buffer holds both parts: the IV in front, the payload behind it.
	out := make([]byte, aes.BlockSize+len(data))
	iv, payload := out[:aes.BlockSize], out[aes.BlockSize:]
	if _, err = io.ReadFull(rand.Reader, iv); err != nil {
		return nil, err
	}

	enc := cipher.NewCFBEncrypter(blk, iv) //nolint:staticcheck // For testing purposes only
	enc.XORKeyStream(payload, data)
	return out, nil
}
|
|
|
|
func TestEncryptAESGCMCiphertext(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, err := decryptAESGCMCiphertext(nil, nil)
|
|
require.ErrorIs(t, err, aes.KeySizeError(0))
|
|
|
|
validKey := []byte(strings.Repeat("A", 16))
|
|
block, err := aes.NewCipher(validKey)
|
|
require.NoError(t, err)
|
|
|
|
aead, err := cipher.NewGCMWithRandomNonce(block)
|
|
require.NoError(t, err)
|
|
|
|
ciphertext := aead.Seal(nil, nil, []byte("MEOWMEOWMEOWMEOWMEOW"), nil)
|
|
|
|
data, err := decryptAESGCMCiphertext(ciphertext, validKey)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "MEOWMEOWMEOWMEOWMEOW", string(data))
|
|
}
|
|
|
|
func TestDecryptAESCFBCiphertext(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, err := decryptAESCFBCiphertext(nil, nil)
|
|
require.ErrorIs(t, err, errAESBlockSize)
|
|
|
|
_, err = decryptAESCFBCiphertext([]byte("WOOFWOOFWOOFWOOFWOOF"), []byte("A"))
|
|
require.ErrorIs(t, err, aes.KeySizeError(1))
|
|
|
|
validKey := []byte(strings.Repeat("A", 16))
|
|
data, err := aesCFBEncrypt(t, []byte("WOOFWOOFWOOFWOOFWOOF"), validKey)
|
|
require.NoError(t, err)
|
|
|
|
data, err = decryptAESCFBCiphertext(data, validKey)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "WOOFWOOFWOOFWOOFWOOF", string(data))
|
|
}
|
|
|
|
func TestIsEncrypted(t *testing.T) {
|
|
t.Parallel()
|
|
assert.True(t, IsEncrypted(encryptionPrefix))
|
|
assert.False(t, IsEncrypted([]byte("mhmmm. Donuts.")))
|
|
}
|
|
|
|
func TestMakeNewSessionDK(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
if _, _, err := makeNewSessionDK(nil); err == nil {
|
|
t.Fatal("makeNewSessionDK passed with nil key")
|
|
}
|
|
}
|
|
|
|
func TestEncryptTwiceReusesSaltButNewCipher(t *testing.T) {
|
|
c := &Config{
|
|
EncryptConfig: 1,
|
|
}
|
|
tempDir := t.TempDir()
|
|
|
|
// Prepare input
|
|
passFile, err := os.CreateTemp(tempDir, "*.pw")
|
|
if err != nil {
|
|
t.Fatalf("Problem creating temp file at %s: %s\n", tempDir, err)
|
|
}
|
|
_, err = passFile.WriteString("pass\npass\n")
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
err = passFile.Close()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
// Temporarily replace Stdin with a custom input
|
|
oldIn := os.Stdin
|
|
t.Cleanup(func() { os.Stdin = oldIn })
|
|
os.Stdin, err = os.Open(passFile.Name())
|
|
if err != nil {
|
|
t.Fatalf("Problem opening temp file at %s: %s\n", passFile.Name(), err)
|
|
}
|
|
t.Cleanup(func() {
|
|
err = os.Stdin.Close()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
})
|
|
|
|
// Save encrypted config
|
|
enc1 := filepath.Join(tempDir, "encrypted.dat")
|
|
err = c.SaveConfigToFile(enc1)
|
|
if err != nil {
|
|
t.Fatalf("Problem storing config in file %s: %s\n", enc1, err)
|
|
}
|
|
// Save again
|
|
enc2 := filepath.Join(tempDir, "encrypted2.dat")
|
|
err = c.SaveConfigToFile(enc2)
|
|
if err != nil {
|
|
t.Fatalf("Problem storing config in file %s: %s\n", enc2, err)
|
|
}
|
|
data1, err := os.ReadFile(enc1)
|
|
if err != nil {
|
|
t.Fatalf("Problem reading file %s: %s\n", enc1, err)
|
|
}
|
|
data2, err := os.ReadFile(enc2)
|
|
if err != nil {
|
|
t.Fatalf("Problem reading file %s: %s\n", enc2, err)
|
|
}
|
|
// length of prefix + salt
|
|
l := len(encryptionPrefix) + len(saltPrefix) + saltRandomLength
|
|
// Even though prefix, including salt with the random bytes is the same
|
|
if !bytes.Equal(data1[:l], data2[:l]) {
|
|
t.Error("Salt is not reused.")
|
|
}
|
|
// the cipher text should not be
|
|
if bytes.Equal(data1, data2) {
|
|
t.Error("Encryption key must have been reused as cipher texts are the same")
|
|
}
|
|
}
|
|
|
|
func TestSaveAndReopenEncryptedConfig(t *testing.T) {
|
|
c := &Config{}
|
|
c.Name = "myCustomName"
|
|
c.EncryptConfig = 1
|
|
tempDir := t.TempDir()
|
|
|
|
// Save encrypted config
|
|
enc := filepath.Join(tempDir, "encrypted.dat")
|
|
withInteractiveResponse(t, "pass\npass\n", func() {
|
|
err := c.SaveConfigToFile(enc)
|
|
require.NoError(t, err, "SaveConfigToFile must not error")
|
|
})
|
|
|
|
readConf := &Config{}
|
|
withInteractiveResponse(t, "pass\n", func() {
|
|
// Load with no existing state, key is read from the prepared file
|
|
err := readConf.ReadConfigFromFile(enc, true)
|
|
require.NoError(t, err, "ReadConfigFromFile must not error")
|
|
})
|
|
|
|
assert.Equal(t, "myCustomName", readConf.Name, "Name should be correct")
|
|
assert.Equal(t, 1, readConf.EncryptConfig, "EncryptConfig should be set correctly")
|
|
}
|
|
|
|
func TestReadConfigWithPrompt(t *testing.T) {
|
|
// Prepare temp dir
|
|
tempDir := t.TempDir()
|
|
|
|
// Ensure we'll get the prompt when loading
|
|
c := &Config{
|
|
EncryptConfig: 0,
|
|
}
|
|
|
|
// Save config
|
|
testConfigFile := filepath.Join(tempDir, "config.json")
|
|
err := c.SaveConfigToFile(testConfigFile)
|
|
require.NoError(t, err, "SaveConfigToFile must not error")
|
|
|
|
// Run the test
|
|
c = &Config{}
|
|
withInteractiveResponse(t, "y\npass\npass\n", func() {
|
|
err = c.ReadConfigFromFile(testConfigFile, false)
|
|
require.NoError(t, err, "ReadConfigFromFile must not error")
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Problem reading config file at %s: %s\n", testConfigFile, err)
|
|
}
|
|
|
|
// Verify results
|
|
data, err := os.ReadFile(testConfigFile)
|
|
if err != nil {
|
|
t.Fatalf("Problem reading saved file at %s: %s\n", testConfigFile, err)
|
|
}
|
|
if c.EncryptConfig != fileEncryptionEnabled {
|
|
t.Error("Config encryption flag should be set after prompts")
|
|
}
|
|
assert.True(t, IsEncrypted(data), "data should be encrypted after prompts")
|
|
}
|
|
|
|
func TestReadEncryptedConfigFromReader(t *testing.T) {
|
|
t.Parallel()
|
|
c := &Config{
|
|
EncryptionKeyProvider: func(_ bool) ([]byte, error) { return []byte("pass"), nil },
|
|
}
|
|
// Encrypted conf for: `{"name":"test"}` with key `pass`
|
|
confBytes := []byte{84, 72, 79, 82, 83, 45, 72, 65, 77, 77, 69, 82, 126, 71, 67, 84, 126, 83, 79, 126, 83, 65, 76, 84, 89, 126, 246, 110, 128, 3, 30, 168, 172, 160, 198, 176, 136, 62, 152, 155, 253, 176, 16, 48, 52, 246, 44, 29, 151, 47, 217, 226, 178, 12, 218, 113, 248, 172, 195, 232, 136, 104, 9, 199, 20, 4, 71, 4, 253, 249}
|
|
err := c.readConfig(bytes.NewReader(confBytes))
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "test", c.Name)
|
|
|
|
// Change the salt
|
|
confBytes[20] = 0
|
|
err = c.readConfig(bytes.NewReader(confBytes))
|
|
require.ErrorIs(t, err, errDecryptFailed)
|
|
}
|
|
|
|
// TestSaveConfigToFileWithErrorInPasswordPrompt should preserve the original file
|
|
func TestSaveConfigToFileWithErrorInPasswordPrompt(t *testing.T) {
|
|
c := &Config{
|
|
Name: "test",
|
|
EncryptConfig: fileEncryptionEnabled,
|
|
}
|
|
testData := []byte("testdata")
|
|
f, err := os.CreateTemp(t.TempDir(), "")
|
|
require.NoError(t, err, "CreateTemp must not error")
|
|
|
|
targetFile := f.Name()
|
|
|
|
_, err = io.Copy(f, bytes.NewReader(testData))
|
|
require.NoError(t, err, "io.Copy must not error")
|
|
require.NoError(t, f.Close(), "file Close must not error")
|
|
|
|
withInteractiveResponse(t, "\n\n", func() {
|
|
err = c.SaveConfigToFile(targetFile)
|
|
require.ErrorIs(t, err, io.EOF, "SaveConfigToFile must not error")
|
|
})
|
|
|
|
data, err := os.ReadFile(targetFile)
|
|
require.NoError(t, err, "ReadFile must not error")
|
|
assert.Equal(t, testData, data)
|
|
}
|
|
|
|
func withInteractiveResponse(tb testing.TB, response string, fn func()) {
|
|
tb.Helper()
|
|
f, err := os.CreateTemp(tb.TempDir(), "*.in")
|
|
require.NoError(tb, err, "CreateTemp must not error")
|
|
defer f.Close()
|
|
_, err = f.WriteString(response)
|
|
require.NoError(tb, err, "WriteString must not error")
|
|
_, err = f.Seek(0, 0)
|
|
require.NoError(tb, err, "Seek must not error")
|
|
defer func(orig *os.File) { os.Stdin = orig }(os.Stdin)
|
|
os.Stdin = f
|
|
fn()
|
|
}
|