diff --git a/internal/action/update.go b/internal/action/update.go index b6b2ca526f..125b6694bd 100644 --- a/internal/action/update.go +++ b/internal/action/update.go @@ -1,9 +1,6 @@ package action import ( - "fmt" - "runtime" - "github.com/gopasspw/gopass/internal/action/exit" "github.com/gopasspw/gopass/internal/out" "github.com/gopasspw/gopass/internal/updater" @@ -23,10 +20,6 @@ func (s *Action) Update(c *cli.Context) error { return nil } - if runtime.GOOS == "windows" { - return fmt.Errorf("gopass update is not supported on windows (#1722)") - } - out.Printf(ctx, "⚒ Checking for available updates ...") if err := updater.Update(ctx, s.version); err != nil { return exit.Error(exit.Unknown, err, "Failed to update gopass: %s", err) diff --git a/internal/updater/access_others.go b/internal/updater/access_others.go index ea60602d92..40581b3150 100644 --- a/internal/updater/access_others.go +++ b/internal/updater/access_others.go @@ -8,3 +8,8 @@ import "golang.org/x/sys/unix" func canWrite(path string) error { return unix.Access(path, unix.W_OK) //nolint:wrapcheck } + +func removeOldBinary(dir, dest string) error { + // no need, os.Rename will replace the destination + return nil +} diff --git a/internal/updater/access_windows.go b/internal/updater/access_windows.go index 3402c01c32..309d9764b7 100644 --- a/internal/updater/access_windows.go +++ b/internal/updater/access_windows.go @@ -3,6 +3,33 @@ package updater +import ( + "fmt" + "os" + "path/filepath" +) + func canWrite(path string) error { return nil } + +// Windows won't allow us to remove the binary that's currently being executed. +// So rename the binary and then the updater should be able to write it's +// update to the correct location. +// +// See https://stackoverflow.com/a/459860 +func removeOldBinary(dir, dest string) error { + bakFile := filepath.Join(dir, filepath.Base(dest)+".bak") + // check if the bakup file already exists + if _, err := os.Stat(bakFile); err == nil { + // ... then remove it + _ = os.Remove(bakFile) + } + // we can't remove the currently running binary, but should be able to + // rename it. + if err := os.Rename(dest, bakFile); err != nil { + return fmt.Errorf("unable to rename %s to %s: %w", dest, bakFile, err) + } + + return nil +} diff --git a/internal/updater/extract.go b/internal/updater/extract.go index 2e32dfecd9..70ff5c3db6 100644 --- a/internal/updater/extract.go +++ b/internal/updater/extract.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "io/ioutil" "os" "path/filepath" @@ -17,6 +18,7 @@ import ( func extractFile(buf []byte, filename, dest string) error { mode := os.FileMode(0o755) + dir := filepath.Dir(dest) // if overwriting an existing binary retain it's mode flags fi, err := os.Lstat(dest) @@ -24,19 +26,32 @@ func extractFile(buf []byte, filename, dest string) error { mode = fi.Mode() } - if err := os.Remove(dest); err != nil { - if !os.IsNotExist(err) { - return fmt.Errorf("unable to remove destination file: %w", err) - } + tfn, err := extractToTempFile(buf, filename, dest) + if err != nil { + return fmt.Errorf("failed to extract update to %s: %w", dest, err) + } + + if err := removeOldBinary(dir, dest); err != nil { + return fmt.Errorf("failed to remove old binary %s: %w", dest, err) } - // open the destination file for writing - dfh, err := os.OpenFile(dest, os.O_WRONLY|os.O_CREATE|os.O_EXCL, mode) + if err := os.Rename(tfn, dest); err != nil { + return fmt.Errorf("failed to rename tempfile %s to %s: %w", tfn, dest, err) + } + + return os.Chmod(dest, mode) +} + +func extractToTempFile(buf []byte, filename, dest string) (string, error) { + // open a temp file for writing + dir := filepath.Dir(dest) + dfh, err := ioutil.TempFile(dir, "gopass") if err != nil { - return fmt.Errorf("failed to open file %q: %w", dest, err) + return "", fmt.Errorf("failed to create temp file in %s: %w", dir, err) } defer func() { + _ = dfh.Sync() _ = dfh.Close() }() @@ -46,23 +61,23 @@ func extractFile(buf []byte, filename, dest string) error { case ".gz": gzr, err := gzip.NewReader(rd) if err != nil { - return fmt.Errorf("failed to open gzip file: %w", err) + return "", fmt.Errorf("failed to open gzip file: %w", err) } - return extractTar(gzr, dfh, dest) + return extractTar(gzr, dfh, dfh.Name()) case ".bz2": - return extractTar(bzip2.NewReader(rd), dfh, dest) + return extractTar(bzip2.NewReader(rd), dfh, dfh.Name()) case ".zip": - return extractZip(buf, dfh, dest) + return extractZip(buf, dfh, dfh.Name()) default: - return fmt.Errorf("unsupported file extension: %q", filepath.Ext(filename)) + return "", fmt.Errorf("unsupported file extension: %q", filepath.Ext(filename)) } } -func extractZip(buf []byte, dfh io.WriteCloser, dest string) error { +func extractZip(buf []byte, dfh io.WriteCloser, dest string) (string, error) { zrd, err := zip.NewReader(bytes.NewReader(buf), int64(len(buf))) if err != nil { - return fmt.Errorf("failed to open zip file: %w", err) + return "", fmt.Errorf("failed to open zip file: %w", err) } for i := 0; i < len(zrd.File); i++ { @@ -72,7 +87,7 @@ func extractZip(buf []byte, dfh io.WriteCloser, dest string) error { file, err := zrd.File[i].Open() if err != nil { - return fmt.Errorf("failed to read from zip file: %w", err) + return "", fmt.Errorf("failed to read from zip file: %w", err) } n, err := io.Copy(dfh, file) @@ -80,18 +95,18 @@ func extractZip(buf []byte, dfh io.WriteCloser, dest string) error { _ = dfh.Close() _ = os.Remove(dest) - return fmt.Errorf("failed to read gopass.exe from zip file: %w", err) + return "", fmt.Errorf("failed to read gopass.exe from zip file: %w", err) } // success debug.Log("wrote %d bytes to %v", n, dest) - return nil + return dest, nil } - return fmt.Errorf("file not found in archive") + return "", fmt.Errorf("file not found in archive") } -func extractTar(rd io.Reader, dfh io.WriteCloser, dest string) error { +func extractTar(rd io.Reader, dfh io.WriteCloser, dest string) (string, error) { tarReader := tar.NewReader(rd) for { @@ -101,7 +116,7 @@ func extractTar(rd io.Reader, dfh io.WriteCloser, dest string) error { } if err != nil { - return fmt.Errorf("failed to read from tar file: %w", err) + return "", fmt.Errorf("failed to read from tar file: %w", err) } name := filepath.Base(header.Name) @@ -119,13 +134,13 @@ func extractTar(rd io.Reader, dfh io.WriteCloser, dest string) error { _ = dfh.Close() _ = os.Remove(dest) - return fmt.Errorf("failed to read gopass from tar file: %w", err) + return "", fmt.Errorf("failed to read gopass from tar file: %w", err) } // success debug.Log("wrote %d bytes to %v", n, dest) - return nil + return dest, nil } - return fmt.Errorf("file not found in archive") + return "", fmt.Errorf("file not found in archive") }