From 8b15be638fa424256e7db1ff0f570d55ffdb3ea9 Mon Sep 17 00:00:00 2001 From: Sergey Pershin Date: Fri, 1 Mar 2024 13:36:22 -0800 Subject: [PATCH] [native] Fix bug in creation of HashNode(kLeftSemiProject). In Presto Native we create HashNode(kLeftSemiProject) with nullAware=false from toVeloxQueryPlan(FilterNode). 'nullAware' dictates how we treat nulls in the left keys. This causes some queries in Presto Native to return more rows than Presto. Because Presto does not have any member in class SemiJoinNode that indicates how it processes nulls and is always 'null aware'. Interesting, that toVeloxQueryPlan(SemiJoinNode) code path creates HashNode(kLeftSemiProject) with nullAware=true, which is correct. To fix the issue we route the creation of HashNode(kLeftSemiProject) in toVeloxQueryPlan(FilterNode) through toVeloxQueryPlan(SemiJoinNode), thus having only one place where we create HashNode(kLeftSemiProject) with nullAware=true. --- .../main/types/PrestoToVeloxQueryPlan.cpp | 27 ++++++------------- .../AbstractTestNativeJoinQueries.java | 7 +++-- .../AbstractTestNativeWindowQueries.java | 5 ++-- 3 files changed, 16 insertions(+), 23 deletions(-) diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp index dcef857fddbb0..a32d4c6583c90 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp @@ -1625,6 +1625,14 @@ core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan( } } + // No clear join type - fallback to the standard 'to velox expr'. + if (!joinType.has_value()) { + return std::make_shared( + node->id, + exprConverter_.toVeloxExpr(node->predicate), + toVeloxQueryPlan(semiJoin, tableWriteInfo, taskId)); + } + std::vector leftKeys = { exprConverter_.toVeloxExpr(semiJoin->sourceJoinVariable)}; std::vector rightKeys = { @@ -1640,25 +1648,6 @@ core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan( auto names = leftNames; names.push_back(semiJoin->semiJoinOutput.name); - if (!joinType.has_value()) { - auto types = leftTypes; - types.push_back(BOOLEAN()); - - return std::make_shared( - node->id, - exprConverter_.toVeloxExpr(node->predicate), - std::make_shared( - semiJoin->id, - core::JoinType::kLeftSemiProject, - false, - leftKeys, - rightKeys, - nullptr, // filter - left, - right, - ROW(std::move(names), std::move(types)))); - } - std::vector projections; projections.reserve(leftNames.size() + 1); for (auto i = 0; i < leftNames.size(); i++) { diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeJoinQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeJoinQueries.java index 92db4277036a1..9dcd7123bfbb7 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeJoinQueries.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeJoinQueries.java @@ -72,8 +72,11 @@ public void testSemiJoin(Session joinTypeSession) { assertQuery(joinTypeSession, "SELECT * FROM orders WHERE orderdate IN (SELECT shipdate FROM lineitem) or orderdate IN (SELECT commitdate FROM lineitem)"); assertQuery(joinTypeSession, "SELECT * FROM lineitem WHERE orderkey IN (SELECT orderkey FROM orders WHERE (orderkey + custkey) % 2 = 0)"); - assertQuery(joinTypeSession, "SELECT * FROM lineitem " + - "WHERE linenumber = 3 OR orderkey IN (SELECT orderkey FROM orders WHERE (orderkey + custkey) % 2 = 0)"); + assertQuery(joinTypeSession, "SELECT * FROM lineitem WHERE linenumber = 3 OR orderkey IN (SELECT orderkey FROM orders WHERE (orderkey + custkey) % 2 = 0)"); + assertQuery(joinTypeSession, "WITH\n" + + "users AS (SELECT orderkey FROM orders ),\n" + + "left_table AS (SELECT * FROM ( VALUES (0, NULL), (283755559, NULL), (NULL, NULL) ) AS left_table (userid, sid_cast))\n" + + "SELECT userid FROM left_table WHERE (sid_cast IS NOT NULL) OR (NOT (userid IN (SELECT * FROM users)))"); } @Test(dataProvider = "joinTypeProvider") diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeWindowQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeWindowQueries.java index 5431b422acc09..3c085def1bf25 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeWindowQueries.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeWindowQueries.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.List; +import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createLineitem; import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrders; public abstract class AbstractTestNativeWindowQueries @@ -36,6 +37,7 @@ protected void createTables() { QueryRunner queryRunner = (QueryRunner) getExpectedQueryRunner(); createOrders(queryRunner); + createLineitem(queryRunner); } private static final List OVER_CLAUSES_WITH_ORDER_BY = Arrays.asList( @@ -197,7 +199,6 @@ public void testOverlappingPartitionAndSortingKeys() assertQuery("SELECT min(orderkey) OVER (PARTITION BY orderdate ORDER BY orderdate, totalprice) FROM orders"); assertQuery("SELECT * FROM (SELECT row_number() over(partition by orderstatus order by orderkey, orderstatus) rn, * from orders) WHERE rn = 1"); - assertQuery("WITH t AS (SELECT linenumber, row_number() over (partition by linenumber order by linenumber) as rn FROM lineitem) " + - "SELECT * FROM t WHERE rn = 1"); + assertQuery("WITH t AS (SELECT linenumber, row_number() over (partition by linenumber order by linenumber) as rn FROM lineitem) SELECT * FROM t WHERE rn = 1"); } }