diff --git a/executor/join.go b/executor/join.go index 48d3e5d5a56f8..8528c900772d8 100644 --- a/executor/join.go +++ b/executor/join.go @@ -20,7 +20,6 @@ import ( "fmt" "runtime/trace" "strconv" - "sync" "sync/atomic" "time" @@ -70,6 +69,8 @@ type HashJoinExec struct { // closeCh add a lock for closing executor. closeCh chan struct{} + worker util.WaitGroupWrapper + waiter util.WaitGroupWrapper joinType plannercore.JoinType requiredRows int64 @@ -92,9 +93,7 @@ type HashJoinExec struct { prepared bool isOuterJoin bool - // joinWorkerWaitGroup is for sync multiple join workers. - joinWorkerWaitGroup sync.WaitGroup - finished atomic.Value + finished atomic.Bool stats *hashJoinRuntimeStats @@ -154,6 +153,7 @@ func (e *HashJoinExec) Close() error { e.probeChkResourceCh = nil e.joinChkResourceCh = nil terror.Call(e.rowContainer.Close) + e.waiter.Wait() } e.outerMatchedStatus = e.outerMatchedStatus[:0] e.buildSideRows = nil @@ -181,9 +181,10 @@ func (e *HashJoinExec) Open(ctx context.Context) error { e.diskTracker = disk.NewTracker(e.id, -1) e.diskTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.DiskTracker) + e.worker = util.WaitGroupWrapper{} + e.waiter = util.WaitGroupWrapper{} e.closeCh = make(chan struct{}) e.finished.Store(false) - e.joinWorkerWaitGroup = sync.WaitGroup{} if e.probeTypes == nil { e.probeTypes = retTypes(e.probeSideExec) @@ -205,7 +206,7 @@ func (e *HashJoinExec) Open(ctx context.Context) error { func (e *HashJoinExec) fetchProbeSideChunks(ctx context.Context) { hasWaitedForBuild := false for { - if e.finished.Load().(bool) { + if e.finished.Load() { return } @@ -283,24 +284,24 @@ func (e *HashJoinExec) wait4BuildSide() (emptyBuild bool, err error) { // fetchBuildSideRows fetches all rows from build side executor, and append them // to e.buildSideResult. -func (e *HashJoinExec) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chunk.Chunk, doneCh <-chan struct{}) { +func (e *HashJoinExec) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chunk.Chunk, errCh chan<- error, doneCh <-chan struct{}) { defer close(chkCh) var err error failpoint.Inject("issue30289", func(val failpoint.Value) { if val.(bool) { err = errors.Errorf("issue30289 build return error") - e.buildFinished <- errors.Trace(err) + errCh <- errors.Trace(err) return } }) for { - if e.finished.Load().(bool) { + if e.finished.Load() { return } chk := e.ctx.GetSessionVars().GetNewChunk(e.buildSideExec.base().retFieldTypes, e.ctx.GetSessionVars().MaxChunkSize) err = Next(ctx, e.buildSideExec, chk) if err != nil { - e.buildFinished <- errors.Trace(err) + errCh <- errors.Trace(err) return } failpoint.Inject("errorFetchBuildSideRowsMockOOMPanic", nil) @@ -357,8 +358,7 @@ func (e *HashJoinExec) initializeForProbe() { func (e *HashJoinExec) fetchAndProbeHashTable(ctx context.Context) { e.initializeForProbe() - e.joinWorkerWaitGroup.Add(1) - go util.WithRecovery(func() { + e.worker.RunWithRecover(func() { defer trace.StartRegion(ctx, "HashJoinProbeSideFetcher").End() e.fetchProbeSideChunks(ctx) }, e.handleProbeSideFetcherPanic) @@ -373,14 +373,13 @@ func (e *HashJoinExec) fetchAndProbeHashTable(ctx context.Context) { } for i := uint(0); i < e.concurrency; i++ { - e.joinWorkerWaitGroup.Add(1) workID := i - go util.WithRecovery(func() { + e.worker.RunWithRecover(func() { defer trace.StartRegion(ctx, "HashJoinWorker").End() e.runJoinWorker(workID, probeKeyColIdx, probeNAKeColIdx) }, e.handleJoinWorkerPanic) } - go util.WithRecovery(e.waitJoinWorkersAndCloseResultChan, nil) + e.waiter.RunWithRecover(e.waitJoinWorkersAndCloseResultChan, nil) } func (e *HashJoinExec) handleProbeSideFetcherPanic(r interface{}) { @@ -390,14 +389,12 @@ func (e *HashJoinExec) handleProbeSideFetcherPanic(r interface{}) { if r != nil { e.joinResultCh <- &hashjoinWorkerResult{err: errors.Errorf("%v", r)} } - e.joinWorkerWaitGroup.Done() } func (e *HashJoinExec) handleJoinWorkerPanic(r interface{}) { if r != nil { e.joinResultCh <- &hashjoinWorkerResult{err: errors.Errorf("%v", r)} } - e.joinWorkerWaitGroup.Done() } // Concurrently handling unmatched rows from the hash table @@ -437,15 +434,14 @@ func (e *HashJoinExec) handleUnmatchedRowsFromHashTable(workerID uint) { } func (e *HashJoinExec) waitJoinWorkersAndCloseResultChan() { - e.joinWorkerWaitGroup.Wait() + e.worker.Wait() if e.useOuterToBuild { // Concurrently handling unmatched rows from the hash table at the tail for i := uint(0); i < e.concurrency; i++ { var workerID = i - e.joinWorkerWaitGroup.Add(1) - go util.WithRecovery(func() { e.handleUnmatchedRowsFromHashTable(workerID) }, e.handleJoinWorkerPanic) + e.worker.RunWithRecover(func() { e.handleUnmatchedRowsFromHashTable(workerID) }, e.handleJoinWorkerPanic) } - e.joinWorkerWaitGroup.Wait() + e.worker.Wait() } close(e.joinResultCh) } @@ -481,7 +477,7 @@ func (e *HashJoinExec) runJoinWorker(workerID uint, probeKeyColIdx, probeNAKeyCo naKeyColIdx: probeNAKeyColIdx, } for ok := true; ok; { - if e.finished.Load().(bool) { + if e.finished.Load() { break } select { @@ -1121,7 +1117,7 @@ func (e *HashJoinExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { for i := uint(0); i < e.concurrency; i++ { e.rowIters = append(e.rowIters, chunk.NewIterator4Slice([]chunk.Row{}).(*chunk.Iterator4Slice)) } - go util.WithRecovery(func() { + e.worker.RunWithRecover(func() { defer trace.StartRegion(ctx, "HashJoinHashTableBuilder").End() e.fetchAndBuildHashTable(ctx) }, e.handleFetchAndBuildHashTablePanic) @@ -1164,10 +1160,10 @@ func (e *HashJoinExec) fetchAndBuildHashTable(ctx context.Context) { buildSideResultCh := make(chan *chunk.Chunk, 1) doneCh := make(chan struct{}) fetchBuildSideRowsOk := make(chan error, 1) - go util.WithRecovery( + e.worker.RunWithRecover( func() { defer trace.StartRegion(ctx, "HashJoinBuildSideFetcher").End() - e.fetchBuildSideRows(ctx, buildSideResultCh, doneCh) + e.fetchBuildSideRows(ctx, buildSideResultCh, fetchBuildSideRowsOk, doneCh) }, func(r interface{}) { if r != nil { @@ -1214,7 +1210,7 @@ func (e *HashJoinExec) buildHashTableForList(buildSideResultCh <-chan *chunk.Chu e.ctx.GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill) } for chk := range buildSideResultCh { - if e.finished.Load().(bool) { + if e.finished.Load() { return nil } if !e.useOuterToBuild { diff --git a/util/wait_group_wrapper.go b/util/wait_group_wrapper.go index 16c8704920a28..3fb72049f1365 100644 --- a/util/wait_group_wrapper.go +++ b/util/wait_group_wrapper.go @@ -43,7 +43,7 @@ func (w *WaitGroupWrapper) RunWithRecover(exec func(), recoverFn func(r interfac go func() { defer func() { r := recover() - if r != nil && recoverFn != nil { + if recoverFn != nil { recoverFn(r) } w.Done()