From 04959c2ac394a2e70b8c61d7abba5354469320da Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Wed, 23 Nov 2022 10:40:07 -0800 Subject: [PATCH 01/17] refactor analyzer adding a new object --- .../sql/catalyst/analysis/Analyzer.scala | 321 ++++++++++-------- 1 file changed, 171 insertions(+), 150 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1daa8ea36bf35..fc149308578c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Random, Success, Try} +import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ @@ -182,6 +183,164 @@ object AnalysisContext { } } +object Analyzer extends Logging { + // support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id + private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq( + (CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)), + (CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)), + (CurrentUser().prettyName, () => CurrentUser(), toPrettySQL), + ("user", () => CurrentUser(), toPrettySQL), + (VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName) + ) + + /** + * Literal functions do not require the user to specify braces when calling them + * When an attributes is not resolvable, we try to resolve it as a literal function. + */ + private def resolveLiteralFunction(nameParts: Seq[String]): Option[NamedExpression] = { + if (nameParts.length != 1) return None + val name = nameParts.head + literalFunctions.find(func => caseInsensitiveResolution(func._1, name)).map { + case (_, getFuncExpr, getAliasName) => + val funcExpr = getFuncExpr() + Alias(funcExpr, getAliasName(funcExpr))() + } + } + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by + * traversing the input expression in top-down manner. It must be top-down because we need to + * skip over unbound lambda function expression. The lambda expressions are resolved in a + * different place [[ResolveLambdaVariables]]. + * + * Example : + * SELECT transform(array(1, 2, 3), (x, i) -> x + i)" + * + * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]. + */ + private def resolveExpression( + expr: Expression, + resolveColumnByName: Seq[String] => Option[Expression], + getAttrCandidates: () => Seq[Attribute], + resolver: Resolver, + throws: Boolean): Expression = { + def innerResolve(e: Expression, isTopLevel: Boolean): Expression = { + if (e.resolved) return e + e match { + case f: LambdaFunction if !f.bound => f + + case GetColumnByOrdinal(ordinal, _) => + val attrCandidates = getAttrCandidates() + assert(ordinal >= 0 && ordinal < attrCandidates.length) + attrCandidates(ordinal) + + case GetViewColumnByNameAndOrdinal( + viewName, colName, ordinal, expectedNumCandidates, viewDDL) => + val attrCandidates = getAttrCandidates() + val matched = attrCandidates.filter(a => resolver(a.name, colName)) + if (matched.length != expectedNumCandidates) { + throw QueryCompilationErrors.incompatibleViewSchemaChange( + viewName, colName, expectedNumCandidates, matched, viewDDL) + } + matched(ordinal) + + case u @ UnresolvedAttribute(nameParts) => + val result = withPosition(u) { + resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map { + // We trim unnecessary alias here. Note that, we cannot trim the alias at top-level, + // as we should resolve `UnresolvedAttribute` to a named expression. The caller side + // can trim the top-level alias if it's safe to do so. Since we will call + // CleanupAliases later in Analyzer, trim non top-level unnecessary alias is safe. + case Alias(child, _) if !isTopLevel => child + case other => other + }.getOrElse(u) + } + logDebug(s"Resolving $u to $result") + result + + case u @ UnresolvedExtractValue(child, fieldName) => + val newChild = innerResolve(child, isTopLevel = false) + if (newChild.resolved) { + withOrigin(u.origin) { + ExtractValue(newChild, fieldName, resolver) + } + } else { + u.copy(child = newChild) + } + + case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) + } + } + + try { + innerResolve(expr, isTopLevel = true) + } catch { + case ae: AnalysisException if !throws => + logDebug(ae.getMessage) + expr + } + } + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the + * input plan's output attributes. In order to resolve the nested fields correctly, this function + * makes use of `throws` parameter to control when to raise an AnalysisException. + * + * Example : + * SELECT * FROM t ORDER BY a.b + * + * In the above example, after `a` is resolved to a struct-type column, we may fail to resolve `b` + * if there is no such nested field named "b". We should not fail and wait for other rules to + * resolve it if possible. + */ + def resolveExpressionByPlanOutput( + expr: Expression, + plan: LogicalPlan, + resolver: Resolver, + throws: Boolean = false): Expression = { + resolveExpression( + expr, + resolveColumnByName = nameParts => { + plan.resolve(nameParts, resolver) + }, + getAttrCandidates = () => plan.output, + resolver = resolver, + throws = throws) + } + + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the + * input plan's children output attributes. + * + * @param e The expression need to be resolved. + * @param q The LogicalPlan whose children are used to resolve expression's attribute. + * @return resolved Expression. + */ + def resolveExpressionByPlanChildren( + e: Expression, + q: LogicalPlan, + resolver: Resolver): Expression = { + resolveExpression( + e, + resolveColumnByName = nameParts => { + q.resolveChildren(nameParts, resolver) + }, + getAttrCandidates = () => { + assert(q.children.length == 1) + q.children.head.output + }, + resolver = resolver, + throws = true) + } + + /** + * Returns true if `exprs` contains a [[Star]]. + */ + def containsStar(exprs: Seq[Expression]): Boolean = + exprs.exists(_.collect { case _: Star => true }.nonEmpty) +} + /** * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]]. @@ -258,6 +417,17 @@ class Analyzer(override val catalogManager: CatalogManager) TypeCoercion.typeCoercionRules } + private def resolveExpressionByPlanOutput( + expr: Expression, + plan: LogicalPlan, + throws: Boolean = false): Expression = { + Analyzer.resolveExpressionByPlanOutput(expr, plan, resolver, throws) + } + + private def resolveExpressionByPlanChildren(e: Expression, q: LogicalPlan): Expression = { + Analyzer.resolveExpressionByPlanChildren(e, q, resolver) + } + override def batches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, // This rule optimizes `UpdateFields` expression chains so looks more like optimization rule. @@ -1386,6 +1556,7 @@ class Analyzer(override val catalogManager: CatalogManager) * a logical plan node's children. */ object ResolveReferences extends Rule[LogicalPlan] { + import org.apache.spark.sql.catalyst.analysis.Analyzer.containsStar /** Return true if there're conflicting attributes among children's outputs of a plan */ def hasConflictingAttrs(p: LogicalPlan): Boolean = { @@ -1698,12 +1869,6 @@ class Analyzer(override val catalogManager: CatalogManager) }.map(_.asInstanceOf[NamedExpression]) } - /** - * Returns true if `exprs` contains a [[Star]]. - */ - def containsStar(exprs: Seq[Expression]): Boolean = - exprs.exists(_.collect { case _: Star => true }.nonEmpty) - private def extractStar(exprs: Seq[Expression]): Seq[Star] = exprs.flatMap(_.collect { case s: Star => s }) @@ -1764,150 +1929,6 @@ class Analyzer(override val catalogManager: CatalogManager) exprs.exists(_.exists(_.isInstanceOf[UnresolvedDeserializer])) } - // support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id - private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq( - (CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)), - (CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)), - (CurrentUser().prettyName, () => CurrentUser(), toPrettySQL), - ("user", () => CurrentUser(), toPrettySQL), - (VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName) - ) - - /** - * Literal functions do not require the user to specify braces when calling them - * When an attributes is not resolvable, we try to resolve it as a literal function. - */ - private def resolveLiteralFunction(nameParts: Seq[String]): Option[NamedExpression] = { - if (nameParts.length != 1) return None - val name = nameParts.head - literalFunctions.find(func => caseInsensitiveResolution(func._1, name)).map { - case (_, getFuncExpr, getAliasName) => - val funcExpr = getFuncExpr() - Alias(funcExpr, getAliasName(funcExpr))() - } - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by - * traversing the input expression in top-down manner. It must be top-down because we need to - * skip over unbound lambda function expression. The lambda expressions are resolved in a - * different place [[ResolveLambdaVariables]]. - * - * Example : - * SELECT transform(array(1, 2, 3), (x, i) -> x + i)" - * - * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]. - */ - private def resolveExpression( - expr: Expression, - resolveColumnByName: Seq[String] => Option[Expression], - getAttrCandidates: () => Seq[Attribute], - throws: Boolean): Expression = { - def innerResolve(e: Expression, isTopLevel: Boolean): Expression = { - if (e.resolved) return e - e match { - case f: LambdaFunction if !f.bound => f - - case GetColumnByOrdinal(ordinal, _) => - val attrCandidates = getAttrCandidates() - assert(ordinal >= 0 && ordinal < attrCandidates.length) - attrCandidates(ordinal) - - case GetViewColumnByNameAndOrdinal( - viewName, colName, ordinal, expectedNumCandidates, viewDDL) => - val attrCandidates = getAttrCandidates() - val matched = attrCandidates.filter(a => resolver(a.name, colName)) - if (matched.length != expectedNumCandidates) { - throw QueryCompilationErrors.incompatibleViewSchemaChange( - viewName, colName, expectedNumCandidates, matched, viewDDL) - } - matched(ordinal) - - case u @ UnresolvedAttribute(nameParts) => - val result = withPosition(u) { - resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map { - // We trim unnecessary alias here. Note that, we cannot trim the alias at top-level, - // as we should resolve `UnresolvedAttribute` to a named expression. The caller side - // can trim the top-level alias if it's safe to do so. Since we will call - // CleanupAliases later in Analyzer, trim non top-level unnecessary alias is safe. - case Alias(child, _) if !isTopLevel => child - case other => other - }.getOrElse(u) - } - logDebug(s"Resolving $u to $result") - result - - case u @ UnresolvedExtractValue(child, fieldName) => - val newChild = innerResolve(child, isTopLevel = false) - if (newChild.resolved) { - withOrigin(u.origin) { - ExtractValue(newChild, fieldName, resolver) - } - } else { - u.copy(child = newChild) - } - - case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) - } - } - - try { - innerResolve(expr, isTopLevel = true) - } catch { - case ae: AnalysisException if !throws => - logDebug(ae.getMessage) - expr - } - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the - * input plan's output attributes. In order to resolve the nested fields correctly, this function - * makes use of `throws` parameter to control when to raise an AnalysisException. - * - * Example : - * SELECT * FROM t ORDER BY a.b - * - * In the above example, after `a` is resolved to a struct-type column, we may fail to resolve `b` - * if there is no such nested field named "b". We should not fail and wait for other rules to - * resolve it if possible. - */ - def resolveExpressionByPlanOutput( - expr: Expression, - plan: LogicalPlan, - throws: Boolean = false): Expression = { - resolveExpression( - expr, - resolveColumnByName = nameParts => { - plan.resolve(nameParts, resolver) - }, - getAttrCandidates = () => plan.output, - throws = throws) - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the - * input plan's children output attributes. - * - * @param e The expression need to be resolved. - * @param q The LogicalPlan whose children are used to resolve expression's attribute. - * @return resolved Expression. - */ - def resolveExpressionByPlanChildren( - e: Expression, - q: LogicalPlan): Expression = { - resolveExpression( - e, - resolveColumnByName = nameParts => { - q.resolveChildren(nameParts, resolver) - }, - getAttrCandidates = () => { - assert(q.children.length == 1) - q.children.head.output - }, - throws = true) - } - /** * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by * clauses. This rule is to convert ordinal positions to the corresponding expressions in the From 6f44c8500850b1d122510c66dc7e9b27e6adaf2d Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Wed, 23 Nov 2022 13:25:41 -0800 Subject: [PATCH 02/17] lca code (cherry picked from commit 94adb3f98d701e4c4f19189eb11134949b61bc45) --- .../main/resources/error/error-classes.json | 6 ++ .../sql/catalyst/analysis/Analyzer.scala | 100 +++++++++++++++++- .../sql/catalyst/rules/RuleIdCollection.scala | 1 + .../sql/errors/QueryCompilationErrors.scala | 10 ++ .../apache/spark/sql/internal/SQLConf.scala | 8 ++ 5 files changed, 124 insertions(+), 1 deletion(-) diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 77d155bfc21e4..e279ffc87d21e 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -5,6 +5,12 @@ ], "sqlState" : "42000" }, + "AMBIGUOUS_LATERAL_COLUMN_ALIAS" : { + "message" : [ + "Lateral column alias is ambiguous and has matches." + ], + "sqlState" : "42000" + }, "AMBIGUOUS_REFERENCE" : { "message" : [ "Reference is ambiguous, could be: ." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index fc149308578c5..5f595e81d56e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin} import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ -import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils, StringUtils} +import org.apache.spark.sql.catalyst.util.{toPrettySQL, CaseInsensitiveMap, CharVarcharUtils, StringUtils} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.{View => _, _} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -458,6 +458,7 @@ class Analyzer(override val catalogManager: CatalogManager) AddMetadataColumns :: DeduplicateRelations :: ResolveReferences :: + ResolveLateralColumnAlias :: ResolveExpressionsWithNamePlaceholders :: ResolveDeserializer :: ResolveNewInstance :: @@ -1551,6 +1552,103 @@ class Analyzer(override val catalogManager: CatalogManager) } } + /** + * Resolve lateral column alias, which references the alias defined previously in the SELECT list, + * - in Project inserting a new Project node with the referenced alias so that it can be + * resolved by other rules + * - in Aggregate TODO. + * + * For Project, it rewrites the Project plan by inserting a newly created Project plan between + * the original Project and its child, and updating the project list of the original Project plan. + * The project list of the new Project plan is the lateral column aliases that are referenced + * in the original project list. These aliases in the original project list are updated to + * attribute references. + * + * Before rewrite: + * Project [age AS a, a + 1] + * +- Child + * + * After rewrite: + * Project [a, a + 1] + * +- Project [age AS a] + * +- Child + */ + object ResolveLateralColumnAlias extends Rule[LogicalPlan] { + private case class AliasEntry(alias: Alias, index: Int) + + private def rewriteLateralColumnAlias(plan: LogicalPlan): LogicalPlan = { + plan.resolveOperatorsUpWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE), ruleId) { + case p @ Project(projectList, child) if p.childrenResolved + && !ResolveReferences.containsStar(projectList) + && projectList.exists(_.containsPattern(UNRESOLVED_ATTRIBUTE)) => + // TODO: delta + var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) + var referencedAliases = Seq[AliasEntry]() + def updateAliasMap(a: Alias, idx: Int): Unit = { + val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) + aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) + } + def searchMatchedLCA(e: Expression): Unit = { + e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { + case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && + resolveExpressionByPlanChildren(u, p).isInstanceOf[UnresolvedAttribute] => + val aliases = aliasMap.get(u.nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) + case _ => + val referencedAlias = aliases.head + // Only resolved alias can be the lateral column alias + if (referencedAlias.alias.resolved) { + referencedAliases :+= referencedAlias + } + } + u + } + } + projectList.zipWithIndex.foreach { + case (a: Alias, idx) => + // Add all alias to the aliasMap. But note only resolved alias can be LCA and pushed + // down. Unresolved alias is added to the map to perform the ambiguous name check. + // If there is a chain of LCA, for example, SELECT 1 AS a, 'a + 1 AS b, 'b + 1 AS c, + // because only resolved alias can be LCA, in the first round the rule application, + // only 1 AS a is pushed down, even though 1 AS a, 'a + 1 AS b and 'b + 1 AS c are + // all added to the aliasMap. On the second round, when 'a + 1 AS b is resolved, + // it is pushed down. + searchMatchedLCA(a) + updateAliasMap(a, idx) + case (e, _) => + searchMatchedLCA(e) + } + + referencedAliases = referencedAliases.sortBy(_.index) + if (referencedAliases.isEmpty) { + p + } else { + val outerProjectList = projectList.to[collection.mutable.Seq] + val innerProjectList = + child.output.map(_.asInstanceOf[NamedExpression]).to[collection.mutable.ArrayBuffer] + referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => + outerProjectList.update(idx, alias.toAttribute) + innerProjectList += alias + } + p.copy( + projectList = outerProjectList.toSeq, + child = Project(innerProjectList.toSeq, child) + ) + } + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_ENABLED)) { + plan + } else { + rewriteLateralColumnAlias(plan) + } + } + } + /** * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from * a logical plan node's children. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index f6bef88ab868e..8493895218f79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -66,6 +66,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolvePivot" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRandomSeed" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences" :: + "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveLateralColumnAlias" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRelations" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubquery" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubqueryColumnAliases" :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 63c912c15a156..eeb1dfc213d94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3402,4 +3402,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { cause = Option(other)) } } + + def ambiguousLateralColumnAlias(name: String, numOfMatches: Int): Throwable = { + new AnalysisException( + errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", + messageParameters = Map( + "name" -> toSQLId(name), + "n" -> numOfMatches.toString + ) + ) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 84d78f365acbc..a5b84660e0581 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4027,6 +4027,14 @@ object SQLConf { .checkValues(ErrorMessageFormat.values.map(_.toString)) .createWithDefault(ErrorMessageFormat.PRETTY.toString) + val LATERAL_COLUMN_ALIAS_ENABLED = + buildConf("spark.sql.lateralColumnAlias.enabled") + .internal() + .doc("Enable lateral column alias in analyzer") + .version("3.4.0") + .booleanConf + .createWithDefault(false) + /** * Holds information about keys that have been deprecated. * From 725e5ac9df65438f87f2c260ea5507aaf1a1bd2b Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 28 Nov 2022 10:38:29 -0800 Subject: [PATCH 03/17] add tests, refine logic (cherry picked from commit 313b2c98e9513e50d2764b28c447c3a7cd281ebb) --- .../sql/catalyst/analysis/Analyzer.scala | 40 ++++--- .../spark/sql/LateralColumnAliasSuite.scala | 109 ++++++++++++++++++ 2 files changed, 130 insertions(+), 19 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5f595e81d56e2..165bb2ecf4ec2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1558,20 +1558,20 @@ class Analyzer(override val catalogManager: CatalogManager) * resolved by other rules * - in Aggregate TODO. * - * For Project, it rewrites the Project plan by inserting a newly created Project plan between - * the original Project and its child, and updating the project list of the original Project plan. - * The project list of the new Project plan is the lateral column aliases that are referenced - * in the original project list. These aliases in the original project list are updated to - * attribute references. + * For Project, it rewrites by inserting a newly created Project plan between the original Project + * and its child, pushing the referenced lateral column aliases to this new Project, and updating + * the project list of the original Project. * * Before rewrite: - * Project [age AS a, a + 1] + * Project [age AS a, 'a + 1] * +- Child * * After rewrite: - * Project [a, a + 1] - * +- Project [age AS a] + * Project [a, 'a + 1] + * +- Project [child output, age AS a] * +- Child + * + * For Aggregate TODO. */ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { private case class AliasEntry(alias: Alias, index: Int) @@ -1581,14 +1581,14 @@ class Analyzer(override val catalogManager: CatalogManager) case p @ Project(projectList, child) if p.childrenResolved && !ResolveReferences.containsStar(projectList) && projectList.exists(_.containsPattern(UNRESOLVED_ATTRIBUTE)) => - // TODO: delta + var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) - var referencedAliases = Seq[AliasEntry]() - def updateAliasMap(a: Alias, idx: Int): Unit = { + def insertIntoAliasMap(a: Alias, idx: Int): Unit = { val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) } - def searchMatchedLCA(e: Expression): Unit = { + def lookUpLCA(e: Expression): Option[AliasEntry] = { + var matchedLCA: Option[AliasEntry] = None e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && resolveExpressionByPlanChildren(u, p).isInstanceOf[UnresolvedAttribute] => @@ -1600,13 +1600,15 @@ class Analyzer(override val catalogManager: CatalogManager) val referencedAlias = aliases.head // Only resolved alias can be the lateral column alias if (referencedAlias.alias.resolved) { - referencedAliases :+= referencedAlias + matchedLCA = Some(referencedAlias) } } u } + matchedLCA } - projectList.zipWithIndex.foreach { + + val referencedAliases = projectList.zipWithIndex.flatMap { case (a: Alias, idx) => // Add all alias to the aliasMap. But note only resolved alias can be LCA and pushed // down. Unresolved alias is added to the map to perform the ambiguous name check. @@ -1615,13 +1617,13 @@ class Analyzer(override val catalogManager: CatalogManager) // only 1 AS a is pushed down, even though 1 AS a, 'a + 1 AS b and 'b + 1 AS c are // all added to the aliasMap. On the second round, when 'a + 1 AS b is resolved, // it is pushed down. - searchMatchedLCA(a) - updateAliasMap(a, idx) + val matchedLCA = lookUpLCA(a) + insertIntoAliasMap(a, idx) + matchedLCA case (e, _) => - searchMatchedLCA(e) - } + lookUpLCA(e) + }.toSet - referencedAliases = referencedAliases.sortBy(_.index) if (referencedAliases.isEmpty) { p } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala new file mode 100644 index 0000000000000..daf750c39bb1e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalactic.source.Position +import org.scalatest.Tag + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { + protected val testTable: String = "employee" + + override def beforeAll(): Unit = { + super.beforeAll() + sql(s"CREATE TABLE $testTable (dept INTEGER, name String, salary INTEGER, bonus INTEGER) " + + s"using orc") + sql( + s""" + |INSERT INTO $testTable VALUES + | (1, 'amy', 10000, 1000), + | (2, 'alex', 12000, 1200), + | (1, 'cathy', 9000, 1200), + | (2, 'david', 10000, 1300), + | (6, 'jen', 12000, 1200) + |""".stripMargin) + } + + override def afterAll(): Unit = { + try { + sql(s"DROP TABLE IF EXISTS $testTable") + } finally { + super.afterAll() + } + } + + val lcaEnabled: Boolean = true + override protected def test(testName: String, testTags: Tag*)(testFun: => Any) + (implicit pos: Position): Unit = { + super.test(testName, testTags: _*) { + withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_ENABLED.key -> lcaEnabled.toString) { + testFun + } + } + } + + test("Lateral alias in project") { + checkAnswer(sql(s"select dept as d, d + 1 as e from $testTable where name = 'amy'"), + Row(1, 2)) + + checkAnswer( + sql( + s"select salary * 2 as new_salary, new_salary + bonus from $testTable where name = 'amy'"), + Row(20000, 21000)) + checkAnswer( + sql( + s"select salary * 2 as new_salary, new_salary + bonus * 2 as new_income from $testTable" + + s" where name = 'amy'"), + Row(20000, 22000)) + + checkAnswer( + sql( + "select salary * 2 as new_salary, (new_salary + bonus) * 3 - new_salary * 2 as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 23000)) + + // When the lateral alias conflicts with the table column, it should resolved as the table + // column + checkAnswer( + sql( + "select salary * 2 as salary, salary * 2 + bonus as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 21000)) + + checkAnswer( + sql( + "select salary * 2 as salary, (salary + bonus) * 3 - (salary + bonus) as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 22000)) + + checkAnswer( + sql( + "select salary * 2 as salary, (salary + bonus) * 2 as bonus, " + + s"salary + bonus as prev_income, prev_income + bonus + salary from $testTable" + + " where name = 'amy'"), + Row(20000, 22000, 11000, 22000)) + + // Corner cases for resolution order + checkAnswer( + sql(s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"), + Row(18000, 18000, 10000) + ) + } +} From 660e1d231b641c65c979199ed37d57f52db2a3ea Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 28 Nov 2022 10:54:58 -0800 Subject: [PATCH 04/17] move lca rule to a new file --- .../sql/catalyst/analysis/Analyzer.scala | 101 +------------- .../analysis/ResolveLateralColumnAlias.scala | 127 ++++++++++++++++++ .../sql/catalyst/rules/RuleIdCollection.scala | 2 +- 3 files changed, 129 insertions(+), 101 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 165bb2ecf4ec2..95101c0d8130b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin} import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ -import org.apache.spark.sql.catalyst.util.{toPrettySQL, CaseInsensitiveMap, CharVarcharUtils, StringUtils} +import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils, StringUtils} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.{View => _, _} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -1552,105 +1552,6 @@ class Analyzer(override val catalogManager: CatalogManager) } } - /** - * Resolve lateral column alias, which references the alias defined previously in the SELECT list, - * - in Project inserting a new Project node with the referenced alias so that it can be - * resolved by other rules - * - in Aggregate TODO. - * - * For Project, it rewrites by inserting a newly created Project plan between the original Project - * and its child, pushing the referenced lateral column aliases to this new Project, and updating - * the project list of the original Project. - * - * Before rewrite: - * Project [age AS a, 'a + 1] - * +- Child - * - * After rewrite: - * Project [a, 'a + 1] - * +- Project [child output, age AS a] - * +- Child - * - * For Aggregate TODO. - */ - object ResolveLateralColumnAlias extends Rule[LogicalPlan] { - private case class AliasEntry(alias: Alias, index: Int) - - private def rewriteLateralColumnAlias(plan: LogicalPlan): LogicalPlan = { - plan.resolveOperatorsUpWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE), ruleId) { - case p @ Project(projectList, child) if p.childrenResolved - && !ResolveReferences.containsStar(projectList) - && projectList.exists(_.containsPattern(UNRESOLVED_ATTRIBUTE)) => - - var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) - def insertIntoAliasMap(a: Alias, idx: Int): Unit = { - val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) - aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) - } - def lookUpLCA(e: Expression): Option[AliasEntry] = { - var matchedLCA: Option[AliasEntry] = None - e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { - case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && - resolveExpressionByPlanChildren(u, p).isInstanceOf[UnresolvedAttribute] => - val aliases = aliasMap.get(u.nameParts.head).get - aliases.size match { - case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) - case _ => - val referencedAlias = aliases.head - // Only resolved alias can be the lateral column alias - if (referencedAlias.alias.resolved) { - matchedLCA = Some(referencedAlias) - } - } - u - } - matchedLCA - } - - val referencedAliases = projectList.zipWithIndex.flatMap { - case (a: Alias, idx) => - // Add all alias to the aliasMap. But note only resolved alias can be LCA and pushed - // down. Unresolved alias is added to the map to perform the ambiguous name check. - // If there is a chain of LCA, for example, SELECT 1 AS a, 'a + 1 AS b, 'b + 1 AS c, - // because only resolved alias can be LCA, in the first round the rule application, - // only 1 AS a is pushed down, even though 1 AS a, 'a + 1 AS b and 'b + 1 AS c are - // all added to the aliasMap. On the second round, when 'a + 1 AS b is resolved, - // it is pushed down. - val matchedLCA = lookUpLCA(a) - insertIntoAliasMap(a, idx) - matchedLCA - case (e, _) => - lookUpLCA(e) - }.toSet - - if (referencedAliases.isEmpty) { - p - } else { - val outerProjectList = projectList.to[collection.mutable.Seq] - val innerProjectList = - child.output.map(_.asInstanceOf[NamedExpression]).to[collection.mutable.ArrayBuffer] - referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => - outerProjectList.update(idx, alias.toAttribute) - innerProjectList += alias - } - p.copy( - projectList = outerProjectList.toSeq, - child = Project(innerProjectList.toSeq, child) - ) - } - } - } - - override def apply(plan: LogicalPlan): LogicalPlan = { - if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_ENABLED)) { - plan - } else { - rewriteLateralColumnAlias(plan) - } - } - } - /** * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from * a logical plan node's children. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala new file mode 100644 index 0000000000000..ea2648ccde553 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_ATTRIBUTE +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf + +/** + * Resolve lateral column alias, which references the alias defined previously in the SELECT list, + * - in Project inserting a new Project node with the referenced alias so that it can be + * resolved by other rules + * - in Aggregate TODO. + * + * For Project, it rewrites by inserting a newly created Project plan between the original Project + * and its child, pushing the referenced lateral column aliases to this new Project, and updating + * the project list of the original Project. + * + * Before rewrite: + * Project [age AS a, 'a + 1] + * +- Child + * + * After rewrite: + * Project [a, 'a + 1] + * +- Project [child output, age AS a] + * +- Child + * + * For Aggregate TODO. + */ +object ResolveLateralColumnAlias extends Rule[LogicalPlan] { + private case class AliasEntry(alias: Alias, index: Int) + def resolver: Resolver = conf.resolver + + private def rewriteLateralColumnAlias(plan: LogicalPlan): LogicalPlan = { + plan.resolveOperatorsUpWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE), ruleId) { + case p @ Project(projectList, child) if p.childrenResolved + && !Analyzer.containsStar(projectList) + && projectList.exists(_.containsPattern(UNRESOLVED_ATTRIBUTE)) => + + var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) + def insertIntoAliasMap(a: Alias, idx: Int): Unit = { + val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) + aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) + } + def lookUpLCA(e: Expression): Option[AliasEntry] = { + var matchedLCA: Option[AliasEntry] = None + e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { + case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && + Analyzer.resolveExpressionByPlanChildren(u, p, resolver) + .isInstanceOf[UnresolvedAttribute] => + val aliases = aliasMap.get(u.nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) + case _ => + val referencedAlias = aliases.head + // Only resolved alias can be the lateral column alias + if (referencedAlias.alias.resolved) { + matchedLCA = Some(referencedAlias) + } + } + u + } + matchedLCA + } + + val referencedAliases = projectList.zipWithIndex.flatMap { + case (a: Alias, idx) => + // Add all alias to the aliasMap. But note only resolved alias can be LCA and pushed + // down. Unresolved alias is added to the map to perform the ambiguous name check. + // If there is a chain of LCA, for example, SELECT 1 AS a, 'a + 1 AS b, 'b + 1 AS c, + // because only resolved alias can be LCA, in the first round the rule application, + // only 1 AS a is pushed down, even though 1 AS a, 'a + 1 AS b and 'b + 1 AS c are + // all added to the aliasMap. On the second round, when 'a + 1 AS b is resolved, + // it is pushed down. + val matchedLCA = lookUpLCA(a) + insertIntoAliasMap(a, idx) + matchedLCA + case (e, _) => + lookUpLCA(e) + }.toSet + + if (referencedAliases.isEmpty) { + p + } else { + val outerProjectList = projectList.to[collection.mutable.Seq] + val innerProjectList = + child.output.map(_.asInstanceOf[NamedExpression]).to[collection.mutable.ArrayBuffer] + referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => + outerProjectList.update(idx, alias.toAttribute) + innerProjectList += alias + } + p.copy( + projectList = outerProjectList.toSeq, + child = Project(innerProjectList.toSeq, child) + ) + } + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_ENABLED)) { + plan + } else { + rewriteLateralColumnAlias(plan) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 8493895218f79..032b0e7a08fcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -66,7 +66,6 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolvePivot" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRandomSeed" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences" :: - "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveLateralColumnAlias" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRelations" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubquery" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubqueryColumnAliases" :: @@ -89,6 +88,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.ResolveHints$ResolveJoinStrategyHints" :: "org.apache.spark.sql.catalyst.analysis.ResolveInlineTables" :: "org.apache.spark.sql.catalyst.analysis.ResolveLambdaVariables" :: + "org.apache.spark.sql.catalyst.analysis.ResolveLateralColumnAlias" :: "org.apache.spark.sql.catalyst.analysis.ResolveTimeZone" :: "org.apache.spark.sql.catalyst.analysis.ResolveUnion" :: "org.apache.spark.sql.catalyst.analysis.SubstituteUnresolvedOrdinals" :: From fd0609438d99643d147afad854e8206624437278 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 28 Nov 2022 11:44:35 -0800 Subject: [PATCH 05/17] rename conf --- .../catalyst/analysis/ResolveLateralColumnAlias.scala | 2 +- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 9 ++++++--- .../org/apache/spark/sql/LateralColumnAliasSuite.scala | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index ea2648ccde553..2b435f1c460a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -118,7 +118,7 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { } override def apply(plan: LogicalPlan): LogicalPlan = { - if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_ENABLED)) { + if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { plan } else { rewriteLateralColumnAlias(plan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a5b84660e0581..575775a0f5519 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4027,10 +4027,13 @@ object SQLConf { .checkValues(ErrorMessageFormat.values.map(_.toString)) .createWithDefault(ErrorMessageFormat.PRETTY.toString) - val LATERAL_COLUMN_ALIAS_ENABLED = - buildConf("spark.sql.lateralColumnAlias.enabled") + val LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED = + buildConf("spark.sql.lateralColumnAlias.enableImplicitResolution") .internal() - .doc("Enable lateral column alias in analyzer") + .doc("Enable resolving implicit lateral column alias defined in the same SELECT list. For " + + "example, with this conf turned on, for query `SELECT 1 AS a, a + 1` the `a` in `a + 1` " + + "can be resolved as the previously defined `1 AS a`. But note that table column has " + + "higher resolution priority than the lateral column alias.") .version("3.4.0") .booleanConf .createWithDefault(false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index daf750c39bb1e..f6b2c919b794c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -53,7 +53,7 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { override protected def test(testName: String, testTags: Tag*)(testFun: => Any) (implicit pos: Position): Unit = { super.test(testName, testTags: _*) { - withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_ENABLED.key -> lcaEnabled.toString) { + withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> lcaEnabled.toString) { testFun } } From 7d4f80f4c74a77dfceb77b4d86d36cd83d63d9c5 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 28 Nov 2022 16:07:21 -0800 Subject: [PATCH 06/17] test failure --- .../sql/catalyst/analysis/ResolveLateralColumnAlias.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index 2b435f1c460a0..e1372664b791e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -102,9 +102,9 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { if (referencedAliases.isEmpty) { p } else { - val outerProjectList = projectList.to[collection.mutable.Seq] + val outerProjectList = collection.mutable.Seq(projectList: _*) val innerProjectList = - child.output.map(_.asInstanceOf[NamedExpression]).to[collection.mutable.ArrayBuffer] + collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]): _*) referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => outerProjectList.update(idx, alias.toAttribute) innerProjectList += alias From b9704d5428fa2f25de9b6da076c972168ee0477d Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Tue, 29 Nov 2022 10:27:43 -0800 Subject: [PATCH 07/17] small fix --- .../sql/catalyst/analysis/ResolveLateralColumnAlias.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index e1372664b791e..5462cee65fd09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -61,8 +61,8 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) } - def lookUpLCA(e: Expression): Option[AliasEntry] = { - var matchedLCA: Option[AliasEntry] = None + def lookUpLCA(e: Expression): Seq[AliasEntry] = { + var matchedLCA: Seq[AliasEntry] = Seq.empty[AliasEntry] e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && Analyzer.resolveExpressionByPlanChildren(u, p, resolver) @@ -75,7 +75,7 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { val referencedAlias = aliases.head // Only resolved alias can be the lateral column alias if (referencedAlias.alias.resolved) { - matchedLCA = Some(referencedAlias) + matchedLCA :+= referencedAlias } } u From 5785943fbb53b525b4434b7566d4f466461ceb61 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Thu, 1 Dec 2022 18:36:33 -0800 Subject: [PATCH 08/17] make changes to accomodate the recent refactor --- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../analysis/ResolveLateralColumnAlias.scala | 64 ++++++--- .../expressions/namedExpressions.scala | 12 +- .../sql/catalyst/expressions/subquery.scala | 4 +- .../spark/sql/LateralColumnAliasSuite.scala | 131 ++++++++++++++++++ 5 files changed, 188 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9d5d87b768770..3bc98c68d8486 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1844,7 +1844,7 @@ class Analyzer(override val catalogManager: CatalogManager) // Only Project and Aggregate can host star expressions. case u @ (_: Project | _: Aggregate) => Try(s.expand(u.children.head, resolver)) match { - case Success(expanded) => expanded.map(wrapOuterReference) + case Success(expanded) => expanded.map(wrapOuterReference(_)) case Failure(_) => throw e } // Do not use the outer plan to resolve the star expression @@ -2165,7 +2165,7 @@ class Analyzer(override val catalogManager: CatalogManager) case u @ UnresolvedAttribute(nameParts) => withPosition(u) { try { AnalysisContext.get.outerPlan.get.resolveChildren(nameParts, resolver) match { - case Some(resolved) => wrapOuterReference(resolved) + case Some(resolved) => wrapOuterReference(resolved, Some(nameParts)) case None => u } } catch { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index 5462cee65fd09..37e93f7c9105b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, NamedExpression, OuterReference} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_ATTRIBUTE +import org.apache.spark.sql.catalyst.trees.TreePattern.{OUTER_REFERENCE, UNRESOLVED_ATTRIBUTE} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -31,6 +31,14 @@ import org.apache.spark.sql.internal.SQLConf * resolved by other rules * - in Aggregate TODO. * + * The name resolution priority: + * local table column > local lateral column alias > outer reference + * + * Because lateral column alias has higher resolution priority than outer reference, it will try + * to resolve an [[OuterReference]] using lateral column alias, similar as an + * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and restores it back to + * [[UnresolvedAttribute]] + * * For Project, it rewrites by inserting a newly created Project plan between the original Project * and its child, pushing the referenced lateral column aliases to this new Project, and updating * the project list of the original Project. @@ -51,19 +59,35 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { def resolver: Resolver = conf.resolver private def rewriteLateralColumnAlias(plan: LogicalPlan): LogicalPlan = { - plan.resolveOperatorsUpWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE), ruleId) { + plan.resolveOperatorsUpWithPruning( + _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) { case p @ Project(projectList, child) if p.childrenResolved && !Analyzer.containsStar(projectList) - && projectList.exists(_.containsPattern(UNRESOLVED_ATTRIBUTE)) => + && projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) def insertIntoAliasMap(a: Alias, idx: Int): Unit = { val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) } - def lookUpLCA(e: Expression): Seq[AliasEntry] = { - var matchedLCA: Seq[AliasEntry] = Seq.empty[AliasEntry] - e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { + + val referencedAliases = collection.mutable.Set.empty[AliasEntry] + def resolveLCA(e: NamedExpression): NamedExpression = { + e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { + case o: OuterReference + if aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) => + val name = o.nameParts.map(_.head).getOrElse(o.name) + val aliases = aliasMap.get(name).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(o.name, n) + case n if n == 1 && aliases.head.alias.resolved => + // Only resolved alias can be the lateral column alias + referencedAliases += aliases.head + o.nameParts.map(UnresolvedAttribute(_)).getOrElse(UnresolvedAttribute(o.name)) + case _ => + o + } case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && Analyzer.resolveExpressionByPlanChildren(u, p, resolver) .isInstanceOf[UnresolvedAttribute] => @@ -71,19 +95,15 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { aliases.size match { case n if n > 1 => throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) - case _ => - val referencedAlias = aliases.head + case n if n == 1 && aliases.head.alias.resolved => // Only resolved alias can be the lateral column alias - if (referencedAlias.alias.resolved) { - matchedLCA :+= referencedAlias - } + referencedAliases += aliases.head + case _ => } u - } - matchedLCA + }.asInstanceOf[NamedExpression] } - - val referencedAliases = projectList.zipWithIndex.flatMap { + val newProjectList = projectList.zipWithIndex.map { case (a: Alias, idx) => // Add all alias to the aliasMap. But note only resolved alias can be LCA and pushed // down. Unresolved alias is added to the map to perform the ambiguous name check. @@ -92,17 +112,17 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { // only 1 AS a is pushed down, even though 1 AS a, 'a + 1 AS b and 'b + 1 AS c are // all added to the aliasMap. On the second round, when 'a + 1 AS b is resolved, // it is pushed down. - val matchedLCA = lookUpLCA(a) - insertIntoAliasMap(a, idx) - matchedLCA + val lcaResolved = resolveLCA(a).asInstanceOf[Alias] + insertIntoAliasMap(lcaResolved, idx) + lcaResolved case (e, _) => - lookUpLCA(e) - }.toSet + resolveLCA(e) + } if (referencedAliases.isEmpty) { p } else { - val outerProjectList = collection.mutable.Seq(projectList: _*) + val outerProjectList = collection.mutable.Seq(newProjectList: _*) val innerProjectList = collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]): _*) referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 8dd28e9aaae3d..f83c3aa462614 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -424,8 +424,18 @@ case class OuterReference(e: NamedExpression) override def qualifier: Seq[String] = e.qualifier override def exprId: ExprId = e.exprId override def toAttribute: Attribute = e.toAttribute - override def newInstance(): NamedExpression = OuterReference(e.newInstance()) + override def newInstance(): NamedExpression = + OuterReference(e.newInstance()).setNameParts(nameParts) final override val nodePatterns: Seq[TreePattern] = Seq(OUTER_REFERENCE) + + // optional field of the original name parts of UnresolvedAttribute before it is resolved to + // OuterReference. Used in rule ResolveLateralColumnAlias to restore OuterReference back to + // UnresolvedAttribute. + var nameParts: Option[Seq[String]] = None + def setNameParts(newNameParts: Option[Seq[String]]): OuterReference = { + nameParts = newNameParts + this + } } object VirtualColumn { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index e7384dac2d53e..d249a2b5a6bb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -158,8 +158,8 @@ object SubExprUtils extends PredicateHelper { /** * Wrap attributes in the expression with [[OuterReference]]s. */ - def wrapOuterReference[E <: Expression](e: E): E = { - e.transform { case a: Attribute => OuterReference(a) }.asInstanceOf[E] + def wrapOuterReference[E <: Expression](e: E, nameParts: Option[Seq[String]] = None): E = { + e.transform { case a: Attribute => OuterReference(a).setNameParts(nameParts) }.asInstanceOf[E] } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index f6b2c919b794c..adf18958a1e92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.scalactic.source.Position import org.scalatest.Tag +import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -59,6 +60,17 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } } + private def withLCAOff(f: => Unit): Unit = { + withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> "false") { + f + } + } + private def withLCAOn(f: => Unit): Unit = { + withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> "true") { + f + } + } + test("Lateral alias in project") { checkAnswer(sql(s"select dept as d, d + 1 as e from $testTable where name = 'amy'"), Row(1, 2)) @@ -106,4 +118,123 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { Row(18000, 18000, 10000) ) } + + test("Duplicated lateral alias names - Project") { + def checkDuplicatedAliasErrorHelper(query: String, parameters: Map[String, String]): Unit = { + checkError( + exception = intercept[AnalysisException] {sql(query)}, + errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", + sqlState = "42000", + parameters = parameters + ) + } + + // Has duplicated names but not referenced is fine + checkAnswer( + sql(s"SELECT salary AS d, bonus AS d FROM $testTable WHERE name = 'jen'"), + Row(12000, 1200) + ) + checkAnswer( + sql(s"SELECT salary AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"), + Row(12000, 12000, 10000) + ) + checkAnswer( + sql(s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"), + Row(18000, 18000, 10000) + ) + checkAnswer( + sql(s"SELECT salary + 1000 AS new_salary, new_salary * 1.0 AS new_salary " + + s"FROM $testTable WHERE name = 'jen'"), + Row(13000, 13000.0)) + + // Referencing duplicated names raises error + checkDuplicatedAliasErrorHelper( + s"SELECT salary * 1.5 AS d, d, 10000 AS d, d + 1 FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT 10000 AS d, d * 1.0, salary * 1.5 AS d, d FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT salary AS d, d + 1 AS d, d + 1 AS d FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT salary * 1.5 AS d, d, bonus * 1.5 AS d, d + d FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + + checkAnswer( + sql( + s""" + |SELECT salary * 1.5 AS salary, salary, 10000 AS salary, salary + |FROM $testTable + |WHERE name = 'jen' + |""".stripMargin), + Row(18000, 12000, 10000, 12000) + ) + } + + test("Lateral alias conflicts with OuterReference - Project") { + // an attribute can both be resolved as LCA and OuterReference + val query1 = + s""" + |SELECT * + |FROM range(1, 7) + |WHERE ( + | SELECT id2 + | FROM (SELECT 1 AS id, id + 1 AS id2)) > 5 + |ORDER BY id + |""".stripMargin + withLCAOff { checkAnswer(sql(query1), Row(5) :: Row(6) :: Nil) } + withLCAOn { checkAnswer(sql(query1), Seq.empty) } + + // an attribute can only be resolved as LCA + val query2 = + s""" + |SELECT * + |FROM range(1, 7) + |WHERE ( + | SELECT id2 + | FROM (SELECT 1 AS id1, id1 + 1 AS id2)) > 5 + |""".stripMargin + withLCAOff { + assert(intercept[AnalysisException] { sql(query2) } + .getErrorClass == "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION") + } + withLCAOn { checkAnswer(sql(query2), Seq.empty) } + + // an attribute should only be resolved as OuterReference + val query3 = + s""" + |SELECT * + |FROM range(1, 7) outer_table + |WHERE ( + | SELECT id2 + | FROM (SELECT 1 AS id, outer_table.id + 1 AS id2)) > 5 + |""".stripMargin + withLCAOff { checkAnswer(sql(query3), Row(5) :: Row(6) :: Nil) } + withLCAOn { checkAnswer(sql(query3), Row(5) :: Row(6) :: Nil) } + + // a bit complex subquery that the id + 1 is first wrapped with OuterReference + // test if lca rule strips the OuterReference and resolves to lateral alias + val query4 = + s""" + |SELECT * + |FROM range(1, 7) + |WHERE ( + | SELECT id2 + | FROM (SELECT dept * 2.0 AS id, id + 1 AS id2 FROM $testTable)) > 5 + |ORDER BY id + |""".stripMargin + withLCAOff { intercept[AnalysisException] { sql(query4) } } // surprisingly can't run .. + withLCAOn { + val analyzedPlan = sql(query4).queryExecution.analyzed + assert(!analyzedPlan.containsPattern(OUTER_REFERENCE)) + // but running it triggers exception + // checkAnswer(sql(query4), Range(1, 7).map(Row(_))) + } + } + // TODO: LCA in subquery } From 757cffb4f0adbb512eb4738fe3e38eea943b474a Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 5 Dec 2022 11:48:19 -0800 Subject: [PATCH 09/17] introduce leaf exp in Project as well --- .../analysis/ResolveLateralColumnAlias.scala | 162 ++++++++++++------ .../expressions/namedExpressions.scala | 36 +++- .../sql/catalyst/trees/TreePatterns.scala | 1 + .../sql/errors/QueryCompilationErrors.scala | 9 + .../spark/sql/LateralColumnAliasSuite.scala | 89 ++++++---- 5 files changed, 211 insertions(+), 86 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index 37e93f7c9105b..a674c8cdbb423 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -17,76 +17,90 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, NamedExpression, OuterReference} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{OUTER_REFERENCE, UNRESOLVED_ATTRIBUTE} +import org.apache.spark.sql.catalyst.expressions.{Alias, LateralColumnAliasReference, NamedExpression, OuterReference} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project} +import org.apache.spark.sql.catalyst.rules.{Rule, UnknownRuleId} +import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, OUTER_REFERENCE, UNRESOLVED_ATTRIBUTE} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf /** - * Resolve lateral column alias, which references the alias defined previously in the SELECT list, - * - in Project inserting a new Project node with the referenced alias so that it can be - * resolved by other rules + * Resolve lateral column alias, which references the alias defined previously in the SELECT list. + * Plan-wise it handles two types of operators: Project and Aggregate. + * - in Project, pushing down the referenced lateral alias into a newly created Project, resolve + * the attributes referencing these aliases * - in Aggregate TODO. * - * The name resolution priority: - * local table column > local lateral column alias > outer reference - * - * Because lateral column alias has higher resolution priority than outer reference, it will try - * to resolve an [[OuterReference]] using lateral column alias, similar as an - * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and restores it back to - * [[UnresolvedAttribute]] - * - * For Project, it rewrites by inserting a newly created Project plan between the original Project - * and its child, pushing the referenced lateral column aliases to this new Project, and updating - * the project list of the original Project. + * The whole process is generally divided into two phases: + * 1) recognize lateral alias, wrap the attributes referencing them with + * [[LateralColumnAliasReference]] + * 2) when the whole operator is resolved, unwrap [[LateralColumnAliasReference]]. + * For Project, it further resolves the attributes and push down the referenced lateral aliases. + * For Aggregate, TODO * + * Example for Project: * Before rewrite: * Project [age AS a, 'a + 1] * +- Child * - * After rewrite: - * Project [a, 'a + 1] + * After phase 1: + * Project [age AS a, lateralalias(a) + 1] + * +- Child + * + * After phase 2: + * Project [a, a + 1] * +- Project [child output, age AS a] * +- Child * - * For Aggregate TODO. + * Example for Aggregate TODO + * + * + * The name resolution priority: + * local table column > local lateral column alias > outer reference + * + * Because lateral column alias has higher resolution priority than outer reference, it will try + * to resolve an [[OuterReference]] using lateral column alias in phase 1, similar as an + * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and also wraps it with + * [[LateralColumnAliasReference]]. */ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { private case class AliasEntry(alias: Alias, index: Int) + private def insertIntoAliasMap( + a: Alias, + idx: Int, + aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): CaseInsensitiveMap[Seq[AliasEntry]] = { + val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) + aliasMap + (a.name -> (prevAliases :+ AliasEntry(a, idx))) + } + def resolver: Resolver = conf.resolver private def rewriteLateralColumnAlias(plan: LogicalPlan): LogicalPlan = { - plan.resolveOperatorsUpWithPruning( + // phase 1: wrap + val rewrittenPlan = plan.resolveOperatorsUpWithPruning( _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) { case p @ Project(projectList, child) if p.childrenResolved && !Analyzer.containsStar(projectList) && projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) - def insertIntoAliasMap(a: Alias, idx: Int): Unit = { - val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) - aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) - } - - val referencedAliases = collection.mutable.Set.empty[AliasEntry] - def resolveLCA(e: NamedExpression): NamedExpression = { + def wrapLCAReference(e: NamedExpression): NamedExpression = { e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { case o: OuterReference if aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) => - val name = o.nameParts.map(_.head).getOrElse(o.name) - val aliases = aliasMap.get(name).get + val nameParts = o.nameParts.getOrElse(Seq(o.name)) + val aliases = aliasMap.get(nameParts.head).get aliases.size match { case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(o.name, n) + throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) case n if n == 1 && aliases.head.alias.resolved => // Only resolved alias can be the lateral column alias - referencedAliases += aliases.head - o.nameParts.map(UnresolvedAttribute(_)).getOrElse(UnresolvedAttribute(o.name)) - case _ => - o + // TODO We need to resolve to the nested field type, e.g. for query + // SELECT named_struct() AS foo, foo.a, we can't say this foo.a is the + // LateralColumnAliasReference(foo, foo.a). Otherwise, the type can be mismatched + LateralColumnAliasReference(aliases.head.alias, nameParts) + case _ => o } case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && Analyzer.resolveExpressionByPlanChildren(u, p, resolver) @@ -97,28 +111,74 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) case n if n == 1 && aliases.head.alias.resolved => // Only resolved alias can be the lateral column alias - referencedAliases += aliases.head - case _ => + // TODO similar problem + LateralColumnAliasReference(aliases.head.alias, u.nameParts) + case _ => u } - u }.asInstanceOf[NamedExpression] } val newProjectList = projectList.zipWithIndex.map { case (a: Alias, idx) => - // Add all alias to the aliasMap. But note only resolved alias can be LCA and pushed - // down. Unresolved alias is added to the map to perform the ambiguous name check. - // If there is a chain of LCA, for example, SELECT 1 AS a, 'a + 1 AS b, 'b + 1 AS c, - // because only resolved alias can be LCA, in the first round the rule application, - // only 1 AS a is pushed down, even though 1 AS a, 'a + 1 AS b and 'b + 1 AS c are - // all added to the aliasMap. On the second round, when 'a + 1 AS b is resolved, - // it is pushed down. - val lcaResolved = resolveLCA(a).asInstanceOf[Alias] - insertIntoAliasMap(lcaResolved, idx) - lcaResolved + val lcaWrapped = wrapLCAReference(a).asInstanceOf[Alias] + // Insert the LCA-resolved alias instead of the unresolved one into map. If it is + // resolved, it can be referenced as LCA by later expressions (chaining). + // Unresolved Alias is also added to the map to perform ambiguous name check, but only + // resolved alias can be LCA + aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) + lcaWrapped case (e, _) => - resolveLCA(e) + wrapLCAReference(e) + } + p.copy(projectList = newProjectList) + } + + // phase 2: unwrap + rewrittenPlan.resolveOperatorsUpWithPruning( + _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), UnknownRuleId) { + case p @ Project(projectList, child) if p.resolved + && projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => + // build the map again in case the project list changes and index goes off + // TODO one risk: is there any rule that strips off the Alias? that the LCA is resolved + // in the beginning, but when it comes to push down, it really can't find the matching one? + // Restore back to UnresolvedAttribute + var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) + val referencedAliases = collection.mutable.Set.empty[AliasEntry] + def unwrapLCAReference(e: NamedExpression): NamedExpression = { + e.transformWithPruning(_.containsAnyPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.nameParts.head) => + val aliasEntry = aliasMap.get(lcaRef.nameParts.head).get.head + if (!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + // If there is no chaining, push down the alias and resolve the attribute by + // constructing a dummy plan + referencedAliases += aliasEntry + // Implementation notes (to-delete): + // this is a design decision whether to restore the UnresolvedAttribute, or + // directly resolve by constructing a plan and using resolveExpressionByPlanChildren + Analyzer.resolveExpressionByPlanOutput( + expr = UnresolvedAttribute(lcaRef.nameParts), + plan = Project(Seq(aliasEntry.alias), OneRowRelation()), + resolver = resolver, + throws = false + ) + } else { + // If there is chaining, don't resolve and save to future rounds + lcaRef + } + case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.nameParts.head) => + // It shouldn't happen. Restore to unresolved attribute to be safe. + UnresolvedAttribute(lcaRef.name) + }.asInstanceOf[NamedExpression] } + val newProjectList = projectList.zipWithIndex.map { + case (a: Alias, idx) => + val lcaResolved = unwrapLCAReference(a) + // Insert the original alias instead of rewritten one to detect chained LCA + aliasMap = insertIntoAliasMap(a, idx, aliasMap) + lcaResolved + case (e, _) => + unwrapLCAReference(e) + } if (referencedAliases.isEmpty) { p } else { @@ -134,7 +194,7 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { child = Project(innerProjectList.toSeq, child) ) } - } + } } override def apply(plan: LogicalPlan): LogicalPlan = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index f83c3aa462614..4a3e5a6487f13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -428,9 +428,9 @@ case class OuterReference(e: NamedExpression) OuterReference(e.newInstance()).setNameParts(nameParts) final override val nodePatterns: Seq[TreePattern] = Seq(OUTER_REFERENCE) - // optional field of the original name parts of UnresolvedAttribute before it is resolved to - // OuterReference. Used in rule ResolveLateralColumnAlias to restore OuterReference back to - // UnresolvedAttribute. + // optional field, the original name parts of UnresolvedAttribute before it is resolved to + // OuterReference. Used in rule ResolveLateralColumnAlias to convert OuterReference back to + // LateralColumnAliasReference. var nameParts: Option[Seq[String]] = None def setNameParts(newNameParts: Option[Seq[String]]): OuterReference = { nameParts = newNameParts @@ -438,6 +438,36 @@ case class OuterReference(e: NamedExpression) } } +/** + * A placeholder used to hold a referenced that has been temporarily resolved as the reference + * to a lateral column alias. This is created and removed by rule [[ResolveLateralColumnAlias]]. + * + * There should be no [[LateralColumnAliasReference]] beyond analyzer: if the plan passes all + * analysis check, then all [[LateralColumnAliasReference]] should already be removed. + * + * @param a A resolved [[Alias]] that is a lateral column alias referenced by the current attribute + * @param nameParts The named parts of the original [[UnresolvedAttribute]]. Used to resolve + * the attribute, or restore back. + */ +case class LateralColumnAliasReference(a: Alias, nameParts: Seq[String]) + extends LeafExpression with NamedExpression with Unevaluable { + assert(a.resolved) + override def name: String = + nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") + override def exprId: ExprId = a.exprId + override def qualifier: Seq[String] = a.qualifier + override def toAttribute: Attribute = a.toAttribute + override def newInstance(): NamedExpression = + LateralColumnAliasReference(a.newInstance().asInstanceOf[Alias], nameParts) + + override def nullable: Boolean = a.nullable + override def dataType: DataType = a.dataType + override def prettyName: String = "lateralAliasReference" + override def sql: String = s"$prettyName($name)" + + final override val nodePatterns: Seq[TreePattern] = Seq(LATERAL_COLUMN_ALIAS_REFERENCE) +} + object VirtualColumn { // The attribute name used by Hive, which has different result than Spark, deprecated. val hiveGroupingIdName: String = "grouping__id" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 8fca9ec60cdff..1a8ad7c7d6213 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -58,6 +58,7 @@ object TreePattern extends Enumeration { val JSON_TO_STRUCT: Value = Value val LAMBDA_FUNCTION: Value = Value val LAMBDA_VARIABLE: Value = Value + val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value val LATERAL_SUBQUERY: Value = Value val LIKE_FAMLIY: Value = Value val LIST_SUBQUERY: Value = Value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 25c509732f9b6..209a80fee2ff0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3413,4 +3413,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { ) ) } + def ambiguousLateralColumnAlias(nameParts: Seq[String], numOfMatches: Int): Throwable = { + new AnalysisException( + errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", + messageParameters = Map( + "name" -> toSQLId(nameParts), + "n" -> numOfMatches.toString + ) + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index adf18958a1e92..d78a661c5a7e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -71,7 +71,7 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } } - test("Lateral alias in project") { + test("Lateral alias basics - Project") { checkAnswer(sql(s"select dept as d, d + 1 as e from $testTable where name = 'amy'"), Row(1, 2)) @@ -91,28 +91,7 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { s"new_income from $testTable where name = 'amy'"), Row(20000, 23000)) - // When the lateral alias conflicts with the table column, it should resolved as the table - // column - checkAnswer( - sql( - "select salary * 2 as salary, salary * 2 + bonus as " + - s"new_income from $testTable where name = 'amy'"), - Row(20000, 21000)) - - checkAnswer( - sql( - "select salary * 2 as salary, (salary + bonus) * 3 - (salary + bonus) as " + - s"new_income from $testTable where name = 'amy'"), - Row(20000, 22000)) - - checkAnswer( - sql( - "select salary * 2 as salary, (salary + bonus) * 2 as bonus, " + - s"salary + bonus as prev_income, prev_income + bonus + salary from $testTable" + - " where name = 'amy'"), - Row(20000, 22000, 11000, 22000)) - - // Corner cases for resolution order + // should referring to the previously defined LCA checkAnswer( sql(s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"), Row(18000, 18000, 10000) @@ -176,6 +155,27 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { ) } + test("Lateral alias conflicts with table column - Project") { + checkAnswer( + sql( + "select salary * 2 as salary, salary * 2 + bonus as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 21000)) + + checkAnswer( + sql( + "select salary * 2 as salary, (salary + bonus) * 3 - (salary + bonus) as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 22000)) + + checkAnswer( + sql( + "select salary * 2 as salary, (salary + bonus) * 2 as bonus, " + + s"salary + bonus as prev_income, prev_income + bonus + salary from $testTable" + + " where name = 'amy'"), + Row(20000, 22000, 11000, 22000)) + } + test("Lateral alias conflicts with OuterReference - Project") { // an attribute can both be resolved as LCA and OuterReference val query1 = @@ -220,14 +220,14 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { // a bit complex subquery that the id + 1 is first wrapped with OuterReference // test if lca rule strips the OuterReference and resolves to lateral alias val query4 = - s""" - |SELECT * - |FROM range(1, 7) - |WHERE ( - | SELECT id2 - | FROM (SELECT dept * 2.0 AS id, id + 1 AS id2 FROM $testTable)) > 5 - |ORDER BY id - |""".stripMargin + s""" + |SELECT * + |FROM range(1, 7) + |WHERE ( + | SELECT id2 + | FROM (SELECT dept * 2.0 AS id, id + 1 AS id2 FROM $testTable)) > 5 + |ORDER BY id + |""".stripMargin withLCAOff { intercept[AnalysisException] { sql(query4) } } // surprisingly can't run .. withLCAOn { val analyzedPlan = sql(query4).queryExecution.analyzed @@ -236,5 +236,30 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { // checkAnswer(sql(query4), Range(1, 7).map(Row(_))) } } - // TODO: LCA in subquery + // TODO: more tests on LCA in subquery + + test("Lateral alias of a struct - Project") { + // This test fails now +// checkAnswer( +// sql("SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar"), +// Row(Row(1), 2)) + } + + test("Lateral alias chaining - Project") { + checkAnswer( + sql( + s""" + |SELECT bonus * 1.1 AS new_bonus, salary + new_bonus AS new_base, + | new_base * 1.1 AS new_total, new_total - new_base AS r, + | new_total - r + |FROM $testTable WHERE name = 'cathy' + |""".stripMargin), + Row(1320, 10320, 11352, 1032, 10320) + ) + + checkAnswer( + sql("SELECT 1 AS a, a + 1 AS b, b - 1, b + 1 AS c, c + 1 AS d, d - a AS e, e + 1"), + Row(1, 2, 1, 3, 4, 3, 4) + ) + } } From 29de892ba1c76f11684167fae78a6abb91165750 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 5 Dec 2022 14:10:56 -0800 Subject: [PATCH 10/17] handle a corner case --- .../analysis/ResolveLateralColumnAlias.scala | 47 ++++++++++++------- .../expressions/namedExpressions.scala | 24 +++++----- .../spark/sql/LateralColumnAliasSuite.scala | 8 ++-- 3 files changed, 47 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index a674c8cdbb423..7a9b6d43c8c17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.internal.SQLConf * - in Aggregate TODO. * * The whole process is generally divided into two phases: - * 1) recognize lateral alias, wrap the attributes referencing them with + * 1) recognize resolved lateral alias, wrap the attributes referencing them with * [[LateralColumnAliasReference]] * 2) when the whole operator is resolved, unwrap [[LateralColumnAliasReference]]. * For Project, it further resolves the attributes and push down the referenced lateral aliases. @@ -64,7 +64,10 @@ import org.apache.spark.sql.internal.SQLConf * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and also wraps it with * [[LateralColumnAliasReference]]. */ +// TODO revisit resolving order: top down, or bottom up object ResolveLateralColumnAlias extends Rule[LogicalPlan] { + def resolver: Resolver = conf.resolver + private case class AliasEntry(alias: Alias, index: Int) private def insertIntoAliasMap( a: Alias, @@ -74,7 +77,20 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { aliasMap + (a.name -> (prevAliases :+ AliasEntry(a, idx))) } - def resolver: Resolver = conf.resolver + /** + * Use the given the lateral alias candidate to resolve the name parts. + * @return The resolved attribute if succeeds. None if fails to resolve. + */ + private def resolveByLateralAlias( + nameParts: Seq[String], lateralAlias: Alias): Option[NamedExpression] = { + val resolvedAttr = Analyzer.resolveExpressionByPlanOutput( + expr = UnresolvedAttribute(nameParts), + plan = Project(Seq(lateralAlias), OneRowRelation()), + resolver = resolver, + throws = false + ).asInstanceOf[NamedExpression] + if (resolvedAttr.resolved) Some(resolvedAttr) else None + } private def rewriteLateralColumnAlias(plan: LogicalPlan): LogicalPlan = { // phase 1: wrap @@ -96,10 +112,9 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) case n if n == 1 && aliases.head.alias.resolved => // Only resolved alias can be the lateral column alias - // TODO We need to resolve to the nested field type, e.g. for query - // SELECT named_struct() AS foo, foo.a, we can't say this foo.a is the - // LateralColumnAliasReference(foo, foo.a). Otherwise, the type can be mismatched - LateralColumnAliasReference(aliases.head.alias, nameParts) + resolveByLateralAlias(nameParts, aliases.head.alias) + .map(LateralColumnAliasReference(_, nameParts)) + .getOrElse(o) case _ => o } case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && @@ -111,8 +126,9 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) case n if n == 1 && aliases.head.alias.resolved => // Only resolved alias can be the lateral column alias - // TODO similar problem - LateralColumnAliasReference(aliases.head.alias, u.nameParts) + resolveByLateralAlias(u.nameParts, aliases.head.alias) + .map(LateralColumnAliasReference(_, u.nameParts)) + .getOrElse(u) case _ => u } }.asInstanceOf[NamedExpression] @@ -138,9 +154,13 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { case p @ Project(projectList, child) if p.resolved && projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => // build the map again in case the project list changes and index goes off - // TODO one risk: is there any rule that strips off the Alias? that the LCA is resolved + // TODO one risk: is there any rule that strips off /add the Alias? that the LCA is resolved // in the beginning, but when it comes to push down, it really can't find the matching one? - // Restore back to UnresolvedAttribute + // Restore back to UnresolvedAttribute. + // Also, when resolving from bottom up should I worry about cases like: + // Project [b AS c, c + 1 AS d] + // +- Project [1 AS a, a AS b] + // b AS c is resolved, even b refers to an alias contains the lateral alias? var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) val referencedAliases = collection.mutable.Set.empty[AliasEntry] def unwrapLCAReference(e: NamedExpression): NamedExpression = { @@ -154,12 +174,7 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { // Implementation notes (to-delete): // this is a design decision whether to restore the UnresolvedAttribute, or // directly resolve by constructing a plan and using resolveExpressionByPlanChildren - Analyzer.resolveExpressionByPlanOutput( - expr = UnresolvedAttribute(lcaRef.nameParts), - plan = Project(Seq(aliasEntry.alias), OneRowRelation()), - resolver = resolver, - throws = false - ) + resolveByLateralAlias(lcaRef.nameParts, aliasEntry.alias).getOrElse(lcaRef) } else { // If there is chaining, don't resolve and save to future rounds lcaRef diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 4a3e5a6487f13..7972cd1399f2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -439,29 +439,29 @@ case class OuterReference(e: NamedExpression) } /** - * A placeholder used to hold a referenced that has been temporarily resolved as the reference + * A placeholder used to hold a attribute that has been temporarily resolved as the reference * to a lateral column alias. This is created and removed by rule [[ResolveLateralColumnAlias]]. * * There should be no [[LateralColumnAliasReference]] beyond analyzer: if the plan passes all * analysis check, then all [[LateralColumnAliasReference]] should already be removed. * - * @param a A resolved [[Alias]] that is a lateral column alias referenced by the current attribute - * @param nameParts The named parts of the original [[UnresolvedAttribute]]. Used to resolve - * the attribute, or restore back. + * @param ne the current attribute is resolved to by lateral column alias + * @param nameParts The named parts of the original [[UnresolvedAttribute]]. Used to later resolve + * the attribute or restore back. */ -case class LateralColumnAliasReference(a: Alias, nameParts: Seq[String]) +case class LateralColumnAliasReference(ne: NamedExpression, nameParts: Seq[String]) extends LeafExpression with NamedExpression with Unevaluable { - assert(a.resolved) + assert(ne.resolved) override def name: String = nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") - override def exprId: ExprId = a.exprId - override def qualifier: Seq[String] = a.qualifier - override def toAttribute: Attribute = a.toAttribute + override def exprId: ExprId = ne.exprId + override def qualifier: Seq[String] = ne.qualifier + override def toAttribute: Attribute = ne.toAttribute override def newInstance(): NamedExpression = - LateralColumnAliasReference(a.newInstance().asInstanceOf[Alias], nameParts) + LateralColumnAliasReference(ne.newInstance(), nameParts) - override def nullable: Boolean = a.nullable - override def dataType: DataType = a.dataType + override def nullable: Boolean = ne.nullable + override def dataType: DataType = ne.dataType override def prettyName: String = "lateralAliasReference" override def sql: String = s"$prettyName($name)" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index d78a661c5a7e4..17b7f5750697b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -239,10 +239,10 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { // TODO: more tests on LCA in subquery test("Lateral alias of a struct - Project") { - // This test fails now -// checkAnswer( -// sql("SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar"), -// Row(Row(1), 2)) + checkAnswer( + sql("SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar"), + Row(Row(1), 2)) + // TODO: more tests } test("Lateral alias chaining - Project") { From 72991c6210b34ef3a0af3e7b2c075f73812f89cb Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Tue, 6 Dec 2022 11:46:44 -0800 Subject: [PATCH 11/17] add more tests; add check rule --- .../sql/catalyst/analysis/CheckAnalysis.scala | 21 +++++++++- .../spark/sql/LateralColumnAliasSuite.scala | 42 ++++++++++++++----- 2 files changed, 52 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 12dac5c632a3b..9937a06de9a98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, Decorrela import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreeNodeTag -import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_WINDOW_EXPRESSION +import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_WINDOW_EXPRESSION} import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils, TypeUtils} import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsPartitionManagement} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} @@ -638,6 +638,14 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB case UnresolvedWindowExpression(_, windowSpec) => throw QueryCompilationErrors.windowSpecificationNotDefinedError(windowSpec.name) }) + // This should not happen, resolved Project or Aggregate should restore or resolve + // all lateral column alias references. Add check for extra safe. + projectList.foreach(_.transformDownWithPruning( + _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + case lcaRef: LateralColumnAliasReference if p.resolved => + failUnresolvedAttribute( + p, UnresolvedAttribute(lcaRef.nameParts), "UNRESOLVED_COLUMN") + }) case j: Join if !j.duplicateResolved => val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet) @@ -714,6 +722,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB "operator" -> other.nodeName, "invalidExprSqls" -> invalidExprSqls.mkString(", "))) + // This should not happen, resolved Project or Aggregate should restore or resolve + // all lateral column alias references. Add check for extra safe. + case agg @ Aggregate(_, aggList, _) + if aggList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) && agg.resolved => + aggList.foreach(_.transformDownWithPruning( + _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + case lcaRef: LateralColumnAliasReference => + failUnresolvedAttribute( + agg, UnresolvedAttribute(lcaRef.nameParts), "UNRESOLVED_COLUMN") + }) + case _ => // Analysis successful! } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index 17b7f5750697b..5c0c120a87d2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -29,16 +29,24 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { override def beforeAll(): Unit = { super.beforeAll() - sql(s"CREATE TABLE $testTable (dept INTEGER, name String, salary INTEGER, bonus INTEGER) " + - s"using orc") + sql( + s""" + |CREATE TABLE $testTable ( + | dept INTEGER, + | name String, + | salary INTEGER, + | bonus INTEGER, + | properties STRUCT) + |USING orc + |""".stripMargin) sql( s""" |INSERT INTO $testTable VALUES - | (1, 'amy', 10000, 1000), - | (2, 'alex', 12000, 1200), - | (1, 'cathy', 9000, 1200), - | (2, 'david', 10000, 1300), - | (6, 'jen', 12000, 1200) + | (1, 'amy', 10000, 1000, named_struct('joinYear', 2019, 'mostRecentEmployer', 'A')), + | (2, 'alex', 12000, 1200, named_struct('joinYear', 2017, 'mostRecentEmployer', 'A')), + | (1, 'cathy', 9000, 1200, named_struct('joinYear', 2020, 'mostRecentEmployer', 'B')), + | (2, 'david', 10000, 1300, named_struct('joinYear', 2019, 'mostRecentEmployer', 'C')), + | (6, 'jen', 12000, 1200, named_struct('joinYear', 2018, 'mostRecentEmployer', 'D')) |""".stripMargin) } @@ -174,6 +182,16 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { s"salary + bonus as prev_income, prev_income + bonus + salary from $testTable" + " where name = 'amy'"), Row(20000, 22000, 11000, 22000)) + + checkAnswer( + sql(s"SELECT named_struct('joinYear', 2022) AS properties, properties.joinYear " + + s"FROM $testTable WHERE name = 'amy'"), + Row(Row(2022), 2019)) + + checkAnswer( + sql(s"SELECT named_struct('name', 'someone') AS $testTable, $testTable.name " + + s"FROM $testTable WHERE name = 'amy'"), + Row(Row("someone"), "amy")) } test("Lateral alias conflicts with OuterReference - Project") { @@ -240,9 +258,13 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { test("Lateral alias of a struct - Project") { checkAnswer( - sql("SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar"), - Row(Row(1), 2)) - // TODO: more tests + sql("SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar, bar + 1"), + Row(Row(1), 2, 3)) + + checkAnswer( + sql("SELECT named_struct('a', named_struct('b', 1)) AS foo, foo.a.b + 1 AS bar"), + Row(Row(Row(1)), 2) + ) } test("Lateral alias chaining - Project") { From d45fe31f0aec6ddb670a012e53495554f03c05cb Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Thu, 8 Dec 2022 14:48:11 -0800 Subject: [PATCH 12/17] uplift the necessity to resolve expression in second phase; add more tests --- .../analysis/ResolveLateralColumnAlias.scala | 91 +++++++++---------- .../expressions/namedExpressions.scala | 17 ++-- .../spark/sql/LateralColumnAliasSuite.scala | 30 +++++- 3 files changed, 82 insertions(+), 56 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index 7a9b6d43c8c17..ad6867042ffa2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, LateralColumnAliasReference, NamedExpression, OuterReference} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, LateralColumnAliasReference, NamedExpression, OuterReference} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project} import org.apache.spark.sql.catalyst.rules.{Rule, UnknownRuleId} import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, OUTER_REFERENCE, UNRESOLVED_ATTRIBUTE} @@ -64,7 +64,6 @@ import org.apache.spark.sql.internal.SQLConf * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and also wraps it with * [[LateralColumnAliasReference]]. */ -// TODO revisit resolving order: top down, or bottom up object ResolveLateralColumnAlias extends Rule[LogicalPlan] { def resolver: Resolver = conf.resolver @@ -78,18 +77,27 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { } /** - * Use the given the lateral alias candidate to resolve the name parts. - * @return The resolved attribute if succeeds. None if fails to resolve. + * Use the given lateral alias to resolve the unresolved attribute with the name parts. + * + * Construct a dummy plan with the given lateral alias as project list, use the output of the + * plan to resolve. + * @return The resolved [[LateralColumnAliasReference]] if succeeds. None if fails to resolve. */ private def resolveByLateralAlias( - nameParts: Seq[String], lateralAlias: Alias): Option[NamedExpression] = { + nameParts: Seq[String], lateralAlias: Alias): Option[LateralColumnAliasReference] = { + // TODO question: everytime it resolves the extract field it generates a new exprId. + // Does it matter? val resolvedAttr = Analyzer.resolveExpressionByPlanOutput( expr = UnresolvedAttribute(nameParts), plan = Project(Seq(lateralAlias), OneRowRelation()), resolver = resolver, throws = false ).asInstanceOf[NamedExpression] - if (resolvedAttr.resolved) Some(resolvedAttr) else None + if (resolvedAttr.resolved) { + Some(LateralColumnAliasReference(resolvedAttr, nameParts, lateralAlias.toAttribute)) + } else { + None + } } private def rewriteLateralColumnAlias(plan: LogicalPlan): LogicalPlan = { @@ -103,20 +111,6 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) def wrapLCAReference(e: NamedExpression): NamedExpression = { e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { - case o: OuterReference - if aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) => - val nameParts = o.nameParts.getOrElse(Seq(o.name)) - val aliases = aliasMap.get(nameParts.head).get - aliases.size match { - case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) - case n if n == 1 && aliases.head.alias.resolved => - // Only resolved alias can be the lateral column alias - resolveByLateralAlias(nameParts, aliases.head.alias) - .map(LateralColumnAliasReference(_, nameParts)) - .getOrElse(o) - case _ => o - } case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && Analyzer.resolveExpressionByPlanChildren(u, p, resolver) .isInstanceOf[UnresolvedAttribute] => @@ -126,11 +120,23 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) case n if n == 1 && aliases.head.alias.resolved => // Only resolved alias can be the lateral column alias - resolveByLateralAlias(u.nameParts, aliases.head.alias) - .map(LateralColumnAliasReference(_, u.nameParts)) - .getOrElse(u) + // The lateral alias can be a struct and have nested field, need to construct + // a dummy plan to resolve the expression + resolveByLateralAlias(u.nameParts, aliases.head.alias).getOrElse(u) case _ => u } + case o: OuterReference + if aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) => + // handle OuterReference exactly same as UnresolvedAttribute + val nameParts = o.nameParts.getOrElse(Seq(o.name)) + val aliases = aliasMap.get(nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) + case n if n == 1 && aliases.head.alias.resolved => + resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o) + case _ => o + } }.asInstanceOf[NamedExpression] } val newProjectList = projectList.zipWithIndex.map { @@ -139,7 +145,7 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { // Insert the LCA-resolved alias instead of the unresolved one into map. If it is // resolved, it can be referenced as LCA by later expressions (chaining). // Unresolved Alias is also added to the map to perform ambiguous name check, but only - // resolved alias can be LCA + // resolved alias can be LCA. aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) lcaWrapped case (e, _) => @@ -153,47 +159,36 @@ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), UnknownRuleId) { case p @ Project(projectList, child) if p.resolved && projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => - // build the map again in case the project list changes and index goes off - // TODO one risk: is there any rule that strips off /add the Alias? that the LCA is resolved - // in the beginning, but when it comes to push down, it really can't find the matching one? - // Restore back to UnresolvedAttribute. - // Also, when resolving from bottom up should I worry about cases like: - // Project [b AS c, c + 1 AS d] - // +- Project [1 AS a, a AS b] - // b AS c is resolved, even b refers to an alias contains the lateral alias? - var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) + var aliasMap = Map[Attribute, AliasEntry]() val referencedAliases = collection.mutable.Set.empty[AliasEntry] def unwrapLCAReference(e: NamedExpression): NamedExpression = { - e.transformWithPruning(_.containsAnyPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { - case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.nameParts.head) => - val aliasEntry = aliasMap.get(lcaRef.nameParts.head).get.head + e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.a) => + val aliasEntry = aliasMap(lcaRef.a) + // If there is no chaining of lateral column alias reference, push down the alias + // and unwrap the LateralColumnAliasReference to the NamedExpression inside + // If there is chaining, don't resolve and save to future rounds if (!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { - // If there is no chaining, push down the alias and resolve the attribute by - // constructing a dummy plan referencedAliases += aliasEntry - // Implementation notes (to-delete): - // this is a design decision whether to restore the UnresolvedAttribute, or - // directly resolve by constructing a plan and using resolveExpressionByPlanChildren - resolveByLateralAlias(lcaRef.nameParts, aliasEntry.alias).getOrElse(lcaRef) + lcaRef.ne } else { - // If there is chaining, don't resolve and save to future rounds lcaRef } - case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.nameParts.head) => - // It shouldn't happen. Restore to unresolved attribute to be safe. - UnresolvedAttribute(lcaRef.name) + case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.a) => + // It shouldn't happen, but restore to unresolved attribute to be safe. + UnresolvedAttribute(lcaRef.nameParts) }.asInstanceOf[NamedExpression] } - val newProjectList = projectList.zipWithIndex.map { case (a: Alias, idx) => val lcaResolved = unwrapLCAReference(a) // Insert the original alias instead of rewritten one to detect chained LCA - aliasMap = insertIntoAliasMap(a, idx, aliasMap) + aliasMap += (a.toAttribute -> AliasEntry(a, idx)) lcaResolved case (e, _) => unwrapLCAReference(e) } + if (referencedAliases.isEmpty) { p } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 7972cd1399f2d..ff65eecafc48d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -439,17 +439,20 @@ case class OuterReference(e: NamedExpression) } /** - * A placeholder used to hold a attribute that has been temporarily resolved as the reference - * to a lateral column alias. This is created and removed by rule [[ResolveLateralColumnAlias]]. + * A placeholder used to hold a [[NamedExpression]] that has been temporarily resolved as the + * reference to a lateral column alias. * + * This is created and removed by Analyzer rule [[ResolveLateralColumnAlias]]. * There should be no [[LateralColumnAliasReference]] beyond analyzer: if the plan passes all * analysis check, then all [[LateralColumnAliasReference]] should already be removed. * - * @param ne the current attribute is resolved to by lateral column alias - * @param nameParts The named parts of the original [[UnresolvedAttribute]]. Used to later resolve - * the attribute or restore back. + * @param ne the resolved [[NamedExpression]] by lateral column alias + * @param nameParts the named parts of the original [[UnresolvedAttribute]]. Used to restore back + * to [[UnresolvedAttribute]] when needed + * @param a the attribute of referenced lateral column alias. Used to match alias when unwrapping + * and resolving LateralColumnAliasReference */ -case class LateralColumnAliasReference(ne: NamedExpression, nameParts: Seq[String]) +case class LateralColumnAliasReference(ne: NamedExpression, nameParts: Seq[String], a: Attribute) extends LeafExpression with NamedExpression with Unevaluable { assert(ne.resolved) override def name: String = @@ -458,7 +461,7 @@ case class LateralColumnAliasReference(ne: NamedExpression, nameParts: Seq[Strin override def qualifier: Seq[String] = ne.qualifier override def toAttribute: Attribute = ne.toAttribute override def newInstance(): NamedExpression = - LateralColumnAliasReference(ne.newInstance(), nameParts) + LateralColumnAliasReference(ne.newInstance(), nameParts, a) override def nullable: Boolean = ne.nullable override def dataType: DataType = ne.dataType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index 5c0c120a87d2c..3c528e5997e8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -256,7 +256,7 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } // TODO: more tests on LCA in subquery - test("Lateral alias of a struct - Project") { + test("Lateral alias of a complex type - Project") { checkAnswer( sql("SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar, bar + 1"), Row(Row(1), 2, 3)) @@ -265,6 +265,34 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { sql("SELECT named_struct('a', named_struct('b', 1)) AS foo, foo.a.b + 1 AS bar"), Row(Row(Row(1)), 2) ) + + checkAnswer( + sql("SELECT array(1, 2, 3) AS foo, foo[1] AS bar, bar + 1"), + Row(Seq(1, 2, 3), 2, 3) + ) + checkAnswer( + sql("SELECT array(array(1, 2), array(1, 2, 3), array(100)) AS foo, foo[2][0] + 1 AS bar"), + Row(Seq(Seq(1, 2), Seq(1, 2, 3), Seq(100)), 101) + ) + checkAnswer( + sql("SELECT array(named_struct('a', 1), named_struct('a', 2)) AS foo, foo[0].a + 1 AS bar"), + Row(Seq(Row(1), Row(2)), 2) + ) + + checkAnswer( + sql("SELECT map('a', 1, 'b', 2) AS foo, foo['b'] AS bar, bar + 1"), + Row(Map("a" -> 1, "b" -> 2), 2, 3) + ) + } + + test("Lateral alias reference attribute further be used by upper plan - Project") { + // this is out of the scope of lateral alias project functionality requirements, but naturally + // supported by the current design + checkAnswer( + sql(s"SELECT properties AS new_properties, new_properties.joinYear AS new_join_year " + + s"FROM $testTable WHERE dept = 1 ORDER BY new_join_year DESC"), + Row(Row(2020, "B"), 2020) :: Row(Row(2019, "A"), 2019) :: Nil + ) } test("Lateral alias chaining - Project") { From 1f55f7381e728b0feff5fd89e71b8a4fc1c60ccd Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Thu, 8 Dec 2022 15:25:07 -0800 Subject: [PATCH 13/17] address comments to add tests for LCA off --- .../spark/sql/LateralColumnAliasSuite.scala | 44 ++++++++++++------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index 3c528e5997e8e..abeb3bb784124 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -59,6 +59,7 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } val lcaEnabled: Boolean = true + // by default the tests in this suites run with LCA on override protected def test(testName: String, testTags: Tag*)(testFun: => Any) (implicit pos: Position): Unit = { super.test(testName, testTags: _*) { @@ -67,6 +68,11 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } } } + // mark special testcases test both LCA on and off + protected def testOnAndOff(testName: String, testTags: Tag*)(testFun: => Any) + (implicit pos: Position): Unit = { + super.test(testName, testTags: _*)(testFun) + } private def withLCAOff(f: => Unit): Unit = { withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> "false") { @@ -79,29 +85,35 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } } - test("Lateral alias basics - Project") { - checkAnswer(sql(s"select dept as d, d + 1 as e from $testTable where name = 'amy'"), + testOnAndOff("Lateral alias basics - Project") { + def checkAnswerWhenOnAndExceptionWhenOff(query: String, expectedAnswerLCAOn: Row): Unit = { + withLCAOn { checkAnswer(sql(query), expectedAnswerLCAOn) } + withLCAOff { + assert(intercept[AnalysisException]{ sql(query) } + .getErrorClass == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + } + } + + checkAnswerWhenOnAndExceptionWhenOff( + s"select dept as d, d + 1 as e from $testTable where name = 'amy'", Row(1, 2)) - checkAnswer( - sql( - s"select salary * 2 as new_salary, new_salary + bonus from $testTable where name = 'amy'"), + checkAnswerWhenOnAndExceptionWhenOff( + s"select salary * 2 as new_salary, new_salary + bonus from $testTable where name = 'amy'", Row(20000, 21000)) - checkAnswer( - sql( - s"select salary * 2 as new_salary, new_salary + bonus * 2 as new_income from $testTable" + - s" where name = 'amy'"), + checkAnswerWhenOnAndExceptionWhenOff( + s"select salary * 2 as new_salary, new_salary + bonus * 2 as new_income from $testTable" + + s" where name = 'amy'", Row(20000, 22000)) - checkAnswer( - sql( - "select salary * 2 as new_salary, (new_salary + bonus) * 3 - new_salary * 2 as " + - s"new_income from $testTable where name = 'amy'"), + checkAnswerWhenOnAndExceptionWhenOff( + "select salary * 2 as new_salary, (new_salary + bonus) * 3 - new_salary * 2 as " + + s"new_income from $testTable where name = 'amy'", Row(20000, 23000)) // should referring to the previously defined LCA - checkAnswer( - sql(s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"), + checkAnswerWhenOnAndExceptionWhenOff( + s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'", Row(18000, 18000, 10000) ) } @@ -194,7 +206,7 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { Row(Row("someone"), "amy")) } - test("Lateral alias conflicts with OuterReference - Project") { + testOnAndOff("Lateral alias conflicts with OuterReference - Project") { // an attribute can both be resolved as LCA and OuterReference val query1 = s""" From f753529afe5ca21543a2d1915fcdbe6f63f218d4 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Thu, 8 Dec 2022 17:55:40 -0800 Subject: [PATCH 14/17] revert the refactor, split LCA into two rules --- .../sql/catalyst/analysis/Analyzer.scala | 424 +++++++++++------- .../analysis/ResolveLateralColumnAlias.scala | 217 --------- .../ResolveLateralColumnAliasReference.scala | 127 ++++++ .../sql/catalyst/rules/RuleIdCollection.scala | 3 +- 4 files changed, 380 insertions(+), 391 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3bc98c68d8486..6bbf2de445418 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -25,7 +25,6 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Random, Success, Try} -import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ @@ -42,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin} import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ -import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils, StringUtils} +import org.apache.spark.sql.catalyst.util.{toPrettySQL, CaseInsensitiveMap, CharVarcharUtils, StringUtils} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.{View => _, _} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -183,164 +182,6 @@ object AnalysisContext { } } -object Analyzer extends Logging { - // support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id - private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq( - (CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)), - (CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)), - (CurrentUser().prettyName, () => CurrentUser(), toPrettySQL), - ("user", () => CurrentUser(), toPrettySQL), - (VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName) - ) - - /** - * Literal functions do not require the user to specify braces when calling them - * When an attributes is not resolvable, we try to resolve it as a literal function. - */ - private def resolveLiteralFunction(nameParts: Seq[String]): Option[NamedExpression] = { - if (nameParts.length != 1) return None - val name = nameParts.head - literalFunctions.find(func => caseInsensitiveResolution(func._1, name)).map { - case (_, getFuncExpr, getAliasName) => - val funcExpr = getFuncExpr() - Alias(funcExpr, getAliasName(funcExpr))() - } - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by - * traversing the input expression in top-down manner. It must be top-down because we need to - * skip over unbound lambda function expression. The lambda expressions are resolved in a - * different place [[ResolveLambdaVariables]]. - * - * Example : - * SELECT transform(array(1, 2, 3), (x, i) -> x + i)" - * - * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]. - */ - private def resolveExpression( - expr: Expression, - resolveColumnByName: Seq[String] => Option[Expression], - getAttrCandidates: () => Seq[Attribute], - resolver: Resolver, - throws: Boolean): Expression = { - def innerResolve(e: Expression, isTopLevel: Boolean): Expression = { - if (e.resolved) return e - e match { - case f: LambdaFunction if !f.bound => f - - case GetColumnByOrdinal(ordinal, _) => - val attrCandidates = getAttrCandidates() - assert(ordinal >= 0 && ordinal < attrCandidates.length) - attrCandidates(ordinal) - - case GetViewColumnByNameAndOrdinal( - viewName, colName, ordinal, expectedNumCandidates, viewDDL) => - val attrCandidates = getAttrCandidates() - val matched = attrCandidates.filter(a => resolver(a.name, colName)) - if (matched.length != expectedNumCandidates) { - throw QueryCompilationErrors.incompatibleViewSchemaChange( - viewName, colName, expectedNumCandidates, matched, viewDDL) - } - matched(ordinal) - - case u @ UnresolvedAttribute(nameParts) => - val result = withPosition(u) { - resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map { - // We trim unnecessary alias here. Note that, we cannot trim the alias at top-level, - // as we should resolve `UnresolvedAttribute` to a named expression. The caller side - // can trim the top-level alias if it's safe to do so. Since we will call - // CleanupAliases later in Analyzer, trim non top-level unnecessary alias is safe. - case Alias(child, _) if !isTopLevel => child - case other => other - }.getOrElse(u) - } - logDebug(s"Resolving $u to $result") - result - - case u @ UnresolvedExtractValue(child, fieldName) => - val newChild = innerResolve(child, isTopLevel = false) - if (newChild.resolved) { - withOrigin(u.origin) { - ExtractValue(newChild, fieldName, resolver) - } - } else { - u.copy(child = newChild) - } - - case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) - } - } - - try { - innerResolve(expr, isTopLevel = true) - } catch { - case ae: AnalysisException if !throws => - logDebug(ae.getMessage) - expr - } - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the - * input plan's output attributes. In order to resolve the nested fields correctly, this function - * makes use of `throws` parameter to control when to raise an AnalysisException. - * - * Example : - * SELECT * FROM t ORDER BY a.b - * - * In the above example, after `a` is resolved to a struct-type column, we may fail to resolve `b` - * if there is no such nested field named "b". We should not fail and wait for other rules to - * resolve it if possible. - */ - def resolveExpressionByPlanOutput( - expr: Expression, - plan: LogicalPlan, - resolver: Resolver, - throws: Boolean = false): Expression = { - resolveExpression( - expr, - resolveColumnByName = nameParts => { - plan.resolve(nameParts, resolver) - }, - getAttrCandidates = () => plan.output, - resolver = resolver, - throws = throws) - } - - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the - * input plan's children output attributes. - * - * @param e The expression need to be resolved. - * @param q The LogicalPlan whose children are used to resolve expression's attribute. - * @return resolved Expression. - */ - def resolveExpressionByPlanChildren( - e: Expression, - q: LogicalPlan, - resolver: Resolver): Expression = { - resolveExpression( - e, - resolveColumnByName = nameParts => { - q.resolveChildren(nameParts, resolver) - }, - getAttrCandidates = () => { - assert(q.children.length == 1) - q.children.head.output - }, - resolver = resolver, - throws = true) - } - - /** - * Returns true if `exprs` contains a [[Star]]. - */ - def containsStar(exprs: Seq[Expression]): Boolean = - exprs.exists(_.collect { case _: Star => true }.nonEmpty) -} - /** * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]]. @@ -417,17 +258,6 @@ class Analyzer(override val catalogManager: CatalogManager) TypeCoercion.typeCoercionRules } - private def resolveExpressionByPlanOutput( - expr: Expression, - plan: LogicalPlan, - throws: Boolean = false): Expression = { - Analyzer.resolveExpressionByPlanOutput(expr, plan, resolver, throws) - } - - private def resolveExpressionByPlanChildren(e: Expression, q: LogicalPlan): Expression = { - Analyzer.resolveExpressionByPlanChildren(e, q, resolver) - } - override def batches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, // This rule optimizes `UpdateFields` expression chains so looks more like optimization rule. @@ -458,7 +288,8 @@ class Analyzer(override val catalogManager: CatalogManager) AddMetadataColumns :: DeduplicateRelations :: ResolveReferences :: - ResolveLateralColumnAlias :: + WrapLateralColumnAliasReference :: + ResolveLateralColumnAliasReference :: ResolveExpressionsWithNamePlaceholders :: ResolveDeserializer :: ResolveNewInstance :: @@ -1558,7 +1389,6 @@ class Analyzer(override val catalogManager: CatalogManager) * a logical plan node's children. */ object ResolveReferences extends Rule[LogicalPlan] { - import org.apache.spark.sql.catalyst.analysis.Analyzer.containsStar /** Return true if there're conflicting attributes among children's outputs of a plan */ def hasConflictingAttrs(p: LogicalPlan): Boolean = { @@ -1871,6 +1701,12 @@ class Analyzer(override val catalogManager: CatalogManager) }.map(_.asInstanceOf[NamedExpression]) } + /** + * Returns true if `exprs` contains a [[Star]]. + */ + def containsStar(exprs: Seq[Expression]): Boolean = + exprs.exists(_.collect { case _: Star => true }.nonEmpty) + private def extractStar(exprs: Seq[Expression]): Seq[Star] = exprs.flatMap(_.collect { case s: Star => s }) @@ -1927,10 +1763,252 @@ class Analyzer(override val catalogManager: CatalogManager) } } + /** + * The first phase to resolve lateral column alias. See comments in + * [[ResolveLateralColumnAliasReference]] for more detailed explanation. + */ + object WrapLateralColumnAliasReference extends Rule[LogicalPlan] { + import ResolveLateralColumnAliasReference.AliasEntry + + private def insertIntoAliasMap( + a: Alias, + idx: Int, + aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): CaseInsensitiveMap[Seq[AliasEntry]] = { + val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) + aliasMap + (a.name -> (prevAliases :+ AliasEntry(a, idx))) + } + + /** + * Use the given lateral alias to resolve the unresolved attribute with the name parts. + * + * Construct a dummy plan with the given lateral alias as project list, use the output of the + * plan to resolve. + * @return The resolved [[LateralColumnAliasReference]] if succeeds. None if fails to resolve. + */ + private def resolveByLateralAlias( + nameParts: Seq[String], lateralAlias: Alias): Option[LateralColumnAliasReference] = { + // TODO question: everytime it resolves the extract field it generates a new exprId. + // Does it matter? + val resolvedAttr = resolveExpressionByPlanOutput( + expr = UnresolvedAttribute(nameParts), + plan = Project(Seq(lateralAlias), OneRowRelation()), + throws = false + ).asInstanceOf[NamedExpression] + if (resolvedAttr.resolved) { + Some(LateralColumnAliasReference(resolvedAttr, nameParts, lateralAlias.toAttribute)) + } else { + None + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { + plan + } else { + // phase 1: wrap + plan.resolveOperatorsUpWithPruning( + _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) { + case p @ Project(projectList, child) if p.childrenResolved + && !ResolveReferences.containsStar(projectList) + && projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => + + var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) + def wrapLCAReference(e: NamedExpression): NamedExpression = { + e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { + case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && + resolveExpressionByPlanChildren(u, p).isInstanceOf[UnresolvedAttribute] => + val aliases = aliasMap.get(u.nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) + case n if n == 1 && aliases.head.alias.resolved => + // Only resolved alias can be the lateral column alias + // The lateral alias can be a struct and have nested field, need to construct + // a dummy plan to resolve the expression + resolveByLateralAlias(u.nameParts, aliases.head.alias).getOrElse(u) + case _ => u + } + case o: OuterReference + if aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) => + // handle OuterReference exactly same as UnresolvedAttribute + val nameParts = o.nameParts.getOrElse(Seq(o.name)) + val aliases = aliasMap.get(nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) + case n if n == 1 && aliases.head.alias.resolved => + resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o) + case _ => o + } + }.asInstanceOf[NamedExpression] + } + val newProjectList = projectList.zipWithIndex.map { + case (a: Alias, idx) => + val lcaWrapped = wrapLCAReference(a).asInstanceOf[Alias] + // Insert the LCA-resolved alias instead of the unresolved one into map. If it is + // resolved, it can be referenced as LCA by later expressions (chaining). + // Unresolved Alias is also added to the map to perform ambiguous name check, but + // only resolved alias can be LCA. + aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) + lcaWrapped + case (e, _) => + wrapLCAReference(e) + } + p.copy(projectList = newProjectList) + } + } + } + } + private def containsDeserializer(exprs: Seq[Expression]): Boolean = { exprs.exists(_.exists(_.isInstanceOf[UnresolvedDeserializer])) } + // support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id + private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq( + (CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)), + (CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)), + (CurrentUser().prettyName, () => CurrentUser(), toPrettySQL), + ("user", () => CurrentUser(), toPrettySQL), + (VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName) + ) + + /** + * Literal functions do not require the user to specify braces when calling them + * When an attributes is not resolvable, we try to resolve it as a literal function. + */ + private def resolveLiteralFunction(nameParts: Seq[String]): Option[NamedExpression] = { + if (nameParts.length != 1) return None + val name = nameParts.head + literalFunctions.find(func => caseInsensitiveResolution(func._1, name)).map { + case (_, getFuncExpr, getAliasName) => + val funcExpr = getFuncExpr() + Alias(funcExpr, getAliasName(funcExpr))() + } + } + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by + * traversing the input expression in top-down manner. It must be top-down because we need to + * skip over unbound lambda function expression. The lambda expressions are resolved in a + * different place [[ResolveLambdaVariables]]. + * + * Example : + * SELECT transform(array(1, 2, 3), (x, i) -> x + i)" + * + * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]. + */ + private def resolveExpression( + expr: Expression, + resolveColumnByName: Seq[String] => Option[Expression], + getAttrCandidates: () => Seq[Attribute], + throws: Boolean): Expression = { + def innerResolve(e: Expression, isTopLevel: Boolean): Expression = { + if (e.resolved) return e + e match { + case f: LambdaFunction if !f.bound => f + + case GetColumnByOrdinal(ordinal, _) => + val attrCandidates = getAttrCandidates() + assert(ordinal >= 0 && ordinal < attrCandidates.length) + attrCandidates(ordinal) + + case GetViewColumnByNameAndOrdinal( + viewName, colName, ordinal, expectedNumCandidates, viewDDL) => + val attrCandidates = getAttrCandidates() + val matched = attrCandidates.filter(a => resolver(a.name, colName)) + if (matched.length != expectedNumCandidates) { + throw QueryCompilationErrors.incompatibleViewSchemaChange( + viewName, colName, expectedNumCandidates, matched, viewDDL) + } + matched(ordinal) + + case u @ UnresolvedAttribute(nameParts) => + val result = withPosition(u) { + resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map { + // We trim unnecessary alias here. Note that, we cannot trim the alias at top-level, + // as we should resolve `UnresolvedAttribute` to a named expression. The caller side + // can trim the top-level alias if it's safe to do so. Since we will call + // CleanupAliases later in Analyzer, trim non top-level unnecessary alias is safe. + case Alias(child, _) if !isTopLevel => child + case other => other + }.getOrElse(u) + } + logDebug(s"Resolving $u to $result") + result + + case u @ UnresolvedExtractValue(child, fieldName) => + val newChild = innerResolve(child, isTopLevel = false) + if (newChild.resolved) { + withOrigin(u.origin) { + ExtractValue(newChild, fieldName, resolver) + } + } else { + u.copy(child = newChild) + } + + case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) + } + } + + try { + innerResolve(expr, isTopLevel = true) + } catch { + case ae: AnalysisException if !throws => + logDebug(ae.getMessage) + expr + } + } + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the + * input plan's output attributes. In order to resolve the nested fields correctly, this function + * makes use of `throws` parameter to control when to raise an AnalysisException. + * + * Example : + * SELECT * FROM t ORDER BY a.b + * + * In the above example, after `a` is resolved to a struct-type column, we may fail to resolve `b` + * if there is no such nested field named "b". We should not fail and wait for other rules to + * resolve it if possible. + */ + def resolveExpressionByPlanOutput( + expr: Expression, + plan: LogicalPlan, + throws: Boolean = false): Expression = { + resolveExpression( + expr, + resolveColumnByName = nameParts => { + plan.resolve(nameParts, resolver) + }, + getAttrCandidates = () => plan.output, + throws = throws) + } + + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the + * input plan's children output attributes. + * + * @param e The expression need to be resolved. + * @param q The LogicalPlan whose children are used to resolve expression's attribute. + * @return resolved Expression. + */ + def resolveExpressionByPlanChildren( + e: Expression, + q: LogicalPlan): Expression = { + resolveExpression( + e, + resolveColumnByName = nameParts => { + q.resolveChildren(nameParts, resolver) + }, + getAttrCandidates = () => { + assert(q.children.length == 1) + q.children.head.output + }, + throws = true) + } + /** * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by * clauses. This rule is to convert ordinal positions to the corresponding expressions in the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala deleted file mode 100644 index ad6867042ffa2..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala +++ /dev/null @@ -1,217 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.analysis - -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, LateralColumnAliasReference, NamedExpression, OuterReference} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project} -import org.apache.spark.sql.catalyst.rules.{Rule, UnknownRuleId} -import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, OUTER_REFERENCE, UNRESOLVED_ATTRIBUTE} -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.internal.SQLConf - -/** - * Resolve lateral column alias, which references the alias defined previously in the SELECT list. - * Plan-wise it handles two types of operators: Project and Aggregate. - * - in Project, pushing down the referenced lateral alias into a newly created Project, resolve - * the attributes referencing these aliases - * - in Aggregate TODO. - * - * The whole process is generally divided into two phases: - * 1) recognize resolved lateral alias, wrap the attributes referencing them with - * [[LateralColumnAliasReference]] - * 2) when the whole operator is resolved, unwrap [[LateralColumnAliasReference]]. - * For Project, it further resolves the attributes and push down the referenced lateral aliases. - * For Aggregate, TODO - * - * Example for Project: - * Before rewrite: - * Project [age AS a, 'a + 1] - * +- Child - * - * After phase 1: - * Project [age AS a, lateralalias(a) + 1] - * +- Child - * - * After phase 2: - * Project [a, a + 1] - * +- Project [child output, age AS a] - * +- Child - * - * Example for Aggregate TODO - * - * - * The name resolution priority: - * local table column > local lateral column alias > outer reference - * - * Because lateral column alias has higher resolution priority than outer reference, it will try - * to resolve an [[OuterReference]] using lateral column alias in phase 1, similar as an - * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and also wraps it with - * [[LateralColumnAliasReference]]. - */ -object ResolveLateralColumnAlias extends Rule[LogicalPlan] { - def resolver: Resolver = conf.resolver - - private case class AliasEntry(alias: Alias, index: Int) - private def insertIntoAliasMap( - a: Alias, - idx: Int, - aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): CaseInsensitiveMap[Seq[AliasEntry]] = { - val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) - aliasMap + (a.name -> (prevAliases :+ AliasEntry(a, idx))) - } - - /** - * Use the given lateral alias to resolve the unresolved attribute with the name parts. - * - * Construct a dummy plan with the given lateral alias as project list, use the output of the - * plan to resolve. - * @return The resolved [[LateralColumnAliasReference]] if succeeds. None if fails to resolve. - */ - private def resolveByLateralAlias( - nameParts: Seq[String], lateralAlias: Alias): Option[LateralColumnAliasReference] = { - // TODO question: everytime it resolves the extract field it generates a new exprId. - // Does it matter? - val resolvedAttr = Analyzer.resolveExpressionByPlanOutput( - expr = UnresolvedAttribute(nameParts), - plan = Project(Seq(lateralAlias), OneRowRelation()), - resolver = resolver, - throws = false - ).asInstanceOf[NamedExpression] - if (resolvedAttr.resolved) { - Some(LateralColumnAliasReference(resolvedAttr, nameParts, lateralAlias.toAttribute)) - } else { - None - } - } - - private def rewriteLateralColumnAlias(plan: LogicalPlan): LogicalPlan = { - // phase 1: wrap - val rewrittenPlan = plan.resolveOperatorsUpWithPruning( - _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) { - case p @ Project(projectList, child) if p.childrenResolved - && !Analyzer.containsStar(projectList) - && projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => - - var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) - def wrapLCAReference(e: NamedExpression): NamedExpression = { - e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { - case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && - Analyzer.resolveExpressionByPlanChildren(u, p, resolver) - .isInstanceOf[UnresolvedAttribute] => - val aliases = aliasMap.get(u.nameParts.head).get - aliases.size match { - case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) - case n if n == 1 && aliases.head.alias.resolved => - // Only resolved alias can be the lateral column alias - // The lateral alias can be a struct and have nested field, need to construct - // a dummy plan to resolve the expression - resolveByLateralAlias(u.nameParts, aliases.head.alias).getOrElse(u) - case _ => u - } - case o: OuterReference - if aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) => - // handle OuterReference exactly same as UnresolvedAttribute - val nameParts = o.nameParts.getOrElse(Seq(o.name)) - val aliases = aliasMap.get(nameParts.head).get - aliases.size match { - case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) - case n if n == 1 && aliases.head.alias.resolved => - resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o) - case _ => o - } - }.asInstanceOf[NamedExpression] - } - val newProjectList = projectList.zipWithIndex.map { - case (a: Alias, idx) => - val lcaWrapped = wrapLCAReference(a).asInstanceOf[Alias] - // Insert the LCA-resolved alias instead of the unresolved one into map. If it is - // resolved, it can be referenced as LCA by later expressions (chaining). - // Unresolved Alias is also added to the map to perform ambiguous name check, but only - // resolved alias can be LCA. - aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) - lcaWrapped - case (e, _) => - wrapLCAReference(e) - } - p.copy(projectList = newProjectList) - } - - // phase 2: unwrap - rewrittenPlan.resolveOperatorsUpWithPruning( - _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), UnknownRuleId) { - case p @ Project(projectList, child) if p.resolved - && projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => - var aliasMap = Map[Attribute, AliasEntry]() - val referencedAliases = collection.mutable.Set.empty[AliasEntry] - def unwrapLCAReference(e: NamedExpression): NamedExpression = { - e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { - case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.a) => - val aliasEntry = aliasMap(lcaRef.a) - // If there is no chaining of lateral column alias reference, push down the alias - // and unwrap the LateralColumnAliasReference to the NamedExpression inside - // If there is chaining, don't resolve and save to future rounds - if (!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { - referencedAliases += aliasEntry - lcaRef.ne - } else { - lcaRef - } - case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.a) => - // It shouldn't happen, but restore to unresolved attribute to be safe. - UnresolvedAttribute(lcaRef.nameParts) - }.asInstanceOf[NamedExpression] - } - val newProjectList = projectList.zipWithIndex.map { - case (a: Alias, idx) => - val lcaResolved = unwrapLCAReference(a) - // Insert the original alias instead of rewritten one to detect chained LCA - aliasMap += (a.toAttribute -> AliasEntry(a, idx)) - lcaResolved - case (e, _) => - unwrapLCAReference(e) - } - - if (referencedAliases.isEmpty) { - p - } else { - val outerProjectList = collection.mutable.Seq(newProjectList: _*) - val innerProjectList = - collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]): _*) - referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => - outerProjectList.update(idx, alias.toAttribute) - innerProjectList += alias - } - p.copy( - projectList = outerProjectList.toSeq, - child = Project(innerProjectList.toSeq, child) - ) - } - } - } - - override def apply(plan: LogicalPlan): LogicalPlan = { - if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { - plan - } else { - rewriteLateralColumnAlias(plan) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala new file mode 100644 index 0000000000000..cd0e0b86a8e48 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, LateralColumnAliasReference, NamedExpression} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE} +import org.apache.spark.sql.internal.SQLConf + +/** + * This rule is the second phase to resolve lateral column alias. + * + * Resolve lateral column alias, which references the alias defined previously in the SELECT list. + * Plan-wise, it handles two types of operators: Project and Aggregate. + * - in Project, pushing down the referenced lateral alias into a newly created Project, resolve + * the attributes referencing these aliases + * - in Aggregate TODO. + * + * The whole process is generally divided into two phases: + * 1) recognize resolved lateral alias, wrap the attributes referencing them with + * [[LateralColumnAliasReference]] + * 2) when the whole operator is resolved, unwrap [[LateralColumnAliasReference]]. + * For Project, it further resolves the attributes and push down the referenced lateral aliases. + * For Aggregate, TODO + * + * Example for Project: + * Before rewrite: + * Project [age AS a, 'a + 1] + * +- Child + * + * After phase 1: + * Project [age AS a, lateralalias(a) + 1] + * +- Child + * + * After phase 2: + * Project [a, a + 1] + * +- Project [child output, age AS a] + * +- Child + * + * Example for Aggregate TODO + * + * + * The name resolution priority: + * local table column > local lateral column alias > outer reference + * + * Because lateral column alias has higher resolution priority than outer reference, it will try + * to resolve an [[OuterReference]] using lateral column alias in phase 1, similar as an + * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and also wraps it with + * [[LateralColumnAliasReference]]. + */ +object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { + case class AliasEntry(alias: Alias, index: Int) + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { + plan + } else { + // phase 2: unwrap + plan.resolveOperatorsUpWithPruning( + _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), ruleId) { + case p @ Project(projectList, child) if p.resolved + && projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => + var aliasMap = Map[Attribute, AliasEntry]() + val referencedAliases = collection.mutable.Set.empty[AliasEntry] + def unwrapLCAReference(e: NamedExpression): NamedExpression = { + e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.a) => + val aliasEntry = aliasMap(lcaRef.a) + // If there is no chaining of lateral column alias reference, push down the alias + // and unwrap the LateralColumnAliasReference to the NamedExpression inside + // If there is chaining, don't resolve and save to future rounds + if (!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + referencedAliases += aliasEntry + lcaRef.ne + } else { + lcaRef + } + case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.a) => + // It shouldn't happen, but restore to unresolved attribute to be safe. + UnresolvedAttribute(lcaRef.nameParts) + }.asInstanceOf[NamedExpression] + } + val newProjectList = projectList.zipWithIndex.map { + case (a: Alias, idx) => + val lcaResolved = unwrapLCAReference(a) + // Insert the original alias instead of rewritten one to detect chained LCA + aliasMap += (a.toAttribute -> AliasEntry(a, idx)) + lcaResolved + case (e, _) => + unwrapLCAReference(e) + } + + if (referencedAliases.isEmpty) { + p + } else { + val outerProjectList = collection.mutable.Seq(newProjectList: _*) + val innerProjectList = + collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]): _*) + referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => + outerProjectList.update(idx, alias.toAttribute) + innerProjectList += alias + } + p.copy( + projectList = outerProjectList.toSeq, + child = Project(innerProjectList.toSeq, child) + ) + } + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 032b0e7a08fcd..efafd3cfbcde8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -77,6 +77,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowFrame" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowOrder" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$WindowsSubstitution" :: + "org.apache.spark.sql.catalyst.analysis.Analyzer$WrapLateralColumnAliasReference" :: "org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$AnsiCombinedTypeCoercionRule" :: "org.apache.spark.sql.catalyst.analysis.ApplyCharTypePadding" :: "org.apache.spark.sql.catalyst.analysis.DeduplicateRelations" :: @@ -88,7 +89,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.ResolveHints$ResolveJoinStrategyHints" :: "org.apache.spark.sql.catalyst.analysis.ResolveInlineTables" :: "org.apache.spark.sql.catalyst.analysis.ResolveLambdaVariables" :: - "org.apache.spark.sql.catalyst.analysis.ResolveLateralColumnAlias" :: + "org.apache.spark.sql.catalyst.analysis.ResolveLateralColumnAliasReference" :: "org.apache.spark.sql.catalyst.analysis.ResolveTimeZone" :: "org.apache.spark.sql.catalyst.analysis.ResolveUnion" :: "org.apache.spark.sql.catalyst.analysis.SubstituteUnresolvedOrdinals" :: From b9f706f9ea23bf80e9248e61136612b1c6ee363b Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Thu, 8 Dec 2022 18:13:15 -0800 Subject: [PATCH 15/17] better refactor --- .../sql/catalyst/analysis/Analyzer.scala | 82 +++++++++++-------- 1 file changed, 46 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6bbf2de445418..5d94defc68d2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1801,50 +1801,61 @@ class Analyzer(override val catalogManager: CatalogManager) } } + /** + * Recognize all the attributes in the given expression that reference lateral column aliases + * by looking up the alias map. Resolve these attributes and replace by wrapping with + * [[LateralColumnAliasReference]]. + * + * @param currentPlan Because lateral alias has lower resolution priority than table columns, + * the current plan is needed to first try resolving the attribute by its + * children + */ + private def wrapLCARefHelper( + e: NamedExpression, + currentPlan: LogicalPlan, + aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): NamedExpression = { + e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { + case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && + resolveExpressionByPlanChildren(u, currentPlan).isInstanceOf[UnresolvedAttribute] => + val aliases = aliasMap.get(u.nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) + case n if n == 1 && aliases.head.alias.resolved => + // Only resolved alias can be the lateral column alias + // The lateral alias can be a struct and have nested field, need to construct + // a dummy plan to resolve the expression + resolveByLateralAlias(u.nameParts, aliases.head.alias).getOrElse(u) + case _ => u + } + case o: OuterReference + if aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) => + // handle OuterReference exactly same as UnresolvedAttribute + val nameParts = o.nameParts.getOrElse(Seq(o.name)) + val aliases = aliasMap.get(nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) + case n if n == 1 && aliases.head.alias.resolved => + resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o) + case _ => o + } + }.asInstanceOf[NamedExpression] + } + override def apply(plan: LogicalPlan): LogicalPlan = { if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { plan } else { - // phase 1: wrap plan.resolveOperatorsUpWithPruning( _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) { - case p @ Project(projectList, child) if p.childrenResolved + case p @ Project(projectList, _) if p.childrenResolved && !ResolveReferences.containsStar(projectList) && projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => - var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) - def wrapLCAReference(e: NamedExpression): NamedExpression = { - e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { - case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && - resolveExpressionByPlanChildren(u, p).isInstanceOf[UnresolvedAttribute] => - val aliases = aliasMap.get(u.nameParts.head).get - aliases.size match { - case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) - case n if n == 1 && aliases.head.alias.resolved => - // Only resolved alias can be the lateral column alias - // The lateral alias can be a struct and have nested field, need to construct - // a dummy plan to resolve the expression - resolveByLateralAlias(u.nameParts, aliases.head.alias).getOrElse(u) - case _ => u - } - case o: OuterReference - if aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) => - // handle OuterReference exactly same as UnresolvedAttribute - val nameParts = o.nameParts.getOrElse(Seq(o.name)) - val aliases = aliasMap.get(nameParts.head).get - aliases.size match { - case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) - case n if n == 1 && aliases.head.alias.resolved => - resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o) - case _ => o - } - }.asInstanceOf[NamedExpression] - } val newProjectList = projectList.zipWithIndex.map { case (a: Alias, idx) => - val lcaWrapped = wrapLCAReference(a).asInstanceOf[Alias] + val lcaWrapped = wrapLCARefHelper(a, p, aliasMap).asInstanceOf[Alias] // Insert the LCA-resolved alias instead of the unresolved one into map. If it is // resolved, it can be referenced as LCA by later expressions (chaining). // Unresolved Alias is also added to the map to perform ambiguous name check, but @@ -1852,7 +1863,7 @@ class Analyzer(override val catalogManager: CatalogManager) aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) lcaWrapped case (e, _) => - wrapLCAReference(e) + wrapLCARefHelper(e, p, aliasMap) } p.copy(projectList = newProjectList) } @@ -1914,7 +1925,7 @@ class Analyzer(override val catalogManager: CatalogManager) attrCandidates(ordinal) case GetViewColumnByNameAndOrdinal( - viewName, colName, ordinal, expectedNumCandidates, viewDDL) => + viewName, colName, ordinal, expectedNumCandidates, viewDDL) => val attrCandidates = getAttrCandidates() val matched = attrCandidates.filter(a => resolver(a.name, colName)) if (matched.length != expectedNumCandidates) { @@ -1985,7 +1996,6 @@ class Analyzer(override val catalogManager: CatalogManager) throws = throws) } - /** * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the * input plan's children output attributes. From 94d5c9ee7c095b40ea5fe676fa50bc7acc5fe885 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Fri, 9 Dec 2022 13:28:07 -0800 Subject: [PATCH 16/17] address comments --- .../spark/sql/catalyst/expressions/AttributeMap.scala | 3 ++- .../spark/sql/catalyst/expressions/AttributeMap.scala | 3 +++ .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 4 +--- .../analysis/ResolveLateralColumnAliasReference.scala | 8 ++++---- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index c55c542d957de..504b65e3db693 100644 --- a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -49,7 +49,8 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)]) override def contains(k: Attribute): Boolean = get(k).isDefined - override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = baseMap.values.toMap + kv + override def + [B1 >: A](kv: (Attribute, B1)): AttributeMap[B1] = + AttributeMap(baseMap.values.toMap + kv) override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator diff --git a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index 3d5d6471d26d4..ac6149f3acc4d 100644 --- a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -49,6 +49,9 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)]) override def contains(k: Attribute): Boolean = get(k).isDefined + override def + [B1 >: A](kv: (Attribute, B1)): AttributeMap[B1] = + AttributeMap(baseMap.values.toMap + kv) + override def updated[B1 >: A](key: Attribute, value: B1): Map[Attribute, B1] = baseMap.values.toMap + (key -> value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5d94defc68d2a..a56a3d9cb6bad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1787,11 +1787,9 @@ class Analyzer(override val catalogManager: CatalogManager) */ private def resolveByLateralAlias( nameParts: Seq[String], lateralAlias: Alias): Option[LateralColumnAliasReference] = { - // TODO question: everytime it resolves the extract field it generates a new exprId. - // Does it matter? val resolvedAttr = resolveExpressionByPlanOutput( expr = UnresolvedAttribute(nameParts), - plan = Project(Seq(lateralAlias), OneRowRelation()), + plan = LocalRelation(Seq(lateralAlias.toAttribute)), throws = false ).asInstanceOf[NamedExpression] if (resolvedAttr.resolved) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala index cd0e0b86a8e48..c86d0a6dff0bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, LateralColumnAliasReference, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, LateralColumnAliasReference, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE} +import org.apache.spark.sql.catalyst.trees.TreePattern.LATERAL_COLUMN_ALIAS_REFERENCE import org.apache.spark.sql.internal.SQLConf /** @@ -76,12 +76,12 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), ruleId) { case p @ Project(projectList, child) if p.resolved && projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => - var aliasMap = Map[Attribute, AliasEntry]() + var aliasMap = AttributeMap.empty[AliasEntry] val referencedAliases = collection.mutable.Set.empty[AliasEntry] def unwrapLCAReference(e: NamedExpression): NamedExpression = { e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.a) => - val aliasEntry = aliasMap(lcaRef.a) + val aliasEntry = aliasMap.get(lcaRef.a).get // If there is no chaining of lateral column alias reference, push down the alias // and unwrap the LateralColumnAliasReference to the NamedExpression inside // If there is chaining, don't resolve and save to future rounds From 8d20986ee90145b1f3abeafa48c22d463d5a0c99 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 12 Dec 2022 20:03:51 -0800 Subject: [PATCH 17/17] address comments --- .../spark/sql/catalyst/analysis/Analyzer.scala | 15 ++++++++++----- .../sql/catalyst/analysis/CheckAnalysis.scala | 12 ++++++++---- .../ResolveLateralColumnAliasReference.scala | 8 ++++++++ .../catalyst/expressions/namedExpressions.scala | 12 +----------- .../spark/sql/catalyst/expressions/subquery.scala | 7 ++++++- 5 files changed, 33 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a56a3d9cb6bad..e28a2f5dfda9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1808,7 +1808,7 @@ class Analyzer(override val catalogManager: CatalogManager) * the current plan is needed to first try resolving the attribute by its * children */ - private def wrapLCARefHelper( + private def wrapLCARef( e: NamedExpression, currentPlan: LogicalPlan, aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): NamedExpression = { @@ -1827,9 +1827,14 @@ class Analyzer(override val catalogManager: CatalogManager) case _ => u } case o: OuterReference - if aliasMap.contains(o.nameParts.map(_.head).getOrElse(o.name)) => + if aliasMap.contains( + o.getTagValue(ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR) + .map(_.head) + .getOrElse(o.name)) => // handle OuterReference exactly same as UnresolvedAttribute - val nameParts = o.nameParts.getOrElse(Seq(o.name)) + val nameParts = o + .getTagValue(ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR) + .getOrElse(Seq(o.name)) val aliases = aliasMap.get(nameParts.head).get aliases.size match { case n if n > 1 => @@ -1853,7 +1858,7 @@ class Analyzer(override val catalogManager: CatalogManager) var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) val newProjectList = projectList.zipWithIndex.map { case (a: Alias, idx) => - val lcaWrapped = wrapLCARefHelper(a, p, aliasMap).asInstanceOf[Alias] + val lcaWrapped = wrapLCARef(a, p, aliasMap).asInstanceOf[Alias] // Insert the LCA-resolved alias instead of the unresolved one into map. If it is // resolved, it can be referenced as LCA by later expressions (chaining). // Unresolved Alias is also added to the map to perform ambiguous name check, but @@ -1861,7 +1866,7 @@ class Analyzer(override val catalogManager: CatalogManager) aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) lcaWrapped case (e, _) => - wrapLCARefHelper(e, p, aliasMap) + wrapLCARef(e, p, aliasMap) } p.copy(projectList = newProjectList) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 9937a06de9a98..ff8450d524c47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -643,8 +643,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB projectList.foreach(_.transformDownWithPruning( _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { case lcaRef: LateralColumnAliasReference if p.resolved => - failUnresolvedAttribute( - p, UnresolvedAttribute(lcaRef.nameParts), "UNRESOLVED_COLUMN") + throw SparkException.internalError("Resolved Project should not contain " + + s"any LateralColumnAliasReference.\nDebugging information: plan: $p", + context = lcaRef.origin.getQueryContext, + summary = lcaRef.origin.context.summary) }) case j: Join if !j.duplicateResolved => @@ -729,8 +731,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB aggList.foreach(_.transformDownWithPruning( _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { case lcaRef: LateralColumnAliasReference => - failUnresolvedAttribute( - agg, UnresolvedAttribute(lcaRef.nameParts), "UNRESOLVED_COLUMN") + throw SparkException.internalError("Resolved Aggregate should not contain " + + s"any LateralColumnAliasReference.\nDebugging information: plan: $agg", + context = lcaRef.origin.getQueryContext, + summary = lcaRef.origin.context.summary) }) case _ => // Analysis successful! diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala index c86d0a6dff0bb..2ca187b95ffda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, LateralColumnAliasReference, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.trees.TreePattern.LATERAL_COLUMN_ALIAS_REFERENCE import org.apache.spark.sql.internal.SQLConf @@ -67,6 +68,13 @@ import org.apache.spark.sql.internal.SQLConf object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { case class AliasEntry(alias: Alias, index: Int) + /** + * A tag to store the nameParts from the original unresolved attribute. + * It is set for [[OuterReference]], used in the current rule to convert [[OuterReference]] back + * to [[LateralColumnAliasReference]]. + */ + val NAME_PARTS_FROM_UNRESOLVED_ATTR = TreeNodeTag[Seq[String]]("name_parts_from_unresolved_attr") + override def apply(plan: LogicalPlan): LogicalPlan = { if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { plan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index ff65eecafc48d..0f5239be6cae7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -424,18 +424,8 @@ case class OuterReference(e: NamedExpression) override def qualifier: Seq[String] = e.qualifier override def exprId: ExprId = e.exprId override def toAttribute: Attribute = e.toAttribute - override def newInstance(): NamedExpression = - OuterReference(e.newInstance()).setNameParts(nameParts) + override def newInstance(): NamedExpression = OuterReference(e.newInstance()) final override val nodePatterns: Seq[TreePattern] = Seq(OUTER_REFERENCE) - - // optional field, the original name parts of UnresolvedAttribute before it is resolved to - // OuterReference. Used in rule ResolveLateralColumnAlias to convert OuterReference back to - // LateralColumnAliasReference. - var nameParts: Option[Seq[String]] = None - def setNameParts(newNameParts: Option[Seq[String]]): OuterReference = { - nameParts = newNameParts - this - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index d249a2b5a6bb7..b510893f370e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.analysis.ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, LogicalPlan} @@ -159,7 +160,11 @@ object SubExprUtils extends PredicateHelper { * Wrap attributes in the expression with [[OuterReference]]s. */ def wrapOuterReference[E <: Expression](e: E, nameParts: Option[Seq[String]] = None): E = { - e.transform { case a: Attribute => OuterReference(a).setNameParts(nameParts) }.asInstanceOf[E] + e.transform { case a: Attribute => + val o = OuterReference(a) + nameParts.map(o.setTagValue(NAME_PARTS_FROM_UNRESOLVED_ATTR, _)) + o + }.asInstanceOf[E] } /**