Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mihailotim-db committed Nov 27, 2024
1 parent 6edcf43 commit 13f2813
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ object ResolveIdentifierClause extends Rule[LogicalPlan] with AliasHelper with E
// Add a `SubqueryAlias` for hint-resolving rules to match relation names.
SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output, d.isStreaming))
}.getOrElse(u)
case other => other
case _ => u
}
case other =>
other.transformExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,19 @@ object BindParameters extends ParameterizedQueryProcessor with QueryErrorsBase {
}
}

private def bind(p: LogicalPlan)(f: PartialFunction[Expression, Expression]): LogicalPlan = {
p.resolveExpressionsWithPruning(_.containsPattern(PARAMETER)) (f orElse {
case sub: SubqueryExpression => sub.withNewPlan(bind(sub.plan)(f))
})
private def bindExpressionsForPlan(plan: LogicalPlan)(
f: PartialFunction[Expression, Expression]): LogicalPlan = {
plan.resolveExpressionsWithPruning(_.containsPattern(PARAMETER))(f)
}

private def bind(p: LogicalPlan)(
f: PartialFunction[Expression, Expression]): LogicalPlan = {
p.transformWithSubqueries {
case _ @UnresolvedWithCTERelations(plan, cteRelations) =>
val planWithIdentifiersReplaced = bindExpressionsForPlan(plan)(f)
UnresolvedWithCTERelations(planWithIdentifiersReplaced, cteRelations)
case plan: LogicalPlan => bindExpressionsForPlan(plan)(f)
}
}

override def apply(plan: LogicalPlan): LogicalPlan = {
Expand Down
11 changes: 11 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -758,4 +758,15 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest {
checkAnswer(spark.sql(query("?"), args = Array("tt1")), Row(1))
}
}

test("SPARK-50441: parameterized identifier referencing a CTE") {
def query(p: String): String = {
s"""
|WITH t1 AS (SELECT 1)
|SELECT * FROM IDENTIFIER($p)""".stripMargin
}

checkAnswer(spark.sql(query(":cte"), args = Map("cte" -> "t1")), Row(1))
checkAnswer(spark.sql(query("?"), args = Array("t1")), Row(1))
}
}

0 comments on commit 13f2813

Please sign in to comment.