Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-21759][SQL] In.checkInputDataTypes should not wrongly report unresolved plans for IN correlated subquery #18968

Closed
wants to merge 11 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -1286,8 +1286,10 @@ class Analyzer(
resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, _, exprId) if !sub.resolved =>
resolveSubQuery(e, plans)(Exists(_, _, exprId))
case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved =>
val expr = resolveSubQuery(l, plans)(ListQuery(_, _, exprId))
case In(value, Seq(l @ ListQuery(sub, _, exprId, _))) if value.resolved && !l.resolved =>
val expr = resolveSubQuery(l, plans)((plan, exprs) => {
ListQuery(plan, exprs, exprId, plan.output)
})
In(value, Seq(expr))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ object TypeCoercion {

// Handle type casting required between value expression and subquery output
// in IN subquery.
case i @ In(a, Seq(ListQuery(sub, children, exprId)))
case i @ In(a, Seq(ListQuery(sub, children, exprId, _)))
if !i.resolved && flattenExpr(a).length == sub.output.length =>
// LHS is the value expression of IN subquery.
val lhs = flattenExpr(a)
Expand Down Expand Up @@ -434,7 +434,8 @@ object TypeCoercion {
case _ => CreateStruct(castedLhs)
}

In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId)))
val newSub = Project(castedRhs, sub)
In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output)))
} else {
i
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,32 +138,33 @@ case class Not(child: Expression)
case class In(value: Expression, list: Seq[Expression]) extends Predicate {

require(list != null, "list should not be null")

override def checkInputDataTypes(): TypeCheckResult = {
list match {
case ListQuery(sub, _, _) :: Nil =>
val valExprs = value match {
case cns: CreateNamedStruct => cns.valExprs
case expr => Seq(expr)
}
if (valExprs.length != sub.output.length) {
TypeCheckResult.TypeCheckFailure(
s"""
|The number of columns in the left hand side of an IN subquery does not match the
|number of columns in the output of subquery.
|#columns in left hand side: ${valExprs.length}.
|#columns in right hand side: ${sub.output.length}.
|Left side columns:
|[${valExprs.map(_.sql).mkString(", ")}].
|Right side columns:
|[${sub.output.map(_.sql).mkString(", ")}].
""".stripMargin)
} else {
val mismatchedColumns = valExprs.zip(sub.output).flatMap {
case (l, r) if l.dataType != r.dataType =>
s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
case _ => None
val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType))
if (mismatchOpt.isDefined) {
list match {
case ListQuery(_, _, _, childOutputs) :: Nil =>
val valExprs = value match {
case cns: CreateNamedStruct => cns.valExprs
case expr => Seq(expr)
}
if (mismatchedColumns.nonEmpty) {
if (valExprs.length != childOutputs.length) {
TypeCheckResult.TypeCheckFailure(
s"""
|The number of columns in the left hand side of an IN subquery does not match the
|number of columns in the output of subquery.
|#columns in left hand side: ${valExprs.length}.
|#columns in right hand side: ${childOutputs.length}.
|Left side columns:
|[${valExprs.map(_.sql).mkString(", ")}].
|Right side columns:
|[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin)
} else {
val mismatchedColumns = valExprs.zip(childOutputs).flatMap {
case (l, r) if l.dataType != r.dataType =>
s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
case _ => None
}
TypeCheckResult.TypeCheckFailure(
s"""
|The data type of one or more elements in the left hand side of an IN subquery
Expand All @@ -173,20 +174,14 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
|Left side:
|[${valExprs.map(_.dataType.catalogString).mkString(", ")}].
|Right side:
|[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
""".stripMargin)
} else {
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
|[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin)
}
}
case _ =>
val mismatchOpt = list.find(l => l.dataType != value.dataType)
if (mismatchOpt.isDefined) {
case _ =>
TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
s"${value.dataType} != ${mismatchOpt.get.dataType}")
} else {
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
}
}
} else {
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,17 +274,24 @@ object ScalarSubquery {
case class ListQuery(
plan: LogicalPlan,
children: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId)
exprId: ExprId = NamedExpression.newExprId,
childOutputs: Seq[Attribute] = Seq.empty)
extends SubqueryExpression(plan, children, exprId) with Unevaluable {
override def dataType: DataType = plan.schema.fields.head.dataType
override def dataType: DataType = if (childOutputs.length > 1) {
childOutputs.toStructType
} else {
childOutputs.head.dataType
}
override lazy val resolved: Boolean = childrenResolved && plan.resolved && childOutputs.nonEmpty
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before we fill in childOutputs, this ListQuery cannot be resolved. Otherwise, to access its dataType causes failure in In.checkInputDataTypes.

override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
override def toString: String = s"list#${exprId.id} $conditionString"
override lazy val canonicalized: Expression = {
ListQuery(
plan.canonicalized,
children.map(_.canonicalized),
ExprId(0))
ExprId(0),
childOutputs.map(_.canonicalized.asInstanceOf[Attribute]))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
case (p, Not(Exists(sub, conditions, _))) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
Join(outerPlan, sub, LeftAnti, joinCond)
case (p, In(value, Seq(ListQuery(sub, conditions, _)))) =>
case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) =>
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
Join(outerPlan, sub, LeftSemi, joinCond)
case (p, Not(In(value, Seq(ListQuery(sub, conditions, _))))) =>
case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) =>
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
// Construct the condition. A NULL in one of the conditions is regarded as a positive
// result; such a row will be filtered out by the Anti-Join operator.
Expand Down Expand Up @@ -116,7 +116,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
val exists = AttributeReference("exists", BooleanType, nullable = false)()
newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
exists
case In(value, Seq(ListQuery(sub, conditions, _))) =>
case In(value, Seq(ListQuery(sub, conditions, _, _))) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
Expand Down Expand Up @@ -227,9 +227,9 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
case Exists(sub, children, exprId) if children.nonEmpty =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
Exists(newPlan, newCond, exprId)
case ListQuery(sub, _, exprId) =>
case ListQuery(sub, _, exprId, childOutputs) =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
ListQuery(newPlan, newCond, exprId)
ListQuery(newPlan, newCond, exprId, childOutputs)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{In, ListQuery}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class PullupCorrelatedPredicatesSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("PullupCorrelatedPredicates", Once,
PullupCorrelatedPredicates) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.double)
val testRelation2 = LocalRelation('c.int, 'd.double)

test("PullupCorrelatedPredicates should not produce unresolved plan") {
val correlatedSubquery =
testRelation2
.where('b < 'd)
.select('c)
val outerQuery =
testRelation
.where(In('a, Seq(ListQuery(correlatedSubquery))))
.select('a).analyze
assert(outerQuery.resolved)

val optimized = Optimize.execute(outerQuery)
assert(optimized.resolved)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ number of columns in the output of subquery.
Left side columns:
[t1.`t1a`].
Right side columns:
[t2.`t2a`, t2.`t2b`].
;
[t2.`t2a`, t2.`t2b`].;


-- !query 6
Expand All @@ -102,5 +101,4 @@ number of columns in the output of subquery.
Left side columns:
[t1.`t1a`, t1.`t1b`].
Right side columns:
[t2.`t2a`].
;
[t2.`t2a`].;