diff --git a/executor/index_merge_reader.go b/executor/index_merge_reader.go index c4da6edc9f5cb..e958f2fe43a94 100644 --- a/executor/index_merge_reader.go +++ b/executor/index_merge_reader.go @@ -361,6 +361,13 @@ func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, SetFromInfoSchema(e.ctx.GetInfoSchema()). SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.ctx, &builder.Request, e.partialNetDataSizes[workID])) + var notClosedSelectResult distsql.SelectResult + defer func() { + // To make sure SelectResult.Close() is called even got panic in fetchHandles(). + if notClosedSelectResult != nil { + terror.Call(notClosedSelectResult.Close) + } + }() for parTblIdx, keyRange := range keyRanges { // check if this executor is closed select { @@ -384,6 +391,8 @@ func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, syncErr(ctx, e.finished, fetchCh, err) return } + notClosedSelectResult = result + failpoint.Inject("testIndexMergePartialIndexWorkerCoprLeak", nil) worker.batchSize = e.maxChunkSize if worker.batchSize > worker.maxBatchSize { worker.batchSize = worker.maxBatchSize @@ -398,6 +407,7 @@ func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, if fetchErr != nil { // this error is synced in fetchHandles(), don't sync it again e.feedbacks[workID].Invalidate() } + notClosedSelectResult = nil if err := result.Close(); err != nil { logutil.Logger(ctx).Error("close Select result failed:", zap.Error(err)) } @@ -475,6 +485,13 @@ func (e *IndexMergeReaderExecutor) startPartialTableWorker(ctx context.Context, partialTableReader.dagPB = e.dagPBs[workID] } + var tableReaderClosed bool + defer func() { + // To make sure SelectResult.Close() is called even got panic in fetchHandles(). + if !tableReaderClosed { + terror.Call(worker.tableReader.Close) + } + }() for parTblIdx, tbl := range tbls { // check if this executor is closed select { @@ -490,6 +507,8 @@ func (e *IndexMergeReaderExecutor) startPartialTableWorker(ctx context.Context, syncErr(ctx, e.finished, fetchCh, err) break } + failpoint.Inject("testIndexMergePartialTableWorkerCoprLeak", nil) + tableReaderClosed = false worker.batchSize = e.maxChunkSize if worker.batchSize > worker.maxBatchSize { worker.batchSize = worker.maxBatchSize @@ -507,6 +526,7 @@ func (e *IndexMergeReaderExecutor) startPartialTableWorker(ctx context.Context, // release related resources cancel() + tableReaderClosed = true if err = worker.tableReader.Close(); err != nil { logutil.Logger(ctx).Error("close Select result failed:", zap.Error(err)) } diff --git a/executor/index_merge_reader_test.go b/executor/index_merge_reader_test.go index be1ff66a163ab..d30fce71a180e 100644 --- a/executor/index_merge_reader_test.go +++ b/executor/index_merge_reader_test.go @@ -881,3 +881,35 @@ func TestIndexMergePanic(t *testing.T) { require.NoError(t, failpoint.Disable(fp)) } } + +func TestIndexMergeCoprGoroutinesLeak(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1(c1 int, c2 bigint, c3 bigint, primary key(c1), key(c2), key(c3));") + insertStr := "insert into t1 values(0, 0, 0)" + for i := 1; i < 1000; i++ { + insertStr += fmt.Sprintf(", (%d, %d, %d)", i, i, i) + } + tk.MustExec(insertStr) + tk.MustExec("analyze table t1;") + tk.MustExec("set tidb_partition_prune_mode = 'dynamic'") + + var err error + sql := "select /*+ use_index_merge(t1) */ c1 from t1 where c1 < 900 or c2 < 1000;" + res := tk.MustQuery("explain " + sql).Rows() + require.Contains(t, res[1][0], "IndexMerge") + + // If got goroutines leak in coprocessor, ci will fail. + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/testIndexMergePartialTableWorkerCoprLeak", `panic("testIndexMergePartialTableWorkerCoprLeak")`)) + err = tk.QueryToErr(sql) + require.Contains(t, err.Error(), "testIndexMergePartialTableWorkerCoprLeak") + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/testIndexMergePartialTableWorkerCoprLeak")) + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/testIndexMergePartialIndexWorkerCoprLeak", `panic("testIndexMergePartialIndexWorkerCoprLeak")`)) + err = tk.QueryToErr(sql) + require.Contains(t, err.Error(), "testIndexMergePartialIndexWorkerCoprLeak") + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/testIndexMergePartialIndexWorkerCoprLeak")) +}