Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-41630][SQL] Support implicit lateral column alias resolution on Project #38776

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions core/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
],
"sqlState" : "42000"
},
"AMBIGUOUS_LATERAL_COLUMN_ALIAS" : {
"message" : [
"Lateral column alias <name> is ambiguous and has <n> matches."
],
"sqlState" : "42000"
},
"AMBIGUOUS_REFERENCE" : {
"message" : [
"Reference <name> is ambiguous, could be: <referenceNames>."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])

// An attribute is contained exactly when looking it up with `get` yields a value.
override def contains(k: Attribute): Boolean = get(k).nonEmpty

override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = baseMap.values.toMap + kv
// Adds (or replaces) one entry and wraps the result again as an AttributeMap, so that the
// returned collection keeps the AttributeMap type instead of degrading to a plain Map.
override def + [B1 >: A](kv: (Attribute, B1)): AttributeMap[B1] = {
  val withNewEntry = baseMap.values.toMap + kv
  AttributeMap(withNewEntry)
}

// Iterates over the stored (attribute, value) pairs, i.e. the values of `baseMap`.
override def iterator: Iterator[(Attribute, A)] = baseMap.values.iterator

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])

// An attribute is contained exactly when looking it up with `get` yields a value.
override def contains(k: Attribute): Boolean = get(k).nonEmpty

// Adds (or replaces) one entry and wraps the result again as an AttributeMap, so that the
// returned collection keeps the AttributeMap type instead of degrading to a plain Map.
override def + [B1 >: A](kv: (Attribute, B1)): AttributeMap[B1] = {
  val withNewEntry = baseMap.values.toMap + kv
  AttributeMap(withNewEntry)
}

/**
 * Returns a new [[AttributeMap]] containing all entries of this map plus `key -> value`.
 *
 * The result is wrapped as an [[AttributeMap]] (a subtype of the `Map[Attribute, B1]` that the
 * overridden method promises), consistent with `+`, so that callers do not silently end up
 * with a plain Map after an update.
 *
 * @param key   the attribute to add or replace
 * @param value the value to associate with `key`
 * @return an [[AttributeMap]] built from the existing entries and the new one
 */
override def updated[B1 >: A](key: Attribute, value: B1): AttributeMap[B1] =
  AttributeMap(baseMap.values.toMap + (key -> value))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils, StringUtils}
import org.apache.spark.sql.catalyst.util.{toPrettySQL, CaseInsensitiveMap, CharVarcharUtils, StringUtils}
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
import org.apache.spark.sql.connector.catalog.{View => _, _}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
Expand Down Expand Up @@ -288,6 +288,8 @@ class Analyzer(override val catalogManager: CatalogManager)
AddMetadataColumns ::
DeduplicateRelations ::
ResolveReferences ::
WrapLateralColumnAliasReference ::
ResolveLateralColumnAliasReference ::
ResolveExpressionsWithNamePlaceholders ::
ResolveDeserializer ::
ResolveNewInstance ::
Expand Down Expand Up @@ -1672,7 +1674,7 @@ class Analyzer(override val catalogManager: CatalogManager)
// Only Project and Aggregate can host star expressions.
case u @ (_: Project | _: Aggregate) =>
Try(s.expand(u.children.head, resolver)) match {
case Success(expanded) => expanded.map(wrapOuterReference)
case Success(expanded) => expanded.map(wrapOuterReference(_))
case Failure(_) => throw e
}
// Do not use the outer plan to resolve the star expression
Expand Down Expand Up @@ -1761,6 +1763,117 @@ class Analyzer(override val catalogManager: CatalogManager)
}
}

/**
* The first phase to resolve lateral column alias. See comments in
* [[ResolveLateralColumnAliasReference]] for more detailed explanation.
*/
object WrapLateralColumnAliasReference extends Rule[LogicalPlan] {
  import ResolveLateralColumnAliasReference.AliasEntry

  /**
   * Registers alias `a`, found at position `idx` of a project list, under its name.
   *
   * @return the map that results from appending an [[AliasEntry]] for `a` to the entries
   *         already stored under `a.name`
   */
  private def insertIntoAliasMap(
      a: Alias,
      idx: Int,
      aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): CaseInsensitiveMap[Seq[AliasEntry]] = {
    val entry = AliasEntry(a, idx)
    val entriesForName = aliasMap.get(a.name).fold(Seq(entry))(_ :+ entry)
    aliasMap + (a.name -> entriesForName)
  }

