Skip to content

Commit

Permalink
fix!: avoid closing output channel automatically by task
Browse files Browse the repository at this point in the history
A task created with task.FromFn will not close its output channel.
source.FromFn and pipe.FromFn do this instead.
  • Loading branch information
hiroara committed Aug 11, 2023
1 parent 37a8f4f commit 942b14f
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 28 deletions.
9 changes: 7 additions & 2 deletions pipe/pipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,16 @@ type Pipe[S, T any] task.Task[S, T]
// A function that defines a Pipe's behavior.
// This function should receive elements from the passed input channel, process them,
// and pass the results to the passed output channel.
// Please note that this function should not close the passed channels.
// Please note that this function should not close the passed channels
// because pipe.FromFn automatically closes the output channel
// and closing the input channel is the upstream task's responsibility.
// The whole pipeline will be aborted when the returned error is not nil.
type PipeFn[S, T any] func(ctx context.Context, in <-chan S, out chan<- T) error

// Build a Pipe with a PipeFn.
func FromFn[S any, T any](fn PipeFn[S, T], opts ...task.Option) Pipe[S, T] {
return task.FromFn(task.TaskFn[S, T](fn), opts...)
return task.FromFn(task.TaskFn[S, T](func(ctx context.Context, in <-chan S, out chan<- T) error {
defer close(out)
return fn(ctx, in, out)
}), opts...)
}
2 changes: 1 addition & 1 deletion pipe/take.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (op *TakeOp[S]) run(ctx context.Context, in <-chan S, out chan<- S) error {
}
c += 1
if c == op.n {
break
return task.ErrAbort
}
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion pipe/take_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func TestTake(t *testing.T) {
t.Run("MoreItems", func(t *testing.T) {
t.Parallel()

result, err := take(context.Background(), []string{"item1", "item2", "item3"})
result, err := take(context.Background(), []string{"item1", "item2", "item3", "item4", "item5"})
require.NoError(t, err)
assert.Equal(t, []string{"item1", "item2"}, result)
})
Expand Down
3 changes: 3 additions & 0 deletions sink/sink.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ type Sink[S any] task.Task[S, struct{}]

// A function that defines a Sink's behavior.
// This function should receive elements via the passed input channel.
// Please note that this function should not close the passed channel
// because closing the input channel is the upstream task's responsibility.
// The whole pipeline will be aborted when the returned error is not nil.
type SinkFn[S any] func(ctx context.Context, in <-chan S) error

// Build a Sink with a SinkFn.
func FromFn[S any](fn SinkFn[S], opts ...task.Option) Sink[S] {
return task.FromFn(func(ctx context.Context, in <-chan S, out chan<- struct{}) error {
defer close(out)
return fn(ctx, in)
}, opts...)
}
4 changes: 3 additions & 1 deletion source/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@ type Source[T any] task.Task[struct{}, T]

// A function that defines a Source's behavior.
// This function should send elements to the passed output channel.
// Please note that this function should not close the output channel.
// Please note that this function should not close the output channel
// because source.FromFn automatically closes the channel.
// The whole pipeline will be aborted when the returned error is not nil.
type SourceFn[T any] func(ctx context.Context, out chan<- T) error

// Build a Source with a SourceFn.
func FromFn[T any](fn SourceFn[T], opts ...task.Option) Source[T] {
return task.FromFn(func(ctx context.Context, in <-chan struct{}, out chan<- T) error {
<-in // Initial input channel will be closed immediately after starting the flow
defer close(out)
return fn(ctx, out)
}, opts...)
}
30 changes: 9 additions & 21 deletions task/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,29 @@ import (
// M: Type of elements that are sent from Src to Dest
// T: Type of elements that are passed to a downstream task
type Connection[S, M, T any] struct {
Src Task[S, M] // The first task that is contained in this Connection.
Dest Task[M, T] // The second task that is contained in this Connection.
srcOut chan M
destOut chan T
Src Task[S, M] // The first task that is contained in this Connection.
Dest Task[M, T] // The second task that is contained in this Connection.
c chan M
}

// Connect two tasks as a Connection.
func Connect[S, M, T any](src Task[S, M], dest Task[M, T], buf int, opts ...Option) Task[S, T] {
conn := &Connection[S, M, T]{Src: src, Dest: dest, srcOut: make(chan M, buf), destOut: make(chan T)}
conn := &Connection[S, M, T]{Src: src, Dest: dest, c: make(chan M, buf)}
return FromFn(conn.run, opts...)
}

var errDownstreamFinished = errors.New("a downstream task has finished")
var ErrAbort = errors.New("connection aborted")

// Run two tasks that the Connection contains.
func (conn *Connection[S, M, T]) run(ctx context.Context, in <-chan S, out chan<- T) error {
grp, ctx := errgroup.WithContext(ctx)
grp, grpctx := errgroup.WithContext(ctx)

grp.Go(func() error { return conn.Src.Run(ctx, in, conn.srcOut) })
grp.Go(func() error { return conn.Src.Run(grpctx, in, conn.c) })

// destOut will be closed by Dest.
grp.Go(func() error { return conn.Dest.Run(ctx, conn.srcOut, conn.destOut) })

// out will be closed by *task.Run.
grp.Go(func() error {
for el := range conn.destOut {
if err := Emit(ctx, out, el); err != nil {
return err
}
}
return errDownstreamFinished
})
grp.Go(func() error { return conn.Dest.Run(ctx, conn.c, out) })

err := grp.Wait()
if errors.Is(err, errDownstreamFinished) {
if errors.Is(err, ErrAbort) {
err = nil
}
return err
Expand Down
4 changes: 2 additions & 2 deletions task/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ type task[S, T any] struct {

// A function that defines a Task's behavior.
// For more details, please see the Run function defined as a part of the Task interface.
// Please note that this function should not close the output channel.
// Please note that this function should close the output channel when the task finishes
// because task.FromFn does not automatically close the channel.
// The whole pipeline will be aborted when the returned error is not nil.
type TaskFn[S, T any] func(ctx context.Context, in <-chan S, out chan<- T) error

Expand All @@ -76,7 +77,6 @@ func (t *task[S, T]) Run(ctx context.Context, in <-chan S, out chan<- T) error {
defer t.RunDeferred()
ctx = metadata.WithName(ctx, t.name)
ctx, in, out = t.wrapInOut(ctx, in, out)
defer close(out)
if err := t.TaskFn(ctx, in, out); err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions task/task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
)

var double = func(ctx context.Context, in <-chan string, out chan<- string) error {
defer close(out)
for el := range in {
if err := task.Emit(ctx, out, el+el); err != nil {
return err
Expand Down

0 comments on commit 942b14f

Please sign in to comment.