Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-41630][SQL] Support implicit lateral column alias resolution on Project #38776

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions core/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
],
"sqlState" : "42000"
},
"AMBIGUOUS_LATERAL_COLUMN_ALIAS" : {
"message" : [
"Lateral column alias <name> is ambiguous and has <n> matches."
],
"sqlState" : "42000"
},
"AMBIGUOUS_REFERENCE" : {
"message" : [
"Reference <name> is ambiguous, could be: <referenceNames>."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils, StringUtils}
import org.apache.spark.sql.catalyst.util.{toPrettySQL, CaseInsensitiveMap, CharVarcharUtils, StringUtils}
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
import org.apache.spark.sql.connector.catalog.{View => _, _}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
Expand Down Expand Up @@ -288,6 +288,8 @@ class Analyzer(override val catalogManager: CatalogManager)
AddMetadataColumns ::
DeduplicateRelations ::
ResolveReferences ::
WrapLateralColumnAliasReference ::
ResolveLateralColumnAliasReference ::
ResolveExpressionsWithNamePlaceholders ::
ResolveDeserializer ::
ResolveNewInstance ::
Expand Down Expand Up @@ -1672,7 +1674,7 @@ class Analyzer(override val catalogManager: CatalogManager)
// Only Project and Aggregate can host star expressions.
case u @ (_: Project | _: Aggregate) =>
Try(s.expand(u.children.head, resolver)) match {
case Success(expanded) => expanded.map(wrapOuterReference)
case Success(expanded) => expanded.map(wrapOuterReference(_))
case Failure(_) => throw e
}
// Do not use the outer plan to resolve the star expression
Expand Down Expand Up @@ -1761,6 +1763,114 @@ class Analyzer(override val catalogManager: CatalogManager)
}
}

/**
* The first phase to resolve lateral column alias. See comments in
* [[ResolveLateralColumnAliasReference]] for more detailed explanation.
*/
object WrapLateralColumnAliasReference extends Rule[LogicalPlan] {
import ResolveLateralColumnAliasReference.AliasEntry

private def insertIntoAliasMap(
    a: Alias,
    idx: Int,
    aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): CaseInsensitiveMap[Seq[AliasEntry]] = {
  // Pair the alias with its position in the project list.
  val entry = AliasEntry(a, idx)
  // Append to whatever is already registered under this name, so that aliases sharing a name
  // accumulate in insertion order (callers use the entry count to detect ambiguity).
  val entriesForName = aliasMap.get(a.name) match {
    case Some(existing) => existing :+ entry
    case None => Seq(entry)
  }
  // Returns a new map; the map passed in is left untouched.
  aliasMap + (a.name -> entriesForName)
}

/**
* Use the given lateral alias to resolve the unresolved attribute with the name parts.
*
* Construct a dummy plan with the given lateral alias as project list, use the output of the
* plan to resolve.
* @return The resolved [[LateralColumnAliasReference]] if succeeds. None if fails to resolve.
*/
private def resolveByLateralAlias(
    nameParts: Seq[String], lateralAlias: Alias): Option[LateralColumnAliasReference] = {
  val unresolved = UnresolvedAttribute(nameParts)
  // Throwaway plan whose only project-list entry is the lateral alias; resolving against its
  // output also covers name parts that reach into a nested field of the alias.
  val dummyPlan = Project(Seq(lateralAlias), OneRowRelation())
  // TODO question: every time this resolves an extracted field a new exprId is generated.
  //  Does it matter?
  val candidate = resolveExpressionByPlanOutput(
    expr = unresolved,
    plan = dummyPlan,
    throws = false
  ).asInstanceOf[NamedExpression]
  // Only a fully resolved result is wrapped; otherwise the caller receives None.
  Some(candidate)
    .filter(_.resolved)
    .map(LateralColumnAliasReference(_, nameParts, lateralAlias.toAttribute))
}

/**
* Recognize all the attributes in the given expression that reference lateral column aliases
* by looking up the alias map. Resolve these attributes and replace by wrapping with
* [[LateralColumnAliasReference]].
*
* @param currentPlan Because lateral alias has lower resolution priority than table columns,
* the current plan is needed to first try resolving the attribute by its
* children
*/
private def wrapLCARefHelper(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
private def wrapLCARefHelper(
private def wrapLCARef(

e: NamedExpression,
currentPlan: LogicalPlan,
aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): NamedExpression = {
e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) {
case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) &&
resolveExpressionByPlanChildren(u, currentPlan).isInstanceOf[UnresolvedAttribute] =>
val aliases = aliasMap.get(u.nameParts.head).get
aliases.size match {
case n if n > 1 =>
throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n)
case n if n == 1 && aliases.head.alias.resolved =>
// Only resolved alias can be the lateral column alias
// The lateral alias can be a struct and have nested field, need to construct
// a dummy plan to resolve the expression
resolveByLateralAlias(u.nameParts, aliases.head.alias).getOrElse(u)
case _ => u
}
case o: OuterReference
if aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) =>
// handle OuterReference exactly same as UnresolvedAttribute
val nameParts = o.nameParts.getOrElse(Seq(o.name))
val aliases = aliasMap.get(nameParts.head).get
aliases.size match {
case n if n > 1 =>
throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n)
case n if n == 1 && aliases.head.alias.resolved =>
resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o)
case _ => o
}
}.asInstanceOf[NamedExpression]
}

