Track mean & standard deviation of text length as a metric for text feature #354

Merged: 78 commits from tn/cardinality into master on Aug 2, 2019

Commits (78):
d504ace  starter code (TuanNguyen27, Jul 2, 2019)
9cd2790  spaghetti code (TuanNguyen27, Jul 2, 2019)
61c52c7  better place to put avgTextLen (TuanNguyen27, Jul 3, 2019)
bf1c283  first fix of unit test (TuanNguyen27, Jul 3, 2019)
d64d087  fix most tests (TuanNguyen27, Jul 3, 2019)
3f5da2c  fix some styles (TuanNguyen27, Jul 3, 2019)
6bb0256  fix more style (TuanNguyen27, Jul 3, 2019)
0f91a40  Merge branch 'master' into tn/cardinality (TuanNguyen27, Jul 3, 2019)
cebf02d  handling division by zero (TuanNguyen27, Jul 3, 2019)
c1e50ca  address comments (TuanNguyen27, Jul 5, 2019)
75da25e  adding some doc on how to use text len cardinality (TuanNguyen27, Jul 8, 2019)
2d7c233  Merge branch 'master' into tn/cardinality (TuanNguyen27, Jul 8, 2019)
9e0e2f9  add default value for avg text len (TuanNguyen27, Jul 8, 2019)
47ad700  add docs (TuanNguyen27, Jul 8, 2019)
42a47ba  fix scala style (TuanNguyen27, Jul 8, 2019)
9082b86  delete extra line (TuanNguyen27, Jul 8, 2019)
1c2235e  Merge branch 'master' into tn/cardinality (TuanNguyen27, Jul 9, 2019)
387f0ea  remove avgtextLength from doc (TuanNguyen27, Jul 9, 2019)
fd64fd8  Merge branch 'tn/cardinality' of https://github.com/salesforce/Transm… (TuanNguyen27, Jul 9, 2019)
0b27fed  starter code on moments & textstat (TuanNguyen27, Jul 10, 2019)
3d1eea9  fix moments aggregation? (TuanNguyen27, Jul 10, 2019)
af3d467  Merge branch 'master' into tn/cardinality (TuanNguyen27, Jul 11, 2019)
34da327  still broken (TuanNguyen27, Jul 11, 2019)
37ff005  Merge branch 'tn/cardinality' of https://github.com/salesforce/Transm… (TuanNguyen27, Jul 11, 2019)
7a69049  Merge branch 'master' into tn/cardinality (leahmcguire, Jul 11, 2019)
0a0ba98  Merge branch 'master' into tn/cardinality (TuanNguyen27, Jul 12, 2019)
35b6fe9  Merge branch 'master' into tn/cardinality (TuanNguyen27, Jul 12, 2019)
06ae7e3  finsh semi group adding logic (TuanNguyen27, Jul 15, 2019)
88c7726  Merge branch 'master' into tn/cardinality (TuanNguyen27, Jul 18, 2019)
273fc8c  removing the old implementation (TuanNguyen27, Jul 19, 2019)
6baf8a0  removing redundant code (TuanNguyen27, Jul 22, 2019)
d2ad2ec  remove redundant changes to tests (TuanNguyen27, Jul 22, 2019)
5c67bd2  remove more extra stuff (TuanNguyen27, Jul 22, 2019)
4127a87  bump default value for maxCard here (TuanNguyen27, Jul 22, 2019)
c6dff11  make cardinality and moments work across both text and numeric features (TuanNguyen27, Jul 23, 2019)
0728855  rename variables (TuanNguyen27, Jul 23, 2019)
797d5e6  wip (TuanNguyen27, Jul 24, 2019)
3366362  wip (TuanNguyen27, Jul 24, 2019)
2b09292  FeatureDistribution update (TuanNguyen27, Jul 24, 2019)
f634049  wip (TuanNguyen27, Jul 25, 2019)
f5c68a3  Merge branch 'master' into tn/cardinality (TuanNguyen27, Jul 25, 2019)
6e22db9  moving cardinality and moments calculation into feature distribution (TuanNguyen27, Jul 25, 2019)
34af0eb  Merge branch 'tn/cardinality' of https://github.com/salesforce/Transm… (TuanNguyen27, Jul 25, 2019)
0eb5f8d  wip (TuanNguyen27, Jul 25, 2019)
10af178  update some compiler error (TuanNguyen27, Jul 25, 2019)
0b902a0  update test (TuanNguyen27, Jul 26, 2019)
520ac05  fix scala style (TuanNguyen27, Jul 26, 2019)
1ea54e5  more fix (TuanNguyen27, Jul 26, 2019)
c62d2db  wip (TuanNguyen27, Jul 26, 2019)
52c0fc3  fix style error (TuanNguyen27, Jul 26, 2019)
315a21a  update test to reflect new members of FeatureDistribution case class (TuanNguyen27, Jul 26, 2019)
3c8f196  update printing (TuanNguyen27, Jul 26, 2019)
be94fe3  added some tests (TuanNguyen27, Jul 27, 2019)
bd4ef29  Update FeatureDistributionTest.scala (TuanNguyen27, Jul 28, 2019)
d24d863  fix scala style (TuanNguyen27, Jul 28, 2019)
9135987  add docs (TuanNguyen27, Jul 29, 2019)
48a742f  move maxCard param to companion object (TuanNguyen27, Jul 29, 2019)
3ab4f91  scala style fix (TuanNguyen27, Jul 29, 2019)
2825443  fix string conversion (TuanNguyen27, Jul 29, 2019)
3973b72  move MaxCardinality to RFF companion object (TuanNguyen27, Jul 29, 2019)
5a99fc0  Try fixing test (TuanNguyen27, Jul 29, 2019)
9f12d0b  Merge branch 'master' into tn/cardinality (TuanNguyen27, Jul 30, 2019)
8f454d1  wip, need to make a bigger dummy dataset (TuanNguyen27, Jul 30, 2019)
2fd2c68  Merge branch 'tn/cardinality' of https://github.com/salesforce/Transm… (TuanNguyen27, Jul 30, 2019)
de826d5  more idiomatic scala (TuanNguyen27, Jul 30, 2019)
8811be5  test Tuple2Semigroup (TuanNguyen27, Jul 30, 2019)
6678746  wip (TuanNguyen27, Jul 30, 2019)
4a569dc  wip (TuanNguyen27, Jul 30, 2019)
d1d7dc0  wip (TuanNguyen27, Jul 30, 2019)
29d1925  update test (TuanNguyen27, Jul 31, 2019)
d72004f  fix scala style (TuanNguyen27, Jul 31, 2019)
6802558  removing verbose lines (TuanNguyen27, Jul 31, 2019)
88c082f  clean up test for cardinality and moments (TuanNguyen27, Aug 1, 2019)
5b71fca  fix scala style (TuanNguyen27, Aug 1, 2019)
85d207f  clean up summation of Option[Moments] (TuanNguyen27, Aug 1, 2019)
c9bdc27  Changed the TextStats SemiGroup to a Monoid so that we can make an Op… (Jauntbox, Aug 2, 2019)
d3c36f2  Fix merge conflict (Jauntbox, Aug 2, 2019)
4736a9c  Merge branch 'master' into tn/cardinality (TuanNguyen27, Aug 2, 2019)
FeatureDistribution.scala
@@ -33,11 +33,11 @@ package com.salesforce.op.filters
import java.util.Objects

import com.salesforce.op.features.{FeatureDistributionLike, FeatureDistributionType}
import com.salesforce.op.stages.impl.feature.{HashAlgorithm, Inclusion, NumericBucketizer}
import com.salesforce.op.stages.impl.feature.{HashAlgorithm, Inclusion, NumericBucketizer, TextStats}
import com.salesforce.op.utils.json.EnumEntrySerializer
import com.twitter.algebird.Monoid._
import com.twitter.algebird._
import com.twitter.algebird.Operators._
import com.twitter.algebird.Semigroup
import org.apache.spark.mllib.feature.HashingTF
import org.json4s.jackson.Serialization
import org.json4s.{DefaultFormats, Formats}
@@ -63,6 +63,8 @@ case class FeatureDistribution
nulls: Long,
distribution: Array[Double],
summaryInfo: Array[Double],
moments: Option[Moments] = None,
cardEstimate: Option[TextStats] = None,
`type`: FeatureDistributionType = FeatureDistributionType.Training
) extends FeatureDistributionLike {

@@ -99,10 +101,19 @@ case class FeatureDistribution
*/
def reduce(fd: FeatureDistribution): FeatureDistribution = {
checkMatch(fd)
// should move this somewhere else
implicit val testStatsMonoid: Monoid[TextStats] = TextStats.monoid(RawFeatureFilter.MaxCardinality)
implicit val opMonoid = optionMonoid[TextStats]

val combinedDist = distribution + fd.distribution
// summary info can be empty or min max if hist is empty but should otherwise match so take the longest info
val combinedSummary = if (summaryInfo.length > fd.summaryInfo.length) summaryInfo else fd.summaryInfo
FeatureDistribution(name, key, count + fd.count, nulls + fd.nulls, combinedDist, combinedSummary, `type`)
val combinedSummaryInfo = if (summaryInfo.length > fd.summaryInfo.length) summaryInfo else fd.summaryInfo

val combinedMoments = moments + fd.moments
val combinedCard = cardEstimate + fd.cardEstimate

FeatureDistribution(name, key, count + fd.count, nulls + fd.nulls, combinedDist,
combinedSummaryInfo, combinedMoments, combinedCard, `type`)
}
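The two new fields are summed with Algebird's option monoid, under which `None` is the identity; a `FeatureDistribution` recorded before this change (no moments, no cardinality estimate) can therefore still be reduced against a new one. A minimal sketch of that behavior, using the same `Monoid._` and `Operators._` imports as above and a local `Stats` stand-in for the `private[op]` TextStats:

```scala
import com.twitter.algebird.{Monoid, Semigroup}
import com.twitter.algebird.Monoid._
import com.twitter.algebird.Operators._

object OptionReduceSketch extends App {
  case class Stats(valueCounts: Map[String, Int])

  // Merge two histograms by summing counts per key (Operators._ supplies Map +)
  implicit val statsSemigroup: Semigroup[Stats] =
    Semigroup.from[Stats]((l, r) => Stats(l.valueCounts + r.valueCounts))
  implicit val optStatsMonoid: Monoid[Option[Stats]] = optionMonoid[Stats]

  val current: Option[Stats] = Some(Stats(Map("male" -> 1, "female" -> 2)))
  val legacy: Option[Stats] = None // e.g. a distribution serialized before this PR

  println(current + legacy) // Some(Stats(Map(male -> 1, female -> 2))): None is the identity
}
```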

/**
@@ -155,19 +166,23 @@ case class FeatureDistribution
"count" -> count.toString,
"nulls" -> nulls.toString,
"distribution" -> distribution.mkString("[", ",", "]"),
"summaryInfo" -> summaryInfo.mkString("[", ",", "]")
"summaryInfo" -> summaryInfo.mkString("[", ",", "]"),
"cardinality" -> cardEstimate.map(_.toString).getOrElse(""),
"moments" -> moments.map(_.toString).getOrElse("")
).map { case (n, v) => s"$n = $v" }.mkString(", ")

s"${getClass.getSimpleName}($valStr)"
}

override def equals(that: Any): Boolean = that match {
case FeatureDistribution(`name`, `key`, `count`, `nulls`, d, s, `type`) =>
distribution.deep == d.deep && summaryInfo.deep == s.deep
case FeatureDistribution(`name`, `key`, `count`, `nulls`, d, s, m, c, `type`) =>
distribution.deep == d.deep && summaryInfo.deep == s.deep &&
moments == m && cardEstimate == c
case _ => false
}

override def hashCode(): Int = Objects.hashCode(name, key, count, nulls, distribution, summaryInfo, `type`)
override def hashCode(): Int = Objects.hashCode(name, key, count, nulls, distribution,
summaryInfo, moments, cardEstimate, `type`)
}

object FeatureDistribution {
@@ -225,17 +240,51 @@ object FeatureDistribution
value.map(seq => 0L -> histValues(seq, summary, bins, textBinsFormula))
.getOrElse(1L -> (Array(summary.min, summary.max, summary.sum, summary.count) -> new Array[Double](bins)))

val moments = value.map(momentsValues)
val cardEstimate = value.map(cardinalityValues)

FeatureDistribution(
name = name,
key = key,
count = 1L,
nulls = nullCount,
summaryInfo = summaryInfo,
distribution = distribution,
moments = moments,
cardEstimate = cardEstimate,
`type` = `type`
)
}

/**
* Function to calculate the first five central moments of numeric values, or length of tokens for text features
*
* @param values values to calculate moments
* @return Moments object containing information about moments
*/
private def momentsValues(values: ProcessedSeq): Moments = {
val population = values match {
case Left(seq) => seq.map(x => x.length.toDouble)
case Right(seq) => seq
}
MomentsGroup.sum(population.map(x => Moments(x)))
}
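For reference, a sketch of what `momentsValues` produces: each observation becomes a single-point `Moments`, and `MomentsGroup.sum` folds them into streaming summary statistics, which is what makes the computation safe to distribute over partitions:

```scala
import com.twitter.algebird.{Moments, MomentsGroup}

object MomentsSketch extends App {
  // Text branch: token lengths stand in for the raw values
  val tokenLengths = Seq("a", "bb", "ccc").map(_.length.toDouble)

  val m = MomentsGroup.sum(tokenLengths.map(Moments(_)))

  println(m.count)    // 3
  println(m.mean)     // 2.0
  println(m.variance) // ~0.667, the population variance of 1, 2, 3
}
```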

/**
* Function to track frequency of the first $(MaxCardinality) unique values
* (number for numeric features, token for text features)
*
* @param values values to track distribution / frequency
* @return TextStats object containing a Map from a value to its frequency (histogram)
*/
private def cardinalityValues(values: ProcessedSeq): TextStats = {
val population = values match {
case Left(seq) => seq
case Right(seq) => seq.map(_.toString)
}
TextStats(population.groupBy(identity).map{case (key, value) => (key, value.size)})
}
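The cardinality estimate itself is a plain value-to-frequency histogram; numeric values are stringified so text and numeric features share one representation. The same computation outside the class:

```scala
object CardinalitySketch extends App {
  // Numeric branch: values are converted to strings before counting
  val values: Seq[Double] = Seq(5.0, 1.0, 3.0, 1.0)

  val histogram: Map[String, Int] =
    values.map(_.toString).groupBy(identity).map { case (k, vs) => k -> vs.size }

  println(histogram) // Map(5.0 -> 1, 1.0 -> 2, 3.0 -> 1)
}
```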

/**
* Function to put data into histogram of counts
*
RawFeatureFilter.scala
@@ -595,6 +595,8 @@ object RawFeatureFilter {
// If there are not enough rows in the scoring set, we should not perform comparisons between the training and
// scoring sets since they will not be reliable. Currently, this is set to the same as the minimum training size.
val minScoringRowsDefault = 500
val MaxCardinality = 500


val stageName = classOf[RawFeatureFilter[_]].getSimpleName

SmartTextMapVectorizer.scala
@@ -39,7 +39,7 @@ import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
import com.twitter.algebird.Monoid._
import com.twitter.algebird.Operators._
import com.twitter.algebird.Semigroup
import com.twitter.algebird.Monoid
import com.twitter.algebird.macros.caseclass
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.{Dataset, Encoder}
@@ -156,7 +156,7 @@ class SmartTextMapVectorizer[T <: OPMap[String]]
val shouldCleanKeys = $(cleanKeys)
val shouldCleanValues = $(cleanText)

implicit val testStatsSG: Semigroup[TextMapStats] = TextMapStats.semiGroup(maxCard)
implicit val testStatsMonoid: Monoid[TextMapStats] = TextMapStats.monoid(maxCard)
val valueStats: Dataset[Array[TextMapStats]] = dataset.map(
_.map(computeTextMapStats(_, shouldCleanKeys, shouldCleanValues)).toArray
)
@@ -186,9 +186,9 @@ private[op] case class TextMapStats(keyValueCounts: Map[String, TextStats]) exte

private[op] object TextMapStats {

def semiGroup(maxCardinality: Int): Semigroup[TextMapStats] = {
implicit val testStatsSG: Semigroup[TextStats] = TextStats.semiGroup(maxCardinality)
caseclass.semigroup[TextMapStats]
def monoid(maxCardinality: Int): Monoid[TextMapStats] = {
implicit val testStatsMonoid: Monoid[TextStats] = TextStats.monoid(maxCardinality)
caseclass.monoid[TextMapStats]
}

}
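The `caseclass.monoid` macro assembles `Monoid[TextMapStats]` from the implicit `Monoid[TextStats]` placed in scope on the previous line. A sketch of the same derivation on a hypothetical single-field case class (assuming Algebird's map monoid resolves implicitly):

```scala
import com.twitter.algebird.Monoid
import com.twitter.algebird.macros.caseclass

object CaseClassMonoidSketch extends App {
  // Hypothetical stand-in for TextMapStats
  case class KeyedCounts(keyValueCounts: Map[String, Int])

  // The macro derives the Monoid field-by-field; Map[String, Int] already
  // has an Algebird Monoid that adds counts pointwise
  implicit val keyedCountsMonoid: Monoid[KeyedCounts] = caseclass.monoid[KeyedCounts]

  val combined = keyedCountsMonoid.plus(
    KeyedCounts(Map("a" -> 1)),
    KeyedCounts(Map("a" -> 2, "b" -> 1))
  )
  println(combined) // KeyedCounts(Map(a -> 3, b -> 1))
}
```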
SmartTextVectorizer.scala
@@ -38,6 +38,7 @@ import com.salesforce.op.stages.impl.feature.VectorizerUtils._
import com.salesforce.op.utils.json.JsonLike
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
import com.twitter.algebird.Monoid
import com.twitter.algebird.Monoid._
import com.twitter.algebird.Operators._
import com.twitter.algebird.Semigroup
@@ -82,7 +83,7 @@ class SmartTextVectorizer[T <: Text](uid: String = UID[SmartTextVectorizer[T]])(
val maxCard = $(maxCardinality)
val shouldCleanText = $(cleanText)

implicit val testStatsSG: Semigroup[TextStats] = TextStats.semiGroup(maxCard)
implicit val testStatsMonoid: Semigroup[TextStats] = TextStats.monoid(maxCard)
val valueStats: Dataset[Array[TextStats]] = dataset.map(_.map(computeTextStats(_, shouldCleanText)).toArray)
val aggregatedStats: Array[TextStats] = valueStats.reduce(_ + _)

@@ -170,12 +171,14 @@ object SmartTextVectorizer {
private[op] case class TextStats(valueCounts: Map[String, Int]) extends JsonLike

private[op] object TextStats {
def semiGroup(maxCardinality: Int): Semigroup[TextStats] = new Semigroup[TextStats] {
def monoid(maxCardinality: Int): Monoid[TextStats] = new Monoid[TextStats] {
override def plus(l: TextStats, r: TextStats): TextStats = {
if (l.valueCounts.size > maxCardinality) l
else if (r.valueCounts.size > maxCardinality) r
else TextStats(l.valueCounts + r.valueCounts)
}

override def zero: TextStats = TextStats.empty
}

def empty: TextStats = TextStats(Map.empty)
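Note the short-circuit in `plus`: once either side exceeds `maxCardinality`, merging stops and the counts become approximate for high-cardinality features; the `zero` added here is what upgrades the semigroup to a monoid. A self-contained sketch of the capping behavior (again with a local `Stats` stand-in for the `private[op]` TextStats):

```scala
import com.twitter.algebird.Monoid

object CappedMonoidSketch extends App {
  case class Stats(valueCounts: Map[String, Int])

  // Same shape as TextStats.monoid: stop merging once either side has
  // more than maxCardinality distinct values
  def cappedMonoid(maxCardinality: Int): Monoid[Stats] = new Monoid[Stats] {
    override def zero: Stats = Stats(Map.empty)
    override def plus(l: Stats, r: Stats): Stats =
      if (l.valueCounts.size > maxCardinality) l
      else if (r.valueCounts.size > maxCardinality) r
      else Stats(
        (l.valueCounts.keySet ++ r.valueCounts.keySet)
          .map(k => k -> (l.valueCounts.getOrElse(k, 0) + r.valueCounts.getOrElse(k, 0)))
          .toMap
      )
  }

  val m = cappedMonoid(maxCardinality = 2)
  val big = Stats(Map("a" -> 1, "b" -> 1, "c" -> 1)) // already past the cap
  println(m.plus(big, Stats(Map("d" -> 1)))) // Stats(Map(a -> 1, b -> 1, c -> 1)): "d" is dropped
}
```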
55 changes: 46 additions & 9 deletions core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala
@@ -48,9 +48,12 @@ import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.junit.runner.RunWith
import com.salesforce.op.features.types.Real
import com.salesforce.op.stages.impl.feature.TextStats
import com.twitter.algebird.Moments
import org.apache.spark.sql.DataFrame
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner
import org.apache.spark.sql.functions._

import scala.util.{Failure, Success}

@@ -117,7 +120,6 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
val linearRegLabel = (smallNorm, bigNorm)
.zipped.map(_.toDouble.get * 5000 + _.toDouble.get).map(RealNN(_))
val labelStd = math.sqrt(5000 * 5000 * smallFeatureVariance + bigFeatureVariance)

def twoFeatureDF(feature1: List[Real], feature2: List[Real], label: List[RealNN]):
(Feature[RealNN], FeatureLike[OPVector], DataFrame) = {
val generatedData = feature1.zip(feature2).zip(label).map {
@@ -147,8 +149,7 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
.setInput(logRegDF._1, logRegDF._2).getOutput()

def getFeatureImp(standardizedModel: FeatureLike[Prediction],
unstandardizedModel: FeatureLike[Prediction],
DF: DataFrame): Array[Double] = {
unstandardizedModel: FeatureLike[Prediction], DF: DataFrame): Array[Double] = {
lazy val workFlow = new OpWorkflow()
.setResultFeatures(standardizedModel, unstandardizedModel).setInputDataset(DF)
lazy val model = workFlow.train()
@@ -163,6 +164,17 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
return Array(descaledsmallCoeff, originalsmallCoeff, descaledbigCoeff, orginalbigCoeff)
}

def getFeatureMomentsAndCard(inputModel: FeatureLike[Prediction],
DF: DataFrame): (Map[String, Moments], Map[String, TextStats]) = {
lazy val workFlow = new OpWorkflow().setResultFeatures(inputModel).setInputDataset(DF)
lazy val dummyReader = workFlow.getReader()
lazy val workFlowRFF = workFlow.withRawFeatureFilter(Some(dummyReader), None)
lazy val model = workFlowRFF.train()
val insights = model.modelInsights(inputModel)
val featureMoments = insights.features.map(f => f.featureName -> f.distributions.head.moments.get).toMap
val featureCardinality = insights.features.map(f => f.featureName -> f.distributions.head.cardEstimate.get).toMap
return (featureMoments, featureCardinality)
}

val params = new OpParams()

@@ -400,16 +412,16 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
case Failure(e) => fail(e)
case Success(deser) =>
insights.label shouldEqual deser.label
insights.features.zip(deser.features).foreach{
insights.features.zip(deser.features).foreach {
case (i, o) =>
i.featureName shouldEqual o.featureName
i.featureType shouldEqual o.featureType
i.derivedFeatures.zip(o.derivedFeatures).foreach{ case (ii, io) => ii.corr shouldEqual io.corr }
i.derivedFeatures.zip(o.derivedFeatures).foreach { case (ii, io) => ii.corr shouldEqual io.corr }
RawFeatureFilterResultsComparison.compareSeqMetrics(i.metrics, o.metrics)
RawFeatureFilterResultsComparison.compareSeqDistributions(i.distributions, o.distributions)
RawFeatureFilterResultsComparison.compareSeqExclusionReasons(i.exclusionReasons, o.exclusionReasons)
}
insights.selectedModelInfo.toSeq.zip(deser.selectedModelInfo.toSeq).foreach{
insights.selectedModelInfo.toSeq.zip(deser.selectedModelInfo.toSeq).foreach {
case (o, i) =>
o.validationType shouldEqual i.validationType
o.validationParameters.keySet shouldEqual i.validationParameters.keySet
@@ -420,7 +432,7 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
o.bestModelUID shouldEqual i.bestModelUID
o.bestModelName shouldEqual i.bestModelName
o.bestModelType shouldEqual i.bestModelType
o.validationResults.zip(i.validationResults).foreach{
o.validationResults.zip(i.validationResults).foreach {
case (ov, iv) => ov.metricValues shouldEqual iv.metricValues
ov.modelParameters.keySet shouldEqual iv.modelParameters.keySet
}
@@ -489,7 +501,7 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
paramsMapI("correlationType") shouldEqual paramsMapD("correlationType")
paramsMapI("jsDivergenceProtectedFeatures") shouldEqual paramsMapD("jsDivergenceProtectedFeatures")
paramsMapI("protectedFeatures") shouldEqual paramsMapD("protectedFeatures")
}
}
}
}

@@ -531,7 +543,7 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
supports = Array(1.0)
), CategoricalGroupStats(
group = "f0_f0_f2",
categoricalFeatures = Array( "f0_f0_f3_2"),
categoricalFeatures = Array("f0_f0_f3_2"),
contingencyMatrix = Map("0" -> Array(11.0, 12.0), "1" -> Array(12.0, 12.0), "2" -> Array(13.0, 12.0)),
cramersV = 6.3,
pointwiseMutualInfo = Map("0" -> Array(7.3), "1" -> Array(8.3), "2" -> Array(9.3)),
@@ -761,4 +773,29 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
absError / bigCoeffSum < tol shouldBe true
absError2 / smallCoeffSum < tol shouldBe true
}

it should "correctly return moments calculation and cardinality calculation for numeric features" in {

import spark.implicits._
val df = linRegDF._3
val meanTol = 0.01
val varTol = 0.01
val (moments, cardinality) = getFeatureMomentsAndCard(standardizedLinpred, linRegDF._3)

// Go through each feature and check that the mean, variance, and unique counts match the data
moments.foreach { case (featureName, value) => {
value.count shouldBe 1000
val (expectedMean, expectedVariance) =
df.select(avg(featureName), variance(featureName)).as[(Double, Double)].collect().head
math.abs((value.mean - expectedMean) / expectedMean) < meanTol shouldBe true
math.abs((value.variance - expectedVariance) / expectedVariance) < varTol shouldBe true
}
}

cardinality.foreach { case (featureName, value) => {
val actualUniques = df.select(featureName).as[Double].collect().toSet
value.valueCounts.keySet.map(_.toDouble).subsetOf(actualUniques) shouldBe true
}
}
}
}
FeatureDistributionTest.scala
@@ -31,8 +31,10 @@
package com.salesforce.op.filters

import com.salesforce.op.features.{FeatureDistributionType, TransientFeature}
import com.salesforce.op.stages.impl.feature.TextStats
import com.salesforce.op.test.PassengerSparkFixtureTest
import com.salesforce.op.testkit.RandomText
import com.twitter.algebird.Moments
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner
@@ -73,6 +75,10 @@ class FeatureDistributionTest extends FlatSpec with PassengerSparkFixtureTest wi
distribs(3).distribution.sum shouldBe 0
distribs(4).distribution.sum shouldBe 3
distribs(4).summaryInfo.length shouldBe bins
distribs(2).cardEstimate.get shouldBe TextStats(Map("male" -> 1, "female" -> 1))
distribs(2).moments.get shouldBe Moments(2, 5.0, 2.0, 0.0, 2.0)
distribs(4).cardEstimate.get shouldBe TextStats(Map("5.0" -> 1, "1.0" -> 1, "3.0" -> 1))
distribs(4).moments.get shouldBe Moments(3, 3.0, 8.0, 0.0, 32.0)
}
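A hedged reading of the five-argument `Moments` values asserted above: the fields appear to be (count, mean, and the 2nd through 4th central moment sums), which checks out against the gender tokens "male" and "female" having lengths 4 and 6:

```scala
import com.twitter.algebird.{Moments, MomentsGroup}

object MomentsAssertionSketch extends App {
  // Lengths 4 and 6: count = 2, mean = 5,
  // m2 = (4-5)^2 + (6-5)^2 = 2, m3 = (-1)^3 + 1^3 = 0, m4 = 1 + 1 = 2
  val m = MomentsGroup.sum(Seq(4.0, 6.0).map(Moments(_)))
  println(m == Moments(2, 5.0, 2.0, 0.0, 2.0)) // expected: true
}
```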

it should "be correctly created for text features" in {
@@ -93,6 +99,7 @@ class FeatureDistributionTest extends FlatSpec with PassengerSparkFixtureTest wi
distribs(0).distribution.length shouldBe 100
distribs(0).distribution.sum shouldBe 10000
distribs.foreach(d => d.featureKey shouldBe d.name -> d.key)
distribs(0).moments.get.count shouldBe 10000
}

it should "be correctly created for map features" in {
@@ -189,12 +196,14 @@ class FeatureDistributionTest extends FlatSpec with PassengerSparkFixtureTest wi
it should "have toString" in {
FeatureDistribution("A", None, 10, 1, Array(1, 4, 0, 0, 6), Array.empty).toString() shouldBe
"FeatureDistribution(type = Training, name = A, key = None, count = 10, nulls = 1, " +
"distribution = [1.0,4.0,0.0,0.0,6.0], summaryInfo = [])"
"distribution = [1.0,4.0,0.0,0.0,6.0], summaryInfo = [], cardinality = , moments = )"
}

it should "marshall to/from json" in {
val fd1 = FeatureDistribution("A", None, 10, 1, Array(1, 4, 0, 0, 6), Array.empty)
val fd2 = FeatureDistribution("A", None, 20, 20, Array(2, 8, 0, 0, 12), Array.empty)
val fd2 = FeatureDistribution("A", None, 10, 1, Array(1, 4, 0, 0, 6),
Array.empty, Some(Moments(1.0)), Some(TextStats(Map("foo" -> 1, "bar" ->2))),
FeatureDistributionType.Scoring)
val json = FeatureDistribution.toJson(Array(fd1, fd2))
FeatureDistribution.fromJson(json) match {
case Success(r) => r shouldBe Seq(fd1, fd2)
Expand All @@ -203,7 +212,8 @@ class FeatureDistributionTest extends FlatSpec with PassengerSparkFixtureTest wi
}

it should "marshall to/from json with default vector args" in {
val fd1 = FeatureDistribution("A", None, 10, 1, Array(1, 4, 0, 0, 6), Array.empty, FeatureDistributionType.Scoring)
val fd1 = FeatureDistribution("A", None, 10, 1, Array(1, 4, 0, 0, 6),
Array.empty, None, None, FeatureDistributionType.Scoring)
val fd2 = FeatureDistribution("A", Some("X"), 20, 20, Array(2, 8, 0, 0, 12), Array.empty)
val json =
"""[{"name":"A","count":10,"nulls":1,"distribution":[1.0,4.0,0.0,0.0,6.0],"type":"Scoring"},