diff --git a/config/config.go b/config/config.go index 38586137..2551022e 100644 --- a/config/config.go +++ b/config/config.go @@ -1487,7 +1487,7 @@ func (c *Config) readConfig(d io.Reader) error { // If they agree, c.EncryptConfig is set to Enabled, the config is encrypted and saved // Otherwise, c.EncryptConfig is set to Disabled and the file is resaved func (c *Config) saveWithEncryptPrompt(path string) error { - if confirm, err := promptForConfigEncryption(); err != nil { + if confirm, err := promptForConfigEncryption(os.Stdin); err != nil { return nil //nolint:nilerr // Ignore encryption prompt failures; The user will be prompted again } else if confirm { c.EncryptConfig = fileEncryptionEnabled diff --git a/config/config_encryption.go b/config/config_encryption.go index d58f2ef8..6914fe24 100644 --- a/config/config_encryption.go +++ b/config/config_encryption.go @@ -43,11 +43,11 @@ var ( // promptForConfigEncryption asks for encryption confirmation // returns true if encryption was desired, false otherwise -func promptForConfigEncryption() (bool, error) { +func promptForConfigEncryption(r io.Reader) (bool, error) { fmt.Println("Would you like to encrypt your config file (y/n)?") input := "" - if _, err := fmt.Scanln(&input); err != nil { + if _, err := fmt.Fscanln(r, &input); err != nil { return false, err } diff --git a/config/config_encryption_test.go b/config/config_encryption_test.go index 4d856932..19cbbbba 100644 --- a/config/config_encryption_test.go +++ b/config/config_encryption_test.go @@ -19,9 +19,49 @@ import ( func TestPromptForConfigEncryption(t *testing.T) { t.Parallel() - confirm, err := promptForConfigEncryption() - require.ErrorIs(t, err, io.EOF) - require.False(t, confirm) + testCases := []struct { + name string + input string + expectedBool bool + expectedError error + }{ + { + name: "input_y", + input: "y\n", + expectedBool: true, + }, + { + name: "input_n", + input: "n\n", + }, + { + name: "input_yes", + input: "yes\n", + expectedBool: true, + }, + { + name: "input_no", + input: "no\n", + }, + { + name: "input_invalid", + input: "invalid\n", + }, + { + name: "input_empty_eof", + expectedError: io.EOF, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + reader := strings.NewReader(tc.input) + confirm, err := promptForConfigEncryption(reader) + require.ErrorIs(t, err, tc.expectedError) + require.Equal(t, tc.expectedBool, confirm) + }) + } } func TestPromptForConfigKey(t *testing.T) {