Skip to content

Commit

Permalink
refactor: create tty copy handler (#1485)
Browse files Browse the repository at this point in the history
Signed-off-by: Terry Howe <terrylhowe@gmail.com>
Signed-off-by: Billy Zha <jinzha1@microsoft.com>
Co-authored-by: Billy Zha <jinzha1@microsoft.com>
  • Loading branch information
TerryHowe and qweeah authored Dec 5, 2024
1 parent 0998ba4 commit 363dc2d
Show file tree
Hide file tree
Showing 8 changed files with 226 additions and 73 deletions.
5 changes: 4 additions & 1 deletion cmd/oras/internal/display/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ func NewManifestIndexUpdateHandler(outputPath string, printer *output.Printer, p
}

// NewCopyHandler returns copy handlers.
func NewCopyHandler(printer *output.Printer, fetcher fetcher.Fetcher) (status.CopyHandler, metadata.CopyHandler) {
func NewCopyHandler(printer *output.Printer, tty *os.File, fetcher fetcher.Fetcher) (status.CopyHandler, metadata.CopyHandler) {
if tty != nil {
return status.NewTTYCopyHandler(tty), text.NewCopyHandler(printer)
}
return status.NewTextCopyHandler(printer, fetcher), text.NewCopyHandler(printer)
}
24 changes: 22 additions & 2 deletions cmd/oras/internal/display/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ limitations under the License.
package display

import (
"oras.land/oras/internal/testutils"
"os"
"reflect"
"testing"

"oras.land/oras/internal/testutils"

"oras.land/oras/cmd/oras/internal/display/metadata/text"
"oras.land/oras/cmd/oras/internal/display/status"
"oras.land/oras/cmd/oras/internal/option"
"oras.land/oras/cmd/oras/internal/output"
)
Expand Down Expand Up @@ -50,3 +52,21 @@ func TestNewPullHandler(t *testing.T) {
t.Errorf("NewPullHandler() error = %v, want nil", err)
}
}

func TestNewCopyHandler(t *testing.T) {
printer := output.NewPrinter(os.Stdout, os.Stderr)
copyHandler, copyMetadataHandler := NewCopyHandler(printer, os.Stdout, nil)
if _, ok := copyHandler.(*status.TTYCopyHandler); !ok {
t.Errorf("expected *status.TTYCopyHandler actual %v", reflect.TypeOf(copyHandler))
}
if _, ok := copyMetadataHandler.(*text.CopyHandler); !ok {
t.Errorf("expected metadata.CopyHandler actual %v", reflect.TypeOf(copyMetadataHandler))
}
copyHandler, copyMetadataHandler = NewCopyHandler(printer, nil, nil)
if _, ok := copyHandler.(*status.TextCopyHandler); !ok {
t.Errorf("expected *status.TextCopyHandler actual %v", reflect.TypeOf(copyHandler))
}
if _, ok := copyMetadataHandler.(*text.CopyHandler); !ok {
t.Errorf("expected metadata.CopyHandler actual %v", reflect.TypeOf(copyMetadataHandler))
}
}
2 changes: 2 additions & 0 deletions cmd/oras/internal/display/status/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ type CopyHandler interface {
PreCopy(ctx context.Context, desc ocispec.Descriptor) error
PostCopy(ctx context.Context, desc ocispec.Descriptor) error
OnMounted(ctx context.Context, desc ocispec.Descriptor) error
StartTracking(gt oras.GraphTarget) (oras.GraphTarget, error)
StopTracking() error
}

// ManifestIndexCreateHandler handles status output for manifest index create command.
Expand Down
10 changes: 10 additions & 0 deletions cmd/oras/internal/display/status/text.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,16 @@ func NewTextCopyHandler(printer *output.Printer, fetcher content.Fetcher) CopyHa
}
}

// StartTracking starts a tracked target from a graph target.
func (ch *TextCopyHandler) StartTracking(gt oras.GraphTarget) (oras.GraphTarget, error) {
return gt, nil
}