override def apply(plan: LogicalPlan): LogicalPlan = {
  val lcaEnabled = conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)
  if (lcaEnabled) {
    plan.resolveOperatorsUpWithPruning(
      _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) {
      case project @ Project(exprs, _)
          if project.childrenResolved &&
            !ResolveReferences.containsStar(exprs) &&
            exprs.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) =>
        val noAliases = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]())
        // Walk the project list left to right, threading the alias map through the fold so
        // that each expression only sees aliases defined before it.
        val (_, wrappedExprs) = exprs.zipWithIndex.foldLeft(
          (noAliases, Vector.empty[NamedExpression])) {
          case ((seenAliases, done), (alias: Alias, position)) =>
            val wrappedAlias =
              wrapLCARefHelper(alias, project, seenAliases).asInstanceOf[Alias]
            // Register the LCA-wrapped alias rather than the raw one: once it is resolved,
            // later expressions can reference it (chaining). Unresolved aliases are
            // registered as well so the ambiguous-name check sees them, but only resolved
            // aliases can act as a lateral column alias.
            (insertIntoAliasMap(wrappedAlias, position, seenAliases), done :+ wrappedAlias)
          case ((seenAliases, done), (other, _)) =>
            // Non-alias expressions can reference earlier aliases but define none.
            (seenAliases, done :+ wrapLCARefHelper(other, project, seenAliases))
        }
        project.copy(projectList = wrappedExprs)
    }
  } else {
    // Feature flag is off: leave the plan untouched.
    plan
  }
}
}

private def containsDeserializer(exprs: Seq[Expression]): Boolean = {
  // True if any node in any of the given expression trees is an UnresolvedDeserializer.
  exprs.exists { expr =>
    expr.exists {
      case _: UnresolvedDeserializer => true
      case _ => false
    }
  }
}
Expand Down Expand Up @@ -2143,7 +2253,7 @@ class Analyzer(override val catalogManager: CatalogManager)
case u @ UnresolvedAttribute(nameParts) => withPosition(u) {
try {
AnalysisContext.get.outerPlan.get.resolveChildren(nameParts, resolver) match {
case Some(resolved) => wrapOuterReference(resolved)
case Some(resolved) => wrapOuterReference(resolved, Some(nameParts))
case None => u
}
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, Decorrela
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_WINDOW_EXPRESSION
import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_WINDOW_EXPRESSION}
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils, TypeUtils}
import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsPartitionManagement}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
Expand Down Expand Up @@ -638,6 +638,14 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
case UnresolvedWindowExpression(_, windowSpec) =>
throw QueryCompilationErrors.windowSpecificationNotDefinedError(windowSpec.name)
})
// This should not happen: a resolved Project or Aggregate should have restored or resolved
// all lateral column alias references. This check is added for extra safety.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we have a rule like RemoveTempResolvedColumn to restore LateralColumnAliasReference?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I intentionally didn't add it. I don't want attributes that can actually be resolved as LCA to show up in the error message as UnresolvedAttribute. Also note that, unlike RemoveTempResolvedColumn, an LCARef can't be directly resolved to the NamedExpression inside it, because the plan won't be right - there is no alias push down.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this should not happen, we should throw an internal error SparkThrowable.internalError, so that it can include more debug information, instead of UNRESOLVED_COLUMN

projectList.foreach(_.transformDownWithPruning(
_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
case lcaRef: LateralColumnAliasReference if p.resolved =>
failUnresolvedAttribute(
p, UnresolvedAttribute(lcaRef.nameParts), "UNRESOLVED_COLUMN")
})

case j: Join if !j.duplicateResolved =>
val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet)
Expand Down Expand Up @@ -714,6 +722,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
"operator" -> other.nodeName,
"invalidExprSqls" -> invalidExprSqls.mkString(", ")))

// This should not happen: a resolved Project or Aggregate should have restored or resolved
// all lateral column alias references. This check is added for extra safety.
case agg @ Aggregate(_, aggList, _)
if aggList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) && agg.resolved =>
aggList.foreach(_.transformDownWithPruning(
_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
case lcaRef: LateralColumnAliasReference =>
failUnresolvedAttribute(
agg, UnresolvedAttribute(lcaRef.nameParts), "UNRESOLVED_COLUMN")
})

