diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 2653af994d7c6..1cd9341598723 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -149,7 +149,7 @@ class LogisticRegressionModel private[ml] ( if (map(probabilityCol) != "") { if (map(rawPredictionCol) != "") { val raw2prob: Vector => Vector = (rawPreds) => { - val prob1 = 1.0 / 1.0 + math.exp(-rawPreds(1)) + val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1))) Vectors.dense(1.0 - prob1, prob1) } tmpData = tmpData.select(Star(None), @@ -171,7 +171,7 @@ class LogisticRegressionModel private[ml] ( predict.call(map(probabilityCol).attr) as map(predictionCol)) } else if (map(rawPredictionCol) != "") { val predict: Vector => Double = (rawPreds) => { - val prob1 = 1.0 / 1.0 + math.exp(-rawPreds(1)) + val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1))) if (prob1 > t) 1.0 else 0.0 } tmpData = tmpData.select(Star(None), @@ -207,7 +207,7 @@ class LogisticRegressionModel private[ml] ( override protected def predictRaw(features: Vector): Vector = { val m = margin(features) - Vectors.dense(-m, m) + Vectors.dense(0.0, m) } override protected def copy(): LogisticRegressionModel = { diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index 56a9dbdd58b64..50995ffef9ad5 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -65,7 +65,7 @@ public void pipeline() { .setStages(new PipelineStage[] {scaler, lr}); PipelineModel model = pipeline.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLinearRegressionSuite.java index 8662d68cd365b..1f47b711ac6d4 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLinearRegressionSuite.java @@ -30,12 +30,12 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.ml.LabeledPoint; import org.apache.spark.ml.regression.LinearRegression; import org.apache.spark.ml.regression.LinearRegressionModel; import static org.apache.spark.mllib.classification.LogisticRegressionSuite .generateLogisticInputAsList; import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.api.java.JavaSQLContext; import org.apache.spark.sql.api.java.JavaSchemaRDD; import org.apache.spark.sql.api.java.Row; @@ -93,35 +93,14 @@ public void linearRegressionWithSetters() { .setMaxIter(10) .setRegParam(1.0); LinearRegressionModel model = lr.fit(dataset); - assert(model.fittingParamMap().get(lr.maxIter()).get() == 10); - assert(model.fittingParamMap().get(lr.regParam()).get() == 1.0); + assert(model.fittingParamMap().apply(lr.maxIter()) == 10); + assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0)); // Call fit() with new params, and check as many params as we can. LinearRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred")); - assert(model2.fittingParamMap().get(lr.maxIter()).get() == 5); - assert(model2.fittingParamMap().get(lr.regParam()).get() == 0.1); + assert(model2.fittingParamMap().apply(lr.maxIter()) == 5); + assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1)); assert(model2.getPredictionCol().equals("thePred")); } - - @Test - public void linearRegressionPredictorClassifierMethods() { - LinearRegression lr = new LinearRegression(); - - // fit() vs. train() - LinearRegressionModel model1 = lr.fit(dataset); - LinearRegressionModel model2 = lr.train(datasetRDD); - assert(model1.intercept() == model2.intercept()); - assert(model1.weights().equals(model2.weights())); - - // transform() vs. predict() - model1.transform(dataset).registerTempTable("transformed"); - JavaSchemaRDD trans = jsql.sql("SELECT prediction FROM transformed"); - JavaRDD preds = model1.predict(featuresRDD); - for (Tuple2 trans_pred: trans.zip(preds).collect()) { - double t = trans_pred._1().getDouble(0); - double p = trans_pred._2(); - assert(t == p); - } - } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 3ad15e516c16e..11acaa3a0d357 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -17,8 +17,6 @@ package org.apache.spark.ml.classification; -import scala.Tuple2; - import java.io.Serializable; import java.lang.Math; import java.util.ArrayList; @@ -34,9 +32,8 @@ import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.ml.LabeledPoint; +import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Row; @@ -47,7 +44,6 @@ public class JavaLogisticRegressionSuite implements Serializable { private transient DataFrame dataset; private transient JavaRDD datasetRDD; - private transient JavaRDD featuresRDD; private double eps = 1e-5; @Before @@ -60,9 +56,6 @@ public void setUp() { points.add(new LabeledPoint(lp.label(), lp.features())); } datasetRDD = jsc.parallelize(points, 2); - featuresRDD = datasetRDD.map(new Function() { - @Override public Vector call(LabeledPoint lp) { return lp.features(); } - }); dataset = jsql.applySchema(datasetRDD, LabeledPoint.class); dataset.registerTempTable("dataset"); } @@ -79,13 +72,13 @@ public void logisticRegressionDefaultParams() { assert(lr.getLabelCol().equals("label")); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); // Check defaults assert(model.getThreshold() == 0.5); assert(model.getFeaturesCol().equals("features")); assert(model.getPredictionCol().equals("prediction")); - assert(model.getScoreCol().equals("score")); + assert(model.getProbabilityCol().equals("probability")); } @Test @@ -95,17 +88,17 @@ public void logisticRegressionWithSetters() { .setMaxIter(10) .setRegParam(1.0) .setThreshold(0.6) - .setScoreCol("probability"); + .setProbabilityCol("myProbability"); LogisticRegressionModel model = lr.fit(dataset); - assert(model.fittingParamMap().get(lr.maxIter()).get() == 10); - assert(model.fittingParamMap().get(lr.regParam()).get() == 1.0); - assert(model.fittingParamMap().get(lr.threshold()).get() == 0.6); + assert(model.fittingParamMap().apply(lr.maxIter()) == 10); + assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0)); + assert(model.fittingParamMap().apply(lr.threshold()).equals(0.6)); assert(model.getThreshold() == 0.6); // Modify model params, and check that the params worked. model.setThreshold(1.0); model.transform(dataset).registerTempTable("predAllZero"); - SchemaRDD predAllZero = jsql.sql("SELECT prediction, probability FROM predAllZero"); + SchemaRDD predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); for (Row r: predAllZero.collectAsList()) { assert(r.getDouble(0) == 0.0); } @@ -117,7 +110,7 @@ public void logisticRegressionWithSetters() { predictions.collectAsList(); */ - model.transform(dataset, model.threshold().w(0.0), model.scoreCol().w("myProb")) + model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) .registerTempTable("predNotAllZero"); SchemaRDD predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); boolean foundNonZero = false; @@ -128,54 +121,37 @@ public void logisticRegressionWithSetters() { // Call fit() with new params, and check as many params as we can. LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), - lr.threshold().w(0.4), lr.scoreCol().w("theProb")); - assert(model2.fittingParamMap().get(lr.maxIter()).get() == 5); - assert(model2.fittingParamMap().get(lr.regParam()).get() == 0.1); - assert(model2.fittingParamMap().get(lr.threshold()).get() == 0.4); + lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); + assert(model2.fittingParamMap().apply(lr.maxIter()) == 5); + assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1)); + assert(model2.fittingParamMap().apply(lr.threshold()).equals(0.4)); assert(model2.getThreshold() == 0.4); - assert(model2.getScoreCol().equals("theProb")); + assert(model2.getProbabilityCol().equals("theProb")); } + @SuppressWarnings("unchecked") @Test public void logisticRegressionPredictorClassifierMethods() { LogisticRegression lr = new LogisticRegression(); - - // fit() vs. train() - LogisticRegressionModel model1 = lr.fit(dataset); - LogisticRegressionModel model2 = lr.train(datasetRDD); - assert(model1.intercept() == model2.intercept()); - assert(model1.weights().equals(model2.weights())); - assert(model1.numClasses() == model2.numClasses()); - assert(model1.numClasses() == 2); - - // transform() vs. predict() - model1.transform(dataset).registerTempTable("transformed"); - SchemaRDD trans = jsql.sql("SELECT prediction FROM transformed"); - JavaRDD preds = model1.predict(featuresRDD); - for (scala.Tuple2 trans_pred: trans.toJavaRDD().zip(preds).collect()) { - double t = trans_pred._1().getDouble(0); - double p = trans_pred._2(); - assert(t == p); + LogisticRegressionModel model = lr.fit(dataset); + assert(model.numClasses() == 2); + + model.transform(dataset).registerTempTable("transformed"); + SchemaRDD trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); + for (Row row: trans1.collect()) { + Vector raw = (Vector)row.get(0); + Vector prob = (Vector)row.get(1); + assert(raw.size() == 2); + assert(prob.size() == 2); + double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1))); + assert(Math.abs(prob.apply(1) - probFromRaw1) < eps); + assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps); } - // Check various types of predictions. - JavaRDD rawPredictions = model1.predictRaw(featuresRDD); - JavaRDD probabilities = model1.predictProbabilities(featuresRDD); - JavaRDD predictions = model1.predict(featuresRDD); - double threshold = model1.getThreshold(); - for (Tuple2 raw_prob: rawPredictions.zip(probabilities).collect()) { - Vector raw = raw_prob._1(); - Vector prob = raw_prob._2(); - for (int i = 0; i < raw.size(); ++i) { - double r = raw.apply(i); - double p = prob.apply(i); - double pFromR = 1.0 / (1.0 + Math.exp(-r)); - assert(Math.abs(r - pFromR) < eps); - } - } - for (Tuple2 prob_pred: probabilities.zip(predictions).collect()) { - Vector prob = prob_pred._1(); - double pred = prob_pred._2(); + SchemaRDD trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); + for (Row row: trans2.collect()) { + double pred = row.getDouble(0); + Vector prob = (Vector)row.get(1); double probOfPred = prob.apply((int)pred); for (int i = 0; i < prob.size(); ++i) { assert(probOfPred >= prob.apply(i)); diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 8cf7b81834918..f412622572c1b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -82,7 +82,9 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { .select('prediction, 'myProbability) .collect() .map { case Row(pred: Double, prob: Vector) => pred } - assert(predAllZero.forall(_ === 0.0)) + assert(predAllZero.forall(_ === 0), + s"With threshold=1.0, expected predictions to be all 0, but only" + + s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.") // Call transform with params, and check that the params worked. val predNotAllZero = model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb") @@ -115,10 +117,11 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { // Compare rawPrediction with probability results.select('rawPrediction, 'probability).collect().map { case Row(raw: Vector, prob: Vector) => - val raw2prob: (Double => Double) = (m) => 1.0 / (1.0 + math.exp(-m)) - raw.toArray.map(raw2prob).zip(prob.toArray).foreach { case (r, p) => - assert(r ~== p relTol eps) - } + assert(raw.size === 2) + assert(prob.size === 2) + val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1))) + assert(prob(1) ~== probFromRaw1 relTol eps) + assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps) } // Compare prediction with probability