// StopTracking ends the copy tracking for the target.
func (ch *TextCopyHandler) StopTracking() error {
return nil
}

// OnCopySkipped is called when an object already exists.
func (ch *TextCopyHandler) OnCopySkipped(_ context.Context, desc ocispec.Descriptor) error {
ch.committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle])
Expand Down
61 changes: 61 additions & 0 deletions cmd/oras/internal/display/status/tty.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,64 @@ func (ph *TTYPullHandler) TrackTarget(gt oras.GraphTarget) (oras.GraphTarget, St
ph.tracked = tracked
return tracked, tracked.Close, nil
}

// TTYCopyHandler handles tty status output for copy events.
type TTYCopyHandler struct {
tty *os.File
committed sync.Map
tracked track.GraphTarget
}

// NewTTYCopyHandler returns a new handler for copy command.
func NewTTYCopyHandler(tty *os.File) CopyHandler {
return &TTYCopyHandler{
tty: tty,
}
}

// StartTracking returns a tracked target from a graph target.
func (ch *TTYCopyHandler) StartTracking(gt oras.GraphTarget) (oras.GraphTarget, error) {
var err error
ch.tracked, err = track.NewTarget(gt, copyPromptCopying, copyPromptCopied, ch.tty)
if err != nil {
return nil, err
}
return ch.tracked, err
}

// StopTracking ends the copy tracking for the target.
func (ch *TTYCopyHandler) StopTracking() error {
return ch.tracked.Close()
}

// OnCopySkipped is called when an object already exists.
func (ch *TTYCopyHandler) OnCopySkipped(_ context.Context, desc ocispec.Descriptor) error {
ch.committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle])
return ch.tracked.Prompt(desc, copyPromptExists)
}

// PreCopy implements PreCopy of CopyHandler.
func (ch *TTYCopyHandler) PreCopy(context.Context, ocispec.Descriptor) error {
return nil
}

// PostCopy implements PostCopy of CopyHandler.
func (ch *TTYCopyHandler) PostCopy(ctx context.Context, desc ocispec.Descriptor) error {
ch.committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle])
successors, err := graph.FilteredSuccessors(ctx, desc, ch.tracked, DeduplicatedFilter(&ch.committed))
if err != nil {
return err
}
for _, successor := range successors {
if err = ch.tracked.Prompt(successor, copyPromptSkipped); err != nil {
return err
}
}
return nil
}

// OnMounted implements OnMounted of CopyHandler.
func (ch *TTYCopyHandler) OnMounted(_ context.Context, desc ocispec.Descriptor) error {
ch.committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle])
return ch.tracked.Prompt(desc, copyPromptMounted)
}
106 changes: 105 additions & 1 deletion cmd/oras/internal/display/status/tty_console_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,18 @@ limitations under the License.
package status

import (
"strconv"
"testing"

"oras.land/oras-go/v2"
"oras.land/oras-go/v2/content/memory"
"oras.land/oras/internal/testutils"
"testing"
)

type testGraphTarget struct {
oras.GraphTarget
}

func TestTTYPushHandler_TrackTarget(t *testing.T) {
// prepare pty
_, slave, err := testutils.NewPty()
Expand Down Expand Up @@ -78,3 +85,100 @@ func Test_TTYPullHandler_TrackTarget(t *testing.T) {
}
})
}

func TestTTYCopyHandler_OnMounted(t *testing.T) {
pty, slave, err := testutils.NewPty()
if err != nil {
t.Fatal(err)
}
defer slave.Close()
ch := NewTTYCopyHandler(slave)
_, err = ch.StartTracking(&testGraphTarget{memory.New()})
if err != nil {
t.Fatal(err)
}

if err = ch.OnMounted(ctx, mockFetcher.OciImage); err != nil {
t.Fatalf("OnMounted() should not return an error: %v", err)
}

if err = ch.StopTracking(); err != nil {
t.Fatalf("StopTracking() should not return an error: %v", err)
}

if err = testutils.MatchPty(pty, slave, "✓", "Mounted", strconv.FormatInt(mockFetcher.OciImage.Size, 10), "100.00%", mockFetcher.OciImage.Digest.String()); err != nil {
t.Fatal(err)
}
}