case _ => // Analysis successful!
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, LateralColumnAliasReference, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE}
import org.apache.spark.sql.internal.SQLConf

/**
* This rule is the second phase to resolve lateral column alias.
*
* Resolve lateral column alias, which references the alias defined previously in the SELECT list.
* Plan-wise, it handles two types of operators: Project and Aggregate.
* - in Project, pushing down the referenced lateral alias into a newly created Project, resolve
* the attributes referencing these aliases
* - in Aggregate TODO.
*
* The whole process is generally divided into two phases:
* 1) recognize resolved lateral alias, wrap the attributes referencing them with
* [[LateralColumnAliasReference]]
* 2) when the whole operator is resolved, unwrap [[LateralColumnAliasReference]].
* For Project, it further resolves the attributes and push down the referenced lateral aliases.
* For Aggregate, TODO
*
* Example for Project:
* Before rewrite:
* Project [age AS a, 'a + 1]
* +- Child
*
* After phase 1:
* Project [age AS a, lateralalias(a) + 1]
* +- Child
*
* After phase 2:
* Project [a, a + 1]
* +- Project [child output, age AS a]
* +- Child
*
* Example for Aggregate TODO
*
*
* The name resolution priority:
* local table column > local lateral column alias > outer reference
*
* Because lateral column alias has higher resolution priority than outer reference, it will try
* to resolve an [[OuterReference]] using lateral column alias in phase 1, similar to an
* [[UnresolvedAttribute]]. If it succeeds, it strips [[OuterReference]] and also wraps it with
* [[LateralColumnAliasReference]].
*/
object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] {
  /** An alias from a project list together with its position in that list. */
  case class AliasEntry(alias: Alias, index: Int)

  /**
   * Phase 2 of lateral column alias resolution: for every resolved [[Project]] whose project
   * list still contains [[LateralColumnAliasReference]]s, unwrap the references that point to
   * aliases which are themselves free of lateral references, and push those aliases down into a
   * new child [[Project]]. References to aliases that still contain lateral references
   * (chaining) are left in place for a later invocation of the rule.
   *
   * @param plan the plan to rewrite
   * @return the plan unchanged when the feature flag is disabled, otherwise the rewritten plan
   */
  override def apply(plan: LogicalPlan): LogicalPlan = {
    if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) {
      plan
    } else {
      // phase 2: unwrap
      plan.resolveOperatorsUpWithPruning(
        _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), ruleId) {
        case p @ Project(projectList, child) if p.resolved
          && projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
          // Aliases seen so far in this project list, keyed by their output attribute.
          var aliasMap = Map[Attribute, AliasEntry]()
          // Insertion-ordered on purpose: the pushed-down aliases are appended to the inner
          // Project in the order they are first referenced, instead of in hash iteration
          // order, which keeps the generated plan stable.
          val referencedAliases = collection.mutable.LinkedHashSet.empty[AliasEntry]

          def unwrapLCAReference(e: NamedExpression): NamedExpression = {
            e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
              case lcaRef: LateralColumnAliasReference =>
                aliasMap.get(lcaRef.a) match {
                  // No chaining: schedule the alias for push down and unwrap the reference to
                  // the NamedExpression it holds.
                  case Some(aliasEntry)
                      if !aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE) =>
                    referencedAliases += aliasEntry
                    lcaRef.ne
                  // Chaining: the referenced alias itself still contains lateral references,
                  // so leave this one for a future round.
                  case Some(_) =>
                    lcaRef
                  // It shouldn't happen, but restore to unresolved attribute to be safe.
                  case None =>
                    UnresolvedAttribute(lcaRef.nameParts)
                }
            }.asInstanceOf[NamedExpression]
          }

          val newProjectList = projectList.zipWithIndex.map {
            case (a: Alias, idx) =>
              val lcaResolved = unwrapLCAReference(a)
              // Insert the original alias instead of rewritten one to detect chained LCA
              aliasMap += (a.toAttribute -> AliasEntry(a, idx))
              lcaResolved
            case (e, _) =>
              unwrapLCAReference(e)
          }

          if (referencedAliases.isEmpty) {
            // Nothing could be unwrapped in this round; keep the operator as is.
            p
          } else {
            val pushedDown = referencedAliases.toSeq
            // In the outer project list, every pushed-down alias is replaced by its attribute.
            val attributeByIndex =
              pushedDown.map(entry => entry.index -> entry.alias.toAttribute).toMap
            val outerProjectList: Seq[NamedExpression] = newProjectList.zipWithIndex.map {
              case (e, idx) => attributeByIndex.getOrElse(idx, e)
            }
            // The inner project forwards all child output and adds the pushed-down aliases.
            val innerProjectList: Seq[NamedExpression] =
              child.output ++ pushedDown.map(_.alias)
            p.copy(
              projectList = outerProjectList,
              child = Project(innerProjectList, child)
            )
          }
      }
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,51 @@ case class OuterReference(e: NamedExpression)
override def qualifier: Seq[String] = e.qualifier
override def exprId: ExprId = e.exprId
override def toAttribute: Attribute = e.toAttribute
override def newInstance(): NamedExpression = OuterReference(e.newInstance())
override def newInstance(): NamedExpression =
OuterReference(e.newInstance()).setNameParts(nameParts)
final override val nodePatterns: Seq[TreePattern] = Seq(OUTER_REFERENCE)

