Skip to content

Commit

Permalink
[fix](nereids) nest loop join stats estimation (#21275)
Browse files Browse the repository at this point in the history
1. fix bug in nest loop join estimation
2. update column=column stats estimation
  • Loading branch information
englefly authored Jun 30, 2023
1 parent 6d63261 commit 9f44c2d
Show file tree
Hide file tree
Showing 10 changed files with 167 additions and 156 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -390,24 +390,7 @@ private Statistics estimateColumnEqualToColumn(Expression leftExpr, ColumnStatis
rightBuilder.setNdv(rightIntersectLeft.getDistinctValues());
rightBuilder.setMinValue(rightIntersectLeft.getLow());
rightBuilder.setMaxValue(rightIntersectLeft.getDistinctValues());
double sel;
double reduceRatio = 0.25;
double bothSideReducedRatio = 0.9;
if (!leftStats.rangeChanged() && !rightStats.rangeChanged()
&& leftStats.ndv < leftStats.getOriginalNdv() * bothSideReducedRatio
&& rightStats.ndv < rightStats.getOriginalNdv() * bothSideReducedRatio) {
double sel1;
if (leftStats.ndv > rightStats.ndv) {
sel1 = 1 / StatsMathUtil.nonZeroDivisor(leftStats.ndv);
} else {
sel1 = 1 / StatsMathUtil.nonZeroDivisor(rightStats.ndv);
}
double sel2 = Math.min(rightStats.ndv / rightStats.getOriginalNdv(),
leftStats.ndv / leftStats.getOriginalNdv());
sel = sel1 * Math.pow(sel2, reduceRatio);
} else {
sel = 1 / StatsMathUtil.nonZeroDivisor(Math.max(leftStats.ndv, rightStats.ndv));
}
double sel = 1 / StatsMathUtil.nonZeroDivisor(Math.max(leftStats.ndv, rightStats.ndv));
Statistics updatedStatistics = context.statistics.withSel(sel);
updatedStatistics.addColumnStats(leftExpr, leftBuilder.build());
updatedStatistics.addColumnStats(rightExpr, rightBuilder.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,7 @@ private static boolean hashJoinConditionContainsUnknownColumnStats(Statistics le
return false;
}

private static Statistics estimateInnerJoin(Statistics leftStats, Statistics rightStats, Join join) {
if (hashJoinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) {
double rowCount = Math.max(leftStats.getRowCount(), rightStats.getRowCount());
rowCount = Math.max(1, rowCount);
return new StatisticsBuilder()
.setRowCount(rowCount)
.putColumnStatistics(leftStats.columnStatistics())
.putColumnStatistics(rightStats.columnStatistics())
.build();
}
private static Statistics estimateHashJoin(Statistics leftStats, Statistics rightStats, Join join) {
/*
* When we estimate filter A=B,
* if any side of equation, A or B, is almost unique, the confidence level of estimation is high.
Expand All @@ -95,31 +86,31 @@ private static Statistics estimateInnerJoin(Statistics leftStats, Statistics rig
List<EqualTo> trustableConditions = join.getHashJoinConjuncts().stream()
.map(expression -> (EqualTo) expression)
.filter(
expression -> {
// since ndv is not accurate, if ndv/rowcount > almostUniqueThreshold,
// this column is regarded as (almost) unique.
double almostUniqueThreshold = 0.9;
EqualTo equal = normalizeHashJoinCondition(expression, leftStats, rightStats);
ColumnStatistic eqLeftColStats = ExpressionEstimation.estimate(equal.left(), leftStats);
ColumnStatistic eqRightColStats = ExpressionEstimation.estimate(equal.right(), rightStats);
double rightStatsRowCount = StatsMathUtil.nonZeroDivisor(rightStats.getRowCount());
double leftStatsRowCount = StatsMathUtil.nonZeroDivisor(leftStats.getRowCount());
boolean trustable = eqRightColStats.ndv / rightStatsRowCount > almostUniqueThreshold
|| eqLeftColStats.ndv / leftStatsRowCount > almostUniqueThreshold;
if (!trustable) {
double rNdv = StatsMathUtil.nonZeroDivisor(eqRightColStats.ndv);
double lNdv = StatsMathUtil.nonZeroDivisor(eqLeftColStats.ndv);
if (leftBigger) {
unTrustEqualRatio.add((rightStatsRowCount / rNdv)
* Math.min(eqLeftColStats.ndv, eqRightColStats.ndv) / lNdv);
} else {
unTrustEqualRatio.add((leftStatsRowCount / lNdv)
* Math.min(eqLeftColStats.ndv, eqRightColStats.ndv) / rNdv);
expression -> {
// since ndv is not accurate, if ndv/rowcount > almostUniqueThreshold,
// this column is regarded as (almost) unique.
double almostUniqueThreshold = 0.9;
EqualTo equal = normalizeHashJoinCondition(expression, leftStats, rightStats);
ColumnStatistic eqLeftColStats = ExpressionEstimation.estimate(equal.left(), leftStats);
ColumnStatistic eqRightColStats = ExpressionEstimation.estimate(equal.right(), rightStats);
double rightStatsRowCount = StatsMathUtil.nonZeroDivisor(rightStats.getRowCount());
double leftStatsRowCount = StatsMathUtil.nonZeroDivisor(leftStats.getRowCount());
boolean trustable = eqRightColStats.ndv / rightStatsRowCount > almostUniqueThreshold
|| eqLeftColStats.ndv / leftStatsRowCount > almostUniqueThreshold;
if (!trustable) {
double rNdv = StatsMathUtil.nonZeroDivisor(eqRightColStats.ndv);
double lNdv = StatsMathUtil.nonZeroDivisor(eqLeftColStats.ndv);
if (leftBigger) {
unTrustEqualRatio.add((rightStatsRowCount / rNdv)
* Math.min(eqLeftColStats.ndv, eqRightColStats.ndv) / lNdv);
} else {
unTrustEqualRatio.add((leftStatsRowCount / lNdv)
* Math.min(eqLeftColStats.ndv, eqRightColStats.ndv) / rNdv);
}
unTrustableCondition.add(equal);
}
unTrustableCondition.add(equal);
return trustable;
}
return trustable;
}
).collect(Collectors.toList());

Statistics innerJoinStats;
Expand Down Expand Up @@ -159,6 +150,34 @@ private static Statistics estimateInnerJoin(Statistics leftStats, Statistics rig
}
innerJoinStats = crossJoinStats.updateRowCountOnly(outputRowCount);
}
return innerJoinStats;
}

/**
 * Estimates statistics for a join evaluated as a nested loop, i.e. the case
 * {@code estimateInnerJoin} routes here when the join has no hash join conjuncts.
 *
 * <p>The row count is the product of both inputs' row counts (a cross product),
 * clamped to at least 1. Column statistics of both inputs are carried over as-is.
 *
 * @param leftStats statistics of the left input
 * @param rightStats statistics of the right input
 * @param join the join being estimated; not consulted by this method
 * @return statistics holding the cross-product row count and both sides' column statistics
 */
private static Statistics estimateNestLoopJoin(Statistics leftStats, Statistics rightStats, Join join) {
    // Cross product of the two inputs; never report fewer than one row.
    double rowCount = leftStats.getRowCount() * rightStats.getRowCount();
    rowCount = Math.max(1, rowCount);
    StatisticsBuilder crossJoinBuilder = new StatisticsBuilder()
            .setRowCount(rowCount)
            .putColumnStatistics(leftStats.columnStatistics())
            .putColumnStatistics(rightStats.columnStatistics());
    return crossJoinBuilder.build();
}

private static Statistics estimateInnerJoin(Statistics leftStats, Statistics rightStats, Join join) {
if (hashJoinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) {
double rowCount = Math.max(leftStats.getRowCount(), rightStats.getRowCount());
rowCount = Math.max(1, rowCount);
return new StatisticsBuilder()
.setRowCount(rowCount)
.putColumnStatistics(leftStats.columnStatistics())
.putColumnStatistics(rightStats.columnStatistics())
.build();
}

Statistics innerJoinStats;
if (join.getHashJoinConjuncts().isEmpty()) {
innerJoinStats = estimateNestLoopJoin(leftStats, rightStats, join);
} else {
innerJoinStats = estimateHashJoin(leftStats, rightStats, join);
}

if (!join.getOtherJoinConjuncts().isEmpty()) {
FilterEstimation filterEstimation = new FilterEstimation();
Expand Down
20 changes: 10 additions & 10 deletions regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query11.out
Original file line number Diff line number Diff line change
Expand Up @@ -44,24 +44,24 @@ CteAnchor[cteId= ( CTEId#4=] )
----PhysicalDistribute
------PhysicalTopN
--------PhysicalProject
----------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_w_secyear.customer_id)(CASE WHEN (year_total > 0.00) THEN (cast(year_total as DECIMALV3(38, 8)) / year_total) ELSE 0.000000 END > CASE WHEN (year_total > 0.00) THEN (cast(year_total as DECIMALV3(38, 8)) / year_total) ELSE 0.000000 END)
----------hashJoin[INNER_JOIN](t_s_secyear.customer_id = t_s_firstyear.customer_id)(CASE WHEN (year_total > 0.00) THEN (cast(year_total as DECIMALV3(38, 8)) / year_total) ELSE 0.000000 END > CASE WHEN (year_total > 0.00) THEN (cast(year_total as DECIMALV3(38, 8)) / year_total) ELSE 0.000000 END)
------------PhysicalDistribute
--------------PhysicalProject
----------------filter((t_w_secyear.dyear = 2002)(t_w_secyear.sale_type = 'w'))
----------------filter((t_s_secyear.sale_type = 's')(t_s_secyear.dyear = 2002))
------------------CteConsumer[cteId= ( CTEId#4=] )
------------PhysicalProject
--------------hashJoin[INNER_JOIN](t_s_secyear.customer_id = t_s_firstyear.customer_id)
----------------PhysicalDistribute
------------------PhysicalProject
--------------------filter((t_s_secyear.sale_type = 's')(t_s_secyear.dyear = 2002))
----------------------CteConsumer[cteId= ( CTEId#4=] )
----------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_w_firstyear.customer_id)
--------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_w_firstyear.customer_id)
----------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_w_secyear.customer_id)
------------------PhysicalDistribute
--------------------PhysicalProject
----------------------filter((t_s_firstyear.dyear = 2001)(t_s_firstyear.sale_type = 's')(t_s_firstyear.year_total > 0.00))
----------------------filter((t_w_secyear.dyear = 2002)(t_w_secyear.sale_type = 'w'))
------------------------CteConsumer[cteId= ( CTEId#4=] )
------------------PhysicalDistribute
--------------------PhysicalProject
----------------------filter((t_w_firstyear.year_total > 0.00)(t_w_firstyear.sale_type = 'w')(t_w_firstyear.dyear = 2001))
----------------------filter((t_s_firstyear.dyear = 2001)(t_s_firstyear.sale_type = 's')(t_s_firstyear.year_total > 0.00))
------------------------CteConsumer[cteId= ( CTEId#4=] )
----------------PhysicalDistribute
------------------PhysicalProject
--------------------filter((t_w_firstyear.year_total > 0.00)(t_w_firstyear.sale_type = 'w')(t_w_firstyear.dyear = 2001))
----------------------CteConsumer[cteId= ( CTEId#4=] )

59 changes: 30 additions & 29 deletions regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query18.out
Original file line number Diff line number Diff line change
Expand Up @@ -9,40 +9,41 @@ PhysicalTopN
------------hashAgg[LOCAL]
--------------PhysicalRepeat
----------------PhysicalProject
------------------hashJoin[INNER_JOIN](catalog_sales.cs_item_sk = item.i_item_sk)
--------------------PhysicalProject
----------------------PhysicalOlapScan[item]
------------------hashJoin[INNER_JOIN](customer.c_current_cdemo_sk = cd2.cd_demo_sk)
--------------------PhysicalDistribute
----------------------PhysicalProject
------------------------hashJoin[INNER_JOIN](customer.c_current_cdemo_sk = cd2.cd_demo_sk)
--------------------------PhysicalDistribute
----------------------------PhysicalProject
------------------------------PhysicalOlapScan[customer_demographics]
------------------------PhysicalOlapScan[customer_demographics]
--------------------PhysicalDistribute
----------------------PhysicalProject
------------------------hashJoin[INNER_JOIN](catalog_sales.cs_item_sk = item.i_item_sk)
--------------------------PhysicalProject
----------------------------PhysicalOlapScan[item]
--------------------------PhysicalDistribute
----------------------------PhysicalProject
------------------------------hashJoin[INNER_JOIN](catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)
--------------------------------PhysicalProject
----------------------------------hashJoin[INNER_JOIN](catalog_sales.cs_bill_cdemo_sk = cd1.cd_demo_sk)
------------------------------------PhysicalProject
--------------------------------------hashJoin[INNER_JOIN](catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)
----------------------------------------PhysicalProject
------------------------------------------PhysicalOlapScan[catalog_sales]
----------------------------------------PhysicalDistribute
------------------------------------------hashJoin[INNER_JOIN](customer.c_current_addr_sk = customer_address.ca_address_sk)
--------------------------------------------PhysicalDistribute
----------------------------------------------PhysicalProject
------------------------------------------------filter(c_birth_month IN (1, 2, 4, 7, 8, 10))
--------------------------------------------------PhysicalOlapScan[customer]
--------------------------------------------PhysicalDistribute
------------------------------hashJoin[INNER_JOIN](customer.c_current_addr_sk = customer_address.ca_address_sk)
--------------------------------PhysicalDistribute
----------------------------------PhysicalProject
------------------------------------hashJoin[INNER_JOIN](catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)
--------------------------------------PhysicalDistribute
----------------------------------------hashJoin[INNER_JOIN](catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)
------------------------------------------PhysicalProject
--------------------------------------------hashJoin[INNER_JOIN](catalog_sales.cs_bill_cdemo_sk = cd1.cd_demo_sk)
----------------------------------------------PhysicalProject
------------------------------------------------filter(ca_state IN ('WA', 'GA', 'NC', 'ME', 'WY', 'OK', 'IN'))
--------------------------------------------------PhysicalOlapScan[customer_address]
------------------------------------PhysicalDistribute
--------------------------------------PhysicalProject
----------------------------------------filter((cast(cd_gender as VARCHAR(*)) = 'F')(cast(cd_education_status as VARCHAR(*)) = 'Advanced Degree'))
------------------------------------------PhysicalOlapScan[customer_demographics]
------------------------------------------------PhysicalOlapScan[catalog_sales]
----------------------------------------------PhysicalDistribute
------------------------------------------------PhysicalProject
--------------------------------------------------filter((cast(cd_gender as VARCHAR(*)) = 'F')(cast(cd_education_status as VARCHAR(*)) = 'Advanced Degree'))
----------------------------------------------------PhysicalOlapScan[customer_demographics]
------------------------------------------PhysicalDistribute
--------------------------------------------PhysicalProject
----------------------------------------------filter((date_dim.d_year = 1998))
------------------------------------------------PhysicalOlapScan[date_dim]
--------------------------------------PhysicalDistribute
----------------------------------------PhysicalProject
------------------------------------------filter(c_birth_month IN (1, 2, 4, 7, 8, 10))
--------------------------------------------PhysicalOlapScan[customer]
--------------------------------PhysicalDistribute
----------------------------------PhysicalProject
------------------------------------filter((date_dim.d_year = 1998))
--------------------------------------PhysicalOlapScan[date_dim]
------------------------------------filter(ca_state IN ('WA', 'GA', 'NC', 'ME', 'WY', 'OK', 'IN'))
--------------------------------------PhysicalOlapScan[customer_address]

44 changes: 22 additions & 22 deletions regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query4.out
Original file line number Diff line number Diff line change
Expand Up @@ -65,35 +65,35 @@ CteAnchor[cteId= ( CTEId#6=] )
------PhysicalTopN
--------PhysicalProject
----------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_w_secyear.customer_id)(CASE WHEN (year_total > 0.000000) THEN (cast(year_total as DECIMALV3(38, 16)) / year_total) ELSE NULL END > CASE WHEN (year_total > 0.000000) THEN (cast(year_total as DECIMALV3(38, 16)) / year_total) ELSE NULL END)
------------PhysicalProject
--------------filter((t_w_secyear.sale_type = 'w')(t_w_secyear.dyear = 2000))
----------------CteConsumer[cteId= ( CTEId#6=] )
------------PhysicalDistribute
--------------PhysicalProject
----------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_w_firstyear.customer_id)
----------------filter((t_w_secyear.sale_type = 'w')(t_w_secyear.dyear = 2000))
------------------CteConsumer[cteId= ( CTEId#6=] )
------------PhysicalProject
--------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_w_firstyear.customer_id)
----------------PhysicalDistribute
------------------PhysicalProject
--------------------filter((t_w_firstyear.dyear = 1999)(t_w_firstyear.sale_type = 'w')(t_w_firstyear.year_total > 0.000000))
----------------------CteConsumer[cteId= ( CTEId#6=] )
------------------PhysicalDistribute
----------------PhysicalProject
------------------hashJoin[INNER_JOIN](t_s_secyear.customer_id = t_s_firstyear.customer_id)(CASE WHEN (year_total > 0.000000) THEN (cast(year_total as DECIMALV3(38, 16)) / year_total) ELSE NULL END > CASE WHEN (year_total > 0.000000) THEN (cast(year_total as DECIMALV3(38, 16)) / year_total) ELSE NULL END)
--------------------PhysicalDistribute
----------------------PhysicalProject
------------------------filter((t_s_secyear.sale_type = 's')(t_s_secyear.dyear = 2000))
--------------------------CteConsumer[cteId= ( CTEId#6=] )
--------------------PhysicalProject
----------------------hashJoin[INNER_JOIN](t_s_secyear.customer_id = t_s_firstyear.customer_id)(CASE WHEN (year_total > 0.000000) THEN (cast(year_total as DECIMALV3(38, 16)) / year_total) ELSE NULL END > CASE WHEN (year_total > 0.000000) THEN (cast(year_total as DECIMALV3(38, 16)) / year_total) ELSE NULL END)
----------------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_c_secyear.customer_id)
------------------------PhysicalDistribute
--------------------------PhysicalProject
----------------------------filter((t_s_secyear.sale_type = 's')(t_s_secyear.dyear = 2000))
----------------------------filter((t_c_secyear.sale_type = 'c')(t_c_secyear.dyear = 2000))
------------------------------CteConsumer[cteId= ( CTEId#6=] )
------------------------PhysicalProject
--------------------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_c_secyear.customer_id)
----------------------------PhysicalDistribute
------------------------------PhysicalProject
--------------------------------filter((t_c_secyear.sale_type = 'c')(t_c_secyear.dyear = 2000))
----------------------------------CteConsumer[cteId= ( CTEId#6=] )
----------------------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_c_firstyear.customer_id)
------------------------------PhysicalDistribute
--------------------------------PhysicalProject
----------------------------------filter((t_s_firstyear.year_total > 0.000000)(t_s_firstyear.dyear = 1999)(t_s_firstyear.sale_type = 's'))
------------------------------------CteConsumer[cteId= ( CTEId#6=] )
------------------------------PhysicalDistribute
--------------------------------PhysicalProject
----------------------------------filter((t_c_firstyear.year_total > 0.000000)(t_c_firstyear.dyear = 1999)(t_c_firstyear.sale_type = 'c'))
------------------------------------CteConsumer[cteId= ( CTEId#6=] )
------------------------hashJoin[INNER_JOIN](t_s_firstyear.customer_id = t_c_firstyear.customer_id)
--------------------------PhysicalDistribute
----------------------------PhysicalProject
------------------------------filter((t_s_firstyear.year_total > 0.000000)(t_s_firstyear.dyear = 1999)(t_s_firstyear.sale_type = 's'))
--------------------------------CteConsumer[cteId= ( CTEId#6=] )
--------------------------PhysicalDistribute
----------------------------PhysicalProject
------------------------------filter((t_c_firstyear.year_total > 0.000000)(t_c_firstyear.dyear = 1999)(t_c_firstyear.sale_type = 'c'))
--------------------------------CteConsumer[cteId= ( CTEId#6=] )

Loading

0 comments on commit 9f44c2d

Please sign in to comment.