From 3fab712f69f0073d6e5481d43c455363431952fc Mon Sep 17 00:00:00 2001 From: Mihailo Timotic Date: Fri, 29 Nov 2024 21:46:32 +0800 Subject: [PATCH] [SPARK-50441][SQL] Fix parametrized identifiers not working when referencing CTEs ### What changes were proposed in this pull request? Fix parametrized identifiers not working when referencing CTEs ### Why are the changes needed? For a query: `with t1 as (select 1) select * from identifier(:cte) using cte as "t1"` the resolution fails: `BindParameters` can't resolve parameters because it waits for `ResolveIdentifierClause` to resolve `UnresolvedWithCTERelation`, but `ResolveIdentifierClause` can't resolve `UnresolvedWithCTERelation` until all `NamedParameters` in the plan are resolved. Instead of delaying CTE resolution with `UnresolvedWithCTERelation`, we can remove the node entirely and delay the resolution by keeping the original `PlanWithUnresolvedIdentifier` and moving the CTE resolution to its `planBuilder`. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added a new test to `ParametersSuite` ### Was this patch authored or co-authored using generative AI tooling? No Closes #48994 from mihailotim-db/mihailotim-db/cte_identifer. 
Authored-by: Mihailo Timotic Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 3 -- .../catalyst/analysis/CTESubstitution.scala | 41 +++++++++++++------ .../analysis/ResolveIdentifierClause.scala | 15 ++----- .../sql/catalyst/analysis/parameters.scala | 6 +-- .../sql/catalyst/analysis/unresolved.scala | 13 +----- .../sql/catalyst/trees/TreePatterns.scala | 1 - .../apache/spark/sql/ParametersSuite.scala | 11 +++++ 7 files changed, 47 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3af3565220bdb..089e18e3df4e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1610,9 +1610,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case s: Sort if !s.resolved || s.missingInput.nonEmpty => resolveReferencesInSort(s) - case u: UnresolvedWithCTERelations => - UnresolvedWithCTERelations(this.apply(u.unresolvedPlan), u.cteRelations) - case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString(conf.maxToStringFields)}") q.mapExpressions(resolveExpressionByPlanChildren(_, q, includeLastResort = true)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala index ff0dbcd7ef153..d75e7d528d5b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala @@ -267,6 +267,25 @@ object CTESubstitution extends Rule[LogicalPlan] { resolvedCTERelations } + private def resolveWithCTERelations( + table: String, + alwaysInline: Boolean, + cteRelations: 
Seq[(String, CTERelationDef)], + unresolvedRelation: UnresolvedRelation): LogicalPlan = { + cteRelations + .find(r => conf.resolver(r._1, table)) + .map { + case (_, d) => + if (alwaysInline) { + d.child + } else { + // Add a `SubqueryAlias` for hint-resolving rules to match relation names. + SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output, d.isStreaming)) + } + } + .getOrElse(unresolvedRelation) + } + private def substituteCTE( plan: LogicalPlan, alwaysInline: Boolean, @@ -279,22 +298,20 @@ object CTESubstitution extends Rule[LogicalPlan] { throw QueryCompilationErrors.timeTravelUnsupportedError(toSQLId(table)) case u @ UnresolvedRelation(Seq(table), _, _) => - cteRelations.find(r => plan.conf.resolver(r._1, table)).map { case (_, d) => - if (alwaysInline) { - d.child - } else { - // Add a `SubqueryAlias` for hint-resolving rules to match relation names. - SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output, d.isStreaming)) - } - }.getOrElse(u) + resolveWithCTERelations(table, alwaysInline, cteRelations, u) case p: PlanWithUnresolvedIdentifier => // We must look up CTE relations first when resolving `UnresolvedRelation`s, // but we can't do it here as `PlanWithUnresolvedIdentifier` is a leaf node - // and may produce `UnresolvedRelation` later. - // Here we wrap it with `UnresolvedWithCTERelations` so that we can - // delay the CTE relations lookup after `PlanWithUnresolvedIdentifier` is resolved. - UnresolvedWithCTERelations(p, cteRelations) + // and may produce `UnresolvedRelation` later. Instead, we delay CTE resolution + // by moving it to the planBuilder of the corresponding `PlanWithUnresolvedIdentifier`. 
+ p.copy(planBuilder = (nameParts, children) => { + p.planBuilder.apply(nameParts, children) match { + case u @ UnresolvedRelation(Seq(table), _, _) => + resolveWithCTERelations(table, alwaysInline, cteRelations, u) + case other => other + } + }) case other => // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala index 0e1e71a658c8b..2cf3c6390d5fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.plans.logical.{CTERelationRef, LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{UNRESOLVED_IDENTIFIER, UNRESOLVED_IDENTIFIER_WITH_CTE} +import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_IDENTIFIER import org.apache.spark.sql.types.StringType /** @@ -30,18 +30,9 @@ import org.apache.spark.sql.types.StringType object ResolveIdentifierClause extends Rule[LogicalPlan] with AliasHelper with EvalHelper { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( - _.containsAnyPattern(UNRESOLVED_IDENTIFIER, UNRESOLVED_IDENTIFIER_WITH_CTE)) { + _.containsPattern(UNRESOLVED_IDENTIFIER)) { case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved && p.childrenResolved => p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr), 
p.children) - case u @ UnresolvedWithCTERelations(p, cteRelations) => - this.apply(p) match { - case u @ UnresolvedRelation(Seq(table), _, _) => - cteRelations.find(r => plan.conf.resolver(r._1, table)).map { case (_, d) => - // Add a `SubqueryAlias` for hint-resolving rules to match relation names. - SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output, d.isStreaming)) - }.getOrElse(u) - case other => other - } case other => other.transformExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER)) { case e: ExpressionWithUnresolvedIdentifier if e.identifierExpr.resolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala index f24227abbb651..de73747769469 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, LeafExpression, Literal, MapFromArrays, MapFromEntries, SubqueryExpression, Unevaluable, VariableReference} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SupervisingCommand} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMAND, PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_IDENTIFIER_WITH_CTE, UNRESOLVED_WITH} +import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMAND, PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_WITH} import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.types.DataType @@ -189,7 +189,7 @@ object BindParameters extends ParameterizedQueryProcessor with QueryErrorsBase { // We should wait for `CTESubstitution` to resolve CTE before binding parameters, as CTE // relations 
are not children of `UnresolvedWith`. case NameParameterizedQuery(child, argNames, argValues) - if !child.containsAnyPattern(UNRESOLVED_WITH, UNRESOLVED_IDENTIFIER_WITH_CTE) && + if !child.containsPattern(UNRESOLVED_WITH) && argValues.forall(_.resolved) => if (argNames.length != argValues.length) { throw SparkException.internalError(s"The number of argument names ${argNames.length} " + @@ -200,7 +200,7 @@ object BindParameters extends ParameterizedQueryProcessor with QueryErrorsBase { bind(child) { case NamedParameter(name) if args.contains(name) => args(name) } case PosParameterizedQuery(child, args) - if !child.containsAnyPattern(UNRESOLVED_WITH, UNRESOLVED_IDENTIFIER_WITH_CTE) && + if !child.containsPattern(UNRESOLVED_WITH) && args.forall(_.resolved) => val indexedArgs = args.zipWithIndex checkArgs(indexedArgs.map(arg => (s"_${arg._2}", arg._1))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 7fc8aff72b81d..0a73b6b856740 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIden import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, LeafNode, LogicalPlan, UnaryNode} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode} import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId @@ -76,17 +76,6 @@ case class PlanWithUnresolvedIdentifier( copy(identifierExpr, 
newChildren, planBuilder) } -/** - * A logical plan placeholder which delays CTE resolution - * to moment when PlanWithUnresolvedIdentifier gets resolved - */ -case class UnresolvedWithCTERelations( - unresolvedPlan: LogicalPlan, - cteRelations: Seq[(String, CTERelationDef)]) - extends UnresolvedLeafNode { - final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_IDENTIFIER_WITH_CTE) -} - /** * An expression placeholder that holds the identifier clause string expression. It will be * replaced by the actual expression with the evaluated identifier string. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 7435f4c527034..e95712281cb42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -154,7 +154,6 @@ object TreePattern extends Enumeration { val UNRESOLVED_FUNCTION: Value = Value val UNRESOLVED_HINT: Value = Value val UNRESOLVED_WINDOW_EXPRESSION: Value = Value - val UNRESOLVED_IDENTIFIER_WITH_CTE: Value = Value // Unresolved Plan patterns (Alphabetically ordered) val UNRESOLVED_FUNC: Value = Value diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala index 791bcc91d5094..2ac8ed26868a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala @@ -758,4 +758,15 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest { checkAnswer(spark.sql(query("?"), args = Array("tt1")), Row(1)) } } + + test("SPARK-50441: parameterized identifier referencing a CTE") { + def query(p: String): String = { + s""" + |WITH t1 AS (SELECT 1) + |SELECT * FROM IDENTIFIER($p)""".stripMargin + } + + 
checkAnswer(spark.sql(query(":cte"), args = Map("cte" -> "t1")), Row(1)) + checkAnswer(spark.sql(query("?"), args = Array("t1")), Row(1)) + } }