Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Nov 10, 2014
1 parent 986593e commit 3df7952
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 26 deletions.
20 changes: 10 additions & 10 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,34 +60,34 @@ class Param[T] (

// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...

/** Specialized version of [[Param[Double]]] for Java. */
class DoubleParam(parent: Params, name: String, doc: String, default: Option[Double] = None)
  extends Param[Double](parent, name, doc, default) {
  // Override with the concrete primitive type so Java callers see `double`
  // instead of the erased `Object`; delegates to the base implementation.
  override def w(value: Double): ParamPair[Double] = super.w(value)
}

/** Specialized version of [[Param[Int]]] for Java. */
class IntParam(parent: Params, name: String, doc: String, default: Option[Int] = None)
  extends Param[Int](parent, name, doc, default) {
  // Override with the concrete primitive type so Java callers see `int`
  // instead of the erased `Object`; delegates to the base implementation.
  override def w(value: Int): ParamPair[Int] = super.w(value)
}

/** Specialized version of [[Param[Float]]] for Java. */
class FloatParam(parent: Params, name: String, doc: String, default: Option[Float] = None)
  extends Param[Float](parent, name, doc, default) {
  // Override with the concrete primitive type so Java callers see `float`
  // instead of the erased `Object`; delegates to the base implementation.
  override def w(value: Float): ParamPair[Float] = super.w(value)
}

/** Specialized version of [[Param[Long]]] for Java. */
class LongParam(parent: Params, name: String, doc: String, default: Option[Long] = None)
  extends Param[Long](parent, name, doc, default) {
  // Override with the concrete primitive type so Java callers see `long`
  // instead of the erased `Object`; delegates to the base implementation.
  override def w(value: Long): ParamPair[Long] = super.w(value)
}

/** Specialized version of [[Param[Boolean]]] for Java. */
class BooleanParam(parent: Params, name: String, doc: String, default: Option[Boolean] = None)
  extends Param[Boolean](parent, name, doc, default) {
  // Override with the concrete primitive type so Java callers see `boolean`
  // instead of the erased `Object`; delegates to the base implementation.
  override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.{SchemaRDD, StructType}
/**
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
*/
trait CrossValidatorParams extends Params {
private[ml] trait CrossValidatorParams extends Params {
/** param for the estimator to be cross-validated */
val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
def getEstimator: Estimator[_] = get(estimator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,14 @@ import org.apache.spark.sql.test.TestSQLContext._

class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {

var dataset: SchemaRDD = _

override def beforeAll(): Unit = {
super.beforeAll()
dataset = sparkContext.parallelize(generateLogisticInput(1.0, 1.0, 1000, 42), 2)
}

override def afterAll(): Unit = {
dataset = null
super.afterAll()
}
var dataset: SchemaRDD = sparkContext.parallelize(generateLogisticInput(1.0, 1.0, 1000, 42), 2)

test("logistic regression") {
val lr = new LogisticRegression
val model = lr.fit(dataset)
model.transform(dataset)
.select('label, 'prediction)
.collect()
.foreach(println)
}

test("logistic regression with setters") {
Expand All @@ -52,15 +41,15 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
.setRegParam(1.0)
val model = lr.fit(dataset)
model.transform(dataset, model.threshold -> 0.8) // overwrite threshold
.select('label, 'score, 'prediction).collect()
.foreach(println)
.select('label, 'score, 'prediction)
.collect()
}

test("logistic regression fit and transform with varargs") {
val lr = new LogisticRegression
val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability")
.select('label, 'probability, 'prediction)
.foreach(println)
.collect()
}
}

0 comments on commit 3df7952

Please sign in to comment.