Skip to content

Commit

Permalink
fix: fix case where last a few outputs are dropped
Browse files Browse the repository at this point in the history
  • Loading branch information
hiroara committed Aug 13, 2023
1 parent e895ef7 commit 7f84697
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 75 deletions.
2 changes: 1 addition & 1 deletion pipe/take.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (op *TakeOp[S]) run(ctx context.Context, in <-chan S, out chan<- S) error {
}
c += 1
if c == op.n {
return task.ErrAbort
break
}
}
return nil
Expand Down
43 changes: 34 additions & 9 deletions task/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,44 @@ func Connect[S, M, T any](src Task[S, M], dest Task[M, T], buf int, opts ...Opti
return FromFn(conn.run, opts...)
}

var ErrAbort = errors.New("connection aborted")
var errDownstreamFinished = errors.New("a downstream task has finished")

func ignoreIfErrDownstreamFinished(err error) error {
if errors.Is(err, errDownstreamFinished) {
return nil
}
return err
}

// Run two tasks that the Connection contains.
func (conn *Connection[S, M, T]) run(ctx context.Context, in <-chan S, out chan<- T) error {
grp, grpctx := errgroup.WithContext(ctx)
grp, ctx := errgroup.WithContext(ctx)
grp.SetLimit(2)

grp.Go(func() error { return conn.Src.Run(grpctx, in, conn.c) })
srcCtx, cancel := context.WithCancelCause(ctx)

grp.Go(func() error { return conn.Dest.Run(ctx, conn.c, out) })
destDone := make(chan struct{})

err := grp.Wait()
if errors.Is(err, ErrAbort) {
err = nil
}
return err
grp.Go(func() error {
err := conn.Src.Run(srcCtx, in, conn.c)
return ignoreIfErrDownstreamFinished(err)
})

grp.Go(func() error {
defer close(destDone)
err := conn.Dest.Run(ctx, conn.c, out)
return ignoreIfErrDownstreamFinished(err)
})

grp.Go(func() error {
select {
case <-ctx.Done():
case <-destDone:
// Call cancel if Dest finished early
cancel(errDownstreamFinished)
}
return nil
})

return grp.Wait()
}
22 changes: 8 additions & 14 deletions task/internal/inout/inout.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,12 @@ type Options struct {
Timeout time.Duration
}

