diff --git a/pipe/pipe.go b/pipe/pipe.go index 90e3bbd..9366355 100644 --- a/pipe/pipe.go +++ b/pipe/pipe.go @@ -15,11 +15,16 @@ type Pipe[S, T any] task.Task[S, T] // A function that defines a Pipe's behavior. // This function should receive elements from the passed input channel, process them, // and pass the results to the passed output channel. -// Please note that this function should not close the passed channels. +// Please note that this function should not close the passed channels +// because pipe.FromFn automatically closes the output channel +// and closing the input channel is the upstream task's responsibility. // The whole pipeline will be aborted when the returned error is not nil. type PipeFn[S, T any] func(ctx context.Context, in <-chan S, out chan<- T) error // Build a Pipe with a PipeFn. func FromFn[S any, T any](fn PipeFn[S, T], opts ...task.Option) Pipe[S, T] { - return task.FromFn(task.TaskFn[S, T](fn), opts...) + return task.FromFn(task.TaskFn[S, T](func(ctx context.Context, in <-chan S, out chan<- T) error { + defer close(out) + return fn(ctx, in, out) + }), opts...) } diff --git a/pipe/take.go b/pipe/take.go index a78a683..855843e 100644 --- a/pipe/take.go +++ b/pipe/take.go @@ -27,7 +27,7 @@ func (op *TakeOp[S]) run(ctx context.Context, in <-chan S, out chan<- S) error { } c += 1 if c == op.n { - break + return task.ErrAbort } } return nil diff --git a/pipe/take_test.go b/pipe/take_test.go index cbefe68..e8737e3 100644 --- a/pipe/take_test.go +++ b/pipe/take_test.go @@ -19,7 +19,7 @@ func TestTake(t *testing.T) { t.Run("MoreItems", func(t *testing.T) { t.Parallel() - result, err := take(context.Background(), []string{"item1", "item2", "item3"}) + result, err := take(context.Background(), []string{"item1", "item2", "item3", "item4", "item5"}) require.NoError(t, err) assert.Equal(t, []string{"item1", "item2"}, result) }) diff --git a/sink/sink.go b/sink/sink.go index 6c8189b..5d81a0e 100644 --- a/sink/sink.go +++ b/sink/sink.go @@ -16,12 +16,15 @@ type Sink[S any] task.Task[S, struct{}] // A function that defines a Sink's behavior. // This function should receive elements via the passed input channel. +// Please note that this function should not close the passed channel +// because closing the input channel is the upstream task's responsibility. // The whole pipeline will be aborted when the returned error is not nil. type SinkFn[S any] func(ctx context.Context, in <-chan S) error // Build a Sink with a SinkFn. func FromFn[S any](fn SinkFn[S], opts ...task.Option) Sink[S] { return task.FromFn(func(ctx context.Context, in <-chan S, out chan<- struct{}) error { + defer close(out) return fn(ctx, in) }, opts...) } diff --git a/source/source.go b/source/source.go index 21eb8dc..197355d 100644 --- a/source/source.go +++ b/source/source.go @@ -16,7 +16,8 @@ type Source[T any] task.Task[struct{}, T] // A function that defines a Source's behavior. // This function should send elements to the passed output channel. -// Please note that this function should not close the output channel. +// Please note that this function should not close the output channel +// because source.FromFn automatically closes the channel. // The whole pipeline will be aborted when the returned error is not nil. type SourceFn[T any] func(ctx context.Context, out chan<- T) error @@ -24,6 +25,7 @@ type SourceFn[T any] func(ctx context.Context, out chan<- T) error func FromFn[T any](fn SourceFn[T], opts ...task.Option) Source[T] { return task.FromFn(func(ctx context.Context, in <-chan struct{}, out chan<- T) error { <-in // Initial input channel will be closed immediately after starting the flow + defer close(out) return fn(ctx, out) }, opts...) } diff --git a/task/connection.go b/task/connection.go index 86760e5..123840b 100644 --- a/task/connection.go +++ b/task/connection.go @@ -15,41 +15,29 @@ import ( // M: Type of elements that are sent from Src to Dest // T: Type of elements that are passed to a downstream task type Connection[S, M, T any] struct { - Src Task[S, M] // The first task that is contained in this Connection. - Dest Task[M, T] // The second task that is contained in this Connection. - srcOut chan M - destOut chan T + Src Task[S, M] // The first task that is contained in this Connection. + Dest Task[M, T] // The second task that is contained in this Connection. + c chan M } // Connect two tasks as a Connection. func Connect[S, M, T any](src Task[S, M], dest Task[M, T], buf int, opts ...Option) Task[S, T] { - conn := &Connection[S, M, T]{Src: src, Dest: dest, srcOut: make(chan M, buf), destOut: make(chan T)} + conn := &Connection[S, M, T]{Src: src, Dest: dest, c: make(chan M, buf)} return FromFn(conn.run, opts...) } -var errDownstreamFinished = errors.New("a downstream task has finished") +var ErrAbort = errors.New("connection aborted") // Run two tasks that the Connection contains. func (conn *Connection[S, M, T]) run(ctx context.Context, in <-chan S, out chan<- T) error { - grp, ctx := errgroup.WithContext(ctx) + grp, grpctx := errgroup.WithContext(ctx) - grp.Go(func() error { return conn.Src.Run(ctx, in, conn.srcOut) }) + grp.Go(func() error { return conn.Src.Run(grpctx, in, conn.c) }) - // destOut will be closed by Dest. - grp.Go(func() error { return conn.Dest.Run(ctx, conn.srcOut, conn.destOut) }) - - // out will be closed by *task.Run. - grp.Go(func() error { - for el := range conn.destOut { - if err := Emit(ctx, out, el); err != nil { - return err - } - } - return errDownstreamFinished - }) + grp.Go(func() error { return conn.Dest.Run(ctx, conn.c, out) }) err := grp.Wait() - if errors.Is(err, errDownstreamFinished) { + if errors.Is(err, ErrAbort) { err = nil } return err diff --git a/task/task.go b/task/task.go index 6837dd4..db19966 100644 --- a/task/task.go +++ b/task/task.go @@ -51,7 +51,8 @@ type task[S, T any] struct { // A function that defines a Task's behavior. // For more details, please see the Run function defined as a part of the Task interface. -// Please note that this function should not close the output channel. +// Please note that this function should close the output channel when the task finishes +// because task.FromFn does not automatically close the channel. // The whole pipeline will be aborted when the returned error is not nil. type TaskFn[S, T any] func(ctx context.Context, in <-chan S, out chan<- T) error @@ -76,7 +77,6 @@ func (t *task[S, T]) Run(ctx context.Context, in <-chan S, out chan<- T) error { defer t.RunDeferred() ctx = metadata.WithName(ctx, t.name) ctx, in, out = t.wrapInOut(ctx, in, out) - defer close(out) if err := t.TaskFn(ctx, in, out); err != nil { return err } diff --git a/task/task_test.go b/task/task_test.go index 41cc686..058d125 100644 --- a/task/task_test.go +++ b/task/task_test.go @@ -12,6 +12,7 @@ import ( ) var double = func(ctx context.Context, in <-chan string, out chan<- string) error { + defer close(out) for el := range in { if err := task.Emit(ctx, out, el+el); err != nil { return err