Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

planner/core: change agg cost factor #25210

Merged
merged 10 commits into from
Jun 8, 2021
8 changes: 6 additions & 2 deletions planner/core/physical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -1048,7 +1048,7 @@ func (p *basePhysicalAgg) numDistinctFunc() (num int) {
return
}

func (p *basePhysicalAgg) getAggFuncCostFactor() (factor float64) {
func (p *basePhysicalAgg) getAggFuncCostFactor(isMPP bool) (factor float64) {
factor = 0.0
for _, agg := range p.AggFuncs {
if fac, ok := aggFuncFactor[agg.Name]; ok {
Expand All @@ -1058,7 +1058,11 @@ func (p *basePhysicalAgg) getAggFuncCostFactor() (factor float64) {
}
}
if factor == 0 {
factor = 1.0
if isMPP {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment and a TODO for this `if` branch.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

factor = aggFuncFactor[ast.AggFuncFirstRow]
} else {
factor = 1.0
}
}
return
}
Expand Down
18 changes: 9 additions & 9 deletions planner/core/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -1829,7 +1829,7 @@ func (p *PhysicalStreamAgg) attach2Task(tasks ...task) task {

// GetCost computes cost of stream aggregation considering CPU/memory.
func (p *PhysicalStreamAgg) GetCost(inputRows float64, isRoot bool) float64 {
aggFuncFactor := p.getAggFuncCostFactor()
aggFuncFactor := p.getAggFuncCostFactor(false)
var cpuCost float64
sessVars := p.ctx.GetSessionVars()
if isRoot {
Expand Down Expand Up @@ -1876,7 +1876,7 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task {
if proj != nil {
attachPlan2Task(proj, mpp)
}
mpp.addCost(p.GetCost(inputRows, false))
mpp.addCost(p.GetCost(inputRows, false, true))
p.cost = mpp.cost()
return mpp
case Mpp2Phase:
Expand Down Expand Up @@ -1909,7 +1909,7 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task {
attachPlan2Task(proj, newMpp)
}
// TODO: how to set 2-phase cost?
newMpp.addCost(p.GetCost(inputRows, false))
newMpp.addCost(p.GetCost(inputRows, false, true))
finalAgg.SetCost(mpp.cost())
if proj != nil {
proj.SetCost(mpp.cost())
Expand All @@ -1920,14 +1920,14 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task {
if partialAgg != nil {
attachPlan2Task(partialAgg, mpp)
}
mpp.addCost(p.GetCost(inputRows, false))
mpp.addCost(p.GetCost(inputRows, false, true))
if partialAgg != nil {
partialAgg.SetCost(mpp.cost())
}
t = mpp.convertToRootTask(p.ctx)
inputRows = t.count()
attachPlan2Task(finalAgg, t)
t.addCost(p.GetCost(inputRows, true))
t.addCost(p.GetCost(inputRows, true, false))
finalAgg.SetCost(t.cost())
return t
default:
Expand Down Expand Up @@ -1958,7 +1958,7 @@ func (p *PhysicalHashAgg) attach2Task(tasks ...task) task {
partialAgg.SetChildren(cop.indexPlan)
cop.indexPlan = partialAgg
}
cop.addCost(p.GetCost(inputRows, false))
cop.addCost(p.GetCost(inputRows, false, false))
}
// In `newPartialAggregate`, we are using stats of final aggregation as stats
// of `partialAgg`, so the network cost of transferring result rows of `partialAgg`
Expand Down Expand Up @@ -1991,16 +1991,16 @@ func (p *PhysicalHashAgg) attach2Task(tasks ...task) task {
// hash aggregation, it would cause under-estimation as the reason mentioned in comment above.
// To make it simple, we also treat 2-phase parallel hash aggregation in TiDB layer as
// 1-phase when computing cost.
t.addCost(p.GetCost(inputRows, true))
t.addCost(p.GetCost(inputRows, true, false))
p.cost = t.cost()
return t
}

// GetCost computes the cost of hash aggregation considering CPU/memory.
func (p *PhysicalHashAgg) GetCost(inputRows float64, isRoot bool) float64 {
func (p *PhysicalHashAgg) GetCost(inputRows float64, isRoot bool, isMPP bool) float64 {
cardinality := p.statsInfo().RowCount
numDistinctFunc := p.numDistinctFunc()
aggFuncFactor := p.getAggFuncCostFactor()
aggFuncFactor := p.getAggFuncCostFactor(isMPP)
var cpuCost float64
sessVars := p.ctx.GetSessionVars()
if isRoot {
Expand Down
24 changes: 2 additions & 22 deletions planner/core/testdata/integration_serial_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@
{
"name": "TestMPPOuterJoinBuildSideForBroadcastJoin",
"cases": [
"explain format = 'brief' select count(*) from a left join b on a.id = b.id",
"explain format = 'brief' select count(*) from b right join a on a.id = b.id"
"explain format = 'brief' select count(*) from a left join b on a.id = b.id",
"explain format = 'brief' select count(*) from b right join a on a.id = b.id"
]
},
{
Expand Down Expand Up @@ -101,26 +101,6 @@
"explain format = 'brief' select count(*) from fact_t where not exists (select 1 from d1_t where d1_k = fact_t.d1_k and value > fact_t.col1)"
]
},
{
"name": "TestBroadcastJoin",
"cases": [
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t, d1_t where fact_t.d1_k = d1_t.d1_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t,d2_t,d3_t) */ count(*) from fact_t, d1_t, d2_t, d3_t where fact_t.d1_k = d1_t.d1_k and fact_t.d2_k = d2_t.d2_k and fact_t.d3_k = d3_t.d3_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t), broadcast_join_local(d1_t) */ count(*) from fact_t, d1_t where fact_t.d1_k = d1_t.d1_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t,d2_t,d3_t), broadcast_join_local(d2_t) */ count(*) from fact_t, d1_t, d2_t, d3_t where fact_t.d1_k = d1_t.d1_k and fact_t.d2_k = d2_t.d2_k and fact_t.d3_k = d3_t.d3_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t left join d1_t on fact_t.d1_k = d1_t.d1_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t right join d1_t on fact_t.d1_k = d1_t.d1_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t join d1_t on fact_t.d1_k = d1_t.d1_k and fact_t.col1 > d1_t.value",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t left join d1_t on fact_t.d1_k = d1_t.d1_k and fact_t.col1 > 10",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t left join d1_t on fact_t.d1_k = d1_t.d1_k and fact_t.col2 > 10 and fact_t.col1 > d1_t.value",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t right join d1_t on fact_t.d1_k = d1_t.d1_k and d1_t.value > 10",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t right join d1_t on fact_t.d1_k = d1_t.d1_k and d1_t.value > 10 and fact_t.col1 > d1_t.value",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t where exists (select 1 from d1_t where d1_k = fact_t.d1_k)",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t where exists (select 1 from d1_t where d1_k = fact_t.d1_k and value > fact_t.col1)",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t where not exists (select 1 from d1_t where d1_k = fact_t.d1_k)",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t where not exists (select 1 from d1_t where d1_k = fact_t.d1_k and value > fact_t.col1)"
]
},
{
"name": "TestJoinNotSupportedByTiFlash",
"cases": [
Expand Down
Loading