  /**
   * Tries to resolve the attribute named by `nameParts` against the given lateral alias.
   *
   * The lateral alias may be a struct with nested fields, so the name parts are resolved against
   * the output of a dummy relation whose only column is the attribute of the lateral alias.
   *
   * @return the [[LateralColumnAliasReference]] wrapping the resolved expression, or None if the
   *         name parts cannot be resolved against the alias
   */
  private def resolveByLateralAlias(
      nameParts: Seq[String], lateralAlias: Alias): Option[LateralColumnAliasReference] = {
    val candidate = resolveExpressionByPlanOutput(
      expr = UnresolvedAttribute(nameParts),
      plan = LocalRelation(Seq(lateralAlias.toAttribute)),
      throws = false
    ).asInstanceOf[NamedExpression]
    Some(candidate)
      .filter(_.resolved)
      .map(ne => LateralColumnAliasReference(ne, nameParts, lateralAlias.toAttribute))
  }

  /**
   * Finds the attributes in `e` that reference a lateral column alias registered in `aliasMap`
   * and replaces them with a [[LateralColumnAliasReference]].
   *
   * An [[UnresolvedAttribute]] is only considered when it cannot be resolved by the children of
   * `currentPlan`, because lateral aliases have lower resolution priority than table columns.
   * An [[OuterReference]] is looked up by the name parts stored in its tag, or by its name when
   * the tag is absent. Only a resolved alias can serve as a lateral column alias; unresolved
   * ones are still tracked so that ambiguous names are detected.
   *
   * @param e           the expression to rewrite
   * @param currentPlan the plan whose children are tried first for resolving attributes
   * @param aliasMap    the aliases seen so far, keyed by alias name
   * @throws an ambiguity error from [[QueryCompilationErrors]] when more than one alias is
   *         registered under the referenced name
   */
  private def wrapLCARef(
      e: NamedExpression,
      currentPlan: LogicalPlan,
      aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): NamedExpression = {
    val namePartsTag = ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR

    // Name parts remembered on the outer reference, falling back to its plain name.
    def namePartsOf(o: OuterReference): Seq[String] =
      o.getTagValue(namePartsTag).getOrElse(Seq(o.name))

    // True when the children of the current plan cannot resolve the attribute.
    def unresolvedByChildren(u: UnresolvedAttribute): Boolean =
      resolveExpressionByPlanChildren(u, currentPlan).isInstanceOf[UnresolvedAttribute]

    val rewritten =
      e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) {
        case u: UnresolvedAttribute
            if aliasMap.contains(u.nameParts.head) && unresolvedByChildren(u) =>
          val candidates = aliasMap(u.nameParts.head)
          if (candidates.size > 1) {
            throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, candidates.size)
          } else if (candidates.size == 1 && candidates.head.alias.resolved) {
            // The single candidate is resolved, so it may act as the lateral column alias.
            resolveByLateralAlias(u.nameParts, candidates.head.alias).getOrElse(u)
          } else {
            u
          }

        case o: OuterReference if aliasMap.contains(namePartsOf(o).head) =>
          // Outer references are handled exactly like unresolved attributes.
          val nameParts = namePartsOf(o)
          val candidates = aliasMap(nameParts.head)
          if (candidates.size > 1) {
            throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, candidates.size)
          } else if (candidates.size == 1 && candidates.head.alias.resolved) {
            resolveByLateralAlias(nameParts, candidates.head.alias).getOrElse(o)
          } else {
            o
          }
      }
    rewritten.asInstanceOf[NamedExpression]
  }

  /**
   * Phase 1 of lateral column alias resolution. When the feature is enabled, walks every
   * [[Project]] whose children are resolved, whose project list has no star and which still
   * contains unresolved attributes or outer references, and wraps the references to aliases
   * defined earlier in the same project list.
   */
  override def apply(plan: LogicalPlan): LogicalPlan = {
    val lcaEnabled = conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)
    if (lcaEnabled) {
      plan.resolveOperatorsUpWithPruning(
        _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) {
        case p @ Project(projectList, _) if p.childrenResolved
          && !ResolveReferences.containsStar(projectList)
          && projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) =>
          // Aliases seen so far, in project list order; grows while the list is traversed.
          var seenAliases = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]())
          val wrappedList = projectList.zipWithIndex.map { case (expr, position) =>
            expr match {
              case alias: Alias =>
                val wrappedAlias = wrapLCARef(alias, p, seenAliases).asInstanceOf[Alias]
                // Register the wrapped alias rather than the original one: once resolved it
                // can be referenced by later expressions (chaining). Unresolved aliases are
                // registered too, but only for the ambiguity check.
                seenAliases = insertIntoAliasMap(wrappedAlias, position, seenAliases)
                wrappedAlias
              case other =>
                wrapLCARef(other, p, seenAliases)
            }
          }
          p.copy(projectList = wrappedList)
      }
    } else {
      plan
    }
  }
}

