From e20d204b197d9c7d7d82b88a097fa7d693796002 Mon Sep 17 00:00:00 2001 From: Adrian Gallagher Date: Thu, 28 Nov 2019 11:56:05 +1100 Subject: [PATCH] Fix Docker os.Rename invalid cross-device link issue (#386) * Adds new file.Move func to address a bug with Golang/Docker volumes when using os.Rename Also uses TempDir for tests instead of live directories and increases test coverage for file.Write * Goimport the imports * Make usage of file package name consistent so it no longer clashes with vars * Remove outputFile if io.Copy fails --- cmd/config/config.go | 16 ++--- cmd/exchange_wrapper_issues/main.go | 5 +- cmd/gen_cert/main.go | 6 +- cmd/gen_sqlboiler_config/main.go | 2 +- cmd/huobi_auth/main.go | 6 +- common/common.go | 11 +-- common/common_test.go | 12 ++-- common/file/file.go | 47 ++++++++++++ common/file/file_test.go | 106 ++++++++++++++++++++++++++++ config/config.go | 25 +++---- currency/storage.go | 3 +- engine/helpers.go | 5 +- exchanges/mock/recording.go | 4 +- exchanges/mock/server.go | 3 +- logger/logger_rotate.go | 4 +- 15 files changed, 205 insertions(+), 50 deletions(-) create mode 100644 common/file/file.go create mode 100644 common/file/file_test.go diff --git a/cmd/config/config.go b/cmd/config/config.go index a85761fb..3274b9ed 100644 --- a/cmd/config/config.go +++ b/cmd/config/config.go @@ -5,7 +5,7 @@ import ( "io/ioutil" "log" - "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/common/file" "github.com/thrasher-corp/gocryptotrader/config" ) @@ -43,19 +43,19 @@ func main() { key = string(result) } - file, err := ioutil.ReadFile(inFile) + fileData, err := ioutil.ReadFile(inFile) if err != nil { log.Fatalf("Unable to read input file %s. Error: %s.", inFile, err) } - if config.ConfirmECS(file) && encrypt { + if config.ConfirmECS(fileData) && encrypt { log.Println("File is already encrypted. Decrypting..") encrypt = false } - if !config.ConfirmECS(file) && !encrypt { + if !config.ConfirmECS(fileData) && !encrypt { var result interface{} - errf := config.ConfirmConfigJSON(file, result) + errf := config.ConfirmConfigJSON(fileData, result) if errf != nil { log.Fatal("File isn't in JSON format") } @@ -65,18 +65,18 @@ func main() { var data []byte if encrypt { - data, err = config.EncryptConfigFile(file, []byte(key)) + data, err = config.EncryptConfigFile(fileData, []byte(key)) if err != nil { log.Fatalf("Unable to encrypt config data. Error: %s.", err) } } else { - data, err = config.DecryptConfigFile(file, []byte(key)) + data, err = config.DecryptConfigFile(fileData, []byte(key)) if err != nil { log.Fatalf("Unable to decrypt config data. Error: %s.", err) } } - err = common.WriteFile(outFile, data) + err = file.Write(outFile, data) if err != nil { log.Fatalf("Unable to write output file %s. Error: %s", outFile, err) } diff --git a/cmd/exchange_wrapper_issues/main.go b/cmd/exchange_wrapper_issues/main.go index a016b544..377c3658 100644 --- a/cmd/exchange_wrapper_issues/main.go +++ b/cmd/exchange_wrapper_issues/main.go @@ -14,6 +14,7 @@ import ( "text/template" "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/common/file" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/engine" @@ -754,7 +755,7 @@ func saveConfig(config *Config) { } log.Printf("Outputting to: %v", filepath.Join(dir, "wrapperconfig.json")) - err = common.WriteFile(filepath.Join(dir, "wrapperconfig.json"), jsonOutput) + err = file.Write(filepath.Join(dir, "wrapperconfig.json"), jsonOutput) if err != nil { log.Printf("Encountered error writing to disk: %v", err) return @@ -775,7 +776,7 @@ func outputToJSON(exchangeResponses []ExchangeResponses) { } log.Printf("Outputting to: %v", filepath.Join(dir, fmt.Sprintf("%v.json", outputFileName))) - err = common.WriteFile(filepath.Join(dir, fmt.Sprintf("%v.json", outputFileName)), jsonOutput) + err = file.Write(filepath.Join(dir, fmt.Sprintf("%v.json", outputFileName)), jsonOutput) if err != nil { log.Printf("Encountered error writing to disk: %v", err) return diff --git a/cmd/gen_cert/main.go b/cmd/gen_cert/main.go index 97aa51c1..d8e96a6e 100644 --- a/cmd/gen_cert/main.go +++ b/cmd/gen_cert/main.go @@ -14,7 +14,7 @@ import ( "os" "time" - "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/common/file" ) func main() { @@ -83,13 +83,13 @@ func main() { log.Fatalf("key pem data is nil") } - err = common.WriteFile("key.pem", keyData) + err = file.Write("key.pem", keyData) if err != nil { log.Fatalf("failed to write key.pem file %s", err) } log.Printf("wrote key.pem file") - err = common.WriteFile("cert.pem", certData) + err = file.Write("cert.pem", certData) if err != nil { log.Fatalf("failed to write cert.pem file %s", err) } diff --git a/cmd/gen_sqlboiler_config/main.go b/cmd/gen_sqlboiler_config/main.go index c5198315..9d739a2c 100644 --- a/cmd/gen_sqlboiler_config/main.go +++ b/cmd/gen_sqlboiler_config/main.go @@ -68,7 +68,7 @@ func main() { } path := filepath.Join(outputFolder, "sqlboiler.json") - err = ioutil.WriteFile(path, jsonOutput, 0644) + err = ioutil.WriteFile(path, jsonOutput, 0770) if err != nil { fmt.Printf("Write failed: %v", err) os.Exit(1) diff --git a/cmd/huobi_auth/main.go b/cmd/huobi_auth/main.go index e7636e1d..2b2a9572 100644 --- a/cmd/huobi_auth/main.go +++ b/cmd/huobi_auth/main.go @@ -13,7 +13,7 @@ import ( "io/ioutil" "log" - "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/common/file" ) func encodePEM(privKey *ecdsa.PrivateKey, pubKey bool) ([]byte, error) { @@ -54,8 +54,8 @@ func decodePEM(pemPrivKey []byte) (*ecdsa.PrivateKey, error) { return x509.ParseECPrivateKey(x509Enc) } -func writeFile(file string, data []byte) error { - return common.WriteFile(file, data) +func writeFile(fileName string, data []byte) error { + return file.Write(fileName, data) } func main() { diff --git a/common/common.go b/common/common.go index 71a6aa75..dcabebf8 100644 --- a/common/common.go +++ b/common/common.go @@ -279,7 +279,7 @@ func ExtractPort(host string) int { func OutputCSV(filePath string, data [][]string) error { _, err := ioutil.ReadFile(filePath) if err != nil { - errTwo := WriteFile(filePath, nil) + errTwo := ioutil.WriteFile(filePath, nil, 0770) if errTwo != nil { return errTwo } @@ -295,11 +295,6 @@ func OutputCSV(filePath string, data [][]string) error { return writer.WriteAll(data) } -// WriteFile writes selected data to a file and returns an error -func WriteFile(file string, data []byte) error { - return ioutil.WriteFile(file, data, 0644) -} - // GetURIPath returns the path of a URL given a URI func GetURIPath(uri string) string { urip, err := url.Parse(uri) @@ -353,8 +348,8 @@ func CreateDir(dir string) error { return os.MkdirAll(dir, 0770) } -// ChangePerm lists all the directories and files in an array -func ChangePerm(directory string) error { +// ChangePermission lists all the directories and files in an array +func ChangePermission(directory string) error { return filepath.Walk(directory, func(path string, info os.FileInfo, err error) error { if err != nil { return err diff --git a/common/common_test.go b/common/common_test.go index d8bf9f84..10a7a858 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -504,11 +504,11 @@ func TestCreateDir(t *testing.T) { } } -func TestChangePerm(t *testing.T) { - testDir := filepath.Join(GetDefaultDataDir(runtime.GOOS), "TestFileASDFGHJ") +func TestChangePermission(t *testing.T) { + testDir := filepath.Join(os.TempDir(), "TestFileASDFGHJ") switch runtime.GOOS { case "windows": - err := ChangePerm("*") + err := ChangePermission("*") if err == nil { t.Fatal("expected an error on non-existent path") } @@ -516,7 +516,7 @@ func TestChangePerm(t *testing.T) { if err != nil { t.Fatalf("Mkdir failed. Err: %v", err) } - err = ChangePerm(GetDefaultDataDir(runtime.GOOS)) + err = ChangePermission(testDir) if err != nil { t.Fatalf("ChangePerm was unsuccessful. Err: %v", err) } @@ -529,7 +529,7 @@ func TestChangePerm(t *testing.T) { t.Fatalf("os.Remove failed. Err: %v", err) } default: - err := ChangePerm("") + err := ChangePermission("") if err == nil { t.Fatal("expected an error on non-existent path") } @@ -537,7 +537,7 @@ func TestChangePerm(t *testing.T) { if err != nil { t.Fatalf("Mkdir failed. Err: %v", err) } - err = ChangePerm(GetDefaultDataDir(runtime.GOOS)) + err = ChangePermission(testDir) if err != nil { t.Fatalf("ChangePerm was unsuccessful. Err: %v", err) } diff --git a/common/file/file.go b/common/file/file.go new file mode 100644 index 00000000..ec3ed644 --- /dev/null +++ b/common/file/file.go @@ -0,0 +1,47 @@ +package file + +import ( + "fmt" + "io" + "io/ioutil" + "os" +) + +// Write writes selected data to a file or returns an error if it fails. This +// func also ensures that all files are set to this permission (only rw access +// for the running user and the group the user is a member of) +func Write(file string, data []byte) error { + return ioutil.WriteFile(file, data, 0770) +} + +// Move moves a file from a source path to a destination path +// This must be used across the codebase for compatibility with Docker volumes +// and Golang (fixes Invalid cross-device link when using os.Rename) +func Move(sourcePath, destPath string) error { + inputFile, err := os.Open(sourcePath) + if err != nil { + return err + } + + outputFile, err := os.Create(destPath) + if err != nil { + inputFile.Close() + return err + } + + _, err = io.Copy(outputFile, inputFile) + inputFile.Close() + outputFile.Close() + if err != nil { + if errRem := os.Remove(destPath); errRem != nil { + return fmt.Errorf( + "unable to os.Remove error: %s after io.Copy error: %s", + errRem, + err, + ) + } + return err + } + + return os.Remove(sourcePath) +} diff --git a/common/file/file_test.go b/common/file/file_test.go new file mode 100644 index 00000000..f0db7d1f --- /dev/null +++ b/common/file/file_test.go @@ -0,0 +1,106 @@ +package file + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +func TestWrite(t *testing.T) { + tester := func(in string) error { + err := Write(in, []byte("GoCryptoTrader")) + if err != nil { + return err + } + return os.Remove(in) + } + + type testTable struct { + InFile string + ErrExpected bool + Cleanup bool + } + + var tests []testTable + testFile := filepath.Join(os.TempDir(), "gcttest.txt") + switch runtime.GOOS { + case "windows": + tests = []testTable{ + {InFile: "*", ErrExpected: true}, + {InFile: testFile, ErrExpected: false}, + } + default: + tests = []testTable{ + {InFile: "", ErrExpected: true}, + {InFile: testFile, ErrExpected: false}, + } + } + + for x := range tests { + err := tester(tests[x].InFile) + if err != nil && !tests[x].ErrExpected { + t.Errorf("Test %d failed, unexpected err %s\n", x, err) + } + } +} + +func TestMove(t *testing.T) { + tester := func(in, out string, write bool) error { + if write { + if err := ioutil.WriteFile(in, []byte("GoCryptoTrader"), 0770); err != nil { + return err + } + } + + if err := Move(in, out); err != nil { + return err + } + + contents, err := ioutil.ReadFile(out) + if err != nil { + return err + } + + if !strings.Contains(string(contents), "GoCryptoTrader") { + return fmt.Errorf("unable to find previously written data") + } + + return os.Remove(out) + } + + type testTable struct { + InFile string + OutFile string + Write bool + ErrExpected bool + } + + var tests []testTable + switch runtime.GOOS { + case "windows": + tests = []testTable{ + {InFile: "*", OutFile: "gct.txt", Write: true, ErrExpected: true}, + {InFile: "*", OutFile: "gct.txt", Write: false, ErrExpected: true}, + {InFile: "in.txt", OutFile: "*", Write: true, ErrExpected: true}, + {InFile: "in.txt", OutFile: "gct.txt", Write: true, ErrExpected: false}, + } + default: + tests = []testTable{ + {InFile: "", OutFile: "gct.txt", Write: true, ErrExpected: true}, + {InFile: "", OutFile: "gct.txt", Write: false, ErrExpected: true}, + {InFile: "in.txt", OutFile: "", Write: true, ErrExpected: true}, + {InFile: "in.txt", OutFile: "gct.txt", Write: true, ErrExpected: false}, + } + } + + for x := range tests { + err := tester(tests[x].InFile, tests[x].OutFile, tests[x].Write) + if err != nil && !tests[x].ErrExpected { + t.Errorf("Test %d failed, unexpected err %s\n", x, err) + } + } +} diff --git a/config/config.go b/config/config.go index 5ad56549..88db3703 100644 --- a/config/config.go +++ b/config/config.go @@ -17,6 +17,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/convert" + "github.com/thrasher-corp/gocryptotrader/common/file" "github.com/thrasher-corp/gocryptotrader/connchecker" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/currency/forexprovider" @@ -1331,9 +1332,9 @@ func (c *Config) CheckConnectionMonitorConfig() { // GetFilePath returns the desired config file or the default config file name // based on if the application is being run under test or normal mode. -func GetFilePath(file string) (string, error) { - if file != "" { - return file, nil +func GetFilePath(configfile string) (string, error) { + if configfile != "" { + return configfile, nil } if flag.Lookup("test.v") != nil && !testBypass { @@ -1375,7 +1376,7 @@ func GetFilePath(file string) (string, error) { return newDirs[x], nil } if filepath.Ext(oldDirs[x]) == ".json" { - err = os.Rename(oldDirs[x], newDirs[0]) + err = file.Move(oldDirs[x], newDirs[0]) if err != nil { return "", err } @@ -1384,7 +1385,7 @@ func GetFilePath(file string) (string, error) { oldDirs[x], newDirs[0]) } else { - err = os.Rename(oldDirs[x], newDirs[1]) + err = file.Move(oldDirs[x], newDirs[1]) if err != nil { return "", err } @@ -1412,7 +1413,7 @@ func GetFilePath(file string) (string, error) { return newDirs[x], nil } - err = os.Rename(newDirs[x], newDirs[1]) + err = file.Move(newDirs[x], newDirs[1]) if err != nil { return "", err } @@ -1423,7 +1424,7 @@ func GetFilePath(file string) (string, error) { return newDirs[x], nil } - err = os.Rename(newDirs[x], newDirs[0]) + err = file.Move(newDirs[x], newDirs[0]) if err != nil { return "", err } @@ -1443,13 +1444,13 @@ func (c *Config) ReadConfig(configPath string, dryrun bool) error { return err } - file, err := ioutil.ReadFile(defaultPath) + fileData, err := ioutil.ReadFile(defaultPath) if err != nil { return err } - if !ConfirmECS(file) { - err = ConfirmConfigJSON(file, &c) + if !ConfirmECS(fileData) { + err = ConfirmConfigJSON(fileData, &c) if err != nil { return err } @@ -1481,7 +1482,7 @@ func (c *Config) ReadConfig(configPath string, dryrun bool) error { } var f []byte - f = append(f, file...) + f = append(f, fileData...) data, err := DecryptConfigFile(f, key) if err != nil { log.Errorf(log.ConfigMgr, "DecryptConfigFile err: %s", err) @@ -1535,7 +1536,7 @@ func (c *Config) SaveConfig(configPath string, dryrun bool) error { return err } } - return common.WriteFile(defaultPath, payload) + return file.Write(defaultPath, payload) } // CheckRemoteControlConfig checks to see if the old c.Webserver field is used diff --git a/currency/storage.go b/currency/storage.go index dd1e6089..e4148647 100644 --- a/currency/storage.go +++ b/currency/storage.go @@ -9,6 +9,7 @@ import ( "time" "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/common/file" "github.com/thrasher-corp/gocryptotrader/currency/coinmarketcap" "github.com/thrasher-corp/gocryptotrader/currency/forexprovider" "github.com/thrasher-corp/gocryptotrader/currency/forexprovider/base" @@ -318,7 +319,7 @@ func (s *Storage) WriteCurrencyDataToFile(path string, mainUpdate bool) error { return err } - return common.WriteFile(path, encoded) + return file.Write(path, encoded) } // LoadFileCurrencyData loads currencies into the currency codes diff --git a/engine/helpers.go b/engine/helpers.go index 3df4c11a..0db45287 100644 --- a/engine/helpers.go +++ b/engine/helpers.go @@ -18,6 +18,7 @@ import ( "github.com/pquerna/otp/totp" "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/common/file" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/dispatch" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" @@ -867,12 +868,12 @@ func genCert(targetDir string) error { return fmt.Errorf("key pem data is nil") } - err = common.WriteFile(filepath.Join(targetDir, "key.pem"), keyData) + err = file.Write(filepath.Join(targetDir, "key.pem"), keyData) if err != nil { return fmt.Errorf("failed to write key.pem file %s", err) } - err = common.WriteFile(filepath.Join(targetDir, "cert.pem"), certData) + err = file.Write(filepath.Join(targetDir, "cert.pem"), certData) if err != nil { return fmt.Errorf("failed to write cert.pem file %s", err) } diff --git a/exchanges/mock/recording.go b/exchanges/mock/recording.go index 36dc9826..e561e1f0 100644 --- a/exchanges/mock/recording.go +++ b/exchanges/mock/recording.go @@ -13,8 +13,8 @@ import ( "strings" "sync" - "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" + "github.com/thrasher-corp/gocryptotrader/common/file" ) // HTTPResponse defines expected response from the end point including request @@ -212,7 +212,7 @@ func HTTPRecord(res *http.Response, service string, respContents []byte) error { return err } - return common.WriteFile(fileout, payload) + return file.Write(fileout, payload) } // GetFilteredHeader filters excluded http headers for insertion into a mock diff --git a/exchanges/mock/server.go b/exchanges/mock/server.go index c04dbd3a..3965ffec 100644 --- a/exchanges/mock/server.go +++ b/exchanges/mock/server.go @@ -15,6 +15,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/crypto" + "github.com/thrasher-corp/gocryptotrader/common/file" ) // DefaultDirectory defines the main mock directory @@ -56,7 +57,7 @@ func NewVCRServer(path string) (string, *http.Client, error) { return "", nil, jErr } - err = common.WriteFile(path, data) + err = file.Write(path, data) if err != nil { return "", nil, err } diff --git a/logger/logger_rotate.go b/logger/logger_rotate.go index 985c4c70..bf42abc1 100644 --- a/logger/logger_rotate.go +++ b/logger/logger_rotate.go @@ -5,6 +5,8 @@ import ( "os" "path/filepath" "time" + + "github.com/thrasher-corp/gocryptotrader/common/file" ) // Write implementation to satisfy io.Writer handles length check and rotation @@ -78,7 +80,7 @@ func (r *Rotate) openNew() error { timestamp := time.Now().Format("2006-01-02T15-04-05") newName := filepath.Join(LogPath, timestamp+"-"+r.FileName) - err = os.Rename(name, newName) + err = file.Move(name, newName) if err != nil { return fmt.Errorf("can't rename log file: %s", err) }