Skip to content

Commit

Permalink
cherry pick pingcap#25210 to release-5.0
Browse files Browse the repository at this point in the history
Signed-off-by: ti-srebot <ti-srebot@pingcap.com>
  • Loading branch information
hanfei1991 authored and ti-srebot committed Jun 8, 2021
1 parent dc40a09 commit cabc0cb
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 267 deletions.
12 changes: 10 additions & 2 deletions planner/core/physical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ func (p *basePhysicalAgg) numDistinctFunc() (num int) {
return
}

func (p *basePhysicalAgg) getAggFuncCostFactor() (factor float64) {
func (p *basePhysicalAgg) getAggFuncCostFactor(isMPP bool) (factor float64) {
factor = 0.0
for _, agg := range p.AggFuncs {
if fac, ok := aggFuncFactor[agg.Name]; ok {
Expand All @@ -1018,7 +1018,15 @@ func (p *basePhysicalAgg) getAggFuncCostFactor() (factor float64) {
}
}
if factor == 0 {
factor = 1.0
if isMPP {
// With pseudo stats, the default factor 1.0 leads to 1-phase agg.
// In MPP cases, however, 2-phase agg is more common, so we use a different factor here.
// TODO: This is still a little tricky and might cause regressions. We should
// calibrate these factors and polish our cost model in the future.
factor = aggFuncFactor[ast.AggFuncFirstRow]
} else {
factor = 1.0
}
}
return
}
Expand Down
38 changes: 34 additions & 4 deletions planner/core/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -1739,7 +1739,7 @@ func (p *PhysicalStreamAgg) attach2Task(tasks ...task) task {

// GetCost computes cost of stream aggregation considering CPU/memory.
func (p *PhysicalStreamAgg) GetCost(inputRows float64, isRoot bool) float64 {
aggFuncFactor := p.getAggFuncCostFactor()
aggFuncFactor := p.getAggFuncCostFactor(false)
var cpuCost float64
sessVars := p.ctx.GetSessionVars()
if isRoot {
Expand Down Expand Up @@ -1786,7 +1786,12 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task {
if proj != nil {
attachPlan2Task(proj, mpp)
}
<<<<<<< HEAD
mpp.addCost(p.GetCost(inputRows, false))
=======
mpp.addCost(p.GetCost(inputRows, false, true))
p.cost = mpp.cost()
>>>>>>> a7f3c4d8b... planner/core: change agg cost factor (#25210)
return mpp
case Mpp2Phase:
proj := p.convertAvgForMPP()
Expand Down Expand Up @@ -1817,18 +1822,38 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task {
attachPlan2Task(proj, newMpp)
}
// TODO: how to set 2-phase cost?
<<<<<<< HEAD
newMpp.addCost(p.GetCost(inputRows, false))
=======
newMpp.addCost(p.GetCost(inputRows, false, true))
finalAgg.SetCost(mpp.cost())
if proj != nil {
proj.SetCost(mpp.cost())
}
>>>>>>> a7f3c4d8b... planner/core: change agg cost factor (#25210)
return newMpp
case MppTiDB:
partialAgg, finalAgg := p.newPartialAggregate(kv.TiFlash, false)
if partialAgg != nil {
attachPlan2Task(partialAgg, mpp)
}
<<<<<<< HEAD
mpp.addCost(p.GetCost(inputRows, false))
t = mpp.convertToRootTask(p.ctx)
inputRows = t.count()
attachPlan2Task(finalAgg, t)
t.addCost(p.GetCost(inputRows, true))
=======
mpp.addCost(p.GetCost(inputRows, false, true))
if partialAgg != nil {
partialAgg.SetCost(mpp.cost())
}
t = mpp.convertToRootTask(p.ctx)
inputRows = t.count()
attachPlan2Task(finalAgg, t)
t.addCost(p.GetCost(inputRows, true, false))
finalAgg.SetCost(t.cost())
>>>>>>> a7f3c4d8b... planner/core: change agg cost factor (#25210)
return t
default:
return invalidTask
Expand Down Expand Up @@ -1858,7 +1883,7 @@ func (p *PhysicalHashAgg) attach2Task(tasks ...task) task {
partialAgg.SetChildren(cop.indexPlan)
cop.indexPlan = partialAgg
}
cop.addCost(p.GetCost(inputRows, false))
cop.addCost(p.GetCost(inputRows, false, false))
}
// In `newPartialAggregate`, we are using stats of final aggregation as stats
// of `partialAgg`, so the network cost of transferring result rows of `partialAgg`
Expand Down Expand Up @@ -1891,15 +1916,20 @@ func (p *PhysicalHashAgg) attach2Task(tasks ...task) task {
// hash aggregation, it would cause under-estimation for the reason mentioned in the comment above.
// To keep it simple, we also treat 2-phase parallel hash aggregation in the TiDB layer as
// 1-phase when computing cost.
<<<<<<< HEAD
t.addCost(p.GetCost(inputRows, true))
=======
t.addCost(p.GetCost(inputRows, true, false))
p.cost = t.cost()
>>>>>>> a7f3c4d8b... planner/core: change agg cost factor (#25210)
return t
}

// GetCost computes the cost of hash aggregation considering CPU/memory.
func (p *PhysicalHashAgg) GetCost(inputRows float64, isRoot bool) float64 {
func (p *PhysicalHashAgg) GetCost(inputRows float64, isRoot bool, isMPP bool) float64 {
cardinality := p.statsInfo().RowCount
numDistinctFunc := p.numDistinctFunc()
aggFuncFactor := p.getAggFuncCostFactor()
aggFuncFactor := p.getAggFuncCostFactor(isMPP)
var cpuCost float64
sessVars := p.ctx.GetSessionVars()
if isRoot {
Expand Down
49 changes: 29 additions & 20 deletions planner/core/testdata/integration_serial_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,36 @@
"explain format = 'brief' select count(*) from fact_t where exists (select 1 from d1_t where d1_k = fact_t.d1_k)",
"explain format = 'brief' select count(*) from fact_t where exists (select 1 from d1_t where d1_k = fact_t.d1_k and value > fact_t.col1)",
"explain format = 'brief' select count(*) from fact_t where not exists (select 1 from d1_t where d1_k = fact_t.d1_k)",
<<<<<<< HEAD
"explain format = 'brief' select count(*) from fact_t where not exists (select 1 from d1_t where d1_k = fact_t.d1_k and value > fact_t.col1)"
=======
"explain format = 'brief' select count(*) from fact_t where not exists (select 1 from d1_t where d1_k = fact_t.d1_k and value > fact_t.col1)",
"explain format = 'brief' select count(*) from fact_t join d1_t on fact_t.d1_k > d1_t.d1_k",
"explain format = 'brief' select count(*) from fact_t left join d1_t on fact_t.d1_k > d1_t.d1_k",
"explain format = 'brief' select count(*) from fact_t right join d1_t on fact_t.d1_k > d1_t.d1_k",
"explain format = 'brief' select count(*) from fact_t where d1_k not in (select d1_k from d1_t)"
]
},
{
"name": "TestMPPOuterJoinBuildSideForBroadcastJoin",
"cases": [
"explain format = 'brief' select count(*) from a left join b on a.id = b.id",
"explain format = 'brief' select count(*) from b right join a on a.id = b.id"
]
},
{
"name": "TestMPPOuterJoinBuildSideForShuffleJoinWithFixedBuildSide",
"cases": [
"explain format = 'brief' select count(*) from a left join b on a.id = b.id",
"explain format = 'brief' select count(*) from b right join a on a.id = b.id"
]
},
{
"name": "TestMPPOuterJoinBuildSideForShuffleJoin",
"cases": [
"explain format = 'brief' select count(*) from a left join b on a.id = b.id",
"explain format = 'brief' select count(*) from b right join a on a.id = b.id"
>>>>>>> a7f3c4d8b... planner/core: change agg cost factor (#25210)
]
},
{
Expand All @@ -55,26 +84,6 @@
"explain format = 'brief' select count(*) from fact_t where not exists (select 1 from d1_t where d1_k = fact_t.d1_k and value > fact_t.col1)"
]
},
{
"name": "TestBroadcastJoin",
"cases": [
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t, d1_t where fact_t.d1_k = d1_t.d1_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t,d2_t,d3_t) */ count(*) from fact_t, d1_t, d2_t, d3_t where fact_t.d1_k = d1_t.d1_k and fact_t.d2_k = d2_t.d2_k and fact_t.d3_k = d3_t.d3_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t), broadcast_join_local(d1_t) */ count(*) from fact_t, d1_t where fact_t.d1_k = d1_t.d1_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t,d2_t,d3_t), broadcast_join_local(d2_t) */ count(*) from fact_t, d1_t, d2_t, d3_t where fact_t.d1_k = d1_t.d1_k and fact_t.d2_k = d2_t.d2_k and fact_t.d3_k = d3_t.d3_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t left join d1_t on fact_t.d1_k = d1_t.d1_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t right join d1_t on fact_t.d1_k = d1_t.d1_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t join d1_t on fact_t.d1_k = d1_t.d1_k and fact_t.col1 > d1_t.value",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t left join d1_t on fact_t.d1_k = d1_t.d1_k and fact_t.col1 > 10",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t left join d1_t on fact_t.d1_k = d1_t.d1_k and fact_t.col2 > 10 and fact_t.col1 > d1_t.value",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t right join d1_t on fact_t.d1_k = d1_t.d1_k and d1_t.value > 10",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t right join d1_t on fact_t.d1_k = d1_t.d1_k and d1_t.value > 10 and fact_t.col1 > d1_t.value",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t where exists (select 1 from d1_t where d1_k = fact_t.d1_k)",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t where exists (select 1 from d1_t where d1_k = fact_t.d1_k and value > fact_t.col1)",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t where not exists (select 1 from d1_t where d1_k = fact_t.d1_k)",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t where not exists (select 1 from d1_t where d1_k = fact_t.d1_k and value > fact_t.col1)"
]
},
{
"name": "TestJoinNotSupportedByTiFlash",
"cases": [
Expand Down
Loading

0 comments on commit cabc0cb

Please sign in to comment.