Skip to content

Commit

Permalink
executor: fix incorrect result when hash agg spill is triggered (#55336)
Browse files Browse the repository at this point in the history
close #55290
  • Loading branch information
xzhangxian1008 authored Aug 10, 2024
1 parent f29a4ea commit 236a79c
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 86 deletions.
1 change: 1 addition & 0 deletions pkg/executor/aggfuncs/aggfuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ var (
_ AggFunc = (*maxMin4Float64)(nil)
_ AggFunc = (*maxMin4Decimal)(nil)
_ AggFunc = (*maxMin4String)(nil)
_ AggFunc = (*maxMin4Time)(nil)
_ AggFunc = (*maxMin4Duration)(nil)
_ AggFunc = (*maxMin4JSON)(nil)
_ AggFunc = (*maxMin4Enum)(nil)
Expand Down
4 changes: 2 additions & 2 deletions pkg/executor/aggfuncs/func_count.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ func (e *baseCount) DeserializePartialResult(src *chunk.Chunk) ([]PartialResult,

func (e *baseCount) deserializeForSpill(helper *deserializeHelper) (PartialResult, int64) {
pr, memDelta := e.AllocPartialResult()
result := *(*partialResult4Count)(pr)
success := helper.deserializePartialResult4Count(&result)
result := (*partialResult4Count)(pr)
success := helper.deserializePartialResult4Count(result)
if !success {
return nil, 0
}
Expand Down
20 changes: 20 additions & 0 deletions pkg/executor/aggfuncs/func_max_min.go
Original file line number Diff line number Diff line change
Expand Up @@ -1287,6 +1287,26 @@ func (e *maxMin4Time) MergePartialResult(_ AggFuncUpdateContext, src, dst Partia
return 0, nil
}

// SerializePartialResult encodes the given max/min time partial result with
// the spill helper and appends the encoded bytes to chk in the column at
// e.ordinal.
func (e *maxMin4Time) SerializePartialResult(partialResult PartialResult, chk *chunk.Chunk, spillHelper *SerializeHelper) {
	// The helper takes the partial result by value, so dereference the typed pointer.
	encoded := spillHelper.serializePartialResult4MaxMinTime(*(*partialResult4MaxMinTime)(partialResult))
	chk.AppendBytes(e.ordinal, encoded)
}

// DeserializePartialResult restores partial results from the column of src at
// e.ordinal. It delegates to deserializePartialResultCommon, supplying
// e.deserializeForSpill as the per-result decoder, and returns that call's
// results unchanged.
func (e *maxMin4Time) DeserializePartialResult(src *chunk.Chunk) ([]PartialResult, int64) {
	restored, memDelta := deserializePartialResultCommon(src, e.ordinal, e.deserializeForSpill)
	return restored, memDelta
}

// deserializeForSpill allocates a fresh partial result and hands a typed
// pointer to it to the helper for deserialization. The pointer (not a copy)
// is passed, so whatever the helper writes lands in the allocated result.
//
// It returns the allocated partial result together with the memory delta
// reported by AllocPartialResult, or (nil, 0) when the helper reports failure.
func (e *maxMin4Time) deserializeForSpill(helper *deserializeHelper) (PartialResult, int64) {
	pr, memDelta := e.AllocPartialResult()
	if ok := helper.deserializePartialResult4MaxMinTime((*partialResult4MaxMinTime)(pr)); !ok {
		return nil, 0
	}
	return pr, memDelta
}

type maxMin4TimeSliding struct {
maxMin4Time
windowInfo
Expand Down
7 changes: 6 additions & 1 deletion pkg/executor/aggregate/agg_hash_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,9 +395,14 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) error {
spillChunkFieldTypes[i] = types.NewFieldType(mysql.TypeVarString)
}
spillChunkFieldTypes[baseRetTypeNum] = types.NewFieldType(mysql.TypeString)
e.spillHelper = newSpillHelper(e.memTracker, e.PartialAggFuncs, func() *chunk.Chunk {

var err error
e.spillHelper, err = newSpillHelper(e.memTracker, e.PartialAggFuncs, e.FinalAggFuncs, func() *chunk.Chunk {
return chunk.New(spillChunkFieldTypes, e.InitCap(), e.MaxChunkSize())
}, spillChunkFieldTypes)
if err != nil {
return err
}

if isTrackerEnabled && isParallelHashAggSpillEnabled {
if e.diskTracker != nil {
Expand Down
15 changes: 12 additions & 3 deletions pkg/executor/aggregate/agg_spill.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"sync"
"sync/atomic"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/executor/aggfuncs"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/types"
Expand Down Expand Up @@ -66,15 +67,22 @@ type parallelHashAggSpillHelper struct {
// They are only used for restoring data that was spilled to disk in the partial stage.
aggFuncsForRestoring []aggfuncs.AggFunc

finalWorkerAggFuncs []aggfuncs.AggFunc

getNewSpillChunkFunc func() *chunk.Chunk
spillChunkFieldTypes []*types.FieldType
}

func newSpillHelper(
tracker *memory.Tracker,
aggFuncsForRestoring []aggfuncs.AggFunc,
finalWorkerAggFuncs []aggfuncs.AggFunc,
getNewSpillChunkFunc func() *chunk.Chunk,
spillChunkFieldTypes []*types.FieldType) *parallelHashAggSpillHelper {
spillChunkFieldTypes []*types.FieldType) (*parallelHashAggSpillHelper, error) {
if len(aggFuncsForRestoring) != len(finalWorkerAggFuncs) {
return nil, errors.NewNoStackError("len(aggFuncsForRestoring) != len(finalWorkerAggFuncs)")
}

mu := new(sync.Mutex)
helper := &parallelHashAggSpillHelper{
lock: struct {
Expand All @@ -97,11 +105,12 @@ func newSpillHelper(
memTracker: tracker,
hasError: atomic.Bool{},
aggFuncsForRestoring: aggFuncsForRestoring,
finalWorkerAggFuncs: finalWorkerAggFuncs,
getNewSpillChunkFunc: getNewSpillChunkFunc,
spillChunkFieldTypes: spillChunkFieldTypes,
}

return helper
return helper, nil
}

func (p *parallelHashAggSpillHelper) close() {
Expand Down Expand Up @@ -294,7 +303,7 @@ func (p *parallelHashAggSpillHelper) processRow(context *processRowContext) (tot
exprCtx := context.ctx.GetExprCtx()
// The key has appeared before, merge results.
for aggPos := 0; aggPos < context.aggFuncNum; aggPos++ {
memDelta, err := p.aggFuncsForRestoring[aggPos].MergePartialResult(exprCtx.GetEvalCtx(), context.partialResultsRestored[aggPos][context.rowPos], prs[aggPos])
memDelta, err := p.finalWorkerAggFuncs[aggPos].MergePartialResult(exprCtx.GetEvalCtx(), context.partialResultsRestored[aggPos][context.rowPos], prs[aggPos])
if err != nil {
return totalMemDelta, 0, err
}
Expand Down
Loading

0 comments on commit 236a79c

Please sign in to comment.