diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index f5fca686df144..a88f52674102c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -21,13 +21,14 @@ import scala.collection.mutable
 
 import breeze.linalg.{DenseVector => BDV}
 import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
+import org.apache.hadoop.fs.Path
 
 import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.linalg.BLAS._
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
@@ -396,7 +397,7 @@ class LogisticRegressionModel private[ml] (
     val coefficients: Vector,
     val intercept: Double)
   extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
-  with LogisticRegressionParams {
+  with LogisticRegressionParams with Writable {
 
   @deprecated("Use coefficients instead.", "1.6.0")
   def weights: Vector = coefficients
@@ -510,8 +511,71 @@ class LogisticRegressionModel private[ml] (
     // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
     if (probability(1) > getThreshold) 1 else 0
   }
+
+  /**
+   * Returns a [[Writer]] instance for this ML instance.
+   *
+   * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]].
+   * An option to save [[summary]] may be added in the future.
+   */
+  override def write: Writer = new LogisticRegressionWriter(this)
+}
+
+
+/** [[Writer]] instance for [[LogisticRegressionModel]] */
+private[classification] class LogisticRegressionWriter(instance: LogisticRegressionModel)
+  extends Writer with Logging {
+
+  private case class Data(
+      numClasses: Int,
+      numFeatures: Int,
+      intercept: Double,
+      coefficients: Vector)
+
+  override protected def saveImpl(path: String): Unit = {
+    // Save metadata and Params
+    DefaultParamsWriter.saveMetadata(instance, path, sc)
+    // Save model data: numClasses, numFeatures, intercept, coefficients
+    val data = Data(instance.numClasses, instance.numFeatures, instance.intercept,
+      instance.coefficients)
+    val dataPath = new Path(path, "data").toString
+    sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+  }
+}
+
+
+object LogisticRegressionModel extends Readable[LogisticRegressionModel] {
+
+  override def read: Reader[LogisticRegressionModel] = new LogisticRegressionReader
+
+  override def load(path: String): LogisticRegressionModel = read.load(path)
 }
 
+
+private[classification] class LogisticRegressionReader extends Reader[LogisticRegressionModel] {
+
+  /** Checked against metadata when loading model */
+  private val className = "org.apache.spark.ml.classification.LogisticRegressionModel"
+
+  override def load(path: String): LogisticRegressionModel = {
+    val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+    val dataPath = new Path(path, "data").toString
+    val data = sqlContext.read.format("parquet").load(dataPath)
+      .select("numClasses", "numFeatures", "intercept", "coefficients").head()
+    // We will need numClasses, numFeatures in the future for multinomial logreg support.
+    // val numClasses = data.getInt(0)
+    // val numFeatures = data.getInt(1)
+    val intercept = data.getDouble(2)
+    val coefficients = data.getAs[Vector](3)
+    val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept)
+
+    DefaultParamsReader.getAndSetParams(model, metadata)
+    model
+  }
+}
+
+
 /**
  * MultiClassSummarizer computes the number of distinct labels and corresponding counts,
  * and validates the data to see if the labels used for k class multi-label classification
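For context, a minimal usage sketch of the save/load API wired up above (not part of the patch; `training` is an assumed DataFrame of labeled points, `path` an assumed writable directory, and `overwrite()` is the Writer option exercised by DefaultReadWriteTest further below):

  import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}

  val model: LogisticRegressionModel = new LogisticRegression().setMaxIter(10).fit(training)

  // Persists the set Params as JSON metadata plus intercept/coefficients as Parquet under `path`.
  model.write.save(path)
  // Overwrite an existing directory instead of failing:
  model.write.overwrite().save(path)

  // Restores the coefficients, the intercept, and any explicitly set Params.
  val sameModel = LogisticRegressionModel.load(path)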
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index cbdf913ba8dfa..85f888c9f2f67 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -175,6 +175,21 @@ trait Readable[T] {
 private[ml] class DefaultParamsWriter(instance: Params) extends Writer {
 
   override protected def saveImpl(path: String): Unit = {
+    DefaultParamsWriter.saveMetadata(instance, path, sc)
+  }
+}
+
+private[ml] object DefaultParamsWriter {
+
+  /**
+   * Saves metadata + Params to: path + "/metadata"
+   *  - class
+   *  - timestamp
+   *  - sparkVersion
+   *  - uid
+   *  - paramMap
+   */
+  def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = {
     val uid = instance.uid
     val cls = instance.getClass.getName
     val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
@@ -201,14 +216,61 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer {
 private[ml] class DefaultParamsReader[T] extends Reader[T] {
 
   override def load(path: String): T = {
-    implicit val format = DefaultFormats
+    val metadata = DefaultParamsReader.loadMetadata(path, sc)
+    val cls = Utils.classForName(metadata.className)
+    val instance =
+      cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
+    DefaultParamsReader.getAndSetParams(instance, metadata)
+    instance.asInstanceOf[T]
+  }
+}
+
+private[ml] object DefaultParamsReader {
+
+  /**
+   * All info from metadata file.
+   * @param params  paramMap, as a [[JValue]]
+   * @param metadataStr  Full metadata file String (for debugging)
+   */
+  case class Metadata(
+      className: String,
+      uid: String,
+      timestamp: Long,
+      sparkVersion: String,
+      params: JValue,
+      metadataStr: String)
+
+  /**
+   * Load metadata from file.
+   * @param expectedClassName  If non empty, this is checked against the loaded metadata.
+   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
+   */
+  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
     val metadataPath = new Path(path, "metadata").toString
     val metadataStr = sc.textFile(metadataPath, 1).first()
     val metadata = parse(metadataStr)
-    val cls = Utils.classForName((metadata \ "class").extract[String])
+
+    implicit val format = DefaultFormats
+    val className = (metadata \ "class").extract[String]
     val uid = (metadata \ "uid").extract[String]
-    val instance = cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[Params]
-    (metadata \ "paramMap") match {
+    val timestamp = (metadata \ "timestamp").extract[Long]
+    val sparkVersion = (metadata \ "sparkVersion").extract[String]
+    val params = metadata \ "paramMap"
+    if (expectedClassName.nonEmpty) {
+      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
+        s" $expectedClassName but found class name $className")
+    }
+
+    Metadata(className, uid, timestamp, sparkVersion, params, metadataStr)
+  }
+
+  /**
+   * Extract Params from metadata, and set them in the instance.
+   * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
+   */
+  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
+    implicit val format = DefaultFormats
+    metadata.params match {
       case JObject(pairs) =>
         pairs.foreach { case (paramName, jsonValue) =>
           val param = instance.getParam(paramName)
@@ -216,8 +278,8 @@ private[ml] class DefaultParamsReader[T] extends Reader[T] {
           instance.set(param, value)
         }
       case _ =>
-        throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $metadataStr.")
+        throw new IllegalArgumentException(
+          s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
     }
-    instance.asInstanceOf[T]
   }
 }
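For reference, the helpers above split persistence into a metadata file and a data directory. A sketch of the on-disk layout produced for the LogisticRegressionModel writer (the JSON values below are illustrative, not taken from a real run, but the keys match what loadMetadata extracts; assumes a SparkContext `sc` and the same `path`):

  import org.apache.hadoop.fs.Path

  // <path>/metadata  - single-line JSON written by DefaultParamsWriter.saveMetadata, e.g.
  //   {"class":"org.apache.spark.ml.classification.LogisticRegressionModel",
  //    "timestamp":1446000000000,"sparkVersion":"1.6.0","uid":"logreg_1234abcd",
  //    "paramMap":{"elasticNetParam":0.1,"maxIter":2}}
  // <path>/data      - Parquet with columns numClasses, numFeatures, intercept, coefficients
  val metadataJson = sc.textFile(new Path(path, "metadata").toString, 1).first()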
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 325faf37e8eea..51b06b7eb6d53 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -23,7 +23,7 @@ import scala.util.Random
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{Identifiable, DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.classification.LogisticRegressionSuite._
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -31,7 +31,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row}
 
-class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class LogisticRegressionSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   @transient var dataset: DataFrame = _
   @transient var binaryDataset: DataFrame = _
@@ -869,6 +870,18 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3)
     assert(model1a0.coefficients ~== model1b.coefficients absTol 1E-3)
     assert(model1a0.intercept ~== model1b.intercept absTol 1E-3)
   }
+
+  test("read/write") {
+    // Set some Params to make sure set Params are serialized.
+    val lr = new LogisticRegression()
+      .setElasticNetParam(0.1)
+      .setMaxIter(2)
+      .fit(dataset)
+    val lr2 = testDefaultReadWrite(lr)
+    assert(lr.intercept === lr2.intercept)
+    assert(lr.coefficients.toArray === lr2.coefficients.toArray)
+    assert(lr.numClasses === lr2.numClasses)
+    assert(lr.numFeatures === lr2.numFeatures)
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 4545b0f281f5a..cac4bd9aa3ab8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -31,8 +31,9 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
    * Checks "overwrite" option and params.
   * @param instance ML instance to test saving/loading
   * @tparam T ML instance type
+   * @return Instance loaded from file
    */
-  def testDefaultReadWrite[T <: Params with Writable](instance: T): Unit = {
+  def testDefaultReadWrite[T <: Params with Writable](instance: T): T = {
     val uid = instance.uid
     val path = new File(tempDir, uid).getPath
 
@@ -61,6 +62,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
 
     val load = instance.getClass.getMethod("load", classOf[String])
     val another = load.invoke(instance, path).asInstanceOf[T]
     assert(another.uid === instance.uid)
+    another
   }
 }
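Finally, because testDefaultReadWrite now returns the loaded copy, other suites can mix in the trait and layer model-specific assertions on top of the default uid/Params/overwrite checks, exactly as the LogisticRegressionSuite hunk above does. A hedged sketch with placeholder names (MyModelSuite, MyEstimator, and someDataset are illustrative, not real Spark classes):

  class MyModelSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

    test("read/write") {
      val model = new MyEstimator().fit(someDataset)   // placeholder Writable model under test
      val loaded = testDefaultReadWrite(model)         // save, check overwrite handling, reload, check Params
      // Add checks on state that is not covered by Params, e.g. learned coefficients:
      assert(loaded.uid === model.uid)
    }
  }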