Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-50441][SQL] Fix parametrized identifiers not working when referencing CTEs #48994

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1609,9 +1609,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
case s: Sort if !s.resolved || s.missingInput.nonEmpty =>
resolveReferencesInSort(s)

case u: UnresolvedWithCTERelations =>
UnresolvedWithCTERelations(this.apply(u.unresolvedPlan), u.cteRelations)

case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString(conf.maxToStringFields)}")
q.mapExpressions(resolveExpressionByPlanChildren(_, q, includeLastResort = true))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,25 @@ object CTESubstitution extends Rule[LogicalPlan] {
resolvedCTERelations
}

private def resolveWithCTERelations(
table: String,
alwaysInline: Boolean,
cteRelations: Seq[(String, CTERelationDef)],
unresolvedRelation: UnresolvedRelation): LogicalPlan = {
cteRelations
.find(r => conf.resolver(r._1, table))
.map {
case (_, d) =>
if (alwaysInline) {
d.child
} else {
// Add a `SubqueryAlias` for hint-resolving rules to match relation names.
SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output, d.isStreaming))
}
}
.getOrElse(unresolvedRelation)
}

private def substituteCTE(
plan: LogicalPlan,
alwaysInline: Boolean,
Expand All @@ -279,22 +298,20 @@ object CTESubstitution extends Rule[LogicalPlan] {
throw QueryCompilationErrors.timeTravelUnsupportedError(toSQLId(table))

case u @ UnresolvedRelation(Seq(table), _, _) =>
cteRelations.find(r => plan.conf.resolver(r._1, table)).map { case (_, d) =>
if (alwaysInline) {
d.child
} else {
// Add a `SubqueryAlias` for hint-resolving rules to match relation names.
SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output, d.isStreaming))
}
}.getOrElse(u)
resolveWithCTERelations(table, alwaysInline, cteRelations, u)

case p: PlanWithUnresolvedIdentifier =>
// We must look up CTE relations first when resolving `UnresolvedRelation`s,
// but we can't do it here as `PlanWithUnresolvedIdentifier` is a leaf node
// and may produce `UnresolvedRelation` later.
// Here we wrap it with `UnresolvedWithCTERelations` so that we can
// delay the CTE relations lookup after `PlanWithUnresolvedIdentifier` is resolved.
Copy link
Contributor

@cloud-fan cloud-fan Nov 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part of the comment needs an update, we now use a different planBuilder to delay the CTE lookup.

UnresolvedWithCTERelations(p, cteRelations)
// and may produce `UnresolvedRelation` later. Instead, we delay CTE resolution
// by moving it to the planBuilder of the corresponding `PlanWithUnresolvedIdentifier`.
p.copy(planBuilder = (nameParts, children) => {
p.planBuilder.apply(nameParts, children) match {
case u @ UnresolvedRelation(Seq(table), _, _) =>
resolveWithCTERelations(table, alwaysInline, cteRelations, u)
case other => other
}
})

case other =>
// This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.{CTERelationRef, LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{UNRESOLVED_IDENTIFIER, UNRESOLVED_IDENTIFIER_WITH_CTE}
import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_IDENTIFIER
import org.apache.spark.sql.types.StringType

/**
Expand All @@ -30,18 +30,9 @@ import org.apache.spark.sql.types.StringType
object ResolveIdentifierClause extends Rule[LogicalPlan] with AliasHelper with EvalHelper {

override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
_.containsAnyPattern(UNRESOLVED_IDENTIFIER, UNRESOLVED_IDENTIFIER_WITH_CTE)) {
_.containsPattern(UNRESOLVED_IDENTIFIER)) {
case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved && p.childrenResolved =>
p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr), p.children)
case u @ UnresolvedWithCTERelations(p, cteRelations) =>
this.apply(p) match {
case u @ UnresolvedRelation(Seq(table), _, _) =>
cteRelations.find(r => plan.conf.resolver(r._1, table)).map { case (_, d) =>
// Add a `SubqueryAlias` for hint-resolving rules to match relation names.
SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output, d.isStreaming))
}.getOrElse(u)
case other => other
}
case other =>
other.transformExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
case e: ExpressionWithUnresolvedIdentifier if e.identifierExpr.resolved =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, LeafExpression, Literal, MapFromArrays, MapFromEntries, SubqueryExpression, Unevaluable, VariableReference}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SupervisingCommand}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMAND, PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_IDENTIFIER_WITH_CTE, UNRESOLVED_WITH}
import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMAND, PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_WITH}
import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.types.DataType

Expand Down Expand Up @@ -189,7 +189,7 @@ object BindParameters extends ParameterizedQueryProcessor with QueryErrorsBase {
// We should wait for `CTESubstitution` to resolve CTE before binding parameters, as CTE
// relations are not children of `UnresolvedWith`.
case NameParameterizedQuery(child, argNames, argValues)
if !child.containsAnyPattern(UNRESOLVED_WITH, UNRESOLVED_IDENTIFIER_WITH_CTE) &&
if !child.containsPattern(UNRESOLVED_WITH) &&
argValues.forall(_.resolved) =>
if (argNames.length != argValues.length) {
throw SparkException.internalError(s"The number of argument names ${argNames.length} " +
Expand All @@ -200,7 +200,7 @@ object BindParameters extends ParameterizedQueryProcessor with QueryErrorsBase {
bind(child) { case NamedParameter(name) if args.contains(name) => args(name) }

case PosParameterizedQuery(child, args)
if !child.containsAnyPattern(UNRESOLVED_WITH, UNRESOLVED_IDENTIFIER_WITH_CTE) &&
if !child.containsPattern(UNRESOLVED_WITH) &&
args.forall(_.resolved) =>
val indexedArgs = args.zipWithIndex
checkArgs(indexedArgs.map(arg => (s"_${arg._2}", arg._1)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIden
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, LeafNode, LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
Expand Down Expand Up @@ -76,17 +76,6 @@ case class PlanWithUnresolvedIdentifier(
copy(identifierExpr, newChildren, planBuilder)
}

/**
* A logical plan placeholder which delays CTE resolution
* to moment when PlanWithUnresolvedIdentifier gets resolved
*/
case class UnresolvedWithCTERelations(
unresolvedPlan: LogicalPlan,
cteRelations: Seq[(String, CTERelationDef)])
extends UnresolvedLeafNode {
final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_IDENTIFIER_WITH_CTE)
}

/**
* An expression placeholder that holds the identifier clause string expression. It will be
* replaced by the actual expression with the evaluated identifier string.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ object TreePattern extends Enumeration {
val UNRESOLVED_FUNCTION: Value = Value
val UNRESOLVED_HINT: Value = Value
val UNRESOLVED_WINDOW_EXPRESSION: Value = Value
val UNRESOLVED_IDENTIFIER_WITH_CTE: Value = Value

// Unresolved Plan patterns (Alphabetically ordered)
val UNRESOLVED_FUNC: Value = Value
Expand Down
11 changes: 11 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -758,4 +758,15 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest {
checkAnswer(spark.sql(query("?"), args = Array("tt1")), Row(1))
}
}

test("SPARK-50441: parameterized identifier referencing a CTE") {
def query(p: String): String = {
s"""
|WITH t1 AS (SELECT 1)
|SELECT * FROM IDENTIFIER($p)""".stripMargin
}

checkAnswer(spark.sql(query(":cte"), args = Map("cte" -> "t1")), Row(1))
checkAnswer(spark.sql(query("?"), args = Array("t1")), Row(1))
}
}