// Returns true if any of the given expressions contains an UnresolvedDeserializer node.
private def containsDeserializer(exprs: Seq[Expression]): Boolean = {
  exprs.exists { expr =>
    expr.exists {
      case _: UnresolvedDeserializer => true
      case _ => false
    }
  }
}
Expand Down Expand Up @@ -2143,7 +2256,7 @@ class Analyzer(override val catalogManager: CatalogManager)
case u @ UnresolvedAttribute(nameParts) => withPosition(u) {
try {
AnalysisContext.get.outerPlan.get.resolveChildren(nameParts, resolver) match {
case Some(resolved) => wrapOuterReference(resolved)
case Some(resolved) => wrapOuterReference(resolved, Some(nameParts))
case None => u
}
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, Decorrela
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_WINDOW_EXPRESSION
import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_WINDOW_EXPRESSION}
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils, TypeUtils}
import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsPartitionManagement}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
Expand Down Expand Up @@ -638,6 +638,16 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
case UnresolvedWindowExpression(_, windowSpec) =>
throw QueryCompilationErrors.windowSpecificationNotDefinedError(windowSpec.name)
})
// This should not happen: a resolved Project or Aggregate should restore or resolve
// all lateral column alias references. This check is added for extra safety.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we have a rule like RemoveTempResolvedColumn to restore LateralColumnAliasReference?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I intentionally didn't add it, because I don't want attributes that can actually be resolved as LCA to show up in the error message as UnresolvedAttribute. Also note that, unlike RemoveTempResolvedColumn, an LCARef can't be directly resolved to the NamedExpression inside it, because the plan wouldn't be right - there is no alias push down.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this should not happen, we should throw an internal error SparkThrowable.internalError, so that it can include more debug information, instead of UNRESOLVED_COLUMN

projectList.foreach(_.transformDownWithPruning(
_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
case lcaRef: LateralColumnAliasReference if p.resolved =>
throw SparkException.internalError("Resolved Project should not contain " +
s"any LateralColumnAliasReference.\nDebugging information: plan: $p",
context = lcaRef.origin.getQueryContext,
summary = lcaRef.origin.context.summary)
})

case j: Join if !j.duplicateResolved =>
val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet)
Expand Down Expand Up @@ -714,6 +724,19 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
"operator" -> other.nodeName,
"invalidExprSqls" -> invalidExprSqls.mkString(", ")))

// This should not happen: a resolved Project or Aggregate should restore or resolve
// all lateral column alias references. This check is added for extra safety.
case agg @ Aggregate(_, aggList, _)
if aggList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) && agg.resolved =>
aggList.foreach(_.transformDownWithPruning(
_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
case lcaRef: LateralColumnAliasReference =>
throw SparkException.internalError("Resolved Aggregate should not contain " +
s"any LateralColumnAliasReference.\nDebugging information: plan: $agg",
context = lcaRef.origin.getQueryContext,
summary = lcaRef.origin.context.summary)
})

case _ => // Analysis successful!
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, LateralColumnAliasReference, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern.LATERAL_COLUMN_ALIAS_REFERENCE
import org.apache.spark.sql.internal.SQLConf

/**
* This rule is the second phase to resolve lateral column alias.
*
* Resolve lateral column alias, which references the alias defined previously in the SELECT list.
* Plan-wise, it handles two types of operators: Project and Aggregate.
* - in Project, pushing down the referenced lateral alias into a newly created Project, resolve
* the attributes referencing these aliases
* - in Aggregate TODO.
*
* The whole process is generally divided into two phases:
* 1) recognize resolved lateral alias, wrap the attributes referencing them with
* [[LateralColumnAliasReference]]
* 2) when the whole operator is resolved, unwrap [[LateralColumnAliasReference]].
* For Project, it further resolves the attributes and push down the referenced lateral aliases.
* For Aggregate, TODO
*
* Example for Project:
* Before rewrite:
* Project [age AS a, 'a + 1]
* +- Child
*
* After phase 1:
* Project [age AS a, lateralalias(a) + 1]
* +- Child
*
* After phase 2:
* Project [a, a + 1]
* +- Project [child output, age AS a]
* +- Child
*
* Example for Aggregate TODO
*
*
* The name resolution priority:
* local table column > local lateral column alias > outer reference
*
* Because lateral column alias has higher resolution priority than outer reference, phase 1
* tries to resolve an [[OuterReference]] using lateral column aliases, in the same way as an
* [[UnresolvedAttribute]]. If this succeeds, it strips the [[OuterReference]] and wraps the
* result with [[LateralColumnAliasReference]].
*/
object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] {
  /**
   * An alias of a project list together with its position in that list.
   *
   * @param alias the alias expression
   * @param index the zero-based position of the alias in the project list
   */
  case class AliasEntry(alias: Alias, index: Int)

