Use the proper serializer in limit.
rxin committed Mar 26, 2014
1 parent 9b79246 commit 87b7d37
Showing 1 changed file with 16 additions and 8 deletions.
@@ -20,13 +20,14 @@ package execution

 import scala.reflect.runtime.universe.TypeTag

-import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext
-
+import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
+import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{OrderedDistribution, UnspecifiedDistribution}
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.util.MutablePair


 case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
   override def output = projectList.map(_.toAttribute)
@@ -70,17 +71,24 @@ case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends
  * data to a single partition to compute the global limit.
  */
 case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+  // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
+  // partition local limit -> exchange into one partition -> partition local limit again
+
   override def otherCopyArgs = sc :: Nil

   override def output = child.output

   override def executeCollect() = child.execute().map(_.copy()).take(limit)

   override def execute() = {
-    child.execute()
-      .mapPartitions(_.take(limit).map(_.copy()))
-      .coalesce(1, shuffle = true)
-      .mapPartitions(_.take(limit))
+    val rdd = child.execute().mapPartitions { iter =>
+      val mutablePair = new MutablePair[Boolean, Row]()
+      iter.take(limit).map(row => mutablePair.update(false, row))
+    }
+    val part = new HashPartitioner(1)
+    val shuffled = new ShuffledRDD[Boolean, Row, MutablePair[Boolean, Row]](rdd, part)
+    shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+    shuffled.mapPartitions(_.take(limit).map(_._2))
   }
 }
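For context, here is a minimal sketch (not part of the commit) of the pattern the new execute() follows: apply a partition-local limit, tag each surviving row with a dummy key, shuffle everything into a single partition through an explicitly built ShuffledRDD so the shuffle serializer can be chosen (SparkSqlSerializer in the real code, instead of whatever the old coalesce(1, shuffle = true) path would default to), then take the first `limit` rows again on that one partition. The object and method names below are made up for illustration, Seq[Any] stands in for Row, and the ShuffledRDD constructor is assumed to match the Spark API used in the diff above.

// Hypothetical sketch of the single-partition limit pattern used by the new execute().
import org.apache.spark.HashPartitioner
import org.apache.spark.rdd.{RDD, ShuffledRDD}

object SinglePartitionLimitSketch {
  def globalLimit(rows: RDD[Seq[Any]], limit: Int): RDD[Seq[Any]] = {
    // Partition-local limit first, so at most `limit` rows per partition are shuffled,
    // each keyed with a dummy value so they can be partitioned.
    val keyed: RDD[(Boolean, Seq[Any])] =
      rows.mapPartitions(_.take(limit).map(row => (false, row)))
    // One target partition, so a single task sees all surviving rows.
    val shuffled = new ShuffledRDD[Boolean, Seq[Any], (Boolean, Seq[Any])](
      keyed, new HashPartitioner(1))
    // The point of the commit: pick the shuffle serializer explicitly here
    // (SparkSqlSerializer in the real code) rather than relying on the default.
    // shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
    // Global limit on the single shuffled partition, dropping the dummy keys.
    shuffled.mapPartitions(_.take(limit).map(_._2))
  }
}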

