From bc88e13edd8d4857fc1171aa1da3ec815d3e5e7d Mon Sep 17 00:00:00 2001
From: tangenta
Date: Mon, 21 Aug 2023 14:39:33 +0800
Subject: [PATCH] workerpool: generic result type for worker pool (#46185)

ref pingcap/tidb#46258
---
 ddl/backfilling_scheduler.go                  | 11 ++-
 executor/executor.go                          |  8 +-
 resourcemanager/pool/workerpool/BUILD.bazel   |  3 +-
 resourcemanager/pool/workerpool/workerpool.go | 98 +++++++++++--------
 .../pool/workerpool/workpool_test.go          | 96 +++++++++++++++++-
 5 files changed, 162 insertions(+), 54 deletions(-)

diff --git a/ddl/backfilling_scheduler.go b/ddl/backfilling_scheduler.go
index caa1e96cfc15b..6596b3e49d611 100644
--- a/ddl/backfilling_scheduler.go
+++ b/ddl/backfilling_scheduler.go
@@ -268,7 +268,7 @@ type ingestBackfillScheduler struct {
 
 	copReqSenderPool *copReqSenderPool
 
-	writerPool    *workerpool.WorkerPool[idxRecResult]
+	writerPool    *workerpool.WorkerPool[idxRecResult, workerpool.None]
 	writerMaxID   int
 	poolErr       chan error
 	backendCtx    ingest.BackendCtx
@@ -308,12 +308,12 @@ func (b *ingestBackfillScheduler) setupWorkers() error {
 	}
 	b.copReqSenderPool = copReqSenderPool
 	readerCnt, writerCnt := b.expectedWorkerSize()
-	skipReg := workerpool.OptionSkipRegister[idxRecResult]{}
 	writerPool, err := workerpool.NewWorkerPool[idxRecResult]("ingest_writer",
-		poolutil.DDL, writerCnt, b.createWorker, skipReg)
+		poolutil.DDL, writerCnt, b.createWorker)
 	if err != nil {
 		return errors.Trace(err)
 	}
+	writerPool.Start()
 	b.writerPool = writerPool
 	b.copReqSenderPool.chunkSender = writerPool
 	b.copReqSenderPool.adjustSize(readerCnt)
@@ -382,7 +382,7 @@ func (b *ingestBackfillScheduler) adjustWorkerSize() error {
 	return nil
 }
 
-func (b *ingestBackfillScheduler) createWorker() workerpool.Worker[idxRecResult] {
+func (b *ingestBackfillScheduler) createWorker() workerpool.Worker[idxRecResult, workerpool.None] {
 	reorgInfo := b.reorgInfo
 	job := reorgInfo.Job
 	sessCtx, err := newSessCtx(reorgInfo)
@@ -447,7 +447,7 @@ func (*ingestBackfillScheduler) expectedWorkerSize() (readerSize int, writerSize
 	return readerSize, writerSize
 }
 
-func (w *addIndexIngestWorker) HandleTask(rs idxRecResult) {
+func (w *addIndexIngestWorker) HandleTask(rs idxRecResult) (_ workerpool.None) {
 	defer util.Recover(metrics.LabelDDL, "ingestWorker.HandleTask", func() {
 		w.resultCh <- &backfillResult{taskID: rs.id, err: dbterror.ErrReorgPanic}
 	}, false)
@@ -494,6 +494,7 @@ func (w *addIndexIngestWorker) HandleTask(rs idxRecResult) {
 		ResultCounterForTest.Add(1)
 	}
 	w.resultCh <- result
+	return
 }
 
 func (*addIndexIngestWorker) Close() {}
diff --git a/executor/executor.go b/executor/executor.go
index e6856a6703193..f7eb3dabb1a43 100644
--- a/executor/executor.go
+++ b/executor/executor.go
@@ -2350,7 +2350,7 @@ func getCheckSum(ctx context.Context, se sessionctx.Context, sql string) ([]grou
 }
 
 // HandleTask implements the Worker interface.
-func (w *checkIndexWorker) HandleTask(task checkIndexTask) {
+func (w *checkIndexWorker) HandleTask(task checkIndexTask) (_ workerpool.None) {
 	defer w.e.wg.Done()
 	idxInfo := w.indexInfos[task.indexOffset]
 	bucketSize := int(CheckTableFastBucketSize.Load())
@@ -2688,12 +2688,13 @@ func (w *checkIndexWorker) HandleTask(task checkIndexTask) {
 			}
 		}
 	}
+	return
 }
 
 // Close implements the Worker interface.
 func (*checkIndexWorker) Close() {}
 
-func (e *FastCheckTableExec) createWorker() workerpool.Worker[checkIndexTask] {
+func (e *FastCheckTableExec) createWorker() workerpool.Worker[checkIndexTask, workerpool.None] {
 	return &checkIndexWorker{sctx: e.Ctx(), dbName: e.dbName, table: e.table, indexInfos: e.indexInfos, e: e}
 }
 
@@ -2711,10 +2712,11 @@ func (e *FastCheckTableExec) Next(context.Context, *chunk.Chunk) error {
 	}()
 
 	workerPool, err := workerpool.NewWorkerPool[checkIndexTask]("checkIndex",
-		poolutil.CheckTable, 3, e.createWorker, workerpool.OptionSkipRegister[checkIndexTask]{})
+		poolutil.CheckTable, 3, e.createWorker)
 	if err != nil {
 		return errors.Trace(err)
 	}
+	workerPool.Start()
 
 	e.wg.Add(len(e.indexInfos))
 	for i := range e.indexInfos {
diff --git a/resourcemanager/pool/workerpool/BUILD.bazel b/resourcemanager/pool/workerpool/BUILD.bazel
index cc6f0222e834b..30647736b656d 100644
--- a/resourcemanager/pool/workerpool/BUILD.bazel
+++ b/resourcemanager/pool/workerpool/BUILD.bazel
@@ -7,7 +7,6 @@ go_library(
     visibility = ["//visibility:public"],
     deps = [
         "//metrics",
-        "//resourcemanager",
         "//resourcemanager/util",
         "//util",
         "//util/syncutil",
@@ -25,11 +24,13 @@ go_test(
     embed = [":workerpool"],
     flaky = True,
     race = "on",
+    shard_count = 3,
     deps = [
         "//resourcemanager/util",
         "//testkit/testsetup",
         "//util/logutil",
         "@com_github_stretchr_testify//require",
+        "@org_golang_x_sync//errgroup",
         "@org_uber_go_goleak//:goleak",
         "@org_uber_go_zap//:zap",
     ],
diff --git a/resourcemanager/pool/workerpool/workerpool.go b/resourcemanager/pool/workerpool/workerpool.go
index e97ed218ee228..dee92351b674d 100644
--- a/resourcemanager/pool/workerpool/workerpool.go
+++ b/resourcemanager/pool/workerpool/workerpool.go
@@ -18,7 +18,6 @@ import (
 	"time"
 
 	"github.com/pingcap/tidb/metrics"
-	"github.com/pingcap/tidb/resourcemanager"
 	"github.com/pingcap/tidb/resourcemanager/util"
 	tidbutil "github.com/pingcap/tidb/util"
 	"github.com/pingcap/tidb/util/syncutil"
@@ -26,85 +25,99 @@ import (
 )
 
 // Worker is worker interface.
-type Worker[T any] interface {
-	HandleTask(task T)
+type Worker[T, R any] interface {
+	HandleTask(task T) R
 	Close()
 }
 
 // WorkerPool is a pool of workers.
-type WorkerPool[T any] struct {
+type WorkerPool[T, R any] struct {
 	name          string
 	numWorkers    int32
 	originWorkers int32
 	runningTask   atomicutil.Int32
 	taskChan      chan T
+	resChan       chan R
 	quitChan      chan struct{}
 	wg            tidbutil.WaitGroupWrapper
-	createWorker  func() Worker[T]
+	createWorker  func() Worker[T, R]
 	lastTuneTs    atomicutil.Time
 	mu            syncutil.RWMutex
-	skipRegister  bool
 }
 
 // Option is the config option for WorkerPool.
-type Option[T any] interface {
-	Apply(pool *WorkerPool[T])
+type Option[T, R any] interface {
+	Apply(pool *WorkerPool[T, R])
 }
 
-// OptionSkipRegister is an option to skip register the worker pool to resource manager.
-type OptionSkipRegister[T any] struct{}
-
-// Apply implements the Option interface.
-func (OptionSkipRegister[T]) Apply(pool *WorkerPool[T]) {
-	pool.skipRegister = true
-}
+// None is a type placeholder for the worker pool that does not have a result receiver.
+type None struct{}
 
 // NewWorkerPool creates a new worker pool.
-func NewWorkerPool[T any](name string, component util.Component, numWorkers int,
-	createWorker func() Worker[T], opts ...Option[T]) (*WorkerPool[T], error) {
+func NewWorkerPool[T, R any](name string, _ util.Component, numWorkers int,
+	createWorker func() Worker[T, R], opts ...Option[T, R]) (*WorkerPool[T, R], error) {
 	if numWorkers <= 0 {
 		numWorkers = 1
 	}
-	p := &WorkerPool[T]{
+	p := &WorkerPool[T, R]{
 		name:          name,
 		numWorkers:    int32(numWorkers),
 		originWorkers: int32(numWorkers),
-		taskChan:      make(chan T),
 		quitChan:      make(chan struct{}),
-		createWorker:  createWorker,
 	}
 
 	for _, opt := range opts {
 		opt.Apply(p)
 	}
 
-	if !p.skipRegister {
-		err := resourcemanager.InstanceResourceManager.Register(p, name, component)
-		if err != nil {
-			return nil, err
+	p.createWorker = createWorker
+	return p, nil
+}
+
+// SetTaskReceiver sets the task receiver for the pool.
+func (p *WorkerPool[T, R]) SetTaskReceiver(recv chan T) {
+	p.taskChan = recv
+}
+
+// SetResultSender sets the result sender for the pool.
+func (p *WorkerPool[T, R]) SetResultSender(sender chan R) {
+	p.resChan = sender
+}
+
+// Start starts default count of workers.
+func (p *WorkerPool[T, R]) Start() {
+	if p.taskChan == nil {
+		p.taskChan = make(chan T)
+	}
+
+	if p.resChan == nil {
+		var zero R
+		var r interface{} = zero
+		if _, ok := r.(None); !ok {
+			p.resChan = make(chan R)
 		}
 	}
 
-	// Start default count of workers.
 	for i := 0; i < int(p.numWorkers); i++ {
 		p.runAWorker()
 	}
-
-	return p, nil
 }
 
-func (p *WorkerPool[T]) handleTaskWithRecover(w Worker[T], task T) {
+func (p *WorkerPool[T, R]) handleTaskWithRecover(w Worker[T, R], task T) {
 	p.runningTask.Add(1)
 	defer func() {
 		p.runningTask.Add(-1)
 	}()
 	defer tidbutil.Recover(metrics.LabelWorkerPool, "handleTaskWithRecover", nil, false)
-	w.HandleTask(task)
+	r := w.HandleTask(task)
+	if p.resChan != nil {
+		p.resChan <- r
+	}
 }
 
-func (p *WorkerPool[T]) runAWorker() {
+func (p *WorkerPool[T, R]) runAWorker() {
 	w := p.createWorker()
 	if w == nil {
 		return // Fail to create worker, quit.
@@ -123,12 +136,17 @@ func (p *WorkerPool[T]) runAWorker() {
 }
 
 // AddTask adds a task to the pool.
-func (p *WorkerPool[T]) AddTask(task T) {
+func (p *WorkerPool[T, R]) AddTask(task T) {
 	p.taskChan <- task
 }
 
+// GetResultChan gets the result channel from the pool.
+func (p *WorkerPool[T, R]) GetResultChan() <-chan R {
+	return p.resChan
+}
+
 // Tune tunes the pool to the specified number of workers.
-func (p *WorkerPool[T]) Tune(numWorkers int32) {
+func (p *WorkerPool[T, R]) Tune(numWorkers int32) {
 	if numWorkers <= 0 {
 		numWorkers = 1
 	}
@@ -151,37 +169,37 @@ func (p *WorkerPool[T]) Tune(numWorkers int32) {
 }
 
 // LastTunerTs returns the last time when the pool was tuned.
-func (p *WorkerPool[T]) LastTunerTs() time.Time {
+func (p *WorkerPool[T, R]) LastTunerTs() time.Time {
 	return p.lastTuneTs.Load()
 }
 
 // Cap returns the capacity of the pool.
-func (p *WorkerPool[T]) Cap() int32 {
+func (p *WorkerPool[T, R]) Cap() int32 {
 	p.mu.RLock()
 	defer p.mu.RUnlock()
 	return p.numWorkers
 }
 
 // Running returns the number of running workers.
-func (p *WorkerPool[T]) Running() int32 {
+func (p *WorkerPool[T, R]) Running() int32 {
 	return p.runningTask.Load()
 }
 
 // Name returns the name of the pool.
-func (p *WorkerPool[T]) Name() string {
+func (p *WorkerPool[T, R]) Name() string {
 	return p.name
 }
 
 // ReleaseAndWait releases the pool and wait for complete.
-func (p *WorkerPool[T]) ReleaseAndWait() {
+func (p *WorkerPool[T, R]) ReleaseAndWait() {
 	close(p.quitChan)
 	p.wg.Wait()
-	if !p.skipRegister {
-		resourcemanager.InstanceResourceManager.Unregister(p.Name())
+	if p.resChan != nil {
+		close(p.resChan)
 	}
 }
 
 // GetOriginConcurrency return the concurrency of the pool at the init.
-func (p *WorkerPool[T]) GetOriginConcurrency() int32 {
+func (p *WorkerPool[T, R]) GetOriginConcurrency() int32 {
 	return p.originWorkers
 }
diff --git a/resourcemanager/pool/workerpool/workpool_test.go b/resourcemanager/pool/workerpool/workpool_test.go
index 21e800706820e..0d21795603d30 100644
--- a/resourcemanager/pool/workerpool/workpool_test.go
+++ b/resourcemanager/pool/workerpool/workpool_test.go
@@ -23,35 +23,48 @@ import (
 	"github.com/pingcap/tidb/util/logutil"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/zap"
+	"golang.org/x/sync/errgroup"
 )
 
 var globalCnt atomic.Int64
 var cntWg sync.WaitGroup
 
-type MyWorker[T int64] struct {
+type MyWorker[T int64, R struct{}] struct {
 	id int
 }
 
-func (w *MyWorker[T]) HandleTask(task int64) {
+func (w *MyWorker[T, R]) HandleTask(task int64) struct{} {
 	globalCnt.Add(task)
 	cntWg.Done()
 	logutil.BgLogger().Info("Worker handling task")
+	return struct{}{}
 }
 
-func (w *MyWorker[T]) Close() {
+func (w *MyWorker[T, R]) Close() {
 	logutil.BgLogger().Info("Close worker", zap.Any("id", w.id))
 }
 
-func createMyWorker() Worker[int64] {
-	return &MyWorker[int64]{}
+func createMyWorker() Worker[int64, struct{}] {
+	return &MyWorker[int64, struct{}]{}
 }
 
 func TestWorkerPool(t *testing.T) {
 	// Create a worker pool with 3 workers.
 	pool, err := NewWorkerPool[int64]("test", util.UNKNOWN, 3, createMyWorker)
 	require.NoError(t, err)
+	pool.Start()
 
 	globalCnt.Store(0)
 
+	g := new(errgroup.Group)
+	g.Go(func() error {
+		// Consume the results.
+		for range pool.GetResultChan() {
+			// Do nothing.
+		}
+		return nil
+	})
+	defer g.Wait()
+
 	// Add some tasks to the pool.
 	cntWg.Add(10)
 	for i := 0; i < 10; i++ {
@@ -91,3 +104,76 @@ func TestWorkerPool(t *testing.T) {
 	// Wait for the tasks to be completed.
 	pool.ReleaseAndWait()
 }
+
+type dummyWorker[T, R any] struct {
+}
+
+func (d dummyWorker[T, R]) HandleTask(task T) R {
+	var zero R
+	return zero
+}
+
+func (d dummyWorker[T, R]) Close() {}
+
+func TestWorkerPoolNoneResult(t *testing.T) {
+	pool, err := NewWorkerPool[int64, None](
+		"test", util.UNKNOWN, 3,
+		func() Worker[int64, None] {
+			return dummyWorker[int64, None]{}
+		})
+	require.NoError(t, err)
+	pool.Start()
+	ch := pool.GetResultChan()
+	require.Nil(t, ch)
+	pool.ReleaseAndWait()
+
+	pool2, err := NewWorkerPool[int64, int64](
+		"test", util.UNKNOWN, 3,
+		func() Worker[int64, int64] {
+			return dummyWorker[int64, int64]{}
+		})
+	require.NoError(t, err)
+	pool2.Start()
+	require.NotNil(t, pool2.GetResultChan())
+	pool2.ReleaseAndWait()
+
+	pool3, err := NewWorkerPool[int64, struct{}](
+		"test", util.UNKNOWN, 3,
+		func() Worker[int64, struct{}] {
+			return dummyWorker[int64, struct{}]{}
+		})
+	require.NoError(t, err)
+	pool3.Start()
+	require.NotNil(t, pool3.GetResultChan())
+	pool3.ReleaseAndWait()
+}
+
+func TestWorkerPoolCustomChan(t *testing.T) {
+	pool, err := NewWorkerPool[int64, int64](
+		"test", util.UNKNOWN, 3,
+		func() Worker[int64, int64] {
+			return dummyWorker[int64, int64]{}
+		})
+	require.NoError(t, err)
+
+	taskCh := make(chan int64)
+	pool.SetTaskReceiver(taskCh)
+	resultCh := make(chan int64)
+	pool.SetResultSender(resultCh)
+	count := 0
+	g := new(errgroup.Group)
+	g.Go(func() error {
+		for range resultCh {
+			count++
+		}
+		return nil
+	})
+
+	pool.Start()
+	for i := 0; i < 5; i++ {
+		taskCh <- int64(i)
+	}
+	pool.ReleaseAndWait()
+	require.NoError(t, g.Wait())
+	require.Equal(t, 5, count)
+}
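
For illustration, a minimal sketch of how caller code might drive the generic pool after this change. The workerpool and util identifiers (NewWorkerPool, Worker, Start, AddTask, GetResultChan, ReleaseAndWait, util.UNKNOWN) are the ones introduced or used in the patch; the squareWorker type and the int64 task/result choice are hypothetical.

package main

import (
	"fmt"

	"github.com/pingcap/tidb/resourcemanager/pool/workerpool"
	"github.com/pingcap/tidb/resourcemanager/util"
)

// squareWorker is a hypothetical worker: each task is an int64 and the
// result is its square.
type squareWorker struct{}

func (squareWorker) HandleTask(task int64) int64 { return task * task }
func (squareWorker) Close()                      {}

func main() {
	// A pool whose result type is not workerpool.None allocates a result
	// channel in Start, so callers should drain GetResultChan.
	pool, err := workerpool.NewWorkerPool[int64, int64]("square", util.UNKNOWN, 3,
		func() workerpool.Worker[int64, int64] { return squareWorker{} })
	if err != nil {
		panic(err)
	}
	pool.Start()

	done := make(chan struct{})
	go func() {
		// ReleaseAndWait closes the result channel, which ends this loop.
		for r := range pool.GetResultChan() {
			fmt.Println(r)
		}
		close(done)
	}()

	for i := int64(1); i <= 5; i++ {
		pool.AddTask(i)
	}
	pool.ReleaseAndWait()
	<-done
}

With workerpool.None as the result type, Start skips allocating the channel and GetResultChan returns nil, which is the mode the DDL ingest and fast-check callers in this patch rely on.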