Fix Docker os.Rename invalid cross-device link issue (#386)

* Adds new file.Move func to address a bug with Golang/Docker volumes when using os.Rename

Also uses TempDir for tests instead of live directories and increases test coverage for file.Write

* Goimport the imports

* Make usage of file package name consistent so it no longer clashes with vars

* Remove outputFile if io.Copy fails
This commit is contained in:
Adrian Gallagher
2019-11-28 11:56:05 +11:00
committed by GitHub
parent 63191ce3ec
commit e20d204b19
15 changed files with 205 additions and 50 deletions

View File

@@ -5,7 +5,7 @@ import (
"io/ioutil"
"log"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/file"
"github.com/thrasher-corp/gocryptotrader/config"
)
@@ -43,19 +43,19 @@ func main() {
key = string(result)
}
file, err := ioutil.ReadFile(inFile)
fileData, err := ioutil.ReadFile(inFile)
if err != nil {
log.Fatalf("Unable to read input file %s. Error: %s.", inFile, err)
}
if config.ConfirmECS(file) && encrypt {
if config.ConfirmECS(fileData) && encrypt {
log.Println("File is already encrypted. Decrypting..")
encrypt = false
}
if !config.ConfirmECS(file) && !encrypt {
if !config.ConfirmECS(fileData) && !encrypt {
var result interface{}
errf := config.ConfirmConfigJSON(file, result)
errf := config.ConfirmConfigJSON(fileData, result)
if errf != nil {
log.Fatal("File isn't in JSON format")
}
@@ -65,18 +65,18 @@ func main() {
var data []byte
if encrypt {
data, err = config.EncryptConfigFile(file, []byte(key))
data, err = config.EncryptConfigFile(fileData, []byte(key))
if err != nil {
log.Fatalf("Unable to encrypt config data. Error: %s.", err)
}
} else {
data, err = config.DecryptConfigFile(file, []byte(key))
data, err = config.DecryptConfigFile(fileData, []byte(key))
if err != nil {
log.Fatalf("Unable to decrypt config data. Error: %s.", err)
}
}
err = common.WriteFile(outFile, data)
err = file.Write(outFile, data)
if err != nil {
log.Fatalf("Unable to write output file %s. Error: %s", outFile, err)
}

View File

@@ -14,6 +14,7 @@ import (
"text/template"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/file"
"github.com/thrasher-corp/gocryptotrader/config"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/engine"
@@ -754,7 +755,7 @@ func saveConfig(config *Config) {
}
log.Printf("Outputting to: %v", filepath.Join(dir, "wrapperconfig.json"))
err = common.WriteFile(filepath.Join(dir, "wrapperconfig.json"), jsonOutput)
err = file.Write(filepath.Join(dir, "wrapperconfig.json"), jsonOutput)
if err != nil {
log.Printf("Encountered error writing to disk: %v", err)
return
@@ -775,7 +776,7 @@ func outputToJSON(exchangeResponses []ExchangeResponses) {
}
log.Printf("Outputting to: %v", filepath.Join(dir, fmt.Sprintf("%v.json", outputFileName)))
err = common.WriteFile(filepath.Join(dir, fmt.Sprintf("%v.json", outputFileName)), jsonOutput)
err = file.Write(filepath.Join(dir, fmt.Sprintf("%v.json", outputFileName)), jsonOutput)
if err != nil {
log.Printf("Encountered error writing to disk: %v", err)
return

View File

@@ -14,7 +14,7 @@ import (
"os"
"time"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/file"
)
func main() {
@@ -83,13 +83,13 @@ func main() {
log.Fatalf("key pem data is nil")
}
err = common.WriteFile("key.pem", keyData)
err = file.Write("key.pem", keyData)
if err != nil {
log.Fatalf("failed to write key.pem file %s", err)
}
log.Printf("wrote key.pem file")
err = common.WriteFile("cert.pem", certData)
err = file.Write("cert.pem", certData)
if err != nil {
log.Fatalf("failed to write cert.pem file %s", err)
}

View File

@@ -68,7 +68,7 @@ func main() {
}
path := filepath.Join(outputFolder, "sqlboiler.json")
err = ioutil.WriteFile(path, jsonOutput, 0644)
err = ioutil.WriteFile(path, jsonOutput, 0770)
if err != nil {
fmt.Printf("Write failed: %v", err)
os.Exit(1)

View File

@@ -13,7 +13,7 @@ import (
"io/ioutil"
"log"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/file"
)
func encodePEM(privKey *ecdsa.PrivateKey, pubKey bool) ([]byte, error) {
@@ -54,8 +54,8 @@ func decodePEM(pemPrivKey []byte) (*ecdsa.PrivateKey, error) {
return x509.ParseECPrivateKey(x509Enc)
}
func writeFile(file string, data []byte) error {
return common.WriteFile(file, data)
func writeFile(fileName string, data []byte) error {
return file.Write(fileName, data)
}
func main() {

View File

@@ -279,7 +279,7 @@ func ExtractPort(host string) int {
func OutputCSV(filePath string, data [][]string) error {
_, err := ioutil.ReadFile(filePath)
if err != nil {
errTwo := WriteFile(filePath, nil)
errTwo := ioutil.WriteFile(filePath, nil, 0770)
if errTwo != nil {
return errTwo
}
@@ -295,11 +295,6 @@ func OutputCSV(filePath string, data [][]string) error {
return writer.WriteAll(data)
}
// WriteFile writes selected data to a file and returns an error
func WriteFile(file string, data []byte) error {
return ioutil.WriteFile(file, data, 0644)
}
// GetURIPath returns the path of a URL given a URI
func GetURIPath(uri string) string {
urip, err := url.Parse(uri)
@@ -353,8 +348,8 @@ func CreateDir(dir string) error {
return os.MkdirAll(dir, 0770)
}
// ChangePerm lists all the directories and files in an array
func ChangePerm(directory string) error {
// ChangePermission lists all the directories and files in an array
func ChangePermission(directory string) error {
return filepath.Walk(directory, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err

View File

@@ -504,11 +504,11 @@ func TestCreateDir(t *testing.T) {
}
}
func TestChangePerm(t *testing.T) {
testDir := filepath.Join(GetDefaultDataDir(runtime.GOOS), "TestFileASDFGHJ")
func TestChangePermission(t *testing.T) {
testDir := filepath.Join(os.TempDir(), "TestFileASDFGHJ")
switch runtime.GOOS {
case "windows":
err := ChangePerm("*")
err := ChangePermission("*")
if err == nil {
t.Fatal("expected an error on non-existent path")
}
@@ -516,7 +516,7 @@ func TestChangePerm(t *testing.T) {
if err != nil {
t.Fatalf("Mkdir failed. Err: %v", err)
}
err = ChangePerm(GetDefaultDataDir(runtime.GOOS))
err = ChangePermission(testDir)
if err != nil {
t.Fatalf("ChangePerm was unsuccessful. Err: %v", err)
}
@@ -529,7 +529,7 @@ func TestChangePerm(t *testing.T) {
t.Fatalf("os.Remove failed. Err: %v", err)
}
default:
err := ChangePerm("")
err := ChangePermission("")
if err == nil {
t.Fatal("expected an error on non-existent path")
}
@@ -537,7 +537,7 @@ func TestChangePerm(t *testing.T) {
if err != nil {
t.Fatalf("Mkdir failed. Err: %v", err)
}
err = ChangePerm(GetDefaultDataDir(runtime.GOOS))
err = ChangePermission(testDir)
if err != nil {
t.Fatalf("ChangePerm was unsuccessful. Err: %v", err)
}

47
common/file/file.go Normal file
View File

@@ -0,0 +1,47 @@
package file
import (
"fmt"
"io"
"io/ioutil"
"os"
)
// Write writes selected data to a file or returns an error if it fails. This
// func also ensures that all files are set to this permission (only rw access
// for the running user and the group the user is a member of)
func Write(file string, data []byte) error {
return ioutil.WriteFile(file, data, 0770)
}
// Move moves a file from a source path to a destination path
// This must be used across the codebase for compatibility with Docker volumes
// and Golang (fixes Invalid cross-device link when using os.Rename)
func Move(sourcePath, destPath string) error {
inputFile, err := os.Open(sourcePath)
if err != nil {
return err
}
outputFile, err := os.Create(destPath)
if err != nil {
inputFile.Close()
return err
}
_, err = io.Copy(outputFile, inputFile)
inputFile.Close()
outputFile.Close()
if err != nil {
if errRem := os.Remove(destPath); errRem != nil {
return fmt.Errorf(
"unable to os.Remove error: %s after io.Copy error: %s",
errRem,
err,
)
}
return err
}
return os.Remove(sourcePath)
}

106
common/file/file_test.go Normal file
View File

@@ -0,0 +1,106 @@
package file
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
func TestWrite(t *testing.T) {
tester := func(in string) error {
err := Write(in, []byte("GoCryptoTrader"))
if err != nil {
return err
}
return os.Remove(in)
}
type testTable struct {
InFile string
ErrExpected bool
Cleanup bool
}
var tests []testTable
testFile := filepath.Join(os.TempDir(), "gcttest.txt")
switch runtime.GOOS {
case "windows":
tests = []testTable{
{InFile: "*", ErrExpected: true},
{InFile: testFile, ErrExpected: false},
}
default:
tests = []testTable{
{InFile: "", ErrExpected: true},
{InFile: testFile, ErrExpected: false},
}
}
for x := range tests {
err := tester(tests[x].InFile)
if err != nil && !tests[x].ErrExpected {
t.Errorf("Test %d failed, unexpected err %s\n", x, err)
}
}
}
func TestMove(t *testing.T) {
tester := func(in, out string, write bool) error {
if write {
if err := ioutil.WriteFile(in, []byte("GoCryptoTrader"), 0770); err != nil {
return err
}
}
if err := Move(in, out); err != nil {
return err
}
contents, err := ioutil.ReadFile(out)
if err != nil {
return err
}
if !strings.Contains(string(contents), "GoCryptoTrader") {
return fmt.Errorf("unable to find previously written data")
}
return os.Remove(out)
}
type testTable struct {
InFile string
OutFile string
Write bool
ErrExpected bool
}
var tests []testTable
switch runtime.GOOS {
case "windows":
tests = []testTable{
{InFile: "*", OutFile: "gct.txt", Write: true, ErrExpected: true},
{InFile: "*", OutFile: "gct.txt", Write: false, ErrExpected: true},
{InFile: "in.txt", OutFile: "*", Write: true, ErrExpected: true},
{InFile: "in.txt", OutFile: "gct.txt", Write: true, ErrExpected: false},
}
default:
tests = []testTable{
{InFile: "", OutFile: "gct.txt", Write: true, ErrExpected: true},
{InFile: "", OutFile: "gct.txt", Write: false, ErrExpected: true},
{InFile: "in.txt", OutFile: "", Write: true, ErrExpected: true},
{InFile: "in.txt", OutFile: "gct.txt", Write: true, ErrExpected: false},
}
}
for x := range tests {
err := tester(tests[x].InFile, tests[x].OutFile, tests[x].Write)
if err != nil && !tests[x].ErrExpected {
t.Errorf("Test %d failed, unexpected err %s\n", x, err)
}
}
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/convert"
"github.com/thrasher-corp/gocryptotrader/common/file"
"github.com/thrasher-corp/gocryptotrader/connchecker"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/currency/forexprovider"
@@ -1331,9 +1332,9 @@ func (c *Config) CheckConnectionMonitorConfig() {
// GetFilePath returns the desired config file or the default config file name
// based on if the application is being run under test or normal mode.
func GetFilePath(file string) (string, error) {
if file != "" {
return file, nil
func GetFilePath(configfile string) (string, error) {
if configfile != "" {
return configfile, nil
}
if flag.Lookup("test.v") != nil && !testBypass {
@@ -1375,7 +1376,7 @@ func GetFilePath(file string) (string, error) {
return newDirs[x], nil
}
if filepath.Ext(oldDirs[x]) == ".json" {
err = os.Rename(oldDirs[x], newDirs[0])
err = file.Move(oldDirs[x], newDirs[0])
if err != nil {
return "", err
}
@@ -1384,7 +1385,7 @@ func GetFilePath(file string) (string, error) {
oldDirs[x],
newDirs[0])
} else {
err = os.Rename(oldDirs[x], newDirs[1])
err = file.Move(oldDirs[x], newDirs[1])
if err != nil {
return "", err
}
@@ -1412,7 +1413,7 @@ func GetFilePath(file string) (string, error) {
return newDirs[x], nil
}
err = os.Rename(newDirs[x], newDirs[1])
err = file.Move(newDirs[x], newDirs[1])
if err != nil {
return "", err
}
@@ -1423,7 +1424,7 @@ func GetFilePath(file string) (string, error) {
return newDirs[x], nil
}
err = os.Rename(newDirs[x], newDirs[0])
err = file.Move(newDirs[x], newDirs[0])
if err != nil {
return "", err
}
@@ -1443,13 +1444,13 @@ func (c *Config) ReadConfig(configPath string, dryrun bool) error {
return err
}
file, err := ioutil.ReadFile(defaultPath)
fileData, err := ioutil.ReadFile(defaultPath)
if err != nil {
return err
}
if !ConfirmECS(file) {
err = ConfirmConfigJSON(file, &c)
if !ConfirmECS(fileData) {
err = ConfirmConfigJSON(fileData, &c)
if err != nil {
return err
}
@@ -1481,7 +1482,7 @@ func (c *Config) ReadConfig(configPath string, dryrun bool) error {
}
var f []byte
f = append(f, file...)
f = append(f, fileData...)
data, err := DecryptConfigFile(f, key)
if err != nil {
log.Errorf(log.ConfigMgr, "DecryptConfigFile err: %s", err)
@@ -1535,7 +1536,7 @@ func (c *Config) SaveConfig(configPath string, dryrun bool) error {
return err
}
}
return common.WriteFile(defaultPath, payload)
return file.Write(defaultPath, payload)
}
// CheckRemoteControlConfig checks to see if the old c.Webserver field is used

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/file"
"github.com/thrasher-corp/gocryptotrader/currency/coinmarketcap"
"github.com/thrasher-corp/gocryptotrader/currency/forexprovider"
"github.com/thrasher-corp/gocryptotrader/currency/forexprovider/base"
@@ -318,7 +319,7 @@ func (s *Storage) WriteCurrencyDataToFile(path string, mainUpdate bool) error {
return err
}
return common.WriteFile(path, encoded)
return file.Write(path, encoded)
}
// LoadFileCurrencyData loads currencies into the currency codes

View File

@@ -18,6 +18,7 @@ import (
"github.com/pquerna/otp/totp"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/file"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/dispatch"
exchange "github.com/thrasher-corp/gocryptotrader/exchanges"
@@ -867,12 +868,12 @@ func genCert(targetDir string) error {
return fmt.Errorf("key pem data is nil")
}
err = common.WriteFile(filepath.Join(targetDir, "key.pem"), keyData)
err = file.Write(filepath.Join(targetDir, "key.pem"), keyData)
if err != nil {
return fmt.Errorf("failed to write key.pem file %s", err)
}
err = common.WriteFile(filepath.Join(targetDir, "cert.pem"), certData)
err = file.Write(filepath.Join(targetDir, "cert.pem"), certData)
if err != nil {
return fmt.Errorf("failed to write cert.pem file %s", err)
}

View File

@@ -13,8 +13,8 @@ import (
"strings"
"sync"
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/crypto"
"github.com/thrasher-corp/gocryptotrader/common/file"
)
// HTTPResponse defines expected response from the end point including request
@@ -212,7 +212,7 @@ func HTTPRecord(res *http.Response, service string, respContents []byte) error {
return err
}
return common.WriteFile(fileout, payload)
return file.Write(fileout, payload)
}
// GetFilteredHeader filters excluded http headers for insertion into a mock

View File

@@ -15,6 +15,7 @@ import (
"github.com/thrasher-corp/gocryptotrader/common"
"github.com/thrasher-corp/gocryptotrader/common/crypto"
"github.com/thrasher-corp/gocryptotrader/common/file"
)
// DefaultDirectory defines the main mock directory
@@ -56,7 +57,7 @@ func NewVCRServer(path string) (string, *http.Client, error) {
return "", nil, jErr
}
err = common.WriteFile(path, data)
err = file.Write(path, data)
if err != nil {
return "", nil, err
}

View File

@@ -5,6 +5,8 @@ import (
"os"
"path/filepath"
"time"
"github.com/thrasher-corp/gocryptotrader/common/file"
)
// Write implementation to satisfy io.Writer handles length check and rotation
@@ -78,7 +80,7 @@ func (r *Rotate) openNew() error {
timestamp := time.Now().Format("2006-01-02T15-04-05")
newName := filepath.Join(LogPath, timestamp+"-"+r.FileName)
err = os.Rename(name, newName)
err = file.Move(name, newName)
if err != nil {
return fmt.Errorf("can't rename log file: %s", err)
}