Skip to content

Commit

Permalink
add tests, refine logic
Browse files Browse the repository at this point in the history
  • Loading branch information
anchovYu committed Nov 28, 2022
1 parent 94adb3f commit 313b2c9
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1388,20 +1388,20 @@ class Analyzer(override val catalogManager: CatalogManager)
* resolved by other rules
* - in Aggregate TODO.
*
* For Project, it rewrites the Project plan by inserting a newly created Project plan between
* the original Project and its child, and updating the project list of the original Project plan.
* The project list of the new Project plan is the lateral column aliases that are referenced
* in the original project list. These aliases in the original project list are updated to
* attribute references.
* For Project, it rewrites by inserting a newly created Project plan between the original Project
* and its child, pushing the referenced lateral column aliases to this new Project, and updating
* the project list of the original Project.
*
* Before rewrite:
* Project [age AS a, a + 1]
* Project [age AS a, 'a + 1]
* +- Child
*
* After rewrite:
* Project [a, a + 1]
* +- Project [age AS a]
* Project [a, 'a + 1]
* +- Project [child output, age AS a]
* +- Child
*
* For Aggregate TODO.
*/
object ResolveLateralColumnAlias extends Rule[LogicalPlan] {
private case class AliasEntry(alias: Alias, index: Int)
Expand All @@ -1411,14 +1411,14 @@ class Analyzer(override val catalogManager: CatalogManager)
case p @ Project(projectList, child) if p.childrenResolved
&& !ResolveReferences.containsStar(projectList)
&& projectList.exists(_.containsPattern(UNRESOLVED_ATTRIBUTE)) =>
// TODO: delta

var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]())
var referencedAliases = Seq[AliasEntry]()
def updateAliasMap(a: Alias, idx: Int): Unit = {
def insertIntoAliasMap(a: Alias, idx: Int): Unit = {
val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry])
aliasMap += (a.name -> (prevAliases :+ AliasEntry(a, idx)))
}
def searchMatchedLCA(e: Expression): Unit = {
def lookUpLCA(e: Expression): Option[AliasEntry] = {
var matchedLCA: Option[AliasEntry] = None
e.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) {
case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) &&
resolveExpressionByPlanChildren(u, p).isInstanceOf[UnresolvedAttribute] =>
Expand All @@ -1430,13 +1430,15 @@ class Analyzer(override val catalogManager: CatalogManager)
val referencedAlias = aliases.head
// Only resolved alias can be the lateral column alias
if (referencedAlias.alias.resolved) {
referencedAliases :+= referencedAlias
matchedLCA = Some(referencedAlias)
}
}
u
}
matchedLCA
}
projectList.zipWithIndex.foreach {

val referencedAliases = projectList.zipWithIndex.flatMap {
case (a: Alias, idx) =>
// Add all alias to the aliasMap. But note only resolved alias can be LCA and pushed
// down. Unresolved alias is added to the map to perform the ambiguous name check.
Expand All @@ -1445,13 +1447,13 @@ class Analyzer(override val catalogManager: CatalogManager)
// only 1 AS a is pushed down, even though 1 AS a, 'a + 1 AS b and 'b + 1 AS c are
// all added to the aliasMap. On the second round, when 'a + 1 AS b is resolved,
// it is pushed down.
searchMatchedLCA(a)
updateAliasMap(a, idx)
val matchedLCA = lookUpLCA(a)
insertIntoAliasMap(a, idx)
matchedLCA
case (e, _) =>
searchMatchedLCA(e)
}
lookUpLCA(e)
}.toSet

referencedAliases = referencedAliases.sortBy(_.index)
if (referencedAliases.isEmpty) {
p
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.scalactic.source.Position
import org.scalatest.Tag

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession

class LateralColumnAliasSuite extends QueryTest with SharedSparkSession {
protected val testTable: String = "employee"

override def beforeAll(): Unit = {
super.beforeAll()
sql(s"CREATE TABLE $testTable (dept INTEGER, name String, salary INTEGER, bonus INTEGER) " +
s"using orc")
sql(
s"""
|INSERT INTO $testTable VALUES
| (1, 'amy', 10000, 1000),
| (2, 'alex', 12000, 1200),
| (1, 'cathy', 9000, 1200),
| (2, 'david', 10000, 1300),
| (6, 'jen', 12000, 1200)
|""".stripMargin)
}

override def afterAll(): Unit = {
try {
sql(s"DROP TABLE IF EXISTS $testTable")
} finally {
super.afterAll()
}
}

val lcaEnabled: Boolean = true
override protected def test(testName: String, testTags: Tag*)(testFun: => Any)
(implicit pos: Position): Unit = {
super.test(testName, testTags: _*) {
withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_ENABLED.key -> lcaEnabled.toString) {
testFun
}
}
}

test("Lateral alias in project") {
checkAnswer(sql(s"select dept as d, d + 1 as e from $testTable where name = 'amy'"),
Row(1, 2))

checkAnswer(
sql(
s"select salary * 2 as new_salary, new_salary + bonus from $testTable where name = 'amy'"),
Row(20000, 21000))
checkAnswer(
sql(
s"select salary * 2 as new_salary, new_salary + bonus * 2 as new_income from $testTable" +
s" where name = 'amy'"),
Row(20000, 22000))

checkAnswer(
sql(
"select salary * 2 as new_salary, (new_salary + bonus) * 3 - new_salary * 2 as " +
s"new_income from $testTable where name = 'amy'"),
Row(20000, 23000))

// When the lateral alias conflicts with the table column, it should resolved as the table
// column
checkAnswer(
sql(
"select salary * 2 as salary, salary * 2 + bonus as " +
s"new_income from $testTable where name = 'amy'"),
Row(20000, 21000))

checkAnswer(
sql(
"select salary * 2 as salary, (salary + bonus) * 3 - (salary + bonus) as " +
s"new_income from $testTable where name = 'amy'"),
Row(20000, 22000))

checkAnswer(
sql(
"select salary * 2 as salary, (salary + bonus) * 2 as bonus, " +
s"salary + bonus as prev_income, prev_income + bonus + salary from $testTable" +
" where name = 'amy'"),
Row(20000, 22000, 11000, 22000))

// Corner cases for resolution order
checkAnswer(
sql(s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"),
Row(18000, 18000, 10000)
)
}
}

0 comments on commit 313b2c9

Please sign in to comment.