Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add context to the task function signature #12

Merged
merged 1 commit into from
Feb 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ routines would be detrimental to performances.
package main

import (
"context"
"fmt"
"os"
"runtime"
Expand Down Expand Up @@ -41,7 +42,7 @@ func main() {
id := fmt.Sprintf("task #%d", i)
// Use Submit to submit tasks for processing. Submit blocks when no
// worker is available to pick up the task.
err := wp.Submit(id, func() error {
err := wp.Submit(id, func(_ context.Context) error {
fmt.Println("isprime", n)
if IsPrime(n) {
fmt.Println(n, "is prime!")
Expand Down
3 changes: 2 additions & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package workerpool

import (
"context"
"fmt"
"os"
"runtime"
Expand All @@ -40,7 +41,7 @@ func Example() {
id := fmt.Sprintf("task #%d", i)
// Use Submit to submit tasks for processing. Submit blocks when no
// worker is available to pick up the task.
err := wp.Submit(id, func() error {
err := wp.Submit(id, func(_ context.Context) error {
fmt.Println("isprime", n)
if IsPrime(n) {
fmt.Println(n, "is prime!")
Expand Down
7 changes: 5 additions & 2 deletions task.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

package workerpool

import "fmt"
import (
"context"
"fmt"
)

// Task is a unit of work.
type Task interface {
Expand All @@ -27,7 +30,7 @@ type Task interface {

type task struct {
id string
run func() error
run func(context.Context) error
err error
}

Expand Down
23 changes: 16 additions & 7 deletions workerpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package workerpool

import (
"context"
"errors"
"fmt"
"sync"
Expand All @@ -40,6 +41,7 @@ type WorkerPool struct {
wg sync.WaitGroup
mu sync.Mutex
draining bool
cancel context.CancelFunc
closed bool
}

Expand All @@ -53,7 +55,9 @@ func New(n int) *WorkerPool {
workers: make(chan struct{}, n),
tasks: make(chan *task),
}
go wp.run()
ctx, cancel := context.WithCancel(context.Background())
wp.cancel = cancel
go wp.run(ctx)
return wp
}

Expand All @@ -63,13 +67,16 @@ func (wp *WorkerPool) Cap() int {
}

// Submit submits f for processing by a worker. The given id is useful for
// identifying the task once it is completed.
// identifying the task once it is completed. The task f must return when the
// context ctx is cancelled.
//
// Submit blocks until a routine start processing the task.
//
// If a drain operation is in progress, ErrDraining is returned and the task
// is not submitted for processing.
// If the worker pool is closed, ErrClosed is returned and the task is not
// submitted for processing.
func (wp *WorkerPool) Submit(id string, f func() error) error {
func (wp *WorkerPool) Submit(id string, f func(ctx context.Context) error) error {
wp.mu.Lock()
if wp.closed {
wp.mu.Unlock()
Expand Down Expand Up @@ -127,8 +134,9 @@ func (wp *WorkerPool) Drain() ([]Task, error) {
}

// Close closes the worker pool, rendering it unable to process new tasks.
// It should be called after a call to Drain and the worker pool is no longer
// needed. Close will return ErrClosed if it has already been called.
// Close sends the cancellation signal to any running task and waits for all
// workers, if any, to return.
// Close will return ErrClosed if it has already been called.
func (wp *WorkerPool) Close() error {
wp.mu.Lock()
if wp.closed {
Expand All @@ -138,6 +146,7 @@ func (wp *WorkerPool) Close() error {
wp.closed = true
wp.mu.Unlock()

wp.cancel()
wp.wg.Wait()

// At this point, all routines have returned. This means that Submit is not
Expand All @@ -151,14 +160,14 @@ func (wp *WorkerPool) Close() error {

// run loops over the tasks channel and starts processing routines. It should
// only be called once during the lifetime of a WorkerPool.
func (wp *WorkerPool) run() {
func (wp *WorkerPool) run(ctx context.Context) {
for t := range wp.tasks {
t := t
wp.results = append(wp.results, t)
wp.workers <- struct{}{}
go func() {
defer wp.wg.Done()
t.err = t.run()
t.err = t.run(ctx)
<-wp.workers
}()
}
Expand Down
37 changes: 33 additions & 4 deletions workerpool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func TestWorkerPoolConcurrentTasksCount(t *testing.T) {
// NOTE: schedule one more task than we have workers, hence n+1.
for i := 0; i < n+1; i++ {
id := fmt.Sprintf("task #%2d", i)
err := wp.Submit(id, func() error {
err := wp.Submit(id, func(_ context.Context) error {
working <- struct{}{}
<-ctx.Done()
return nil
Expand Down Expand Up @@ -140,7 +140,7 @@ func TestWorkerPool(t *testing.T) {
wg.Add(numTasks - 1)
for i := 0; i < numTasks-1; i++ {
id := fmt.Sprintf("task #%2d", i)
err := wp.Submit(id, func() error {
err := wp.Submit(id, func(_ context.Context) error {
defer wg.Done()
working <- struct{}{}
done <- struct{}{}
Expand All @@ -163,7 +163,7 @@ func TestWorkerPool(t *testing.T) {
go func() {
id := fmt.Sprintf("task #%2d", numTasks-1)
ready <- struct{}{}
wp.Submit(id, func() error {
wp.Submit(id, func(_ context.Context) error {
defer wg.Done()
done <- struct{}{}
return nil
Expand Down Expand Up @@ -227,7 +227,7 @@ func TestConcurrentDrain(t *testing.T) {
wg.Add(numTasks)
for i := 0; i < numTasks; i++ {
id := fmt.Sprintf("task #%2d", i)
err := wp.Submit(id, func() error {
err := wp.Submit(id, func(_ context.Context) error {
defer wg.Done()
done <- struct{}{}
return nil
Expand Down Expand Up @@ -332,3 +332,32 @@ func TestWorkerPoolManyClose(t *testing.T) {
t.Fatalf("got %v; want %v", err, ErrClosed)
}
}

func TestWorkerPoolClose(t *testing.T) {
n := runtime.NumCPU()
wp := New(n)

// working is written to by each task as soon as possible.
working := make(chan struct{})
var wg sync.WaitGroup
wg.Add(n)
for i := 0; i < n; i++ {
id := fmt.Sprintf("task #%2d", i)
wp.Submit(id, func(ctx context.Context) error {
working <- struct{}{}
<-ctx.Done()
wg.Done()
return ctx.Err()
})
}

// ensure n workers are busy
for i := 0; i < n; i++ {
<-working
}

if err := wp.Close(); err != nil {
t.Fatalf("unexpected error on Close(): %s", err)
}
wg.Wait() // all routines should have returned
}