From 313b2c98e9513e50d2764b28c447c3a7cd281ebb Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Mon, 28 Nov 2022 10:38:29 -0800 Subject: [PATCH] add tests, refine logic --- .../sql/catalyst/analysis/Analyzer.scala | 40 ++++--- .../spark/sql/LateralColumnAliasSuite.scala | 109 ++++++++++++++++++ 2 files changed, 130 insertions(+), 19 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a3c79228f1cf1..bb92639c6b3c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1388,20 +1388,20 @@ class Analyzer(override val catalogManager: CatalogManager) * resolved by other rules * - in Aggregate TODO. * - * For Project, it rewrites the Project plan by inserting a newly created Project plan between - * the original Project and its child, and updating the project list of the original Project plan. - * The project list of the new Project plan is the lateral column aliases that are referenced - * in the original project list. These aliases in the original project list are updated to - * attribute references. + * For Project, it rewrites by inserting a newly created Project plan between the original Project + * and its child, pushing the referenced lateral column aliases to this new Project, and updating + * the project list of the original Project. * * Before rewrite: - * Project [age AS a, a + 1] + * Project [age AS a, 'a + 1] * +- Child * * After rewrite: - * Project [a, a + 1] - * +- Project [age AS a] + * Project [a, 'a + 1] + * +- Project [child output, age AS a] * +- Child + * + * For Aggregate TODO. */ object ResolveLateralColumnAlias extends Rule[LogicalPlan] { private case class AliasEntry(alias: Alias, index: Int) @@ -1411,14 +1411,14 @@ class Analyzer(override val catalogManager: CatalogManager) case p @ Project(projectList, child) if p.childrenResolved && !ResolveReferences.containsStar(projectList) && projectList.exists(_.containsPattern(UNRESOLVED_ATTRIBUTE)) => - // TODO: delta + var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) - var referencedAliases = Seq[AliasEntry]() - def updateAliasMap(a: Alias, idx: Int): Unit = { + def insertIntoAliasMap(a: Alias, idx: Int): Unit = { val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx))) } - def searchMatchedLCA(e: Expression): Unit = { + def lookUpLCA(e: Expression): Option[AliasEntry] = { + var matchedLCA: Option[AliasEntry] = None e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && resolveExpressionByPlanChildren(u, p).isInstanceOf[UnresolvedAttribute] => @@ -1430,13 +1430,15 @@ class Analyzer(override val catalogManager: CatalogManager) val referencedAlias = aliases.head // Only resolved alias can be the lateral column alias if (referencedAlias.alias.resolved) { - referencedAliases :+= referencedAlias + matchedLCA = Some(referencedAlias) } } u } + matchedLCA } - projectList.zipWithIndex.foreach { + + val referencedAliases = projectList.zipWithIndex.flatMap { case (a: Alias, idx) => // Add all alias to the aliasMap. But note only resolved alias can be LCA and pushed // down. Unresolved alias is added to the map to perform the ambiguous name check. @@ -1445,13 +1447,13 @@ class Analyzer(override val catalogManager: CatalogManager) // only 1 AS a is pushed down, even though 1 AS a, 'a + 1 AS b and 'b + 1 AS c are // all added to the aliasMap. On the second round, when 'a + 1 AS b is resolved, // it is pushed down. - searchMatchedLCA(a) - updateAliasMap(a, idx) + val matchedLCA = lookUpLCA(a) + insertIntoAliasMap(a, idx) + matchedLCA case (e, _) => - searchMatchedLCA(e) - } + lookUpLCA(e) + }.toSet - referencedAliases = referencedAliases.sortBy(_.index) if (referencedAliases.isEmpty) { p } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala new file mode 100644 index 0000000000000..daf750c39bb1e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalactic.source.Position +import org.scalatest.Tag + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { + protected val testTable: String = "employee" + + override def beforeAll(): Unit = { + super.beforeAll() + sql(s"CREATE TABLE $testTable (dept INTEGER, name String, salary INTEGER, bonus INTEGER) " + + s"using orc") + sql( + s""" + |INSERT INTO $testTable VALUES + | (1, 'amy', 10000, 1000), + | (2, 'alex', 12000, 1200), + | (1, 'cathy', 9000, 1200), + | (2, 'david', 10000, 1300), + | (6, 'jen', 12000, 1200) + |""".stripMargin) + } + + override def afterAll(): Unit = { + try { + sql(s"DROP TABLE IF EXISTS $testTable") + } finally { + super.afterAll() + } + } + + val lcaEnabled: Boolean = true + override protected def test(testName: String, testTags: Tag*)(testFun: => Any) + (implicit pos: Position): Unit = { + super.test(testName, testTags: _*) { + withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_ENABLED.key -> lcaEnabled.toString) { + testFun + } + } + } + + test("Lateral alias in project") { + checkAnswer(sql(s"select dept as d, d + 1 as e from $testTable where name = 'amy'"), + Row(1, 2)) + + checkAnswer( + sql( + s"select salary * 2 as new_salary, new_salary + bonus from $testTable where name = 'amy'"), + Row(20000, 21000)) + checkAnswer( + sql( + s"select salary * 2 as new_salary, new_salary + bonus * 2 as new_income from $testTable" + + s" where name = 'amy'"), + Row(20000, 22000)) + + checkAnswer( + sql( + "select salary * 2 as new_salary, (new_salary + bonus) * 3 - new_salary * 2 as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 23000)) + + // When the lateral alias conflicts with the table column, it should resolved as the table + // column + checkAnswer( + sql( + "select salary * 2 as salary, salary * 2 + bonus as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 21000)) + + checkAnswer( + sql( + "select salary * 2 as salary, (salary + bonus) * 3 - (salary + bonus) as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 22000)) + + checkAnswer( + sql( + "select salary * 2 as salary, (salary + bonus) * 2 as bonus, " + + s"salary + bonus as prev_income, prev_income + bonus + salary from $testTable" + + " where name = 'amy'"), + Row(20000, 22000, 11000, 22000)) + + // Corner cases for resolution order + checkAnswer( + sql(s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"), + Row(18000, 18000, 10000) + ) + } +}