Skip to content

Commit

Permalink
add files.
Browse files Browse the repository at this point in the history
  • Loading branch information
rxin committed May 23, 2017
1 parent ad09e4c commit 4cc2c45
Show file tree
Hide file tree
Showing 20 changed files with 122 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1336,7 +1336,7 @@ class Analyzer(

// Category 1:
// ResolvedHint, Distinct, LeafNode, Repartition, and SubqueryAlias
case _: BroadcastHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias =>
case _: ResolvedHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias =>

// Category 2:
// These operators can be anywhere in a correlated subquery.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ trait CheckAnalysis extends PredicateHelper {
|in operator ${operator.simpleString}
""".stripMargin)

case _: Hint =>
case _: UnresolvedHint =>
throw new IllegalStateException(
"Internal error: logical hint operator should have been removed during analysis")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ object ResolveHints {
val newNode = CurrentOrigin.withOrigin(plan.origin) {
plan match {
case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) =>
BroadcastHint(plan)
ResolvedHint(isBroadcastable = Option(true), plan)
case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) =>
BroadcastHint(plan)
ResolvedHint(isBroadcastable = Option(true), plan)

case _: BroadcastHint | _: View | _: With | _: SubqueryAlias =>
case _: ResolvedHint | _: View | _: With | _: SubqueryAlias =>
// Don't traverse down these nodes.
// For an existing broadcast hint, there is no point going down (if we do, we either
// won't change the structure, or will introduce another broadcast hint that is useless.
Expand All @@ -85,10 +85,10 @@ object ResolveHints {
}

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
if (h.parameters.isEmpty) {
// If no table alias is specified, wrap the entire subtree in a broadcast ResolvedHint.
BroadcastHint(h.child)
ResolvedHint(isBroadcastable = Option(true), h.child)
} else {
// Otherwise, find within the subtree query plans that should be broadcasted.
applyBroadcastHint(h.child, h.parameters.toSet)
Expand All @@ -102,7 +102,7 @@ object ResolveHints {
*/
object RemoveAllHints extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case h: Hint => h.child
case h: UnresolvedHint => h.child
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
// Note that some operators (e.g. project, aggregate, union) are being handled separately
// (earlier in this rule).
case _: AppendColumns => true
case _: BroadcastHint => true
case _: ResolvedHint => true
case _: Distinct => true
case _: Generate => true
case _: Pivot => true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ object FoldablePropagation extends Rule[LogicalPlan] {
case _: Distinct => true
case _: AppendColumns => true
case _: AppendColumnsWithObject => true
case _: BroadcastHint => true
case _: ResolvedHint => true
case _: RepartitionByExpression => true
case _: Repartition => true
case _: Sort => true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -533,13 +533,13 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
}

/**
* Add a [[Hint]] to a logical plan.
* Add a [[UnresolvedHint]] to a logical plan.
*/
private def withHints(
ctx: HintContext,
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
val stmt = ctx.hintStatement
Hint(stmt.hintName.getText, stmt.parameters.asScala.map(_.getText), query)
UnresolvedHint(stmt.hintName.getText, stmt.parameters.asScala.map(_.getText), query)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ object PhysicalOperation extends PredicateHelper {
val substitutedCondition = substitute(aliases)(condition)
(fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases)

case BroadcastHint(child) =>
collectProjectsAndFilters(child)
case h: ResolvedHint =>
collectProjectsAndFilters(h.child)

case other =>
(None, Nil, other, Map.empty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ case class Statistics(
s"isBroadcastable=$isBroadcastable"
).filter(_.nonEmpty).mkString(", ")
}

/**
 * Returns a copy of these statistics with the hints reset, which here means setting
 * `isBroadcastable` to false. Must be called when computing stats for a join operator, so
 * that a hint placed on one join input is not carried over to the join's own output.
 */
def resetHintsForJoin(): Statistics = {
  copy(isBroadcastable = false)
}
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ case class Join(
case _ =>
// Make sure we don't propagate isBroadcastable in other joins, because
// they could explode the size.
super.computeStats(conf).copy(isBroadcastable = false)
super.computeStats(conf).resetHintsForJoin()
}

if (conf.cboEnabled) {
Expand All @@ -375,26 +375,6 @@ case class Join(
}
}

/**
* A hint for the optimizer that we should broadcast the `child` if used in a join operator.
*/
case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output

// set isBroadcastable to true so the child will be broadcasted
override def computeStats(conf: SQLConf): Statistics =
child.stats(conf).copy(isBroadcastable = true)
}

/**
* A general hint for the child. This node will be eliminated post analysis.
* A pair of (name, parameters).
*/
case class Hint(name: String, parameters: Seq[String], child: LogicalPlan) extends UnaryNode {
override lazy val resolved: Boolean = false
override def output: Seq[Attribute] = child.output
}

/**
* Insert some data into a table. Note that this plan is unresolved and has to be replaced by the
* concrete implementations during analysis.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.internal.SQLConf

/**
 * A general hint for the child that is not yet resolved. This node is generated by the parser
 * and should be eliminated post analysis.
 *
 * @param name the name of the hint
 * @param parameters the parameters of the hint
 * @param child the plan the hint is attached to
 */
case class UnresolvedHint(name: String, parameters: Seq[String], child: LogicalPlan)
  extends UnaryNode {

  // The hint adds no columns of its own: the output is exactly the child's output.
  override def output: Seq[Attribute] = child.output

  // Always reported as unresolved.
  override lazy val resolved: Boolean = false
}

/**
 * A resolved hint node. The analyzer should convert all [[UnresolvedHint]] into [[ResolvedHint]].
 *
 * @param isBroadcastable when defined, replaces `isBroadcastable` in the child's statistics;
 *                        when `None`, the child's statistics are returned unchanged
 * @param child the plan the hint is attached to
 */
case class ResolvedHint(
    isBroadcastable: Option[Boolean] = None,
    child: LogicalPlan)
  extends UnaryNode {

  override def output: Seq[Attribute] = child.output

  override def computeStats(conf: SQLConf): Statistics = {
    val childStats = child.stats(conf)
    isBroadcastable match {
      // An explicit hint value overrides whatever the child reported.
      case Some(flag) => childStats.copy(isBroadcastable = flag)
      // No hint value: pass the child's statistics through untouched.
      case None => childStats
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,68 +28,70 @@ class ResolveHintsSuite extends AnalysisTest {

test("invalid hints should be ignored") {
checkAnalysis(
Hint("some_random_hint_that_does_not_exist", Seq("TaBlE"), table("TaBlE")),
UnresolvedHint("some_random_hint_that_does_not_exist", Seq("TaBlE"), table("TaBlE")),
testRelation,
caseSensitive = false)
}

test("case-sensitive or insensitive parameters") {
checkAnalysis(
Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE")),
BroadcastHint(testRelation),
UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")),
ResolvedHint(isBroadcastable = Option(true), testRelation),
caseSensitive = false)

checkAnalysis(
Hint("MAPJOIN", Seq("table"), table("TaBlE")),
BroadcastHint(testRelation),
UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")),
ResolvedHint(isBroadcastable = Option(true), testRelation),
caseSensitive = false)

checkAnalysis(
Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE")),
BroadcastHint(testRelation),
UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")),
ResolvedHint(isBroadcastable = Option(true), testRelation),
caseSensitive = true)

checkAnalysis(
Hint("MAPJOIN", Seq("table"), table("TaBlE")),
UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")),
testRelation,
caseSensitive = true)
}

test("multiple broadcast hint aliases") {
checkAnalysis(
Hint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))),
Join(BroadcastHint(testRelation), BroadcastHint(testRelation2), Inner, None),
UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))),
Join(ResolvedHint(isBroadcastable = Option(true), testRelation),
ResolvedHint(isBroadcastable = Option(true), testRelation2), Inner, None),
caseSensitive = false)
}

test("do not traverse past existing broadcast hints") {
checkAnalysis(
Hint("MAPJOIN", Seq("table"), BroadcastHint(table("table").where('a > 1))),
BroadcastHint(testRelation.where('a > 1)).analyze,
UnresolvedHint("MAPJOIN", Seq("table"),
ResolvedHint(isBroadcastable = Option(true), table("table").where('a > 1))),
ResolvedHint(isBroadcastable = Option(true), testRelation.where('a > 1)).analyze,
caseSensitive = false)
}

test("should work for subqueries") {
checkAnalysis(
Hint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")),
BroadcastHint(testRelation),
UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")),
ResolvedHint(isBroadcastable = Option(true), testRelation),
caseSensitive = false)

checkAnalysis(
Hint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)),
BroadcastHint(testRelation),
UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)),
ResolvedHint(isBroadcastable = Option(true), testRelation),
caseSensitive = false)

// Negative case: if the alias doesn't match, don't match the original table name.
checkAnalysis(
Hint("MAPJOIN", Seq("table"), table("table").as("tableAlias")),
UnresolvedHint("MAPJOIN", Seq("table"), table("table").as("tableAlias")),
testRelation,
caseSensitive = false)
}

test("do not traverse past subquery alias") {
checkAnalysis(
Hint("MAPJOIN", Seq("table"), table("table").where('a > 1).subquery('tableAlias)),
UnresolvedHint("MAPJOIN", Seq("table"), table("table").where('a > 1).subquery('tableAlias)),
testRelation.where('a > 1).analyze,
caseSensitive = false)
}
Expand All @@ -102,7 +104,8 @@ class ResolveHintsSuite extends AnalysisTest {
|SELECT /*+ BROADCAST(ctetable) */ * FROM ctetable
""".stripMargin
),
BroadcastHint(testRelation.where('a > 1).select('a)).select('a).analyze,
ResolvedHint(isBroadcastable = Option(true),
testRelation.where('a > 1).select('a)).select('a).analyze,
caseSensitive = false)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,14 +321,16 @@ class ColumnPruningSuite extends PlanTest {
Project(Seq($"x.key", $"y.key"),
Join(
SubqueryAlias("x", input),
BroadcastHint(SubqueryAlias("y", input)), Inner, None)).analyze
ResolvedHint(isBroadcastable = Option(true),
SubqueryAlias("y", input)), Inner, None)).analyze