  /**
   * A tag to store the nameParts from the original unresolved attribute.
   * It is meant to be set on an [[OuterReference]], so that lateral column alias resolution can
   * convert the [[OuterReference]] back to a [[LateralColumnAliasReference]].
   */
  val NAME_PARTS_FROM_UNRESOLVED_ATTR = TreeNodeTag[Seq[String]]("name_parts_from_unresolved_attr")

  /**
   * Phase 2 of lateral column alias resolution: unwraps [[LateralColumnAliasReference]]s in
   * resolved [[Project]]s.
   *
   * For every resolved Project whose project list contains a [[LateralColumnAliasReference]]:
   *  - a reference to an alias defined earlier in the same list is replaced by the named
   *    expression it wraps, provided that alias does not itself contain a lateral reference;
   *  - every alias referenced this way is pushed down into a new child Project (appended to the
   *    output of the original child) and replaced by its attribute in the original Project;
   *  - a reference to an alias that still contains a lateral reference (chaining) is left as
   *    is, to be handled by a later invocation of the rule;
   *  - a reference whose alias was not defined earlier in the list is restored to an
   *    [[UnresolvedAttribute]].
   *
   * @param plan the plan to rewrite
   * @return `plan` itself when implicit lateral column alias resolution is disabled, otherwise
   *         the rewritten plan
   */
  override def apply(plan: LogicalPlan): LogicalPlan = {
    if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) {
      plan
    } else {
      // phase 2: unwrap
      plan.resolveOperatorsUpWithPruning(
        _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), ruleId) {
        case p @ Project(projectList, child) if p.resolved
          && projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
          // Aliases defined so far in the project list; grows while the list is traversed, so
          // an expression only sees the aliases that precede it.
          var aliasMap = AttributeMap.empty[AliasEntry]
          // Insertion-ordered, so that the aliases are pushed down in the order in which they
          // are first referenced and the generated plan does not depend on hash codes.
          val referencedAliases = collection.mutable.LinkedHashSet.empty[AliasEntry]

          def unwrapLCAReference(e: NamedExpression): NamedExpression = {
            e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
              case lcaRef: LateralColumnAliasReference =>
                aliasMap.get(lcaRef.a) match {
                  case Some(aliasEntry)
                      if !aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE) =>
                    // No chaining: remember the alias for push down and unwrap the reference
                    // to the NamedExpression inside.
                    referencedAliases += aliasEntry
                    lcaRef.ne
                  case Some(_) =>
                    // Chaining: don't resolve now, leave it to a future round.
                    lcaRef
                  case None =>
                    // It shouldn't happen, but restore to unresolved attribute to be safe.
                    UnresolvedAttribute(lcaRef.nameParts)
                }
            }.asInstanceOf[NamedExpression]
          }

          val newProjectList = projectList.zipWithIndex.map {
            case (a: Alias, idx) =>
              val lcaResolved = unwrapLCAReference(a)
              // Insert the original alias instead of rewritten one to detect chained LCA
              aliasMap += (a.toAttribute -> AliasEntry(a, idx))
              lcaResolved
            case (e, _) =>
              unwrapLCAReference(e)
          }

          if (referencedAliases.isEmpty) {
            p
          } else {
            val pushedDown = referencedAliases.toSeq
            // In the outer Project, each pushed-down alias is replaced by its attribute.
            val outerProjectList = pushedDown.foldLeft(newProjectList) { (list, entry) =>
              list.updated(entry.index, entry.alias.toAttribute)
            }
            // The inner Project passes through the child's output and computes the aliases.
            val innerProjectList: Seq[NamedExpression] = child.output ++ pushedDown.map(_.alias)
            p.copy(
              projectList = outerProjectList,
              child = Project(innerProjectList, child)
            )
          }
      }
    }
  }
}
}
Loading