From b95093922f6ede87fe200f1d1964881b2d5a5170 Mon Sep 17 00:00:00 2001 From: Jonathan Boulle Date: Sun, 29 May 2016 14:36:26 +0200 Subject: [PATCH] progressutil: refactor to use channels+select Rather than maintaining a list of results and individual synchronised booleans signifying goroutine completion, just use a single channel to propagate the result of each copyReader. This means we can use a for/select loop as His Lord and Holiness Rob Pike always intended. Notably, it is now explicitly prohibited to: i) add additional copy operations to a CopyProgressPrinter after it has been started (i.e. after PrintAndWait has been called) ii) call PrintAndWait more than once If either of these is attempted, ErrAlreadyStarted is returned. To achieve i), we would need to either change the interface of the CopyProgressPrinter to accept a pre-defined size (i.e. number of copy operations), or other sychronise the "shutting down" of the CopyProgressPrinter (i.e. to block further copies being added just before the CopyProgressPrinter is about to return - otherwise we can never be sure that another will not be added as we're finishing). Both of these seem overly complex for now and I suggest we only explore them if use cases arise. Similarly, ii) would require more delicate refactoring and it's hard to imagine a use case since this package is typically designed to be used with a single output stream (i.e. stdout). Adding the safety check in this PR helps to mitigate potential abuse of the package, though (e.g. https://github.com/appc/docker2aci/issues/167) --- progressutil/iocopy.go | 123 ++++++++++++++++++++---------------- progressutil/progressbar.go | 4 +- 2 files changed, 69 insertions(+), 58 deletions(-) diff --git a/progressutil/iocopy.go b/progressutil/iocopy.go index a335545..c02f48d 100644 --- a/progressutil/iocopy.go +++ b/progressutil/iocopy.go @@ -15,32 +15,22 @@ package progressutil import ( + "errors" "fmt" "io" "sync" "time" ) -type copyReader struct { - reader io.Reader - current int64 - total int64 - done bool - doneLock sync.Mutex - pb *ProgressBar -} - -func (cr *copyReader) getDone() bool { - cr.doneLock.Lock() - val := cr.done - cr.doneLock.Unlock() - return val -} +var ( + ErrAlreadyStarted = errors.New("cannot add copies after PrintAndWait has been called") +) -func (cr *copyReader) setDone(val bool) { - cr.doneLock.Lock() - cr.done = val - cr.doneLock.Unlock() +type copyReader struct { + reader io.Reader + current int64 + total int64 + pb *ProgressBar } func (cr *copyReader) Read(p []byte) (int, error) { @@ -63,12 +53,20 @@ func (cr *copyReader) updateProgressBar() error { return cr.pb.SetCurrentProgress(progress) } +// NewCopyProgressPrinter returns a new CopyProgressPrinter +func NewCopyProgressPrinter() *CopyProgressPrinter { + return &CopyProgressPrinter{results: make(chan error), cancel: make(chan struct{})} +} + // CopyProgressPrinter will perform an arbitrary number of io.Copy calls, while // continually printing the progress of each copy. type CopyProgressPrinter struct { - readers []*copyReader - errors []error + results chan error + cancel chan struct{} + lock sync.Mutex + readers []*copyReader + started bool pbp *ProgressBarPrinter } @@ -76,8 +74,15 @@ type CopyProgressPrinter struct { // will be made to copy bytes from reader to dest, and name and size will be // used to label the progress bar and display how much progress has been made. // If size is 0, the total size of the reader is assumed to be unknown. -func (cpp *CopyProgressPrinter) AddCopy(reader io.Reader, name string, size int64, dest io.Writer) { +// AddCopy can only be called before PrintAndWait; otherwise, ErrAlreadyStarted +// will be returned. +func (cpp *CopyProgressPrinter) AddCopy(reader io.Reader, name string, size int64, dest io.Writer) error { cpp.lock.Lock() + defer cpp.lock.Unlock() + + if cpp.started { + return ErrAlreadyStarted + } if cpp.pbp == nil { cpp.pbp = &ProgressBarPrinter{} cpp.pbp.PadToBeEven = true @@ -93,60 +98,66 @@ func (cpp *CopyProgressPrinter) AddCopy(reader io.Reader, name string, size int6 cr.pb.SetPrintAfter(cr.formattedProgress()) cpp.readers = append(cpp.readers, cr) - cpp.lock.Unlock() go func() { _, err := io.Copy(dest, cr) - if err != nil { - cpp.lock.Lock() - cpp.errors = append(cpp.errors, err) - cpp.lock.Unlock() + select { + case <-cpp.cancel: + return + case cpp.results <- err: } - cr.setDone(true) }() + return nil } // PrintAndWait will print the progress for each copy operation added with // AddCopy to printTo every printInterval. This will continue until every added // copy is finished, or until cancel is written to. +// PrintAndWait may only be called once; any subsequent calls will immediately +// return ErrAlreadyStarted. After PrintAndWait has been called, no more +// copies may be added to the CopyProgressPrinter. func (cpp *CopyProgressPrinter) PrintAndWait(printTo io.Writer, printInterval time.Duration, cancel chan struct{}) error { - for { - // If cancel is not nil, see if anything has been written to it. If - // something has, return, otherwise keep drawing. - if cancel != nil { - select { - case <-cancel: - return nil - default: - } - } - - cpp.lock.Lock() - readers := cpp.readers - errors := cpp.errors + cpp.lock.Lock() + if cpp.started { cpp.lock.Unlock() + return ErrAlreadyStarted + } + cpp.started = true + cpp.lock.Unlock() - if len(errors) > 0 { - return errors[0] - } + n := len(cpp.readers) + if n == 0 { + // Nothing to do. + return nil + } - if len(readers) > 0 { + defer close(cpp.cancel) + t := time.NewTicker(printInterval) + allDone := false + for i := 0; i < n; { + select { + case <-cancel: + return nil + case <-t.C: _, err := cpp.pbp.Print(printTo) if err != nil { return err } + case err := <-cpp.results: + i++ + if err == nil { + // Once completion is signaled, further on this just drains + // (unlikely) errors from the channel. + if !allDone { + allDone, err = cpp.pbp.Print(printTo) + } + } + if err != nil { + return err + } } - - allDone := true - for _, r := range readers { - allDone = allDone && r.getDone() - } - if allDone && len(readers) > 0 { - return nil - } - - time.Sleep(printInterval) } + return nil } func (cr *copyReader) formattedProgress() string { diff --git a/progressutil/progressbar.go b/progressutil/progressbar.go index 31c6247..224c124 100644 --- a/progressutil/progressbar.go +++ b/progressutil/progressbar.go @@ -191,14 +191,14 @@ func (pbp *ProgressBarPrinter) Print(printTo io.Writer) (bool, error) { } } - allDone := false + allDone := true for _, bar := range bars { if isTerminal(printTo) { bar.printToTerminal(printTo, numColumns, pbp.PadToBeEven, pbp.maxBefore, pbp.maxAfter) } else { bar.printToNonTerminal(printTo) } - allDone = allDone || bar.GetCurrentProgress() == 1 + allDone = allDone && bar.GetCurrentProgress() == 1 } pbp.numLinesInLastPrint = len(bars)