Fix ModelInsights for xgboost (#170)
kinfaikan authored and tovbinm committed Nov 6, 2018
1 parent c83caef commit 904835a
Showing 3 changed files with 99 additions and 18 deletions.
37 changes: 22 additions & 15 deletions core/src/main/scala/com/salesforce/op/ModelInsights.scala
@@ -46,6 +46,7 @@ import com.salesforce.op.utils.spark.RichMetadata._
import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
import com.salesforce.op.utils.table.Alignment._
import com.salesforce.op.utils.table.Table
import ml.dmlc.xgboost4j.scala.spark.OpXGBoost.RichBooster
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostRegressionModel}
import org.apache.spark.ml.classification._
import org.apache.spark.ml.regression._
@@ -99,7 +100,7 @@ case class ModelInsights
def prettyPrint(topK: Int = 15): String = {
val res = new ArrayBuffer[String]()
res ++= prettyValidationResults
res += prettySelectedModelInfo
res ++= prettySelectedModelInfo
res += modelEvaluationMetrics
res ++= topKCorrelations(topK)
res ++= topKContributions(topK)
@@ -151,7 +152,7 @@ case class ModelInsights
Seq(evalSummary, modelEvalRes.mkString("\n"))
}

private def prettySelectedModelInfo: String = {
private def prettySelectedModelInfo: Seq[String] = {
val excludedParams = Set(
SparkWrapperParams.SparkStageParamName,
ModelSelectorNames.outputParamName, ModelSelectorNames.inputParam1Name,
@@ -170,8 +171,10 @@
val params = e.modelParameters.filterKeys(!excludedParams.contains(_))
Seq("name" -> e.modelName, "uid" -> e.modelUID, "modelType" -> e.modelType) ++ params
}).flatten.sortBy(_._1)
val table = Table(name = name, columns = Seq("Model Param", "Value"), rows = validationResults)
table.prettyString()
if (validationResults.nonEmpty) {
val table = Table(name = name, columns = Seq("Model Param", "Value"), rows = validationResults)
Seq(table.prettyString())
} else Seq.empty
}
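// Editorial note (not part of the diff): returning Seq[String] lets `res ++= prettySelectedModelInfo`
// in prettyPrint append nothing when validationResults is empty (for example, a hypothetical workflow
// that fits an xgboost model directly, with no model selector), rather than rendering a table with no rows.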

private def modelEvaluationMetrics: String = {
@@ -527,10 +530,9 @@ case object ModelInsights {
blacklistedMapKeys: Map[String, Set[String]],
rawFeatureDistributions: Array[FeatureDistribution]
): Seq[FeatureInsights] = {
val contributions = getModelContributions(model)

val featureInsights = (vectorInfo, summary) match {
case (Some(v), Some(s)) =>
val contributions = getModelContributions(model, Option(v.columns.length))
val droppedSet = s.dropped.toSet
val indexInToIndexKept = v.columns
.collect { case c if !droppedSet.contains(c.makeColName()) => c.index }
@@ -567,21 +569,24 @@
getIfExists(idx, s.categoricalStats(groupIdx).contingencyMatrix)
case _ => Map.empty[String, Double]
},
contribution = keptIndex.map(i => contributions.map(_.applyOrElse(i, Seq.empty))).getOrElse(Seq.empty),
contribution =
keptIndex.map(i => contributions.map(_.applyOrElse(i, (_: Int) => 0.0))).getOrElse(Seq.empty),
min = getIfExists(h.index, s.featuresStatistics.min),
max = getIfExists(h.index, s.featuresStatistics.max),
mean = getIfExists(h.index, s.featuresStatistics.mean),
variance = getIfExists(h.index, s.featuresStatistics.variance)
)
}
case (Some(v), None) => v.getColumnHistory().map { h =>
h.parentFeatureOrigins ->
Insights(
case (Some(v), None) =>
val contributions = getModelContributions(model, Option(v.columns.length))
v.getColumnHistory().map { h =>
h.parentFeatureOrigins -> Insights(
derivedFeatureName = h.columnName,
stagesApplied = h.parentFeatureStages,
derivedFeatureGroup = h.grouping,
derivedFeatureValue = h.indicatorValue,
contribution = contributions.map(_.applyOrElse(h.index, Seq.empty)) // nothing dropped without sanity check
contribution =
contributions.map(_.applyOrElse(h.index, (_: Int) => 0.0)) // nothing dropped without sanity check
)
}
case (None, _) => Seq.empty
@@ -631,7 +636,8 @@
}
}

private[op] def getModelContributions(model: Option[Model[_]]): Seq[Seq[Double]] = {
private[op] def getModelContributions
(model: Option[Model[_]], featureVectorSize: Option[Int] = None): Seq[Seq[Double]] = {
val stage = model.flatMap {
case m: SparkWrapperParams[_] => m.getSparkMlStage()
case _ => None
@@ -648,8 +654,8 @@
case m: RandomForestRegressionModel => Seq(m.featureImportances.toArray.toSeq)
case m: GBTRegressionModel => Seq(m.featureImportances.toArray.toSeq)
case m: GeneralizedLinearRegressionModel => Seq(m.coefficients.toArray.toSeq)
case m: XGBoostRegressionModel => Seq(m.nativeBooster.getFeatureScore().values.map(_.toDouble).toSeq)
case m: XGBoostClassificationModel => Seq(m.nativeBooster.getFeatureScore().values.map(_.toDouble).toSeq)
case m: XGBoostRegressionModel => Seq(m.nativeBooster.getFeatureScoreVector(featureVectorSize).toArray.toSeq)
case m: XGBoostClassificationModel => Seq(m.nativeBooster.getFeatureScoreVector(featureVectorSize).toArray.toSeq)
}
contributions.getOrElse(Seq.empty)
}
@@ -668,7 +674,8 @@
case p if p.param.name == OpPipelineStageParamsNames.InputFeatures =>
p.param.name -> p.value.asInstanceOf[Array[TransientFeature]].map(_.toJsonString()).mkString(", ")
case p if p.param.name != OpPipelineStageParamsNames.OutputMetadata &&
p.param.name != OpPipelineStageParamsNames.InputSchema => p.param.name -> p.value.toString
p.param.name != OpPipelineStageParamsNames.InputSchema && Option(p.value).nonEmpty =>
p.param.name -> p.value.toString
}.toMap
}
stages.map { s =>
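One subtlety in the ModelInsights.scala changes above is the applyOrElse default. Seq itself extends Int => A, so the old default of Seq.empty just indexed into an empty Seq and still threw IndexOutOfBoundsException whenever the requested index was past the end of the contribution sequence; the new (_: Int) => 0.0 default yields a zero contribution instead. A minimal sketch of the difference, using hypothetical values:

val contribution = Seq(5.0, 0.0, 0.0, 2.0)      // hypothetical per-column scores from one model
contribution.applyOrElse(7, (_: Int) => 0.0)    // new default: out-of-range column falls back to 0.0
// contribution.applyOrElse(7, Seq.empty)       // old default: throws IndexOutOfBoundsException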
@@ -31,9 +31,10 @@
package ml.dmlc.xgboost4j.scala.spark

import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.spark.params.GeneralParams
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}

import scala.collection.mutable.ArrayBuffer

@@ -81,6 +82,27 @@ case object OpXGBoost {
}
}

implicit class RichBooster(val booster: Booster) extends AnyVal {
/**
* Converts feature score map into a vector
*
* @param featureVectorSize size of feature vectors the xgboost model is trained on
* @return vector containing feature scores
*/
def getFeatureScoreVector(featureVectorSize: Option[Int] = None): Vector = {
val featureScore = booster.getFeatureScore()
require(featureScore.nonEmpty, "Feature score map is empty")
val indexScore = featureScore.map { case (fid, score) =>
val index = fid.tail.toInt
index -> score.toDouble
}.toSeq
val maxIndex = indexScore.map(_._1).max
require(featureVectorSize.forall(_ > maxIndex), "Feature vector size must be larger than max feature index")
val size = featureVectorSize.getOrElse(maxIndex + 1)
Vectors.sparse(size, indexScore)
}
}

/**
* Hack to access [[ml.dmlc.xgboost4j.scala.spark.XGBoost.removeMissingValues]] private method
*/
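The getFeatureScoreVector helper added above is easiest to see with concrete numbers. A minimal sketch, assuming a hypothetical booster that only split on columns 0 and 3 of a 4-column feature vector (the Map literal stands in for what booster.getFeatureScore() might return):

import org.apache.spark.ml.linalg.Vectors

val featureScore = Map("f0" -> 5, "f3" -> 2)    // hypothetical score map; keys are "f<column index>"
val indexScore = featureScore.toSeq.map { case (fid, score) => fid.tail.toInt -> score.toDouble }
val vector = Vectors.sparse(4, indexScore)      // 4 comes from featureVectorSize = Some(4)
// vector.toArray.toSeq == Seq(5.0, 0.0, 0.0, 2.0): columns the booster never split on score 0.0,
// whereas the old getFeatureScore().values gave Seq(5.0, 2.0) with the column indices lost,
// which is why ModelInsights now passes the feature vector size through to this helper.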
56 changes: 54 additions & 2 deletions core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala
@@ -33,9 +33,9 @@ package com.salesforce.op
import com.salesforce.op.features.types._
import com.salesforce.op.features.{Feature, FeatureDistributionType}
import com.salesforce.op.filters.FeatureDistribution
import com.salesforce.op.stages.impl.classification.{BinaryClassificationModelSelector, BinaryClassificationModelsToTry, MultiClassificationModelSelector, OpLogisticRegression}
import com.salesforce.op.stages.impl.classification._
import com.salesforce.op.stages.impl.preparators._
import com.salesforce.op.stages.impl.regression.{OpLinearRegression, RegressionModelSelector}
import com.salesforce.op.stages.impl.regression.{OpLinearRegression, OpXGBoostRegressor, RegressionModelSelector}
import com.salesforce.op.stages.impl.selector.ModelSelectorNames.EstimatorType
import com.salesforce.op.stages.impl.selector.SelectedModel
import com.salesforce.op.stages.impl.selector.ValidationType._
@@ -84,6 +84,13 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest {
val lrParams = new ParamGridBuilder().addGrid(lr.regParam, Array(0.01, 0.1)).build()
val models = Seq(lr -> lrParams).asInstanceOf[Seq[(EstimatorType, Array[ParamMap])]]

val xgbClassifier = new OpXGBoostClassifier().setSilent(1).setSeed(42L)
val xgbRegressor = new OpXGBoostRegressor().setSilent(1).setSeed(42L)
val xgbClassifierPred = xgbClassifier.setInput(label, features).getOutput()
val xgbRegressorPred = xgbRegressor.setInput(label, features).getOutput()
lazy val xgbWorkflow =
new OpWorkflow().setResultFeatures(xgbClassifierPred, xgbRegressorPred).setReader(dataReader)
lazy val xgbWorkflowModel = xgbWorkflow.train()

val pred = BinaryClassificationModelSelector
.withCrossValidation(seed = 42, splitter = Option(DataSplitter(seed = 42, reserveTestFraction = 0.1)),
@@ -543,4 +550,49 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest {
insights.features.foreach(f => f.distributions shouldBe empty)
}

it should "return model insights for xgboost classification" in {
noException should be thrownBy xgbWorkflowModel.modelInsights(xgbClassifierPred)
val insights = xgbWorkflowModel.modelInsights(xgbClassifierPred)
val ageInsights = insights.features.filter(_.featureName == age.name).head
val genderInsights = insights.features.filter(_.featureName == genderPL.name).head
insights.features.size shouldBe 5
insights.features.map(_.featureName).toSet shouldEqual rawNames
ageInsights.derivedFeatures.size shouldBe 2
ageInsights.derivedFeatures.foreach { f =>
f.contribution.size shouldBe 1
f.corr.isEmpty shouldBe true
f.variance.isEmpty shouldBe true
f.cramersV.isEmpty shouldBe true
}
genderInsights.derivedFeatures.size shouldBe 4
genderInsights.derivedFeatures.foreach { f =>
f.contribution.size shouldBe 1
f.corr.isEmpty shouldBe true
f.variance.isEmpty shouldBe true
f.cramersV.isEmpty shouldBe true
}
}

it should "return model insights for xgboost regression" in {
noException should be thrownBy xgbWorkflowModel.modelInsights(xgbRegressorPred)
val insights = xgbWorkflowModel.modelInsights(xgbRegressorPred)
val ageInsights = insights.features.filter(_.featureName == age.name).head
val genderInsights = insights.features.filter(_.featureName == genderPL.name).head
insights.features.size shouldBe 5
insights.features.map(_.featureName).toSet shouldEqual rawNames
ageInsights.derivedFeatures.size shouldBe 2
ageInsights.derivedFeatures.foreach { f =>
f.contribution.size shouldBe 1
f.corr.isEmpty shouldBe true
f.variance.isEmpty shouldBe true
f.cramersV.isEmpty shouldBe true
}
genderInsights.derivedFeatures.size shouldBe 4
genderInsights.derivedFeatures.foreach { f =>
f.contribution.size shouldBe 1
f.corr.isEmpty shouldBe true
f.variance.isEmpty shouldBe true
f.cramersV.isEmpty shouldBe true
}
}
}