func StartWithContext[T any](ctx context.Context, io InOut[T]) context.Context {
ctx, cancel := context.WithCancelCause(ctx)
go func() {
defer io.Close()
ok := true
var err error
for ok {
ok, err = io.passThrough(ctx)
}
if err != nil {
cancel(err)
}
}()
return ctx
func StartWithContext[T any](ctx context.Context, io InOut[T]) error {
defer io.Close()
ok := true
var err error
for ok {
ok, err = io.passThrough(ctx)
}
return err
}
5 changes: 4 additions & 1 deletion task/internal/inout/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ func (in *Input[T]) passThrough(ctx context.Context) (bool, error) {
return false, context.Cause(ctx)
case el, ok := <-in.src:
if ok {
in.dest <- el
select {
case <-ctx.Done():
case in.dest <- el:
}
}
return ok, nil
}
Expand Down
38 changes: 17 additions & 21 deletions task/internal/inout/input_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package inout_test

import (
"context"
"fmt"
"testing"
"time"

Expand All @@ -19,19 +18,25 @@ func TestInput(t *testing.T) {
in := inout.NewInput(src, nil)
dest := in.Chan()

_ = inout.StartWithContext[string](context.Background(), in)

go func() {
defer close(src)
src <- "string1"
src <- "string2"
}()

out := make([]string, 0)
for el := range dest {
out = append(out, el)
}
assert.Equal(t, []string{"string1", "string2"}, out)
checked := make(chan struct{})
go func() {
defer close(checked)

out := make([]string, 0)
for el := range dest {
out = append(out, el)
}
assert.Equal(t, []string{"string1", "string2"}, out)
}()

require.NoError(t, inout.StartWithContext[string](context.Background(), in))
<-checked // Wait until consumer goroutine is done
}

func TestInputWithTimeout(t *testing.T) {
Expand All @@ -48,18 +53,9 @@ func TestInputWithTimeout(t *testing.T) {
src <- "string1"
}()

ctx := context.Background()
ctx = inout.StartWithContext[string](ctx, in)
err := inout.StartWithContext[string](context.Background(), in)
assert.ErrorIs(t, err, context.DeadlineExceeded)

for {
select {
case el := <-dest:
require.Fail(t, fmt.Sprintf("Test timeout (received %s)", el))
case <-ctx.Done(): // Timeout by input option
assert.ErrorIs(t, context.Cause(ctx), context.DeadlineExceeded)
return
case <-time.After(1 * time.Second):
require.Fail(t, "Test timeout")
}
}
_, ok := <-dest
assert.False(t, ok)
}
14 changes: 10 additions & 4 deletions task/internal/inout/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,15 @@ func (out *Output[T]) Close() error {
}

func (out *Output[T]) passThrough(ctx context.Context) (bool, error) {
el, ok := <-out.src
if !ok {
return false, nil
var el T
var ok bool
select {
case <-ctx.Done():
return false, context.Cause(ctx)
case el, ok = <-out.src:
if !ok {
return false, nil
}
}

cancel := func() {}
Expand All @@ -40,6 +46,6 @@ func (out *Output[T]) passThrough(ctx context.Context) (bool, error) {
case <-ctx.Done():
return false, context.Cause(ctx)
case out.dest <- el:
return ok, nil
return true, nil
}
}
31 changes: 16 additions & 15 deletions task/internal/inout/output_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,16 @@ func TestOutput(t *testing.T) {
src <- "string2"
}()

_ = inout.StartWithContext[string](context.Background(), out)
checked := make(chan struct{})
go func() {
defer close(checked)

assert.Equal(t, "string1", <-dest)
assert.Equal(t, "string2", <-dest)
}()

assert.Equal(t, "string1", <-dest)
assert.Equal(t, "string2", <-dest)
require.NoError(t, inout.StartWithContext[string](context.Background(), out))
<-checked // Wait until consumer goroutine is done
}

func TestOutputWithTimeout(t *testing.T) {
Expand All @@ -42,17 +48,12 @@ func TestOutputWithTimeout(t *testing.T) {
<-dest
}()

ctx := context.Background()
ctx = inout.StartWithContext[string](ctx, out)

src := out.Chan()
src <- "item1"
close(src)

select {
case <-ctx.Done(): // Returned context is canceled when timeout is exceeded.
assert.ErrorIs(t, ctx.Err(), context.Canceled)
case <-time.After(time.Second):
require.Fail(t, "Test timeout")
}
go func() {
defer close(src)
src <- "item1"
}()

err := inout.StartWithContext[string](context.Background(), out)
assert.ErrorIs(t, err, context.DeadlineExceeded)
}
31 changes: 21 additions & 10 deletions task/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package task
import (
"context"

"golang.org/x/sync/errgroup"

"github.com/hiroara/carbo/deferrer"
"github.com/hiroara/carbo/task/internal/inout"
"github.com/hiroara/carbo/task/internal/metadata"
Expand Down Expand Up @@ -76,17 +78,26 @@ var GetName = metadata.GetName
func (t *task[S, T]) Run(ctx context.Context, in <-chan S, out chan<- T) error {
defer t.RunDeferred()
ctx = metadata.WithName(ctx, t.name)
ctx, in, out = t.wrapInOut(ctx, in, out)
if err := t.TaskFn(ctx, in, out); err != nil {
return err
}
return context.Cause(ctx)
}

func (t *task[S, T]) wrapInOut(ctx context.Context, in <-chan S, out chan<- T) (context.Context, <-chan S, chan<- T) {
grp, ctx := errgroup.WithContext(ctx)

ip := inout.NewInput(in, newOptions(t.inOpts))
op := inout.NewOutput(out, newOptions(t.outOpts))
ctx = inout.StartWithContext[S](ctx, ip)
ctx = inout.StartWithContext[T](ctx, op)
return ctx, ip.Chan(), op.Chan()

grp.Go(func() error {
err := inout.StartWithContext[S](ctx, ip)
return ignoreIfErrDownstreamFinished(err)
})

grp.Go(func() error {
err := t.TaskFn(ctx, ip.Chan(), op.Chan())
return ignoreIfErrDownstreamFinished(err)
})

grp.Go(func() error {
err := inout.StartWithContext[T](ctx, op)
return ignoreIfErrDownstreamFinished(err)
})

return grp.Wait()
}

0 comments on commit 7f84697

Please sign in to comment.