[SPARK-6726][ML] Import/export for spark.ml LogisticRegressionModel
This PR adds model save/load for spark.ml's LogisticRegressionModel. It also lightly refactors the default save/load classes so that their metadata-handling code can be reused.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9606 from jkbradley/logreg-io2.
jkbradley authored and mengxr committed Nov 11, 2015
1 parent 745e45d commit 6e101d2
Showing 4 changed files with 152 additions and 11 deletions.
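For orientation, here is a minimal sketch of how the API introduced by this commit is meant to be used. The path and the training DataFrame are placeholders, and overwrite() comes from the shared ML Writer API rather than from this diff:

    import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}

    // Fit a model; `training` stands in for any DataFrame of labeled instances.
    val model = new LogisticRegression().setMaxIter(10).fit(training)

    // write returns a Writer; save persists metadata + data under the given path.
    // overwrite() is optional and replaces any existing output.
    model.write.overwrite().save("/tmp/lr-model")

    // Load back through the Readable companion object. Params, intercept, and
    // coefficients round-trip; the training summary does not (see below).
    val restored: LogisticRegressionModel = LogisticRegressionModel.load("/tmp/lr-model")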
mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -21,13 +21,14 @@ import scala.collection.mutable

import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
+import org.apache.hadoop.fs.Path

import org.apache.spark.{Logging, SparkException}
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
@@ -396,7 +397,7 @@ class LogisticRegressionModel private[ml] (
val coefficients: Vector,
val intercept: Double)
extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
-  with LogisticRegressionParams {
+  with LogisticRegressionParams with Writable {

@deprecated("Use coefficients instead.", "1.6.0")
def weights: Vector = coefficients
@@ -510,8 +511,71 @@ class LogisticRegressionModel private[ml] (
// Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
if (probability(1) > getThreshold) 1 else 0
}

+  /**
+   * Returns a [[Writer]] instance for this ML instance.
+   *
+   * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]].
+   * An option to save [[summary]] may be added in the future.
+   */
+  override def write: Writer = new LogisticRegressionWriter(this)
}


+/** [[Writer]] instance for [[LogisticRegressionModel]] */
+private[classification] class LogisticRegressionWriter(instance: LogisticRegressionModel)
+  extends Writer with Logging {
+
+  private case class Data(
+      numClasses: Int,
+      numFeatures: Int,
+      intercept: Double,
+      coefficients: Vector)
+
+  override protected def saveImpl(path: String): Unit = {
+    // Save metadata and Params
+    DefaultParamsWriter.saveMetadata(instance, path, sc)
+    // Save model data: numClasses, numFeatures, intercept, coefficients
+    val data = Data(instance.numClasses, instance.numFeatures, instance.intercept,
+      instance.coefficients)
+    val dataPath = new Path(path, "data").toString
+    sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+  }
+}


+object LogisticRegressionModel extends Readable[LogisticRegressionModel] {
+
+  override def read: Reader[LogisticRegressionModel] = new LogisticRegressionReader
+
+  override def load(path: String): LogisticRegressionModel = read.load(path)
+}


+private[classification] class LogisticRegressionReader extends Reader[LogisticRegressionModel] {
+
+  /** Checked against metadata when loading model */
+  private val className = "org.apache.spark.ml.classification.LogisticRegressionModel"
+
+  override def load(path: String): LogisticRegressionModel = {
+    val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+    val dataPath = new Path(path, "data").toString
+    val data = sqlContext.read.format("parquet").load(dataPath)
+      .select("numClasses", "numFeatures", "intercept", "coefficients").head()
+    // We will need numClasses, numFeatures in the future for multinomial logreg support.
+    // val numClasses = data.getInt(0)
+    // val numFeatures = data.getInt(1)
+    val intercept = data.getDouble(2)
+    val coefficients = data.getAs[Vector](3)
+    val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept)
+
+    DefaultParamsReader.getAndSetParams(model, metadata)
+    model
+  }
+}


/**
* MultiClassSummarizer computes the number of distinct labels and corresponding counts,
* and validates the data to see if the labels used for k class multi-label classification
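Net effect of the writer and reader above: a saved model is a directory with two parts. A sketch of the layout (the exact part-file names inside each subdirectory are generated by Hadoop and Spark SQL and will vary):

    /tmp/lr-model/
      metadata/    a single JSON line, written as a text file:
                   class, timestamp, sparkVersion, uid, paramMap
      data/        a Parquet dataset with columns numClasses, numFeatures,
                   intercept, coefficients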
74 changes: 68 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -175,6 +175,21 @@ trait Readable[T] {
private[ml] class DefaultParamsWriter(instance: Params) extends Writer {

override protected def saveImpl(path: String): Unit = {
+    DefaultParamsWriter.saveMetadata(instance, path, sc)
+  }
+}
+
+private[ml] object DefaultParamsWriter {
+
+  /**
+   * Saves metadata + Params to: path + "/metadata"
+   *  - class
+   *  - timestamp
+   *  - sparkVersion
+   *  - uid
+   *  - paramMap
+   */
+  def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
@@ -201,23 +216,70 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer {
private[ml] class DefaultParamsReader[T] extends Reader[T] {

override def load(path: String): T = {
-    implicit val format = DefaultFormats
+    val metadata = DefaultParamsReader.loadMetadata(path, sc)
+    val cls = Utils.classForName(metadata.className)
+    val instance =
+      cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
+    DefaultParamsReader.getAndSetParams(instance, metadata)
+    instance.asInstanceOf[T]
+  }
+}

+private[ml] object DefaultParamsReader {
+
+  /**
+   * All info from metadata file.
+   * @param params  paramMap, as a [[JValue]]
+   * @param metadataStr  Full metadata file String (for debugging)
+   */
+  case class Metadata(
+      className: String,
+      uid: String,
+      timestamp: Long,
+      sparkVersion: String,
+      params: JValue,
+      metadataStr: String)
+
+  /**
+   * Load metadata from file.
+   * @param expectedClassName  If non-empty, this is checked against the loaded metadata.
+   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
+   */
+  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
val metadataPath = new Path(path, "metadata").toString
val metadataStr = sc.textFile(metadataPath, 1).first()
val metadata = parse(metadataStr)
-    val cls = Utils.classForName((metadata \ "class").extract[String])
+
+    implicit val format = DefaultFormats
+    val className = (metadata \ "class").extract[String]
val uid = (metadata \ "uid").extract[String]
-    val instance = cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[Params]
-    (metadata \ "paramMap") match {
+    val timestamp = (metadata \ "timestamp").extract[Long]
+    val sparkVersion = (metadata \ "sparkVersion").extract[String]
+    val params = metadata \ "paramMap"
+    if (expectedClassName.nonEmpty) {
+      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
+        s" $expectedClassName but found class name $className")
+    }
+
+    Metadata(className, uid, timestamp, sparkVersion, params, metadataStr)
+  }

+  /**
+   * Extract Params from metadata, and set them in the instance.
+   * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
+   */
+  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
+    implicit val format = DefaultFormats
+    metadata.params match {
case JObject(pairs) =>
pairs.foreach { case (paramName, jsonValue) =>
val param = instance.getParam(paramName)
val value = param.jsonDecode(compact(render(jsonValue)))
instance.set(param, value)
}
case _ =>
-        throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $metadataStr.")
+        throw new IllegalArgumentException(
+          s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
}
-    instance.asInstanceOf[T]
}
}
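Concretely, the metadata document that saveMetadata emits and loadMetadata parses back is a single JSON line built from the fields above; the values shown here are illustrative, not taken from a real run:

    {"class":"org.apache.spark.ml.classification.LogisticRegressionModel",
     "timestamp":1447273869815,"sparkVersion":"1.6.0","uid":"logreg_8d6e94f2a1c3",
     "paramMap":{"elasticNetParam":0.1,"maxIter":2}}

getAndSetParams then walks each (paramName, jsonValue) pair in paramMap and restores it through the corresponding Param's own jsonDecode, which is why every Param is expected to support JSON decoding.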
mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -23,15 +23,16 @@ import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{Identifiable, DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}

-class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class LogisticRegressionSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

@transient var dataset: DataFrame = _
@transient var binaryDataset: DataFrame = _
@@ -869,6 +870,18 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3)
assert(model1a0.coefficients ~== model1b.coefficients absTol 1E-3)
assert(model1a0.intercept ~== model1b.intercept absTol 1E-3)
}

test("read/write") {
// Set some Params to make sure set Params are serialized.
val lr = new LogisticRegression()
.setElasticNetParam(0.1)
.setMaxIter(2)
.fit(dataset)
val lr2 = testDefaultReadWrite(lr)
assert(lr.intercept === lr2.intercept)
assert(lr.coefficients.toArray === lr2.coefficients.toArray)
assert(lr.numClasses === lr2.numClasses)
assert(lr.numFeatures === lr2.numFeatures)
}
}
mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -31,8 +31,9 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
* Checks "overwrite" option and params.
* @param instance ML instance to test saving/loading
* @tparam T ML instance type
+   * @return  Instance loaded from file
*/
-  def testDefaultReadWrite[T <: Params with Writable](instance: T): Unit = {
+  def testDefaultReadWrite[T <: Params with Writable](instance: T): T = {
val uid = instance.uid
val path = new File(tempDir, uid).getPath

@@ -61,6 +62,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
val load = instance.getClass.getMethod("load", classOf[String])
val another = load.invoke(instance, path).asInstanceOf[T]
assert(another.uid === instance.uid)
+    another
}
}

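The "overwrite" option checked by this trait corresponds to behavior along the following lines. This is a sketch that assumes saving to an existing path without overwrite() fails with a java.io.IOException, which is how the ML Writer behaves at this point; `instance` is the trait method's parameter:

    val path = new File(tempDir, "overwrite-demo").getPath
    instance.write.save(path)               // fresh path: succeeds
    intercept[java.io.IOException] {
      instance.write.save(path)             // same path again: refused
    }
    instance.write.overwrite().save(path)   // explicit overwrite: succeeds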
