diff --git a/pipe/take.go b/pipe/take.go index 855843e..a78a683 100644 --- a/pipe/take.go +++ b/pipe/take.go @@ -27,7 +27,7 @@ func (op *TakeOp[S]) run(ctx context.Context, in <-chan S, out chan<- S) error { } c += 1 if c == op.n { - return task.ErrAbort + break } } return nil diff --git a/task/connection.go b/task/connection.go index 123840b..0b570ac 100644 --- a/task/connection.go +++ b/task/connection.go @@ -26,19 +26,44 @@ func Connect[S, M, T any](src Task[S, M], dest Task[M, T], buf int, opts ...Opti return FromFn(conn.run, opts...) } -var ErrAbort = errors.New("connection aborted") +var errDownstreamFinished = errors.New("a downstream task has finished") + +func ignoreIfErrDownstreamFinished(err error) error { + if errors.Is(err, errDownstreamFinished) { + return nil + } + return err +} // Run two tasks that the Connection contains. func (conn *Connection[S, M, T]) run(ctx context.Context, in <-chan S, out chan<- T) error { - grp, grpctx := errgroup.WithContext(ctx) + grp, ctx := errgroup.WithContext(ctx) + grp.SetLimit(2) - grp.Go(func() error { return conn.Src.Run(grpctx, in, conn.c) }) + srcCtx, cancel := context.WithCancelCause(ctx) - grp.Go(func() error { return conn.Dest.Run(ctx, conn.c, out) }) + destDone := make(chan struct{}) - err := grp.Wait() - if errors.Is(err, ErrAbort) { - err = nil - } - return err + grp.Go(func() error { + err := conn.Src.Run(srcCtx, in, conn.c) + return ignoreIfErrDownstreamFinished(err) + }) + + grp.Go(func() error { + defer close(destDone) + err := conn.Dest.Run(ctx, conn.c, out) + return ignoreIfErrDownstreamFinished(err) + }) + + grp.Go(func() error { + select { + case <-ctx.Done(): + case <-destDone: + // Call cancel if Dest finished early + cancel(errDownstreamFinished) + } + return nil + }) + + return grp.Wait() } diff --git a/task/internal/inout/inout.go b/task/internal/inout/inout.go index 5732dd6..c675b56 100644 --- a/task/internal/inout/inout.go +++ b/task/internal/inout/inout.go @@ -15,18 +15,12 @@ type Options struct { Timeout time.Duration } -func StartWithContext[T any](ctx context.Context, io InOut[T]) context.Context { - ctx, cancel := context.WithCancelCause(ctx) - go func() { - defer io.Close() - ok := true - var err error - for ok { - ok, err = io.passThrough(ctx) - } - if err != nil { - cancel(err) - } - }() - return ctx +func StartWithContext[T any](ctx context.Context, io InOut[T]) error { + defer io.Close() + ok := true + var err error + for ok { + ok, err = io.passThrough(ctx) + } + return err } diff --git a/task/internal/inout/input.go b/task/internal/inout/input.go index 65024ef..518f944 100644 --- a/task/internal/inout/input.go +++ b/task/internal/inout/input.go @@ -36,7 +36,10 @@ func (in *Input[T]) passThrough(ctx context.Context) (bool, error) { return false, context.Cause(ctx) case el, ok := <-in.src: if ok { - in.dest <- el + select { + case <-ctx.Done(): + case in.dest <- el: + } } return ok, nil } diff --git a/task/internal/inout/input_test.go b/task/internal/inout/input_test.go index 0d9c003..23adba0 100644 --- a/task/internal/inout/input_test.go +++ b/task/internal/inout/input_test.go @@ -2,7 +2,6 @@ package inout_test import ( "context" - "fmt" "testing" "time" @@ -19,19 +18,25 @@ func TestInput(t *testing.T) { in := inout.NewInput(src, nil) dest := in.Chan() - _ = inout.StartWithContext[string](context.Background(), in) - go func() { defer close(src) src <- "string1" src <- "string2" }() - out := make([]string, 0) - for el := range dest { - out = append(out, el) - } - assert.Equal(t, []string{"string1", "string2"}, out) + checked := make(chan struct{}) + go func() { + defer close(checked) + + out := make([]string, 0) + for el := range dest { + out = append(out, el) + } + assert.Equal(t, []string{"string1", "string2"}, out) + }() + + require.NoError(t, inout.StartWithContext[string](context.Background(), in)) + <-checked // Wait until consumer goroutine is done } func TestInputWithTimeout(t *testing.T) { @@ -48,18 +53,9 @@ func TestInputWithTimeout(t *testing.T) { src <- "string1" }() - ctx := context.Background() - ctx = inout.StartWithContext[string](ctx, in) + err := inout.StartWithContext[string](context.Background(), in) + assert.ErrorIs(t, err, context.DeadlineExceeded) - for { - select { - case el := <-dest: - require.Fail(t, fmt.Sprintf("Test timeout (received %s)", el)) - case <-ctx.Done(): // Timeout by input option - assert.ErrorIs(t, context.Cause(ctx), context.DeadlineExceeded) - return - case <-time.After(1 * time.Second): - require.Fail(t, "Test timeout") - } - } + _, ok := <-dest + assert.False(t, ok) } diff --git a/task/internal/inout/output.go b/task/internal/inout/output.go index 1f5e525..8817506 100644 --- a/task/internal/inout/output.go +++ b/task/internal/inout/output.go @@ -26,9 +26,15 @@ func (out *Output[T]) Close() error { } func (out *Output[T]) passThrough(ctx context.Context) (bool, error) { - el, ok := <-out.src - if !ok { - return false, nil + var el T + var ok bool + select { + case <-ctx.Done(): + return false, context.Cause(ctx) + case el, ok = <-out.src: + if !ok { + return false, nil + } } cancel := func() {} @@ -40,6 +46,6 @@ func (out *Output[T]) passThrough(ctx context.Context) (bool, error) { case <-ctx.Done(): return false, context.Cause(ctx) case out.dest <- el: - return ok, nil + return true, nil } } diff --git a/task/internal/inout/output_test.go b/task/internal/inout/output_test.go index 95cab32..3f32522 100644 --- a/task/internal/inout/output_test.go +++ b/task/internal/inout/output_test.go @@ -24,10 +24,16 @@ func TestOutput(t *testing.T) { src <- "string2" }() - _ = inout.StartWithContext[string](context.Background(), out) + checked := make(chan struct{}) + go func() { + defer close(checked) + + assert.Equal(t, "string1", <-dest) + assert.Equal(t, "string2", <-dest) + }() - assert.Equal(t, "string1", <-dest) - assert.Equal(t, "string2", <-dest) + require.NoError(t, inout.StartWithContext[string](context.Background(), out)) + <-checked // Wait until consumer goroutine is done } func TestOutputWithTimeout(t *testing.T) { @@ -42,17 +48,12 @@ func TestOutputWithTimeout(t *testing.T) { <-dest }() - ctx := context.Background() - ctx = inout.StartWithContext[string](ctx, out) - src := out.Chan() - src <- "item1" - close(src) - - select { - case <-ctx.Done(): // Returned context is canceled when timeout is exceeded. - assert.ErrorIs(t, ctx.Err(), context.Canceled) - case <-time.After(time.Second): - require.Fail(t, "Test timeout") - } + go func() { + defer close(src) + src <- "item1" + }() + + err := inout.StartWithContext[string](context.Background(), out) + assert.ErrorIs(t, err, context.DeadlineExceeded) } diff --git a/task/task.go b/task/task.go index dc2c5d2..0de6896 100644 --- a/task/task.go +++ b/task/task.go @@ -3,6 +3,8 @@ package task import ( "context" + "golang.org/x/sync/errgroup" + "github.com/hiroara/carbo/deferrer" "github.com/hiroara/carbo/task/internal/inout" "github.com/hiroara/carbo/task/internal/metadata" @@ -76,17 +78,26 @@ var GetName = metadata.GetName func (t *task[S, T]) Run(ctx context.Context, in <-chan S, out chan<- T) error { defer t.RunDeferred() ctx = metadata.WithName(ctx, t.name) - ctx, in, out = t.wrapInOut(ctx, in, out) - if err := t.TaskFn(ctx, in, out); err != nil { - return err - } - return context.Cause(ctx) -} -func (t *task[S, T]) wrapInOut(ctx context.Context, in <-chan S, out chan<- T) (context.Context, <-chan S, chan<- T) { + grp, ctx := errgroup.WithContext(ctx) + ip := inout.NewInput(in, newOptions(t.inOpts)) op := inout.NewOutput(out, newOptions(t.outOpts)) - ctx = inout.StartWithContext[S](ctx, ip) - ctx = inout.StartWithContext[T](ctx, op) - return ctx, ip.Chan(), op.Chan() + + grp.Go(func() error { + err := inout.StartWithContext[S](ctx, ip) + return ignoreIfErrDownstreamFinished(err) + }) + + grp.Go(func() error { + err := t.TaskFn(ctx, ip.Chan(), op.Chan()) + return ignoreIfErrDownstreamFinished(err) + }) + + grp.Go(func() error { + err := inout.StartWithContext[T](ctx, op) + return ignoreIfErrDownstreamFinished(err) + }) + + return grp.Wait() }