Skip to content

Commit

Permalink
refactor: task/internal/inout package
Browse files Browse the repository at this point in the history
  • Loading branch information
hiroara committed Aug 13, 2023
1 parent 942b14f commit e895ef7
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 71 deletions.
47 changes: 6 additions & 41 deletions task/internal/inout/inout.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,23 @@ package inout

import (
"context"
"io"
"time"
)

type inOut[T any] struct {
*Options
src <-chan T
dest chan<- T
type InOut[T any] interface {
io.Closer
passThrough(ctx context.Context) (bool, error)
}

type Options struct {
Timeout time.Duration
}

func newInOut[T any](src <-chan T, dest chan<- T, opts *Options) *inOut[T] {
if opts == nil {
opts = &Options{}
}
return &inOut[T]{src: src, dest: dest, Options: opts}
}

func (io *inOut[T]) StartWithContext(ctx context.Context) context.Context {
func StartWithContext[T any](ctx context.Context, io InOut[T]) context.Context {
ctx, cancel := context.WithCancelCause(ctx)
go func() {
defer close(io.dest)
defer io.Close()
ok := true
var err error
for ok {
Expand All @@ -37,31 +30,3 @@ func (io *inOut[T]) StartWithContext(ctx context.Context) context.Context {
}()
return ctx
}

func (io *inOut[T]) passThrough(ctx context.Context) (bool, error) {
cancel := func() {}
if io.Timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, io.Timeout)
}
defer cancel()
select {
case <-ctx.Done():
return false, context.Cause(ctx)
case el, ok := <-io.src:
if ok {
if err := io.emit(ctx, el); err != nil {
return false, err
}
}
return ok, nil
}
}

func (io *inOut[T]) emit(ctx context.Context, el T) error {
select {
case <-ctx.Done():
return context.Cause(ctx)
case io.dest <- el:
return nil
}
}
34 changes: 31 additions & 3 deletions task/internal/inout/input.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,43 @@
package inout

import "context"

type Input[T any] struct {
*inOut[T]
dest <-chan T
src <-chan T
dest chan T
options *Options
}

func NewInput[T any](c <-chan T, opts *Options) *Input[T] {
if opts == nil {
opts = &Options{}
}
dest := make(chan T)
return &Input[T]{inOut: newInOut(c, dest, opts), dest: dest}
return &Input[T]{src: c, dest: dest, options: opts}
}

func (in *Input[T]) Chan() <-chan T {
return in.dest
}

func (in *Input[T]) Close() error {
close(in.dest)
return nil
}

func (in *Input[T]) passThrough(ctx context.Context) (bool, error) {
cancel := func() {}
if in.options.Timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, in.options.Timeout)
}
defer cancel()
select {
case <-ctx.Done():
return false, context.Cause(ctx)
case el, ok := <-in.src:
if ok {
in.dest <- el
}
return ok, nil
}
}
19 changes: 9 additions & 10 deletions task/internal/inout/input_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package inout_test

import (
"context"
"fmt"
"testing"
"time"

Expand All @@ -18,7 +19,7 @@ func TestInput(t *testing.T) {
in := inout.NewInput(src, nil)
dest := in.Chan()

_ = in.StartWithContext(context.Background())
_ = inout.StartWithContext[string](context.Background(), in)

go func() {
defer close(src)
Expand All @@ -40,26 +41,24 @@ func TestInputWithTimeout(t *testing.T) {
in := inout.NewInput(src, &inout.Options{Timeout: 1 * time.Nanosecond})
dest := in.Chan()

out := make([]string, 0)
// Slow upstream
go func() {
defer close(src)
time.Sleep(1 * time.Second)
for el := range dest {
out = append(out, el)
}
time.Sleep(10 * time.Second)
src <- "string1"
}()

ctx := context.Background()
ctx = in.StartWithContext(ctx)
ctx = inout.StartWithContext[string](ctx, in)

timeout := time.After(10 * time.Second)
for {
select {
case src <- "string1":
case el := <-dest:
require.Fail(t, fmt.Sprintf("Test timeout (received %s)", el))
case <-ctx.Done(): // Timeout by input option
assert.ErrorIs(t, context.Cause(ctx), context.DeadlineExceeded)
return
case <-timeout:
case <-time.After(1 * time.Second):
require.Fail(t, "Test timeout")
}
}
Expand Down
40 changes: 35 additions & 5 deletions task/internal/inout/output.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,45 @@
package inout

import "context"

type Output[T any] struct {
*inOut[T]
src chan<- T
src chan T
dest chan<- T
options *Options
}

func NewOutput[T any](c chan<- T, opts *Options) *Output[T] {
if opts == nil {
opts = &Options{}
}
src := make(chan T)
return &Output[T]{inOut: newInOut(src, c, opts), src: src}
return &Output[T]{src: src, dest: c, options: opts}
}

func (out *Output[T]) Chan() chan<- T {
return out.src
}

func (in *Output[T]) Chan() chan<- T {
return in.src
func (out *Output[T]) Close() error {
close(out.dest)
return nil
}

func (out *Output[T]) passThrough(ctx context.Context) (bool, error) {
el, ok := <-out.src
if !ok {
return false, nil
}

cancel := func() {}
if out.options.Timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, out.options.Timeout)
}
defer cancel()
select {
case <-ctx.Done():
return false, context.Cause(ctx)
case out.dest <- el:
return ok, nil
}
}
26 changes: 16 additions & 10 deletions task/internal/inout/output_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/hiroara/carbo/task/internal/inout"
)
Expand All @@ -23,7 +24,7 @@ func TestOutput(t *testing.T) {
src <- "string2"
}()

_ = out.StartWithContext(context.Background())
_ = inout.StartWithContext[string](context.Background(), out)

assert.Equal(t, "string1", <-dest)
assert.Equal(t, "string2", <-dest)
Expand All @@ -34,19 +35,24 @@ func TestOutputWithTimeout(t *testing.T) {

dest := make(chan string)
out := inout.NewOutput(dest, &inout.Options{Timeout: 1 * time.Nanosecond})
src := out.Chan()

// Slow downstream
go func() {
defer close(src)
time.Sleep(1 * time.Second)
src <- "string1"
time.Sleep(10 * time.Second)
<-dest
}()

ctx := context.Background()
ctx = out.StartWithContext(ctx)
ctx = inout.StartWithContext[string](ctx, out)

_, ok := <-dest
assert.False(t, ok)

<-ctx.Done() // Returned context is canceled when timeout is exceeded.
src := out.Chan()
src <- "item1"
close(src)

select {
case <-ctx.Done(): // Returned context is canceled when timeout is exceeded.
assert.ErrorIs(t, ctx.Err(), context.Canceled)
case <-time.After(time.Second):
require.Fail(t, "Test timeout")
}
}
4 changes: 2 additions & 2 deletions task/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (t *task[S, T]) Run(ctx context.Context, in <-chan S, out chan<- T) error {
func (t *task[S, T]) wrapInOut(ctx context.Context, in <-chan S, out chan<- T) (context.Context, <-chan S, chan<- T) {
ip := inout.NewInput(in, newOptions(t.inOpts))
op := inout.NewOutput(out, newOptions(t.outOpts))
ctx = ip.StartWithContext(ctx)
ctx = op.StartWithContext(ctx)
ctx = inout.StartWithContext[S](ctx, ip)
ctx = inout.StartWithContext[T](ctx, op)
return ctx, ip.Chan(), op.Chan()
}

0 comments on commit e895ef7

Please sign in to comment.