Skip to content

Commit

Permalink
progressutil: refactor to use channels+select
Browse files Browse the repository at this point in the history
Rather than maintaining a list of results and individual synchronised
booleans signifying goroutine completion, just use a single channel to
propagate the result of each copyReader. This means we can use a
for/select loop as His Lord and Holiness Rob Pike always intended.

Notably, it is now explicitly prohibited to:
i) add additional copy operations to a CopyProgressPrinter after it has
   been started (i.e. after PrintAndWait has been called)
ii) call PrintAndWait more than once
If either of these is attempted, ErrAlreadyStarted is returned.

To achieve i), we would need to either change the interface of the
CopyProgressPrinter to accept a pre-defined size (i.e. number of copy
operations), or other sychronise the "shutting down" of the
CopyProgressPrinter (i.e. to block further copies being added just
before the CopyProgressPrinter is about to return - otherwise we can
never be sure that another will not be added as we're finishing). Both
of these seem overly complex for now and I suggest we only explore them
if use cases arise.

Similarly, ii) would require more delicate refactoring and it's hard to
imagine a use case since this package is typically designed to be used
with a single output stream (i.e. stdout). Adding the safety check in
this PR helps to mitigate potential abuse of the package, though (e.g.
appc/docker2aci#167)
  • Loading branch information
jonboulle authored and lucab committed May 30, 2016
1 parent 77313bb commit b950939
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 58 deletions.
123 changes: 67 additions & 56 deletions progressutil/iocopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,22 @@
package progressutil

import (
"errors"
"fmt"
"io"
"sync"
"time"
)

type copyReader struct {
reader io.Reader
current int64
total int64
done bool
doneLock sync.Mutex
pb *ProgressBar
}

func (cr *copyReader) getDone() bool {
cr.doneLock.Lock()
val := cr.done
cr.doneLock.Unlock()
return val
}
var (
ErrAlreadyStarted = errors.New("cannot add copies after PrintAndWait has been called")
)

func (cr *copyReader) setDone(val bool) {
cr.doneLock.Lock()
cr.done = val
cr.doneLock.Unlock()
type copyReader struct {
reader io.Reader
current int64
total int64
pb *ProgressBar
}

func (cr *copyReader) Read(p []byte) (int, error) {
Expand All @@ -63,21 +53,36 @@ func (cr *copyReader) updateProgressBar() error {
return cr.pb.SetCurrentProgress(progress)
}

// NewCopyProgressPrinter returns a new CopyProgressPrinter
func NewCopyProgressPrinter() *CopyProgressPrinter {
return &CopyProgressPrinter{results: make(chan error), cancel: make(chan struct{})}
}

// CopyProgressPrinter will perform an arbitrary number of io.Copy calls, while
// continually printing the progress of each copy.
type CopyProgressPrinter struct {
readers []*copyReader
errors []error
results chan error
cancel chan struct{}

lock sync.Mutex
readers []*copyReader
started bool
pbp *ProgressBarPrinter
}

// AddCopy adds a copy for this CopyProgressPrinter to perform. An io.Copy call
// will be made to copy bytes from reader to dest, and name and size will be
// used to label the progress bar and display how much progress has been made.
// If size is 0, the total size of the reader is assumed to be unknown.
func (cpp *CopyProgressPrinter) AddCopy(reader io.Reader, name string, size int64, dest io.Writer) {
// AddCopy can only be called before PrintAndWait; otherwise, ErrAlreadyStarted
// will be returned.
func (cpp *CopyProgressPrinter) AddCopy(reader io.Reader, name string, size int64, dest io.Writer) error {
cpp.lock.Lock()
defer cpp.lock.Unlock()

if cpp.started {
return ErrAlreadyStarted
}
if cpp.pbp == nil {
cpp.pbp = &ProgressBarPrinter{}
cpp.pbp.PadToBeEven = true
Expand All @@ -93,60 +98,66 @@ func (cpp *CopyProgressPrinter) AddCopy(reader io.Reader, name string, size int6
cr.pb.SetPrintAfter(cr.formattedProgress())

cpp.readers = append(cpp.readers, cr)
cpp.lock.Unlock()

go func() {
_, err := io.Copy(dest, cr)
if err != nil {
cpp.lock.Lock()
cpp.errors = append(cpp.errors, err)
cpp.lock.Unlock()
select {
case <-cpp.cancel:
return
case cpp.results <- err:
}
cr.setDone(true)
}()
return nil
}

// PrintAndWait will print the progress for each copy operation added with
// AddCopy to printTo every printInterval. This will continue until every added
// copy is finished, or until cancel is written to.
// PrintAndWait may only be called once; any subsequent calls will immediately
// return ErrAlreadyStarted. After PrintAndWait has been called, no more
// copies may be added to the CopyProgressPrinter.
func (cpp *CopyProgressPrinter) PrintAndWait(printTo io.Writer, printInterval time.Duration, cancel chan struct{}) error {
for {
// If cancel is not nil, see if anything has been written to it. If
// something has, return, otherwise keep drawing.
if cancel != nil {
select {
case <-cancel:
return nil
default:
}
}

cpp.lock.Lock()
readers := cpp.readers
errors := cpp.errors
cpp.lock.Lock()
if cpp.started {
cpp.lock.Unlock()
return ErrAlreadyStarted
}
cpp.started = true
cpp.lock.Unlock()

if len(errors) > 0 {
return errors[0]
}
n := len(cpp.readers)
if n == 0 {
// Nothing to do.
return nil
}

if len(readers) > 0 {
defer close(cpp.cancel)
t := time.NewTicker(printInterval)
allDone := false
for i := 0; i < n; {
select {
case <-cancel:
return nil
case <-t.C:
_, err := cpp.pbp.Print(printTo)
if err != nil {
return err
}
case err := <-cpp.results:
i++
if err == nil {
// Once completion is signaled, further on this just drains
// (unlikely) errors from the channel.
if !allDone {
allDone, err = cpp.pbp.Print(printTo)
}
}
if err != nil {
return err
}
}

allDone := true
for _, r := range readers {
allDone = allDone && r.getDone()
}
if allDone && len(readers) > 0 {
return nil
}

time.Sleep(printInterval)
}
return nil
}

func (cr *copyReader) formattedProgress() string {
Expand Down
4 changes: 2 additions & 2 deletions progressutil/progressbar.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,14 @@ func (pbp *ProgressBarPrinter) Print(printTo io.Writer) (bool, error) {
}
}

allDone := false
allDone := true
for _, bar := range bars {
if isTerminal(printTo) {
bar.printToTerminal(printTo, numColumns, pbp.PadToBeEven, pbp.maxBefore, pbp.maxAfter)
} else {
bar.printToNonTerminal(printTo)
}
allDone = allDone || bar.GetCurrentProgress() == 1
allDone = allDone && bar.GetCurrentProgress() == 1
}

pbp.numLinesInLastPrint = len(bars)
Expand Down

0 comments on commit b950939

Please sign in to comment.