// Optional field: the original name parts of the UnresolvedAttribute before it was resolved
// to OuterReference. Used in rule WrapLateralColumnAliasReference to convert an
// OuterReference back to LateralColumnAliasReference.
var nameParts: Option[Seq[String]] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we have to keep a mutable state, TreeNodeTag is a better choice. Directly using var in catalyst TreeNode is strongly discouraged.

Copy link
Contributor Author

@anchovYu anchovYu Dec 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: I didn't add it in the constructor of OuterReference due to binary compatibility. Is that concern valid? Actually, what is the risk to change the constructor, but also write another unapply function? This seems impossible without introducing a new object with another name, and still requires large portion of code change of pattern matching.

def setNameParts(newNameParts: Option[Seq[String]]): OuterReference = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit tricky. Maybe we should invoke WrapLateralColumnAliasReference in ResolveOuterReferences, so that we don't need to re-resolve outer references and introduce this hack.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we discussed before, I feel it is not safe to do so given the current solution in ResolveOuterReference that each rule is applied only once. I made up a query (it can't run, just for demonstration):

SELECT *
FROM range(1, 7)
WHERE (
  SELECT id2
  FROM (SELECT dept * 2.0 AS id, id + 1 AS id2 FROM $testTable)) > 5
ORDER BY id

It is possible that dept * 2.0 is not resolved because it needs type conversion, so the LCA rule doesn't apply. Then it just wraps the id in id + 1 AS id2 as OuterReference.

nameParts = newNameParts
this
}
}

/**
* A placeholder used to hold a [[NamedExpression]] that has been temporarily resolved as the
* reference to a lateral column alias.
*
* This is created by Analyzer rule [[WrapLateralColumnAliasReference]] and removed by
* [[ResolveLateralColumnAliasReference]].
* There should be no [[LateralColumnAliasReference]] beyond analyzer: if the plan passes all
* analysis check, then all [[LateralColumnAliasReference]] should already be removed.
*
* @param ne the resolved [[NamedExpression]] by lateral column alias
* @param nameParts the named parts of the original [[UnresolvedAttribute]]. Used to restore back
* to [[UnresolvedAttribute]] when needed
* @param a the attribute of referenced lateral column alias. Used to match alias when unwrapping
* and resolving LateralColumnAliasReference
*/
case class LateralColumnAliasReference(ne: NamedExpression, nameParts: Seq[String], a: Attribute)
  extends LeafExpression with NamedExpression with Unevaluable {
  // The wrapped expression must already be resolved when the reference is constructed.
  assert(ne.resolved)

  // Name as originally written: parts joined by dots, with any part that itself contains a dot
  // wrapped in backticks.
  override def name: String = {
    val quotedParts = nameParts.map { part =>
      if (part.contains(".")) "`" + part + "`" else part
    }
    quotedParts.mkString(".")
  }

  // Identity, type and nullability are all delegated to the wrapped expression.
  override def exprId: ExprId = ne.exprId
  override def qualifier: Seq[String] = ne.qualifier
  override def dataType: DataType = ne.dataType
  override def nullable: Boolean = ne.nullable
  override def toAttribute: Attribute = ne.toAttribute

  // Only the wrapped expression gets a new instance; the name parts and the referenced alias
  // attribute are carried over unchanged.
  override def newInstance(): NamedExpression =
    LateralColumnAliasReference(ne.newInstance(), nameParts, a)

  override def prettyName: String = "lateralAliasReference"
  override def sql: String = s"${prettyName}(${name})"

  final override val nodePatterns: Seq[TreePattern] = Seq(LATERAL_COLUMN_ALIAS_REFERENCE)
}

object VirtualColumn {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ object SubExprUtils extends PredicateHelper {
/**
* Wrap attributes in the expression with [[OuterReference]]s.
*/
def wrapOuterReference[E <: Expression](e: E): E = {
e.transform { case a: Attribute => OuterReference(a) }.asInstanceOf[E]
def wrapOuterReference[E <: Expression](e: E, nameParts: Option[Seq[String]] = None): E = {
e.transform { case a: Attribute => OuterReference(a).setNameParts(nameParts) }.asInstanceOf[E]
}

/**
Expand Down
Loading