From af09e47cdda921eb11cab970939740adb1612af4 Mon Sep 17 00:00:00 2001 From: garethgeorge Date: Mon, 26 Aug 2024 19:38:03 -0700 Subject: [PATCH] fix: use 'restic restore :' for restore operations --- internal/orchestrator/repo/repo.go | 15 ++++-- internal/orchestrator/repo/repo_test.go | 63 +++++++++++++++++++++++++ pkg/restic/restic.go | 3 +- 3 files changed, 77 insertions(+), 4 deletions(-) diff --git a/internal/orchestrator/repo/repo.go b/internal/orchestrator/repo/repo.go index 43c04e1a..9e23d3c6 100644 --- a/internal/orchestrator/repo/repo.go +++ b/internal/orchestrator/repo/repo.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "path" "slices" "sort" "strings" @@ -292,7 +293,7 @@ func (r *RepoOrchestrator) Check(ctx context.Context, output io.Writer) error { return nil } -func (r *RepoOrchestrator) Restore(ctx context.Context, snapshotId string, path string, target string, progressCallback func(event *v1.RestoreProgressEntry)) (*v1.RestoreProgressEntry, error) { +func (r *RepoOrchestrator) Restore(ctx context.Context, snapshotId string, snapshotPath string, target string, progressCallback func(event *v1.RestoreProgressEntry)) (*v1.RestoreProgressEntry, error) { r.mu.Lock() defer r.mu.Unlock() ctx, flush := forwardResticLogs(ctx) @@ -302,8 +303,16 @@ func (r *RepoOrchestrator) Restore(ctx context.Context, snapshotId string, path var opts []restic.GenericOption opts = append(opts, restic.WithFlags("--target", target)) - if path != "" { - opts = append(opts, restic.WithFlags("--include", path)) + + if snapshotPath != "" { + dir := path.Dir(snapshotPath) + base := path.Base(snapshotPath) + if dir != "" { + snapshotId = snapshotId + ":" + dir + } + if base != "" { + opts = append(opts, restic.WithFlags("--include", base)) + } } summary, err := r.repo.Restore(ctx, snapshotId, func(event *restic.RestoreProgressEntry) { diff --git a/internal/orchestrator/repo/repo_test.go b/internal/orchestrator/repo/repo_test.go index ee842a9f..8a2fa7ef 100644 --- a/internal/orchestrator/repo/repo_test.go +++ b/internal/orchestrator/repo/repo_test.go @@ -3,6 +3,7 @@ package repo import ( "bytes" "context" + "io/ioutil" "os" "runtime" "slices" @@ -86,6 +87,68 @@ func TestBackup(t *testing.T) { } } +func TestRestore(t *testing.T) { + t.Parallel() + + testFile := t.TempDir() + "/test.txt" + if err := ioutil.WriteFile(testFile, []byte("test"), 0644); err != nil { + t.Fatalf("failed to create test file: %v", err) + } + + r := &v1.Repo{ + Id: "test", + Uri: t.TempDir(), + Password: "test", + Flags: []string{"--no-cache"}, + } + + plan := &v1.Plan{ + Id: "test", + Repo: "test", + Paths: []string{testFile}, + } + + orchestrator := initRepoHelper(t, configForTest, r) + + // Create a backup of the single file + summary, err := orchestrator.Backup(context.Background(), plan, nil) + if err != nil { + t.Fatalf("backup error: %v", err) + } + if summary.SnapshotId == "" { + t.Fatal("expected snapshot id") + } + if summary.FilesNew != 1 { + t.Fatalf("expected 1 new file, got %d", summary.FilesNew) + } + + // Restore the file + restoreDir := t.TempDir() + restoreSummary, err := orchestrator.Restore(context.Background(), summary.SnapshotId, testFile, restoreDir, nil) + if err != nil { + t.Fatalf("restore error: %v", err) + } + if restoreSummary.FilesRestored != 1 { + t.Fatalf("expected 1 new file, got %d", restoreSummary.FilesRestored) + } + if restoreSummary.TotalFiles != 1 { + t.Fatalf("expected 1 total file, got %d", restoreSummary.TotalFiles) + } + + // Check the restored file + restoredFile := restoreDir + "/test.txt" + if _, err := os.Stat(restoredFile); err != nil { + t.Fatalf("failed to stat restored file: %v", err) + } + restoredData, err := ioutil.ReadFile(restoredFile) + if err != nil { + t.Fatalf("failed to read restored file: %v", err) + } + if string(restoredData) != "test" { + t.Fatalf("expected 'test', got '%s'", restoredData) + } +} + func TestSnapshotParenting(t *testing.T) { t.Parallel() diff --git a/pkg/restic/restic.go b/pkg/restic/restic.go index 03c4cf0e..0b373a8e 100644 --- a/pkg/restic/restic.go +++ b/pkg/restic/restic.go @@ -21,6 +21,7 @@ import ( var errAlreadyInitialized = errors.New("repo already initialized") var ErrPartialBackup = errors.New("incomplete backup") var ErrBackupFailed = errors.New("backup failed") +var ErrRestoreFailed = errors.New("restore failed") type Repo struct { cmd string @@ -218,7 +219,7 @@ func (r *Repo) Restore(ctx context.Context, snapshot string, callback func(*Rest if exitErr.ExitCode() == 3 { cmdErr = ErrPartialBackup } else { - cmdErr = fmt.Errorf("exit code %d: %w", exitErr.ExitCode(), ErrBackupFailed) + cmdErr = fmt.Errorf("exit code %d: %w", exitErr.ExitCode(), ErrRestoreFailed) } } }