func TestTTYCopyHandler_OnCopySkipped(t *testing.T) {
pty, slave, err := testutils.NewPty()
if err != nil {
t.Fatal(err)
}
defer slave.Close()
ch := NewTTYCopyHandler(slave)
_, err = ch.StartTracking(&testGraphTarget{memory.New()})
if err != nil {
t.Fatal(err)
}

if err = ch.OnCopySkipped(ctx, mockFetcher.OciImage); err != nil {
t.Errorf("OnCopySkipped() should not return an error: %v", err)
}

if err = ch.StopTracking(); err != nil {
t.Errorf("StopTracking() should not return an error: %v", err)
}
if err = testutils.MatchPty(pty, slave, "Exists", "oci-image", strconv.FormatInt(mockFetcher.OciImage.Size, 10), "100.00%"); err != nil {
t.Fatal(err)
}
}

func TestTTYCopyHandler_PostCopy(t *testing.T) {
pty, slave, err := testutils.NewPty()
if err != nil {
t.Fatal(err)
}
defer slave.Close()
ch := NewTTYCopyHandler(slave)
_, err = ch.StartTracking(&testGraphTarget{memory.New()})
if err != nil {
t.Fatal(err)
}

if ch.PostCopy(ctx, bogus) == nil {
t.Error("PostCopy() should return an error")
}

if err = ch.StopTracking(); err != nil {
t.Errorf("StopTracking() should not return an error: %v", err)
}
if err = testutils.MatchPty(pty, slave, "\x1b[?25l\x1b7\x1b[0m"); err != nil {
t.Fatal(err)
}
}

