Skip to content

Commit

Permalink
Local scoring for model with features of all types (#340)
Browse files Browse the repository at this point in the history
  • Loading branch information
tovbinm authored Jun 22, 2019
1 parent 2f55e7a commit 5b82f42
Showing 1 changed file with 56 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,20 @@ package com.salesforce.op.local

import java.nio.file.Paths

import com.salesforce.op.features.Feature
import com.salesforce.op.features.types._
import com.salesforce.op.readers.DataFrameFieldNames._
import com.salesforce.op.stages.base.unary.UnaryTransformer
import com.salesforce.op.stages.impl.classification.{BinaryClassificationModelSelector, OpLogisticRegression}
import com.salesforce.op.stages.impl.feature.StringIndexerHandleInvalid
import com.salesforce.op.test.{PassengerSparkFixtureTest, TestCommon}
import com.salesforce.op.test.{PassengerSparkFixtureTest, TestCommon, TestFeatureBuilder}
import com.salesforce.op.testkit.{RandomList, RandomText}
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.spark.RichRow._
import com.salesforce.op.{OpWorkflow, OpWorkflowModel}
import com.salesforce.op.{OpWorkflow, OpWorkflowModel, UID}
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner
Expand Down Expand Up @@ -77,7 +82,9 @@ class OpWorkflowModelLocalTest extends FlatSpec with PassengerSparkFixtureTest w
}
// Raw input records generated from the model's raw features, ordered by the key field
// and converted to maps so they can be passed to the local score function.
lazy val rawData = {
  val rawFrame = dataReader.generateDataFrame(model.getRawFeatures())
  rawFrame.sort(KeyFieldName).collect().map(_.toMap)
}
// Scores produced by the model itself, ordered by the same key field, collected for the
// prediction, survivedNum, indexed and deindexed features.
lazy val expectedScores = {
  val scoredByKey = model.score().sort(KeyFieldName)
  scoredByKey.collect(prediction, survivedNum, indexed, deindexed)
}

// Canonical path of the directory where the "all feature types" model is saved and loaded from.
lazy val modelLocation2 = {
  val modelDir = Paths.get(tempDir.toString, "op-runner-local-test-model-2").toFile
  modelDir.getCanonicalPath
}

Spec(classOf[OpWorkflowModelLocal]) should "produce scores without Spark" in {
val scoreFn = OpWorkflowModel.load(modelLocation).scoreFunction
Expand All @@ -104,6 +111,46 @@ class OpWorkflowModelLocalTest extends FlatSpec with PassengerSparkFixtureTest w
elapsed should be <= 10000L
}

// Trains a model on randomly generated features of every type, saves it, loads it back and
// verifies that the loaded score function produces the same predictions as model.score().
it should "produce scores without Spark for all feature types" in {
  // Generate features of all possible types
  val numOfRows = 10
  val (ds, features) = TestFeatureBuilder.random(numOfRows)(
    // HashingTF transformer used in vectorization of text lists does not handle nulls well,
    // therefore setting minLen = 1 for now
    textLists = RandomList.ofTexts(RandomText.strings(0, 10), minLen = 1, maxLen = 10).limit(numOfRows)
  )

  // Prepare the label feature. Fail with an explicit message (rather than a bare
  // NoSuchElementException) if the generated features contain no RealNN feature.
  val label = features.find(_.isSubtypeOf[RealNN])
    .map(_.asInstanceOf[Feature[RealNN]])
    .getOrElse(fail("Generated features do not contain a RealNN feature to use as label"))
    .transformWith(new Labelizer)

  // Transmogrify all the features using default settings
  val featureVector = features.transmogrify()

  // Create a binary classification model selector with a single model type for simplicity
  val allTypesPrediction = BinaryClassificationModelSelector.withTrainValidationSplit(
    modelsAndParameters = Seq(new OpLogisticRegression() -> new ParamGridBuilder().build())
  ).setInput(label, featureVector).getOutput()

  // Use id feature as row key; again fail explicitly if it is missing
  val id = features.find(_.isSubtypeOf[ID])
    .map(_.asInstanceOf[Feature[ID]].name)
    .getOrElse(fail("Generated features do not contain an ID feature to use as row key"))
  val keyFn = (r: Row) => r.getAs[String](id)
  val workflow = new OpWorkflow().setInputDataset(ds, keyFn).setResultFeatures(allTypesPrediction)

  // Train, score and save the model.
  // Locals are named distinctly so they do not shadow the suite-level model / expectedScores / rawData.
  val trainedModel = workflow.train()
  val expectedScoresDF = trainedModel.score()
  val expectedRows = expectedScoresDF.sort(KeyFieldName).select(allTypesPrediction.name).collect().map(_.toMap)
  trainedModel.save(modelLocation2)

  // Load and score the model
  val scoreFn = OpWorkflowModel.load(modelLocation2).scoreFunction
  scoreFn shouldBe a[ScoreFunction]
  // Sort the raw rows by the same key as the expected scores so both line up by index
  val rawRows = ds.withColumn(KeyFieldName, col(id)).sort(KeyFieldName).collect().map(_.toMap)
  val scores = rawRows.map(scoreFn)
  scores.length shouldBe expectedRows.length
  scores.zip(expectedRows).zipWithIndex.foreach { case ((score, expected), i) =>
    withClue(s"Record index $i: ") {
      score shouldBe expected
    }
  }
}

private def assert(
scores: Array[Map[String, Any]],
expectedScores: Array[(Prediction, RealNN, RealNN, Text)]
Expand All @@ -123,3 +170,9 @@ class OpWorkflowModelLocalTest extends FlatSpec with PassengerSparkFixtureTest w
}

}


/**
 * Test transformer that turns a real valued feature into a binary label:
 * values strictly greater than zero map to 1.0 and all other values map to 0.0.
 * The output is marked as a response feature.
 *
 * @param uid unique id of the stage
 */
class Labelizer(uid: String = UID[Labelizer]) extends UnaryTransformer[RealNN, RealNN]("labelizer", uid) {
  override def outputIsResponse: Boolean = true

  def transformFn: RealNN => RealNN = { input =>
    val binarized = input.value.map { x => if (x > 0.0) 1.0 else 0.0 }
    // 0.0 is the argument given to toRealNN - presumably the default used when the value is empty
    binarized.toRealNN(0.0)
  }
}

0 comments on commit 5b82f42

Please sign in to comment.