Skip to content

Commit

Permalink
Add GlobalLimitAndOffset
Browse files Browse the repository at this point in the history
  • Loading branch information
beliefer committed Feb 4, 2020
1 parent 1149727 commit 8bcee2f
Show file tree
Hide file tree
Showing 16 changed files with 149 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,9 @@ trait CheckAnalysis extends PredicateHelper {
}
}

case GlobalLimit(limitExpr, _, _) => checkLimitLikeClause("limit", limitExpr)
case GlobalLimit(limitExpr, _) => checkLimitLikeClause("limit", limitExpr)

case LocalLimit(limitExpr, _, child) =>
case LocalLimit(limitExpr, child) =>
checkLimitLikeClause("limit", limitExpr)
child match {
case Offset(offsetExpr, _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ object UnsupportedOperationChecker extends Logging {
case GroupingSets(_, _, child, _) if child.isStreaming =>
throwError("GroupingSets is not supported on streaming DataFrames/Datasets")

case GlobalLimit(_, _, _) | LocalLimit(_, _, _)
case GlobalLimit(_, _) | LocalLimit(_, _)
if subPlan.children.forall(_.isStreaming) && outputMode == InternalOutputModes.Update =>
throwError("Limits are not supported on streaming DataFrames/Datasets in Update " +
"output mode")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ package object dsl {

def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan)

def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, child = logicalPlan)
def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan)

def offset(offsetExpr: Expression): LogicalPlan = Offset(offsetExpr, logicalPlan)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,24 +452,21 @@ object LimitPushDown extends Rule[LogicalPlan] {

private def stripGlobalLimitIfPresent(plan: LogicalPlan): LogicalPlan = {
plan match {
case GlobalLimit(_, _, child) => child
case GlobalLimit(_, child) => child
case _ => plan
}
}

private def maybePushLocalLimit(
limitExp: Expression,
offsetExp: Expression,
plan: LogicalPlan): LogicalPlan = {
private def maybePushLocalLimit(limitExp: Expression, plan: LogicalPlan): LogicalPlan = {
(limitExp, plan.maxRowsPerPartition) match {
case (IntegerLiteral(newLimit), Some(childMaxRows)) if newLimit < childMaxRows =>
// If the child has a cap on max rows per partition and the cap is larger than
// the new limit, put a new LocalLimit there.
LocalLimit(limitExp, offsetExp, stripGlobalLimitIfPresent(plan))
LocalLimit(limitExp, stripGlobalLimitIfPresent(plan))

case (_, None) =>
// If the child has no cap, put the new LocalLimit.
LocalLimit(limitExp, offsetExp, stripGlobalLimitIfPresent(plan))
LocalLimit(limitExp, stripGlobalLimitIfPresent(plan))

case _ =>
// Otherwise, don't put a new LocalLimit.
Expand All @@ -484,22 +481,22 @@ object LimitPushDown extends Rule[LogicalPlan] {
// Note: right now Union means UNION ALL, which does not de-duplicate rows, so it is safe to
// pushdown Limit through it. Once we add UNION DISTINCT, however, we will not be able to
// pushdown Limit.
case LocalLimit(le, oe, Union(children)) =>
LocalLimit(le, oe, Union(children.map(maybePushLocalLimit(le, oe, _))))
case LocalLimit(exp, Union(children)) =>
LocalLimit(exp, Union(children.map(maybePushLocalLimit(exp, _))))
// Add extra limits below OUTER JOIN. For LEFT OUTER and RIGHT OUTER JOIN we push limits to
// the left and right sides, respectively. It's not safe to push limits below FULL OUTER
// JOIN in the general case without a more invasive rewrite.
// We also need to ensure that this limit pushdown rule will not eventually introduce limits
// on both sides if it is applied multiple times. Therefore:
// - If one side is already limited, stack another limit on top if the new limit is smaller.
// The redundant limit will be collapsed by the CombineLimits rule.
case LocalLimit(le, oe, join @ Join(left, right, joinType, _, _)) =>
case LocalLimit(exp, join @ Join(left, right, joinType, _, _)) =>
val newJoin = joinType match {
case RightOuter => join.copy(right = maybePushLocalLimit(le, oe, right))
case LeftOuter => join.copy(left = maybePushLocalLimit(le, oe, left))
case RightOuter => join.copy(right = maybePushLocalLimit(exp, right))
case LeftOuter => join.copy(left = maybePushLocalLimit(exp, left))
case _ => join
}
LocalLimit(le, oe, newJoin)
LocalLimit(exp, newJoin)
}
}

Expand Down Expand Up @@ -714,11 +711,11 @@ object CollapseProject extends Rule[LogicalPlan] {
agg.copy(aggregateExpressions = buildCleanedProjectList(
p.projectList, agg.aggregateExpressions))
}
case Project(l1, g @ GlobalLimit(_, _, limit @ LocalLimit(_, _, p2 @ Project(l2, _))))
case Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _))))
if isRenaming(l1, l2) =>
val newProjectList = buildCleanedProjectList(l1, l2)
g.copy(child = limit.copy(child = p2.copy(projectList = newProjectList)))
case Project(l1, limit @ LocalLimit(_, _, p2 @ Project(l2, _))) if isRenaming(l1, l2) =>
case Project(l1, limit @ LocalLimit(_, p2 @ Project(l2, _))) if isRenaming(l1, l2) =>
val newProjectList = buildCleanedProjectList(l1, l2)
limit.copy(child = p2.copy(projectList = newProjectList))
case Project(l1, r @ Repartition(_, _, p @ Project(l2, _))) if isRenaming(l1, l2) =>
Expand Down Expand Up @@ -1386,12 +1383,12 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
*/
object CombineLimits extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case GlobalLimit(le, oe, GlobalLimit(nle, noe, grandChild)) =>
GlobalLimit(Least(Seq(nle, le)), Greatest(Seq(noe, oe)), grandChild)
case LocalLimit(le, oe, LocalLimit(nle, noe, grandChild)) =>
LocalLimit(Least(Seq(nle, le)), Greatest(Seq(noe, oe)), grandChild)
case Limit(le, oe, Limit(nle, noe, grandChild)) =>
Limit(Least(Seq(nle, le)), Greatest(Seq(noe, oe)), grandChild)
case GlobalLimit(le, GlobalLimit(ne, grandChild)) =>
GlobalLimit(Least(Seq(ne, le)), grandChild)
case LocalLimit(le, LocalLimit(ne, grandChild)) =>
LocalLimit(Least(Seq(ne, le)), grandChild)
case Limit(le, Limit(ne, grandChild)) =>
Limit(Least(Seq(ne, le)), grandChild)
}
}

