[SPARK-5988][MLlib] add save/load for PowerIterationClusteringModel #5450
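This patch adds persistence to PowerIterationClusteringModel: the model now extends Saveable and gains a save(sc, path) method, a companion object extends Loader and implements a versioned SaveLoadV1_0 format (a one-line JSON metadata file plus a Parquet file of cluster assignments), Assignment becomes a case class so an RDD of assignments can be converted to a DataFrame, and a save/load round-trip test is added to PowerIterationClusteringSuite.

A minimal usage sketch of the new API (editorial illustration, not part of the diff; sc, similarities, and the output path are placeholders):

import org.apache.spark.mllib.clustering.{PowerIterationClustering, PowerIterationClusteringModel}

// Train on an RDD[(Long, Long, Double)] of pairwise similarities (srcId, dstId, similarity).
val model = new PowerIterationClustering()
  .setK(3)
  .setMaxIterations(20)
  .run(similarities)

// New in this patch: write the model out and read it back.
model.save(sc, "/tmp/pic-model")
val sameModel = PowerIterationClusteringModel.load(sc, "/tmp/pic-model")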

Status: Closed (3 commits)
mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
@@ -17,15 +17,20 @@

package org.apache.spark.mllib.clustering

import org.apache.spark.{Logging, SparkException}
import org.json4s.JsonDSL._
import org.json4s._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.{Logging, SparkContext, SparkException}

/**
* :: Experimental ::
@@ -38,7 +43,60 @@ import org.apache.spark.util.random.XORShiftRandom
@Experimental
class PowerIterationClusteringModel(
val k: Int,
val assignments: RDD[PowerIterationClustering.Assignment]) extends Serializable
val assignments: RDD[PowerIterationClustering.Assignment]) extends Saveable with Serializable {

override def save(sc: SparkContext, path: String): Unit = {
PowerIterationClusteringModel.SaveLoadV1_0.save(sc, this, path)
}

override protected def formatVersion: String = "1.0"
}

object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] {
override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path)
}

private[clustering]
object SaveLoadV1_0 {

private val thisFormatVersion = "1.0"

private[clustering]
val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel"

def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))

val dataRDD = model.assignments.toDF()
dataRDD.saveAsParquetFile(Loader.dataPath(path))
}

def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
implicit val formats = DefaultFormats
val sqlContext = new SQLContext(sc)

val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)

val k = (metadata \ "k").extract[Int]
val assignments = sqlContext.parquetFile(Loader.dataPath(path))
Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema)

val assignmentsRDD = assignments.map {
case Row(id: Long, cluster: Int) => PowerIterationClustering.Assignment(id, cluster)
}

new PowerIterationClusteringModel(k, assignmentsRDD)
}
}
}
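
A note on the on-disk layout produced by SaveLoadV1_0 (editorial sketch, not part of the diff): save writes a one-line JSON metadata file and a Parquet file of assignments under the given path. Assuming Loader.metadataPath(path) and Loader.dataPath(path) resolve to "<path>/metadata" and "<path>/data", as for other MLlib model savers, the saved artifacts can be inspected roughly like this ("/tmp/pic-model" is a placeholder):

// Metadata is a single JSON line, e.g.
// {"class":"org.apache.spark.mllib.clustering.PowerIterationClusteringModel","version":"1.0","k":3}
val metadataJson = sc.textFile("/tmp/pic-model/metadata").first()

// Assignments are stored as Parquet with the case class fields as columns (id, cluster).
val saved = new org.apache.spark.sql.SQLContext(sc).parquetFile("/tmp/pic-model/data")
saved.printSchema()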

/**
* :: Experimental ::
@@ -135,7 +193,7 @@ class PowerIterationClustering private[clustering] (
val v = powerIter(w, maxIterations)
val assignments = kMeans(v, k).mapPartitions({ iter =>
iter.map { case (id, cluster) =>
new Assignment(id, cluster)
Assignment(id, cluster)
}
}, preservesPartitioning = true)
new PowerIterationClusteringModel(k, assignments)
@@ -152,7 +210,7 @@ object PowerIterationClustering extends Logging {
* @param cluster assigned cluster id
*/
@Experimental
class Assignment(val id: Long, val cluster: Int) extends Serializable
case class Assignment(id: Long, cluster: Int)

/**
* Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W).
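The normalize helper documented just above is unchanged context; the diff view truncates its body here. As a toy illustration of the row normalization it describes (editorial sketch, not from the diff): each affinity a_ij is divided by its row sum, giving w_ij = a_ij / sum_j a_ij, so every nonempty row of W sums to 1.

// Toy row normalization over an edge list of (i, j, a_ij) triples.
val affinities = Seq((0L, 1L, 2.0), (0L, 2L, 6.0), (1L, 2L, 4.0))
val rowSums = affinities.groupBy(_._1).mapValues(_.map(_._3).sum)
val normalized = affinities.map { case (i, j, a) => (i, j, a / rowSums(i)) }
// normalized == Seq((0,1,0.25), (0,2,0.75), (1,2,1.0))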
mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
@@ -18,12 +18,15 @@
package org.apache.spark.mllib.clustering

import scala.collection.mutable
import scala.util.Random

import org.scalatest.FunSuite

import org.apache.spark.SparkContext
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils

class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext {

@@ -110,4 +113,35 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
assert(x ~== u1(i.toInt) absTol 1e-14)
}
}

test("model save/load") {
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
val model = PowerIterationClusteringSuite.createModel(sc, 3, 10)
try {
model.save(sc, path)
val sameModel = PowerIterationClusteringModel.load(sc, path)
PowerIterationClusteringSuite.checkEqual(model, sameModel)
} finally {
Utils.deleteRecursively(tempDir)
}
}
}

object PowerIterationClusteringSuite extends FunSuite {
def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = {
val assignments = sc.parallelize(
(0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k))))
new PowerIterationClusteringModel(k, assignments)
}

def checkEqual(a: PowerIterationClusteringModel, b: PowerIterationClusteringModel): Unit = {
assert(a.k === b.k)

val aAssignments = a.assignments.map(x => (x.id, x.cluster))
val bAssignments = b.assignments.map(x => (x.id, x.cluster))
val unequalElements = aAssignments.join(bAssignments).filter {
case (id, (c1, c2)) => c1 != c2 }.count()
assert(unequalElements === 0L)
}
}