Skip to content

Commit

Permalink
[native] Fix bug in creation of HashNode(kLeftSemiProject).
Browse files Browse the repository at this point in the history
In Presto Native we create HashNode(kLeftSemiProject) with nullAware=false
from toVeloxQueryPlan(FilterNode). 'nullAware' dictates how we treat nulls
in the left keys.

This causes some queries in Presto Native to return more rows than Presto.
Because Presto does not have any member in class SemiJoinNode that indicates
how it processes nulls and is always 'null aware'.

Interesting, that toVeloxQueryPlan(SemiJoinNode) code path creates
HashNode(kLeftSemiProject) with nullAware=true, which is correct.

To fix the issue we route the creation of HashNode(kLeftSemiProject) in
toVeloxQueryPlan(FilterNode) through toVeloxQueryPlan(SemiJoinNode), thus
having only one place where we create HashNode(kLeftSemiProject) with
nullAware=true.
  • Loading branch information
spershin authored and amitkdutta committed Mar 2, 2024
1 parent 245c93f commit 8b15be6
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1625,6 +1625,14 @@ core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan(
}
}

// No clear join type - fallback to the standard 'to velox expr'.
if (!joinType.has_value()) {
return std::make_shared<core::FilterNode>(
node->id,
exprConverter_.toVeloxExpr(node->predicate),
toVeloxQueryPlan(semiJoin, tableWriteInfo, taskId));
}

std::vector<core::FieldAccessTypedExprPtr> leftKeys = {
exprConverter_.toVeloxExpr(semiJoin->sourceJoinVariable)};
std::vector<core::FieldAccessTypedExprPtr> rightKeys = {
Expand All @@ -1640,25 +1648,6 @@ core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan(
auto names = leftNames;
names.push_back(semiJoin->semiJoinOutput.name);

if (!joinType.has_value()) {
auto types = leftTypes;
types.push_back(BOOLEAN());

return std::make_shared<core::FilterNode>(
node->id,
exprConverter_.toVeloxExpr(node->predicate),
std::make_shared<core::HashJoinNode>(
semiJoin->id,
core::JoinType::kLeftSemiProject,
false,
leftKeys,
rightKeys,
nullptr, // filter
left,
right,
ROW(std::move(names), std::move(types))));
}

std::vector<core::TypedExprPtr> projections;
projections.reserve(leftNames.size() + 1);
for (auto i = 0; i < leftNames.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,11 @@ public void testSemiJoin(Session joinTypeSession)
{
assertQuery(joinTypeSession, "SELECT * FROM orders WHERE orderdate IN (SELECT shipdate FROM lineitem) or orderdate IN (SELECT commitdate FROM lineitem)");
assertQuery(joinTypeSession, "SELECT * FROM lineitem WHERE orderkey IN (SELECT orderkey FROM orders WHERE (orderkey + custkey) % 2 = 0)");
assertQuery(joinTypeSession, "SELECT * FROM lineitem " +
"WHERE linenumber = 3 OR orderkey IN (SELECT orderkey FROM orders WHERE (orderkey + custkey) % 2 = 0)");
assertQuery(joinTypeSession, "SELECT * FROM lineitem WHERE linenumber = 3 OR orderkey IN (SELECT orderkey FROM orders WHERE (orderkey + custkey) % 2 = 0)");
assertQuery(joinTypeSession, "WITH\n" +
"users AS (SELECT orderkey FROM orders ),\n" +
"left_table AS (SELECT * FROM ( VALUES (0, NULL), (283755559, NULL), (NULL, NULL) ) AS left_table (userid, sid_cast))\n" +
"SELECT userid FROM left_table WHERE (sid_cast IS NOT NULL) OR (NOT (userid IN (SELECT * FROM users)))");
}

@Test(dataProvider = "joinTypeProvider")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Arrays;
import java.util.List;

import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createLineitem;
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrders;

public abstract class AbstractTestNativeWindowQueries
Expand All @@ -36,6 +37,7 @@ protected void createTables()
{
QueryRunner queryRunner = (QueryRunner) getExpectedQueryRunner();
createOrders(queryRunner);
createLineitem(queryRunner);
}

private static final List<String> OVER_CLAUSES_WITH_ORDER_BY = Arrays.asList(
Expand Down Expand Up @@ -197,7 +199,6 @@ public void testOverlappingPartitionAndSortingKeys()
assertQuery("SELECT min(orderkey) OVER (PARTITION BY orderdate ORDER BY orderdate, totalprice) FROM orders");
assertQuery("SELECT * FROM (SELECT row_number() over(partition by orderstatus order by orderkey, orderstatus) rn, * from orders) WHERE rn = 1");

assertQuery("WITH t AS (SELECT linenumber, row_number() over (partition by linenumber order by linenumber) as rn FROM lineitem) " +
"SELECT * FROM t WHERE rn = 1");
assertQuery("WITH t AS (SELECT linenumber, row_number() over (partition by linenumber order by linenumber) as rn FROM lineitem) SELECT * FROM t WHERE rn = 1");
}
}

0 comments on commit 8b15be6

Please sign in to comment.