diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 62555c9a99cc3..e53efc3d881c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2425,7 +2425,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       case InSubquery(values, l @ ListQuery(_, _, exprId, _, _, _))
           if values.forall(_.resolved) && !l.resolved =>
         val expr = resolveSubQuery(l, outer)((plan, exprs) => {
-          ListQuery(plan, exprs, exprId, plan.output)
+          ListQuery(plan, exprs, exprId, plan.output.length)
         })
         InSubquery(values, expr.asInstanceOf[ListQuery])
       case s @ LateralSubquery(sub, _, exprId, _, _) if !sub.resolved =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 059c36c4f9044..bd2255134fca0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -361,11 +361,11 @@ abstract class TypeCoercionBase {
 
       // Handle type casting required between value expression and subquery output
       // in IN subquery.
-      case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _, conditions, _))
-          if !i.resolved && lhs.length == sub.output.length =>
+      case i @ InSubquery(lhs, l: ListQuery)
+          if !i.resolved && lhs.length == l.plan.output.length =>
         // LHS is the value expressions of IN subquery.
         // RHS is the subquery output.
-        val rhs = sub.output
+        val rhs = l.plan.output
 
         val commonTypes = lhs.zip(rhs).flatMap { case (l, r) =>
           findWiderTypeForTwo(l.dataType, r.dataType)
@@ -383,8 +383,7 @@ abstract class TypeCoercionBase {
             case (e, _) => e
           }
 
-          val newSub = Project(castedRhs, sub)
-          InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output, conditions))
+          InSubquery(newLhs, l.withNewPlan(Project(castedRhs, l.plan)))
         } else {
           i
         }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 64bee643c86c6..38005e7865385 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -367,12 +367,12 @@ case class InSubquery(values: Seq[Expression], query: ListQuery)
   final override val nodePatterns: Seq[TreePattern] = Seq(IN_SUBQUERY)
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    if (values.length != query.childOutputs.length) {
+    if (values.length != query.numCols) {
       DataTypeMismatch(
         errorSubClass = "IN_SUBQUERY_LENGTH_MISMATCH",
         messageParameters = Map(
           "leftLength" -> values.length.toString,
-          "rightLength" -> query.childOutputs.length.toString,
+          "rightLength" -> query.numCols.toString,
           "leftColumns" -> values.map(toSQLExpr(_)).mkString(", "),
           "rightColumns" -> query.childOutputs.map(toSQLExpr(_)).mkString(", ")
         )
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 228bb4805c85f..1e957466308a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -354,16 +354,19 @@ case class ListQuery(
     plan: LogicalPlan,
     outerAttrs: Seq[Expression] = Seq.empty,
     exprId: ExprId = NamedExpression.newExprId,
-    childOutputs: Seq[Attribute] = Seq.empty,
+    // The plan of list query may have more columns after de-correlation, and we need to track the
+    // number of the columns of the original plan, to report the data type properly.
+    numCols: Int = -1,
     joinCond: Seq[Expression] = Seq.empty,
     hint: Option[HintInfo] = None)
   extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
-  override def dataType: DataType = if (childOutputs.length > 1) {
+  def childOutputs: Seq[Attribute] = plan.output.take(numCols)
+  override def dataType: DataType = if (numCols > 1) {
     childOutputs.toStructType
   } else {
-    childOutputs.head.dataType
+    plan.output.head.dataType
   }
-  override lazy val resolved: Boolean = childrenResolved && plan.resolved && childOutputs.nonEmpty
+  override lazy val resolved: Boolean = childrenResolved && plan.resolved && numCols != -1
   override def nullable: Boolean = false
   override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
   override def withNewHint(hint: Option[HintInfo]): ListQuery = copy(hint = hint)
@@ -373,7 +376,7 @@ case class ListQuery(
       plan.canonicalized,
       outerAttrs.map(_.canonicalized),
       ExprId(0),
-      childOutputs.map(_.canonicalized.asInstanceOf[Attribute]),
+      numCols,
       joinCond.map(_.canonicalized))
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
index 5c11a39857559..44c5586037514 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -109,7 +109,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
       return filterApplicationSidePlan
     }
     val filter = InSubquery(Seq(mayWrapWithHash(filterApplicationSideExp)),
-      ListQuery(aggregate, childOutputs = aggregate.output))
+      ListQuery(aggregate, numCols = aggregate.output.length))
     Filter(filter, filterApplicationSidePlan)
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 83ff5e3973910..7355032db79b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -346,10 +346,10 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
       case Exists(sub, children, exprId, conditions, hint) if children.nonEmpty =>
         val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
         Exists(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint)
-      case ListQuery(sub, children, exprId, childOutputs, conditions, hint) if children.nonEmpty =>
+      case ListQuery(sub, children, exprId, numCols, conditions, hint) if children.nonEmpty =>
         val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
         val joinCond = getJoinCondition(newCond, conditions)
-        ListQuery(newPlan, children, exprId, childOutputs, joinCond, hint)
+        ListQuery(newPlan, children, exprId, numCols, joinCond, hint)
       case LateralSubquery(sub, children, exprId, conditions, hint) if children.nonEmpty =>
         val (newPlan, newCond) = decorrelate(sub, plan, handleCountBug = true)
         LateralSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index d1dcf7b76f67d..fe3a74f66a596 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1500,4 +1500,28 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       assert(refs.map(_.output).distinct.length == 3)
     }
   }
+
+  test("SPARK-43190: ListQuery.childOutput should be consistent with child output") {
+    val listQuery1 = ListQuery(testRelation2.select($"a"))
+    val listQuery2 = ListQuery(testRelation2.select($"b"))
+    val plan = testRelation3.where($"f".in(listQuery1) && $"f".in(listQuery2)).analyze
+    val resolvedCondition = plan.expressions.head
+    val finalPlan = testRelation2.join(testRelation3).where(resolvedCondition).analyze
+    val resolvedListQueries = finalPlan.expressions.flatMap(_.collect {
+      case l: ListQuery => l
+    })
+    assert(resolvedListQueries.length == 2)
+
+    def collectLocalRelations(plan: LogicalPlan): Seq[LocalRelation] = plan.collect {
+      case l: LocalRelation => l
+    }
+    val localRelations = resolvedListQueries.flatMap(l => collectLocalRelations(l.plan))
+    assert(localRelations.length == 2)
+    // DeduplicateRelations should deduplicate plans in subquery expressions as well.
+    assert(localRelations.head.output != localRelations.last.output)
+
+    resolvedListQueries.foreach { l =>
+      assert(l.childOutputs == l.plan.output)
+    }
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index e8943c2dba383..a2ad4f370de2e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -78,10 +78,7 @@ trait PlanTestBase extends PredicateHelper with SQLHelper with SQLConfHelper { s
       case e: Exists =>
         e.copy(plan = normalizeExprIds(e.plan), exprId = ExprId(0))
       case l: ListQuery =>
-        l.copy(
-          plan = normalizeExprIds(l.plan),
-          exprId = ExprId(0),
-          childOutputs = l.childOutputs.map(_.withExprId(ExprId(0))))
+        l.copy(plan = normalizeExprIds(l.plan), exprId = ExprId(0))
       case a: AttributeReference =>
         AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
       case OuterReference(a: AttributeReference) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
index 42c4c20e20d8a..fef92edbce649 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
@@ -83,7 +83,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[Sp
           val alias = Alias(buildKeys(broadcastKeyIndex), buildKeys(broadcastKeyIndex).toString)()
           val aggregate = Aggregate(Seq(alias), Seq(alias), buildPlan)
           DynamicPruningExpression(expressions.InSubquery(
-            Seq(value), ListQuery(aggregate, childOutputs = aggregate.output)))
+            Seq(value), ListQuery(aggregate, numCols = aggregate.output.length)))
         }
       }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala
index f2b513e630b5b..2877ff46edb52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala
@@ -90,6 +90,6 @@ class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPla
     val buildQuery = Aggregate(buildKeys, buildKeys, matchingRowsPlan)
 
     DynamicPruningExpression(
-      InSubquery(pruningKeys, ListQuery(buildQuery, childOutputs = buildQuery.output)))
+      InSubquery(pruningKeys, ListQuery(buildQuery, numCols = buildQuery.output.length)))
   }
 }
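
For readers skimming the diff, here is a minimal standalone Scala sketch of the idea behind the `ListQuery` change. `Plan` and `Attribute` below are simplified stand-ins for Catalyst's classes, not the real API: storing only `numCols` and deriving `childOutputs` from the current plan keeps the reported columns consistent even after rewrites (such as de-correlation or `DeduplicateRelations`) replace or widen the plan's output.

```scala
// Simplified stand-ins for Catalyst's Attribute/LogicalPlan (hypothetical, for illustration only).
case class Attribute(name: String)
case class Plan(output: Seq[Attribute])

// Before: the visible subquery columns were stored eagerly and could drift from the plan
// after rewrites re-created the plan's attributes.
case class ListQueryBefore(plan: Plan, childOutputs: Seq[Attribute] = Seq.empty)

// After: only the original column count is stored; the visible columns are derived from
// the current plan, so they always match it.
case class ListQueryAfter(plan: Plan, numCols: Int = -1) {
  def childOutputs: Seq[Attribute] = plan.output.take(numCols)
  def resolved: Boolean = numCols != -1
}

object ListQueryDemo extends App {
  val original = Plan(Seq(Attribute("a")))
  // De-correlation may append extra columns (e.g. pulled-up join keys) to the plan.
  val decorrelated = Plan(Seq(Attribute("a"), Attribute("outer_key")))

  val lq = ListQueryAfter(original, numCols = original.output.length).copy(plan = decorrelated)
  assert(lq.resolved)
  // Only the original column is reported, and it is read from the current plan.
  assert(lq.childOutputs == Seq(Attribute("a")))
}
```

This mirrors the diff's design choice: `dataType` and the `IN_SUBQUERY_LENGTH_MISMATCH` check only need the original column count, while the attributes themselves are always read from `plan.output`, so they can never go stale.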