val optimized = Optimize.execute(query)

val expected =
Join(
Project(Seq($"x.key"), SubqueryAlias("x", input)),
BroadcastHint(
ResolvedHint(
isBroadcastable = Option(true),
Project(Seq($"y.key"), SubqueryAlias("y", input))),
Inner, None).analyze

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -798,12 +798,12 @@ class FilterPushdownSuite extends PlanTest {
}

test("broadcast hint") {
val originalQuery = BroadcastHint(testRelation)
val originalQuery = ResolvedHint(isBroadcastable = Option(true), testRelation)
.where('a === 2L && 'b + Rand(10).as("rnd") === 3)

val optimized = Optimize.execute(originalQuery.analyze)

val correctAnswer = BroadcastHint(testRelation.where('a === 2L))
val correctAnswer = ResolvedHint(isBroadcastable = Option(true), testRelation.where('a === 2L))
.where('b + Rand(10).as("rnd") === 3)
.analyze

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,16 @@ class JoinOptimizationSuite extends PlanTest {
Project(Seq($"x.key", $"y.key"),
Join(
SubqueryAlias("x", input),
BroadcastHint(SubqueryAlias("y", input)), Cross, None)).analyze
ResolvedHint(isBroadcastable = Option(true),
SubqueryAlias("y", input)), Cross, None)).analyze

val optimized = Optimize.execute(query)

val expected =
Join(
Project(Seq($"x.key"), SubqueryAlias("x", input)),
BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input))),
ResolvedHint(isBroadcastable = Option(true),
Project(Seq($"y.key"), SubqueryAlias("y", input))),
Cross, None).analyze

comparePlans(optimized, expected)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -534,30 +534,31 @@ class PlanParserSuite extends PlanTest {

comparePlans(
parsePlan("SELECT /*+ HINT */ * FROM t"),
Hint("HINT", Seq.empty, table("t").select(star())))
UnresolvedHint("HINT", Seq.empty, table("t").select(star())))

comparePlans(
parsePlan("SELECT /*+ BROADCASTJOIN(u) */ * FROM t"),
Hint("BROADCASTJOIN", Seq("u"), table("t").select(star())))
UnresolvedHint("BROADCASTJOIN", Seq("u"), table("t").select(star())))

comparePlans(
parsePlan("SELECT /*+ MAPJOIN(u) */ * FROM t"),
Hint("MAPJOIN", Seq("u"), table("t").select(star())))
UnresolvedHint("MAPJOIN", Seq("u"), table("t").select(star())))

comparePlans(
parsePlan("SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t"),
Hint("STREAMTABLE", Seq("a", "b", "c"), table("t").select(star())))
UnresolvedHint("STREAMTABLE", Seq("a", "b", "c"), table("t").select(star())))

comparePlans(
parsePlan("SELECT /*+ INDEX(t, emp_job_ix) */ * FROM t"),
Hint("INDEX", Seq("t", "emp_job_ix"), table("t").select(star())))
UnresolvedHint("INDEX", Seq("t", "emp_job_ix"), table("t").select(star())))

comparePlans(
parsePlan("SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`"),
Hint("MAPJOIN", Seq("default.t"), table("default.t").select(star())))
UnresolvedHint("MAPJOIN", Seq("default.t"), table("default.t").select(star())))

comparePlans(
parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"),
Hint("MAPJOIN", Seq("t"), table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc))
UnresolvedHint("MAPJOIN", Seq("t"),
table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
expectedStatsCboOn = filterStatsCboOn,
expectedStatsCboOff = filterStatsCboOff)

val broadcastHint = BroadcastHint(filter)
val broadcastHint = ResolvedHint(isBroadcastable = Option(true), filter)
checkStats(
broadcastHint,
expectedStatsCboOn = filterStatsCboOn.copy(isBroadcastable = true),
Expand Down
Loading

0 comments on commit 4cc2c45

Please sign in to comment.