Skip to content

Commit

Permalink
add save load for power iteration clustering
Browse files Browse the repository at this point in the history
  • Loading branch information
yinxusen committed Apr 10, 2015
1 parent b5c51c8 commit 63c3923
Showing 1 changed file with 64 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@

package org.apache.spark.mllib.clustering

import org.apache.spark.{Logging, SparkException}
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkContext, Logging, SparkException}
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
Expand All @@ -38,7 +44,63 @@ import org.apache.spark.util.random.XORShiftRandom
@Experimental
class PowerIterationClusteringModel(
    val k: Int,
    val assignments: RDD[PowerIterationClustering.Assignment])
  extends Saveable with Serializable {

  /**
   * Saves this model under `path` by delegating to the version 1.0 writer
   * (`SaveLoadV1_0`) in the companion object.
   *
   * @param sc   Spark context used to write the model
   * @param path directory the model is written to
   */
  override def save(sc: SparkContext, path: String): Unit = {
    PowerIterationClusteringModel.SaveLoadV1_0.save(sc, this, path)
  }

  /** Version tag of the format produced by `save`. */
  override protected def formatVersion: String = "1.0"
}

object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] {

  /**
   * Loads a model from `path` by delegating to the version 1.0 reader.
   *
   * @param sc   Spark context used to read the model
   * @param path directory the model was saved to
   * @return the reconstructed model
   */
  override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
    PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path)
  }

  /** Reader and writer for format version 1.0. */
  private[clustering]
  object SaveLoadV1_0 {

    private val thisFormatVersion = "1.0"

    private[clustering]
    val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel"

    /**
     * Writes the model as a one-line JSON metadata file (class name, format version, `k`)
     * plus a Parquet file holding the assignments.
     *
     * @param sc    Spark context used to write both files
     * @param model model to persist
     * @param path  directory the model is written to
     */
    def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = {
      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits._

      val metadata = compact(render(
        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))

      // Store assignments as flat columns rather than a single nested struct column so that
      // `load` can validate the schema against `Assignment` and pattern match rows directly.
      val dataRDD = model.assignments.toDF()
      dataRDD.saveAsParquetFile(Loader.dataPath(path))
    }

    /**
     * Reads back a model written by `save`.
     *
     * Fails with an `AssertionError` when the stored class name or format version does not
     * match this loader.
     *
     * @param sc   Spark context used to read the metadata and data files
     * @param path directory the model was saved to
     * @return a model with the stored `k` and assignments
     */
    def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
      implicit val formats: Formats = DefaultFormats
      val sqlContext = new SQLContext(sc)

      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
      assert(className == thisClassName,
        s"Expected model class $thisClassName but found $className in metadata at $path")
      assert(formatVersion == thisFormatVersion,
        s"Expected format version $thisFormatVersion but found $formatVersion at $path")

      val k = (metadata \ "k").extract[Int]
      val assignments = sqlContext.parquetFile(Loader.dataPath(path))
      Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema)

      // NOTE(review): assumes Assignment is a case class with fields (id: Long, cluster: Int),
      // declared elsewhere in this package -- confirm the field order.
      val assignmentsRDD = assignments.map {
        case Row(id: Long, cluster: Int) => PowerIterationClustering.Assignment(id, cluster)
      }

      // No check that the number of distinct clusters equals k: it would launch a full job on
      // every load and reject valid models in which some of the k clusters are empty.
      new PowerIterationClusteringModel(k, assignmentsRDD)
    }
  }
}

/**
* :: Experimental ::
Expand Down

0 comments on commit 63c3923

Please sign in to comment.