remove external data deps
mengxr committed Nov 10, 2014
1 parent 9fd4933 commit 2b11211
Showing 2 changed files with 8 additions and 20 deletions.
mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

@@ -19,12 +19,8 @@ package org.apache.spark.ml.classification
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
-import org.apache.spark.ml.Pipeline
-import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
-import org.apache.spark.ml.feature.StandardScaler
-import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
-import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.sql.test.TestSQLContext._
 
 class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
@@ -33,13 +29,10 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
   override def beforeAll(): Unit = {
     super.beforeAll()
-    dataset = MLUtils.loadLibSVMFile(
-      sparkContext, "../data/mllib/sample_binary_classification_data.txt")
-    dataset.cache()
+    dataset = sparkContext.parallelize(generateLogisticInput(1.0, 1.0, 1000, 42), 2)
   }
 
   override def afterAll(): Unit = {
-    dataset.unpersist()
     dataset = null
     super.afterAll()
   }
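Both suites in this commit build their test data from generateLogisticInput(1.0, 1.0, 1000, 42) instead of loading sample_binary_classification_data.txt through a relative path, which is what removes the external data dependency. For readers who have not seen the MLlib test helper: it returns a Seq[LabeledPoint] whose labels, as the name suggests, follow a logistic model. The sketch below is a hypothetical stand-in for illustration; SyntheticLogisticData.generate is not the real helper in org.apache.spark.mllib.classification.LogisticRegressionSuite.

import scala.util.Random

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

object SyntheticLogisticData {

  // Hypothetical stand-in for generateLogisticInput, for illustration only:
  // one Gaussian feature per point, label sampled from the logistic probability.
  def generate(offset: Double, scale: Double, nPoints: Int, seed: Int): Seq[LabeledPoint] = {
    val rnd = new Random(seed)
    Seq.fill(nPoints) {
      val x = rnd.nextGaussian()
      // Probability of label 1 under a logistic model with the given intercept and slope.
      val p = 1.0 / (1.0 + math.exp(-(offset + scale * x)))
      val label = if (rnd.nextDouble() < p) 1.0 else 0.0
      LabeledPoint(label, Vectors.dense(x))
    }
  }
}

A test can then parallelize the generated points exactly as the new beforeAll above does with the real helper, e.g. sparkContext.parallelize(SyntheticLogisticData.generate(1.0, 1.0, 1000, 42), 2).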
mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala

@@ -17,28 +17,23 @@
 package org.apache.spark.ml.tuning
 
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.FunSuite
 
 import org.apache.spark.ml.classification.LogisticRegression
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.sql.SchemaRDD
 import org.apache.spark.sql.test.TestSQLContext._
 
-class CrossValidatorSuite extends FunSuite with BeforeAndAfterAll {
+class CrossValidatorSuite extends FunSuite {
 
-  var dataset: SchemaRDD = _
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    dataset = sparkContext.makeRDD(generateLogisticInput(1.0, 1.0, 1000, 42), 2)
-  }
+  val dataset: SchemaRDD = sparkContext.makeRDD(generateLogisticInput(1.0, 1.0, 1000, 42), 2)
 
   test("cross validation with logistic regression") {
     val lr = new LogisticRegression
     val lrParamMaps = new ParamGridBuilder()
-      .addGrid(lr.regParam, Array(0.1, 100.0))
-      .addGrid(lr.maxIter, Array(2, 10))
+      .addGrid(lr.regParam, Array(0.001, 1000.0))
+      .addGrid(lr.maxIter, Array(0, 10))
       .build()
     val eval = new BinaryClassificationEvaluator
     val cv = new CrossValidator()
@@ -48,7 +43,7 @@ class CrossValidatorSuite extends FunSuite with BeforeAndAfterAll {
       .setNumFolds(3)
     val cvModel = cv.fit(dataset)
     val bestParamMap = cvModel.bestModel.fittingParamMap
-    assert(bestParamMap(lr.regParam) === 0.1)
+    assert(bestParamMap(lr.regParam) === 0.001)
     assert(bestParamMap(lr.maxIter) === 10)
   }
 }
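Pulled together, the flow CrossValidatorSuite exercises looks roughly like the sketch below. The setEstimator, setEstimatorParamMaps, and setEvaluator calls fall in the collapsed region between the two hunks, so this reconstructs typical spark.ml tuning usage under that assumption, reusing the suite's imports and the dataset defined above rather than quoting the file verbatim.

// 3-fold cross validation of LogisticRegression over a tiny parameter grid.
val lr = new LogisticRegression
val lrParamMaps = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.001, 1000.0)) // one very small and one very large penalty
  .addGrid(lr.maxIter, Array(0, 10))          // zero iterations versus ten
  .build()
val eval = new BinaryClassificationEvaluator
val cv = new CrossValidator()
  .setEstimator(lr)                   // estimator whose parameters are tuned
  .setEstimatorParamMaps(lrParamMaps) // the grid built above
  .setEvaluator(eval)                 // scores held-out predictions per fold
  .setNumFolds(3)
val cvModel = cv.fit(dataset)
val bestParamMap = cvModel.bestModel.fittingParamMap
// The lightly regularized model trained for 10 iterations should come out on top.
assert(bestParamMap(lr.regParam) === 0.001)
assert(bestParamMap(lr.maxIter) === 10)

With regParam candidates as far apart as 0.001 and 1000.0 and maxIter of 0 versus 10, the winning grid point stays unambiguous on randomly generated data, which is presumably why the grid is more extreme than the earlier 0.1/100.0 and 2/10 values.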
