Commit be37ece

refactor
Davies Liu committed Oct 24, 2014
1 parent eb3938d commit be37ece
Showing 1 changed file with 19 additions and 32 deletions.
core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala (51 changes: 19 additions & 32 deletions)
@@ -99,7 +99,7 @@ private[spark] object SerDeUtil extends Logging {
* Convert an RDD of Java objects to Array (no recursive conversions).
* It is only used by pyspark.sql.
*/
-  private[python] def toJavaArray(jrdd: JavaRDD[_]): JavaRDD[Array[_]] = {
+  def toJavaArray(jrdd: JavaRDD[Any]): JavaRDD[Array[_]] = {
jrdd.rdd.map {
case objs: JArrayList[_] =>
objs.toArray
@@ -139,7 +139,7 @@ private[spark] object SerDeUtil extends Logging {
* Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
* PySpark.
*/
-  def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
+  private[spark] def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
}

@@ -200,54 +200,41 @@ private[spark] object SerDeUtil extends Logging {
*/
def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = {
val (keyFailed, valueFailed) = checkPickle(rdd.first())

rdd.mapPartitions { iter =>
-      val pickle = new Pickler
val cleaned = iter.map { case (k, v) =>
val key = if (keyFailed) k.toString else k
val value = if (valueFailed) v.toString else v
Array[Any](key, value)
}
-      if (batchSize > 1) {
-        cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
+      if (batchSize == 0) {
+        new AutoBatchedPickler(cleaned)
      } else {
-        cleaned.map(pickle.dumps(_))
+        val pickle = new Pickler
+        cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
}
}
}

/**
* Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)].
*/
-  def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = {
+  def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batched: Boolean): RDD[(K, V)] = {
def isPair(obj: Any): Boolean = {
Option(obj.getClass.getComponentType).exists(!_.isPrimitive) &&
obj.asInstanceOf[Array[_]].length == 2
}
-    pyRDD.mapPartitions { iter =>
-      initialize()
-      val unpickle = new Unpickler
-      val unpickled =
-        if (batchSerialized) {
-          iter.flatMap { batch =>
-            unpickle.loads(batch) match {
-              case objs: java.util.List[_] => collectionAsScalaIterable(objs)
-              case other => throw new SparkException(
-                s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD")
-            }
-          }
-        } else {
-          iter.map(unpickle.loads(_))
-        }
-      unpickled.map {
-        case obj if isPair(obj) =>
-          // we only accept (K, V)
-          val arr = obj.asInstanceOf[Array[_]]
-          (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
-        case other => throw new SparkException(
-          s"RDD element of type ${other.getClass.getName} cannot be used")
-      }

+    val rdd = pythonToJava(pyRDD, batched).rdd
+    rdd.first match {
+      case obj if isPair(obj) =>
+        // we only accept (K, V)
+      case other => throw new SparkException(
+        s"RDD element of type ${other.getClass.getName} cannot be used")
+    }
+    rdd.map { obj =>
+      val arr = obj.asInstanceOf[Array[_]]
+      (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
}
}

}
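Note on the pairRDDToPython change above: with this refactor a batchSize of 0 hands the cleaned iterator to AutoBatchedPickler, while a positive batchSize pickles fixed-size groups with a Pickler created inside the branch that needs it. The standalone Scala sketch below is not Spark source: pickle is a hypothetical placeholder for Pyrolite's Pickler.dumps, and autoBatched only roughly imitates AutoBatchedPickler by growing the batch size toward a byte budget.

import scala.collection.mutable.ArrayBuffer

object BatchingSketch {
  // Placeholder serializer standing in for Pyrolite's Pickler.dumps.
  def pickle(batch: Seq[Any]): Array[Byte] =
    batch.mkString("[", ", ", "]").getBytes("UTF-8")

  // Rough stand-in for auto-batching: start with one record per batch and
  // double the batch size while serialized batches stay under a byte budget.
  def autoBatched(records: Iterator[Any], targetBytes: Int = 1 << 16): Iterator[Array[Byte]] =
    new Iterator[Array[Byte]] {
      private var batchSize = 1
      def hasNext: Boolean = records.hasNext
      def next(): Array[Byte] = {
        val buf = ArrayBuffer.empty[Any]
        while (records.hasNext && buf.size < batchSize) buf += records.next()
        val out = pickle(buf.toSeq)
        if (out.length < targetBytes) batchSize *= 2
        out
      }
    }

  // Mirrors the refactored branch: batchSize == 0 means auto-batched,
  // a positive batchSize means fixed-size groups.
  def toBatches(records: Iterator[Any], batchSize: Int): Iterator[Array[Byte]] =
    if (batchSize == 0) autoBatched(records)
    else records.grouped(batchSize).map(group => pickle(group))
}

For example, toBatches(Iterator.tabulate(10)(i => Array[Any](i, i * i)), batchSize = 4) yields three serialized groups holding 4, 4 and 2 records.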

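Note on the pythonToPairRDD change above: instead of unpickling inline, it now reuses pythonToJava and validates only the first element before casting every element to (K, V). The sketch below shows that check-once, cast-all pattern on a plain Scala collection; PairCastSketch and toPairs are hypothetical names, and IllegalArgumentException stands in for SparkException.

object PairCastSketch {
  // Same shape test as the diff: a two-element array of non-primitive component type.
  def isPair(obj: Any): Boolean =
    Option(obj.getClass.getComponentType).exists(!_.isPrimitive) &&
      obj.asInstanceOf[Array[_]].length == 2

  // Validate the first element only, then cast every element into a (K, V) pair.
  def toPairs[K, V](elems: Seq[Any]): Seq[(K, V)] = {
    elems.headOption match {
      case Some(obj) if isPair(obj) => // looks like (K, V); accept the rest on faith
      case Some(other) => throw new IllegalArgumentException(
        s"Element of type ${other.getClass.getName} cannot be used")
      case None => // nothing to check
    }
    elems.map { obj =>
      val arr = obj.asInstanceOf[Array[_]]
      (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
    }
  }
}

For example, toPairs[String, Int](Seq(Array[Any]("a", 1), Array[Any]("b", 2))) returns Seq(("a", 1), ("b", 2)), while a first element that is not a two-element object array fails immediately.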