func TestTTYCopyHandler_PreCopy(t *testing.T) {
pty, slave, err := testutils.NewPty()
if err != nil {
t.Fatal(err)
}
defer slave.Close()
ch := NewTTYCopyHandler(slave)
_, err = ch.StartTracking(&testGraphTarget{memory.New()})
if err != nil {
t.Fatal(err)
}

if err = ch.PreCopy(ctx, mockFetcher.OciImage); err != nil {
t.Errorf("PreCopy() should not return an error: %v", err)
}

if err = ch.StopTracking(); err != nil {
t.Errorf("StopTracking() should not return an error: %v", err)
}
if err = testutils.MatchPty(pty, slave, "\x1b[?25l\x1b7\x1b[0m"); err != nil {
t.Fatal(err)
}
}
71 changes: 18 additions & 53 deletions cmd/oras/root/cp.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ import (
"fmt"
"slices"
"strings"
"sync"

"oras.land/oras/cmd/oras/internal/display/status"

"github.com/opencontainers/go-digest"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
Expand All @@ -36,7 +33,7 @@ import (
"oras.land/oras/cmd/oras/internal/argument"
"oras.land/oras/cmd/oras/internal/command"
"oras.land/oras/cmd/oras/internal/display"
"oras.land/oras/cmd/oras/internal/display/status/track"
"oras.land/oras/cmd/oras/internal/display/status"
oerrors "oras.land/oras/cmd/oras/internal/errors"
"oras.land/oras/cmd/oras/internal/option"
"oras.land/oras/internal/docker"
Expand Down Expand Up @@ -131,9 +128,9 @@ func runCopy(cmd *cobra.Command, opts *copyOptions) error {
return err
}
ctx = registryutil.WithScopeHint(ctx, dst, auth.ActionPull, auth.ActionPush)
copyHandler, handler := display.NewCopyHandler(opts.Printer, dst)
statusHandler, metadataHandler := display.NewCopyHandler(opts.Printer, opts.TTY, dst)

desc, err := doCopy(ctx, copyHandler, src, dst, opts)
desc, err := doCopy(ctx, statusHandler, src, dst, opts)
if err != nil {
return err
}
Expand All @@ -147,7 +144,7 @@ func runCopy(cmd *cobra.Command, opts *copyOptions) error {
if len(opts.extraRefs) != 0 {
tagNOpts := oras.DefaultTagNOptions
tagNOpts.Concurrency = opts.concurrency
tagListener := listener.NewTaggedListener(dst, handler.OnTagged)
tagListener := listener.NewTaggedListener(dst, metadataHandler.OnTagged)
if _, err = oras.TagN(ctx, tagListener, opts.To.Reference, opts.extraRefs, tagNOpts); err != nil {
return err
}
Expand All @@ -158,68 +155,36 @@ func runCopy(cmd *cobra.Command, opts *copyOptions) error {
return nil
}

func doCopy(ctx context.Context, copyHandler status.CopyHandler, src oras.ReadOnlyGraphTarget, dst oras.GraphTarget, opts *copyOptions) (ocispec.Descriptor, error) {
func doCopy(ctx context.Context, copyHandler status.CopyHandler, src oras.ReadOnlyGraphTarget, dst oras.GraphTarget, opts *copyOptions) (desc ocispec.Descriptor, err error) {
// Prepare copy options
committed := &sync.Map{}
extendedCopyOptions := oras.DefaultExtendedCopyOptions
extendedCopyOptions.Concurrency = opts.concurrency
extendedCopyOptions.FindPredecessors = func(ctx context.Context, src content.ReadOnlyGraphStorage, desc ocispec.Descriptor) ([]ocispec.Descriptor, error) {
return registry.Referrers(ctx, src, desc, "")
}

const (
promptExists = "Exists "
promptCopying = "Copying"
promptCopied = "Copied "
promptSkipped = "Skipped"
promptMounted = "Mounted"
)
srcRepo, srcIsRemote := src.(*remote.Repository)
dstRepo, dstIsRemote := dst.(*remote.Repository)
if srcIsRemote && dstIsRemote && srcRepo.Reference.Registry == dstRepo.Reference.Registry {
extendedCopyOptions.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) {
return []string{srcRepo.Reference.Repository}, nil
}
}
if opts.TTY == nil {
// no TTY output
extendedCopyOptions.OnCopySkipped = copyHandler.OnCopySkipped
extendedCopyOptions.PreCopy = copyHandler.PreCopy
extendedCopyOptions.PostCopy = copyHandler.PostCopy
extendedCopyOptions.OnMounted = copyHandler.OnMounted
} else {
// TTY output
tracked, err := track.NewTarget(dst, promptCopying, promptCopied, opts.TTY)
if err != nil {
return ocispec.Descriptor{}, err
}
defer tracked.Close()
dst = tracked
extendedCopyOptions.OnCopySkipped = func(ctx context.Context, desc ocispec.Descriptor) error {
committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle])
return tracked.Prompt(desc, promptExists)
}
extendedCopyOptions.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle])
successors, err := graph.FilteredSuccessors(ctx, desc, tracked, status.DeduplicatedFilter(committed))
if err != nil {
return err
}
for _, successor := range successors {
if err = tracked.Prompt(successor, promptSkipped); err != nil {
return err
}
}
return nil
}
extendedCopyOptions.OnMounted = func(ctx context.Context, desc ocispec.Descriptor) error {
committed.Store(desc.Digest.String(), desc.Annotations[ocispec.AnnotationTitle])
return tracked.Prompt(desc, promptMounted)
}
dst, err = copyHandler.StartTracking(dst)
if err != nil {
return desc, err
}
defer func() {
stopErr := copyHandler.StopTracking()
if err == nil {
err = stopErr
}
}()
extendedCopyOptions.OnCopySkipped = copyHandler.OnCopySkipped
extendedCopyOptions.PreCopy = copyHandler.PreCopy
extendedCopyOptions.PostCopy = copyHandler.PostCopy
extendedCopyOptions.OnMounted = copyHandler.OnMounted

var desc ocispec.Descriptor
var err error
rOpts := oras.DefaultResolveOptions
rOpts.TargetPlatform = opts.Platform.Platform
if opts.recursive {
Expand Down
Loading

0 comments on commit 363dc2d

Please sign in to comment.