Expand All @@ -1401,10 +1398,10 @@ object CombineLimits extends Rule[LogicalPlan] {
*/
object RewriteOffsets extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case GlobalLimit(le, oe, Offset(noe, grandChild)) =>
GlobalLimit(le, Greatest(Seq(noe, oe)), grandChild)
case LocalLimit(le, oe, Offset(noe, grandChild)) =>
Offset(noe, LocalLimit(le, Greatest(Seq(noe, oe)), grandChild))
case GlobalLimit(le, Offset(oe, grandChild)) =>
GlobalLimitAndOffset(le, oe, grandChild)
case LocalLimit(le, Offset(oe, grandChild)) =>
Offset(oe, LocalLimit(Add(le, oe), grandChild))
}
}

Expand Down Expand Up @@ -1512,11 +1509,8 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] {
projection.initialize(0)
LocalRelation(projectList.map(_.toAttribute), data.map(projection(_).copy()), isStreaming)

case Limit(
IntegerLiteral(limit),
IntegerLiteral(offset),
LocalRelation(output, data, isStreaming)) =>
LocalRelation(output, data.drop(offset).take(limit), isStreaming)
case Limit(IntegerLiteral(limit), LocalRelation(output, data, isStreaming)) =>
LocalRelation(output, data.take(limit), isStreaming)

case Filter(condition, LocalRelation(output, data, isStreaming))
if !hasUnevaluableExpr(condition) =>
Expand Down Expand Up @@ -1799,7 +1793,7 @@ object OptimizeLimitZero extends Rule[LogicalPlan] {
// changes up the Logical Plan.
//
// Replace Global Limit 0 nodes with empty Local Relation
case gl @ GlobalLimit(IntegerLiteral(0), _, _) =>
case gl @ GlobalLimit(IntegerLiteral(0), _) =>
empty(gl)

// Note: For all SQL queries, if a LocalLimit 0 node exists in the Logical Plan, then a
Expand All @@ -1808,7 +1802,7 @@ object OptimizeLimitZero extends Rule[LogicalPlan] {
// then the following rule will handle that case as well.
//
// Replace Local Limit 0 nodes with empty Local Relation
case ll @ LocalLimit(IntegerLiteral(0), _, _) =>
case ll @ LocalLimit(IntegerLiteral(0), _) =>
empty(ll)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ object RewriteNonCorrelatedExists extends Rule[LogicalPlan] {
case exists: Exists if exists.children.isEmpty =>
IsNotNull(
ScalarSubquery(
plan = Limit(Literal(1), child = Project(Seq(Alias(Literal(1), "col")()), exists.plan)),
plan = Limit(Literal(1), Project(Seq(Alias(Literal(1), "col")()), exists.plan)),
exprId = exists.exprId))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
// LIMIT
// - LIMIT ALL is the same as omitting the LIMIT clause
withOffset.optional(limit) {
Limit(typedVisit(limit), Literal(0), withOffset)
Limit(typedVisit(limit), withOffset)
}
}

Expand Down Expand Up @@ -1007,7 +1007,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging

ctx.sampleMethod() match {
case ctx: SampleByRowsContext =>
Limit(expression(ctx.expression), child = query)
Limit(expression(ctx.expression), query)

case ctx: SampleByPercentileContext =>
val fraction = ctx.percentage.getText.toDouble
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -797,32 +797,24 @@ case class Pivot(
* So we introduced LocalLimit and GlobalLimit in the logical plan node for limit pushdown.
*/
object Limit {
def apply(
limitExpr: Expression,
offsetExpr: Expression = Literal(0),
child: LogicalPlan): UnaryNode = {
GlobalLimit(limitExpr, offsetExpr, LocalLimit(limitExpr, offsetExpr, child))
def apply(limitExpr: Expression, child: LogicalPlan): UnaryNode = {
GlobalLimit(limitExpr, LocalLimit(limitExpr, child))
}

def unapply(p: GlobalLimit): Option[(Expression, Expression, LogicalPlan)] = {
def unapply(p: GlobalLimit): Option[(Expression, LogicalPlan)] = {
p match {
case GlobalLimit(le1, oe1, LocalLimit(le2, oe2, child)) if le1 == le2 && oe1 == oe2 =>
Some((le1, oe1, child))
case GlobalLimit(le1, LocalLimit(le2, child)) if le1 == le2 => Some((le1, child))
case _ => None
}
}
}

/**
* A global (coordinated) limit. This operator can remove at most `offsetExpr` number
* and emit at most `limitExpr` number in total.
* A global (coordinated) limit. This operator can emit at most `limitExpr` number in total.
*
* See [[Limit]] for more information.
*/
case class GlobalLimit(
limitExpr: Expression,
offsetExpr: Expression,
child: LogicalPlan) extends OrderPreservingUnaryNode {
case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode {
override def output: Seq[Attribute] = child.output
override def maxRows: Option[Long] = {
limitExpr match {
Expand All @@ -833,18 +825,28 @@ case class GlobalLimit(
}

/**
* A partition-local (non-coordinated) limit. This operator remove at most `offsetExpr`
* number and emit at most `limitExpr` number of tuples on each physical partition.
* A partition-local (non-coordinated) limit. This operator can emit at most `limitExpr` number
* of tuples on each physical partition.
*
* See [[Limit]] for more information.
*/
case class LocalLimit(
case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode {
override def output: Seq[Attribute] = child.output

override def maxRowsPerPartition: Option[Long] = {
limitExpr match {
case IntegerLiteral(limit) => Some(limit)
case _ => None
}
}
}

case class GlobalLimitAndOffset(
limitExpr: Expression,
offsetExpr: Expression,
child: LogicalPlan) extends OrderPreservingUnaryNode {
override def output: Seq[Attribute] = child.output

override def maxRowsPerPartition: Option[Long] = {
override def maxRows: Option[Long] = {
limitExpr match {
case IntegerLiteral(limit) => Some(limit)
case _ => None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ class AnalysisErrorSuite extends AnalysisTest {
val plan4 = Filter(
Exists(
Limit(1,
child = Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))
Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))
),
LocalRelation(a))
assertAnalysisError(plan4, "Accessing outer query column is not allowed in" :: Nil)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ class CollapseProjectSuite extends PlanTest {

test("collapse redundant alias through local limit") {
val relation = LocalRelation('a.int, 'b.int)
val query = LocalLimit(1, 0, relation.select('a as 'b)).select('b as 'c).analyze
val query = LocalLimit(1, relation.select('a as 'b)).select('b as 'c).analyze
val optimized = Optimize.execute(query)
val expected = LocalLimit(1, 0, relation.select('a as 'c)).analyze
val expected = LocalLimit(1, relation.select('a as 'c)).analyze
comparePlans(optimized, expected)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,23 @@ class LimitPushdownSuite extends PlanTest {
val unionQuery = Union(testRelation, testRelation2).limit(1)
val unionOptimized = Optimize.execute(unionQuery.analyze)
val unionCorrectAnswer =
Limit(1, 0, Union(LocalLimit(1, 0, testRelation), LocalLimit(1, 0, testRelation2))).analyze
Limit(1, Union(LocalLimit(1, testRelation), LocalLimit(1, testRelation2))).analyze
comparePlans(unionOptimized, unionCorrectAnswer)
}

test("Union: limit to each side with constant-foldable limit expressions") {
val unionQuery = Union(testRelation, testRelation2).limit(Add(1, 1))
val unionOptimized = Optimize.execute(unionQuery.analyze)
val unionCorrectAnswer =
Limit(2, 0, Union(LocalLimit(2, 0, testRelation), LocalLimit(2, 0, testRelation2))).analyze
Limit(2, Union(LocalLimit(2, testRelation), LocalLimit(2, testRelation2))).analyze
comparePlans(unionOptimized, unionCorrectAnswer)
}

test("Union: limit to each side with the new limit number") {
val unionQuery = Union(testRelation, testRelation2.limit(3)).limit(1)
val unionOptimized = Optimize.execute(unionQuery.analyze)
val unionCorrectAnswer =
Limit(1, 0, Union(LocalLimit(1, 0, testRelation), LocalLimit(1, 0, testRelation2))).analyze
Limit(1, Union(LocalLimit(1, testRelation), LocalLimit(1, testRelation2))).analyze
comparePlans(unionOptimized, unionCorrectAnswer)
}

Expand All @@ -74,7 +74,7 @@ class LimitPushdownSuite extends PlanTest {
Union(testRelation.limit(1), testRelation2.select('d, 'e, 'f).limit(1)).limit(2)
val unionOptimized = Optimize.execute(unionQuery.analyze)
val unionCorrectAnswer =
Limit(2, 0, Union(testRelation.limit(1), testRelation2.select('d, 'e, 'f).limit(1))).analyze
Limit(2, Union(testRelation.limit(1), testRelation2.select('d, 'e, 'f).limit(1))).analyze
comparePlans(unionOptimized, unionCorrectAnswer)
}

Expand All @@ -83,8 +83,8 @@ class LimitPushdownSuite extends PlanTest {
Union(testRelation.limit(3), testRelation2.select('d, 'e, 'f).limit(4)).limit(2)
val unionOptimized = Optimize.execute(unionQuery.analyze)
val unionCorrectAnswer =
Limit(2, 0, Union(
LocalLimit(2, 0, testRelation), LocalLimit(2, 0, testRelation2.select('d, 'e, 'f)))).analyze
Limit(2, Union(
LocalLimit(2, testRelation), LocalLimit(2, testRelation2.select('d, 'e, 'f)))).analyze
comparePlans(unionOptimized, unionCorrectAnswer)
}

Expand All @@ -93,49 +93,49 @@ class LimitPushdownSuite extends PlanTest {
test("left outer join") {
val originalQuery = x.join(y, LeftOuter).limit(1)
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = Limit(1, 0, LocalLimit(1, 0, x).join(y, LeftOuter)).analyze
val correctAnswer = Limit(1, LocalLimit(1, x).join(y, LeftOuter)).analyze
comparePlans(optimized, correctAnswer)
}

test("left outer join and left sides are limited") {
val originalQuery = x.limit(2).join(y, LeftOuter).limit(1)
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = Limit(1, 0, LocalLimit(1, 0, x).join(y, LeftOuter)).analyze
val correctAnswer = Limit(1, LocalLimit(1, x).join(y, LeftOuter)).analyze
comparePlans(optimized, correctAnswer)
}

test("left outer join and right sides are limited") {
val originalQuery = x.join(y.limit(2), LeftOuter).limit(1)
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = Limit(1, 0, LocalLimit(1, 0, x).join(Limit(2, 0, y), LeftOuter)).analyze
val correctAnswer = Limit(1, LocalLimit(1, x).join(Limit(2, y), LeftOuter)).analyze
comparePlans(optimized, correctAnswer)
}

test("right outer join") {
val originalQuery = x.join(y, RightOuter).limit(1)
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = Limit(1, 0, x.join(LocalLimit(1, 0, y), RightOuter)).analyze
val correctAnswer = Limit(1, x.join(LocalLimit(1, y), RightOuter)).analyze
comparePlans(optimized, correctAnswer)
}

test("right outer join and right sides are limited") {
val originalQuery = x.join(y.limit(2), RightOuter).limit(1)
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = Limit(1, 0, x.join(LocalLimit(1, 0, y), RightOuter)).analyze
val correctAnswer = Limit(1, x.join(LocalLimit(1, y), RightOuter)).analyze
comparePlans(optimized, correctAnswer)
}

test("right outer join and left sides are limited") {
val originalQuery = x.limit(2).join(y, RightOuter).limit(1)
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = Limit(1, 0, Limit(2, 0, x).join(LocalLimit(1, 0, y), RightOuter)).analyze
val correctAnswer = Limit(1, Limit(2, x).join(LocalLimit(1, y), RightOuter)).analyze
comparePlans(optimized, correctAnswer)
}

test("larger limits are not pushed on top of smaller ones in right outer join") {
val originalQuery = x.join(y.limit(5), RightOuter).limit(10)
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = Limit(10, 0, x.join(Limit(5, 0, y), RightOuter)).analyze
val correctAnswer = Limit(10, x.join(Limit(5, y), RightOuter)).analyze
comparePlans(optimized, correctAnswer)
}

Expand Down
Loading

0 comments on commit 8bcee2f

Please sign in to comment.