Skip to content

Commit

Permalink
Add random test feature generator to generate datasets with features …
Browse files Browse the repository at this point in the history
…of *all* types (#298)
  • Loading branch information
tovbinm authored Apr 19, 2019
1 parent 213962e commit 3ac0a50
Show file tree
Hide file tree
Showing 10 changed files with 375 additions and 175 deletions.
1 change: 1 addition & 0 deletions features/build.gradle
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
dependencies {
compile project(':utils')
testCompile project(':testkit')

// Scala graph
compile "org.scala-graph:graph-core_$scalaVersion:$scalaGraphVersion"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,18 @@ import java.util
import com.salesforce.op.aggregators._
import com.salesforce.op.features.types._
import com.salesforce.op.stages.FeatureGeneratorStage
import com.salesforce.op.test.{Passenger, TestSparkContext}
import com.twitter.algebird.MonoidAggregator
import com.salesforce.op.test.{FeatureAsserts, Passenger, TestSparkContext}
import org.apache.spark.sql.{DataFrame, Row}
import org.joda.time.Duration
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

import scala.reflect.runtime.universe._

/**
 * Plain three-field record holding a string, a long and a double.
 *
 * NOTE(review): presumably a raw input record for the tests in this file - its usages are not visible here.
 *
 * @param s string field
 * @param l long field
 * @param d double field
 */
case class FeatureBuilderContainerTest(
  s: String,
  l: Long,
  d: Double
)

@RunWith(classOf[JUnitRunner])
class FeatureBuilderTest extends FlatSpec with TestSparkContext {
class FeatureBuilderTest extends FlatSpec with TestSparkContext with FeatureAsserts {
private val name = "feature"
private val passenger =
Passenger.newBuilder()
Expand Down Expand Up @@ -180,51 +179,3 @@ class FeatureBuilderTest extends FlatSpec with TestSparkContext {
}

}

/**
 * Assert feature instance on a given input/output
 */
object assertFeature extends Matchers {

  /**
   * Assert feature instance on a given input/output.
   *
   * Checks the feature itself (name, response flag, no parents, uid prefix, type tag, raw flag, type name)
   * and then its origin stage, which must be a [[FeatureGeneratorStage]] (input type tag, aggregator,
   * extract function result for `in`, non-empty extract source, output name, response flag,
   * aggregate window, uid prefix). The first failing check aborts the assertion.
   *
   * @param f               feature to assert
   * @param in              input value
   * @param out             expected output value
   * @param name            expected name
   * @param isResponse      is expected to be a response
   * @param aggregator      expected aggregator
   * @param aggregateWindow expected aggregate window
   * @param tti             expected input typetag
   * @param wtt             expected output typetag
   * @tparam I input type
   * @tparam O output feature type
   */
  def apply[I, O <: FeatureType](f: FeatureLike[O])(
    in: I, out: O, name: String, isResponse: Boolean = false,
    aggregator: WeakTypeTag[O] => MonoidAggregator[Event[O], _, O] =
      (wtt: WeakTypeTag[O]) => MonoidAggregatorDefaults.aggregatorOf[O](wtt),
    aggregateWindow: Option[Duration] = None
  )(implicit tti: WeakTypeTag[I], wtt: WeakTypeTag[O]): Unit = {
    // Feature level checks
    f.name shouldBe name
    f.isResponse shouldBe isResponse
    f.parents shouldBe Nil
    // The uid is expected to start with the simple (last dot-separated) name of the dealiased output type.
    // `should startWith` reports the actual uid on failure, unlike `startsWith(..) shouldBe true`.
    f.uid should startWith(wtt.tpe.dealias.toString.split("\\.").last)
    f.wtt.tpe =:= wtt.tpe shouldBe true
    f.isRaw shouldBe true
    f.typeName shouldBe wtt.tpe.typeSymbol.fullName

    // Origin stage checks: the instance check above the cast keeps a wrong stage type an assertion failure
    f.originStage shouldBe a[FeatureGeneratorStage[_, _ <: FeatureType]]
    val fg = f.originStage.asInstanceOf[FeatureGeneratorStage[I, O]]
    fg.tti shouldBe tti
    fg.aggregator shouldBe aggregator(wtt)
    fg.extractFn(in) shouldBe out
    fg.extractSource.nonEmpty shouldBe true // TODO we should eval the code here: eval(fg.extractSource)(in)
    fg.getOutputFeatureName shouldBe name
    fg.outputIsResponse shouldBe isResponse
    fg.aggregateWindow shouldBe aggregateWindow
    fg.uid should startWith(classOf[FeatureGeneratorStage[I, O]].getSimpleName)
  }

}
89 changes: 89 additions & 0 deletions testkit/src/main/scala/com/salesforce/op/test/FeatureAsserts.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright (c) 2017, Salesforce.com, Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* * Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

package com.salesforce.op.test

import com.salesforce.op.aggregators.{Event, MonoidAggregatorDefaults}
import com.salesforce.op.features.FeatureLike
import com.salesforce.op.features.types.FeatureType
import com.salesforce.op.stages.FeatureGeneratorStage
import com.twitter.algebird.MonoidAggregator
import org.joda.time.Duration
import org.scalatest.Matchers

import scala.reflect.runtime.universe.WeakTypeTag

/**
 * Asserts for Feature instances on a given input/output
 */
trait FeatureAsserts extends Matchers {

  /**
   * Assert Feature instance on a given input/output.
   *
   * Checks the feature itself (name, response flag, no parents, uid prefix, type tag, raw flag, type name)
   * and then its origin stage, which must be a [[FeatureGeneratorStage]] (input type tag, aggregator,
   * extract function result for `in`, non-empty extract source, output name, response flag,
   * aggregate window, uid prefix). The first failing check aborts the assertion.
   *
   * @param f               feature to assert
   * @param in              input value
   * @param out             expected output value
   * @param name            expected name
   * @param isResponse      is expected to be a response
   * @param aggregator      expected aggregator
   * @param aggregateWindow expected aggregate window
   * @param tti             expected input typetag
   * @param wtt             expected output typetag
   * @tparam I input type
   * @tparam O output feature type
   */
  def assertFeature[I, O <: FeatureType](f: FeatureLike[O])(
    in: I, out: O, name: String, isResponse: Boolean = false,
    aggregator: WeakTypeTag[O] => MonoidAggregator[Event[O], _, O] =
      (wtt: WeakTypeTag[O]) => MonoidAggregatorDefaults.aggregatorOf[O](wtt),
    aggregateWindow: Option[Duration] = None
  )(implicit tti: WeakTypeTag[I], wtt: WeakTypeTag[O]): Unit = {
    // Feature level checks
    f.name shouldBe name
    f.isResponse shouldBe isResponse
    f.parents shouldBe Nil
    // The uid is expected to start with the simple (last dot-separated) name of the dealiased output type.
    // `should startWith` reports the actual uid on failure, unlike `startsWith(..) shouldBe true`.
    f.uid should startWith(wtt.tpe.dealias.toString.split("\\.").last)
    f.wtt.tpe =:= wtt.tpe shouldBe true
    f.isRaw shouldBe true
    f.typeName shouldBe wtt.tpe.typeSymbol.fullName

    // Origin stage checks: the instance check above the cast keeps a wrong stage type an assertion failure
    f.originStage shouldBe a[FeatureGeneratorStage[_, _ <: FeatureType]]
    val fg = f.originStage.asInstanceOf[FeatureGeneratorStage[I, O]]
    fg.tti shouldBe tti
    fg.aggregator shouldBe aggregator(wtt)
    fg.extractFn(in) shouldBe out
    fg.extractSource.nonEmpty shouldBe true // TODO we should eval the code here: eval(fg.extractSource)(in)
    fg.getOutputFeatureName shouldBe name
    fg.outputIsResponse shouldBe isResponse
    fg.aggregateWindow shouldBe aggregateWindow
    fg.uid should startWith(classOf[FeatureGeneratorStage[I, O]].getSimpleName)
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,20 @@

package com.salesforce.op.test

import com.salesforce.op.features.types.{FeatureType, FeatureTypeSparkConverter}
import java.text.SimpleDateFormat

import com.salesforce.op.features.types._
import com.salesforce.op.features.{Feature, FeatureBuilder, FeatureSparkTypes}
import com.salesforce.op.testkit.RandomList.UniformGeolocation
import com.salesforce.op.testkit._
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.StructType

import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe._


/**
* Test Feature Builder is a factory for creating datasets and features for tests
*/
Expand Down Expand Up @@ -249,15 +255,161 @@ case object TestFeatureBuilder {
f5name = DefaultFeatureNames.f5, data)
}

/**
 * Build a dataset with an arbitrary number of features of specified types
 *
 * Each sequence in `data` becomes one column. Rows are produced only while every sequence still has
 * a value, so the number of rows equals the length of the shortest sequence. Features are named
 * `f1`, `f2`, ... in the order the sequences are passed, and the type of each feature is looked up
 * from the runtime class of the corresponding value in the first row.
 *
 * @param data  data, one sequence of feature values per column (at least one sequence is required)
 * @param spark spark session
 * @return dataset with an arbitrary number of features of specified types
 * @throws IllegalArgumentException if no sequences are passed or if no rows can be produced
 */
def apply(data: Seq[FeatureType]*)(implicit spark: SparkSession): (DataFrame, Array[Feature[_ <: FeatureType]]) = {
  // Without any column `iterators.forall(_.hasNext)` below is vacuously true,
  // so the loop would never terminate - fail fast instead
  require(data.nonEmpty, "Number of features must be positive")

  val iterators = data.map(_.iterator).toArray
  val rows = ArrayBuffer.empty[Row]
  val featureValues = ArrayBuffer.empty[Array[FeatureType]]

  // Walk all columns in lockstep, stopping as soon as the shortest one is exhausted
  while (iterators.forall(_.hasNext)) {
    val vals: Array[FeatureType] = iterators.map(_.next())
    val sparkVals = vals.map(FeatureTypeSparkConverter.toSpark)
    rows += Row.fromSeq(sparkVals)
    featureValues += vals
  }

  require(rows.nonEmpty && featureValues.nonEmpty, "Number of rows must be positive")

  // Feature types are derived from the values of the first row only
  val features: Array[Feature[_ <: FeatureType]] = featureValues.head.zipWithIndex.map { case (f, i) =>
    val wtt = FeatureType.featureTypeTag(f.getClass.getName).asInstanceOf[WeakTypeTag[FeatureType]]
    feature[FeatureType](name = s"f${i + 1}")(wtt)
  }.toArray

  val schema = StructType(features.map(FeatureSparkTypes.toStructField(_)))
  dataframeOfRows(schema, rows) -> features
}

// Reference date: the string "18/04/19" parsed with the "dd/MM/yy" pattern
private val InitDate = {
  val dayMonthYear = new SimpleDateFormat("dd/MM/yy")
  dayMonthYear.parse("18/04/19")
}

/**
 * Build a dataset with random features of specified size
 *
 * Every parameter of the second parameter list is a by-name sequence of feature values that becomes
 * one column of the dataset. All of them are collected into a single array (in the order listed in
 * the method body) and handed over to `apply(data: _*)`, which builds the dataframe and the features.
 * Each default value is a testkit generator limited to `numOfRows` values.
 *
 * NOTE(review): `numOfRows` only bounds the default sequences; a caller-supplied sequence is passed
 * through as given - confirm the resulting row count against `apply`.
 * NOTE(review): `dateTimeLists` is typed `Seq[DateList]` although its default is built with
 * `RandomList.ofDateTimes` - confirm whether `Seq[DateTimeList]` was intended.
 * NOTE(review): the `texts` default uses `RandomText.base64(5, 10)`, the same generator as `base64s` -
 * confirm this is intended.
 *
 * @param numOfRows number of rows to generate (must be positive)
 * @param spark spark session
 * @return dataset with random features of specified size
 * @throws IllegalArgumentException if `numOfRows` is not positive
 */
// scalastyle:off parameter.number
def random
(
numOfRows: Int = 10
)(
vectors: => Seq[OPVector] = RandomVector.sparse(RandomReal.normal[Real](), 10).limit(numOfRows),
textLists: => Seq[TextList] = RandomList.ofTexts(RandomText.strings(0, 10), maxLen = 10).limit(numOfRows),
dateLists: => Seq[DateList] = RandomList.ofDates(
RandomIntegral.dates(InitDate, 1000, 1000000), maxLen = 10
).limit(numOfRows),
dateTimeLists: => Seq[DateList] = RandomList.ofDateTimes(
RandomIntegral.datetimes(InitDate, 1000, 1000000), maxLen = 10
).limit(numOfRows),
geoLocations: => Seq[Geolocation] = RandomList.ofGeolocations.limit(numOfRows),
base64Maps: => Seq[Base64Map] = RandomMap.of[Base64, Base64Map](RandomText.base64(5, 10), 0, 5).limit(numOfRows),
binaryMaps: => Seq[BinaryMap] = RandomMap.ofBinaries(0.5, 0, 5).limit(numOfRows),
comboBoxMaps: => Seq[ComboBoxMap] = RandomMap.of[ComboBox, ComboBoxMap](
RandomText.comboBoxes(List("choice1", "choice2", "choice3")), 0, 5
).limit(numOfRows),
currencyMaps: => Seq[CurrencyMap] = RandomMap.ofReals[Currency, CurrencyMap](
RandomReal.poisson[Currency](5.0), 0, 5
).limit(numOfRows),
dateMaps: => Seq[DateMap] = RandomMap.of(
RandomIntegral.dates(InitDate, 1000, 1000000), 0, 5
).limit(numOfRows),
dateTimeMaps: => Seq[DateTimeMap] = RandomMap.of(
RandomIntegral.datetimes(InitDate, 1000, 1000000), 0, 5
).limit(numOfRows),
emailMaps: => Seq[EmailMap] = RandomMap.of(
RandomText.emailsOn(RandomStream.of(List("example.com", "test.com"))), 0, 5
).limit(numOfRows),
idMaps: => Seq[IDMap] = RandomMap.of[ID, IDMap](RandomText.ids, 0, 5).limit(numOfRows),
integralMaps: => Seq[IntegralMap] = RandomMap.of(RandomIntegral.integrals, 0, 5).limit(numOfRows),
multiPickListMaps: => Seq[MultiPickListMap] = RandomMap.ofMultiPickLists(
RandomMultiPickList.of(RandomText.countries, maxLen = 5), 0, 5
).limit(numOfRows),
percentMaps: => Seq[PercentMap] = RandomMap.ofReals[Percent, PercentMap](
RandomReal.normal[Percent](50, 5), 0, 5
).limit(numOfRows),
phoneMaps: => Seq[PhoneMap] = RandomMap.of[Phone, PhoneMap](RandomText.phones, 0, 5).limit(numOfRows),
pickListMaps: => Seq[PickListMap] = RandomMap.of[PickList, PickListMap](
RandomText.pickLists(List("pick1", "pick2", "pick3")), 0, 5
).limit(numOfRows),
realMaps: => Seq[RealMap] = RandomMap.ofReals[Real, RealMap](RandomReal.normal[Real](), 0, 5).limit(numOfRows),
textAreaMaps: => Seq[TextAreaMap] = RandomMap.of[TextArea, TextAreaMap](
RandomText.textAreas(0, 50), 0, 5
).limit(numOfRows),
textMaps: => Seq[TextMap] = RandomMap.of[Text, TextMap](RandomText.strings(0, 10), 0, 5).limit(numOfRows),
urlMaps: => Seq[URLMap] = RandomMap.of[URL, URLMap](RandomText.urls, 0, 5).limit(numOfRows),
countryMaps: => Seq[CountryMap] = RandomMap.of[Country, CountryMap](RandomText.countries, 0, 5).limit(numOfRows),
stateMaps: => Seq[StateMap] = RandomMap.of[State, StateMap](RandomText.states, 0, 5).limit(numOfRows),
cityMaps: => Seq[CityMap] = RandomMap.of[City, CityMap](RandomText.cities, 0, 5).limit(numOfRows),
postalCodeMaps: => Seq[PostalCodeMap] = RandomMap.of[PostalCode, PostalCodeMap](
RandomText.postalCodes, 0, 5
).limit(numOfRows),
streetMaps: => Seq[StreetMap] = RandomMap.of[Street, StreetMap](RandomText.streets, 0, 5).limit(numOfRows),
geoLocationMaps: => Seq[GeolocationMap] = RandomMap.ofGeolocations[UniformGeolocation](
RandomList.ofGeolocations, 0, 5
).limit(numOfRows),
binaries: => Seq[Binary] = RandomBinary(0.5).limit(numOfRows),
currencies: => Seq[Currency] = RandomReal.poisson[Currency](5.0).limit(numOfRows),
dates: => Seq[Date] = RandomIntegral.dates(InitDate, 1000, 1000000).limit(numOfRows),
dateTimes: => Seq[DateTime] = RandomIntegral.datetimes(InitDate, 1000, 1000000).limit(numOfRows),
integrals: => Seq[Integral] = RandomIntegral.integrals.limit(numOfRows),
percents: => Seq[Percent] = RandomReal.normal[Percent](50, 5).limit(numOfRows),
reals: => Seq[Real] = RandomReal.normal[Real]().limit(numOfRows),
realNNs: => Seq[RealNN] = RandomReal.normal[RealNN]().limit(numOfRows),
multiPickLists: => Seq[MultiPickList] = RandomMultiPickList.of(RandomText.countries, maxLen = 5).limit(numOfRows),
base64s: => Seq[Base64] = RandomText.base64(5, 10).limit(numOfRows),
comboBoxes: => Seq[ComboBox] = RandomText.comboBoxes(List("choice1", "choice2", "choice3")).limit(numOfRows),
emails: => Seq[Email] = RandomText.emailsOn(RandomStream.of(List("example.com", "test.com"))).limit(numOfRows),
ids: => Seq[ID] = RandomText.ids.limit(numOfRows),
phones: => Seq[Phone] = RandomText.phones.limit(numOfRows),
pickLists: => Seq[PickList] = RandomText.pickLists(List("pick1", "pick2", "pick3")).limit(numOfRows),
texts: => Seq[Text] = RandomText.base64(5, 10).limit(numOfRows),
textAreas: => Seq[TextArea] = RandomText.textAreas(0, 50).limit(numOfRows),
urls: => Seq[URL] = RandomText.urls.limit(numOfRows),
countries: => Seq[Country] = RandomText.countries.limit(numOfRows),
states: => Seq[State] = RandomText.states.limit(numOfRows),
cities: => Seq[City] = RandomText.cities.limit(numOfRows),
postalCodes: => Seq[PostalCode] = RandomText.postalCodes.limit(numOfRows),
streets: => Seq[Street] = RandomText.streets.limit(numOfRows)
)(implicit spark: SparkSession): (DataFrame, Array[Feature[_ <: FeatureType]]) = {

// Checked before any by-name argument is evaluated
require(numOfRows > 0, "Number of rows must be positive")

// Each by-name argument is referenced exactly once here, i.e. evaluated once;
// the position in this array is the column position passed on to `apply`
val data: Array[Seq[FeatureType]] = Array(
vectors, textLists, dateLists, dateTimeLists, geoLocations,
base64Maps, binaryMaps, comboBoxMaps, currencyMaps, dateMaps,
dateTimeMaps, emailMaps, idMaps, integralMaps, multiPickListMaps,
percentMaps, phoneMaps, pickListMaps, realMaps, textAreaMaps,
textMaps, urlMaps, countryMaps, stateMaps, cityMaps,
postalCodeMaps, streetMaps, geoLocationMaps, binaries, currencies,
dates, dateTimes, integrals, percents, reals, realNNs,
multiPickLists, base64s, comboBoxes, emails, ids, phones,
pickLists, texts, textAreas, urls, countries, states,
cities, postalCodes, streets)

this.apply(data: _*)(spark)
}
// scalastyle:on

/**
 * Create a dataframe from a sequence of products whose elements are feature values
 *
 * Every element of each product is converted with `FeatureTypeSparkConverter.toSpark`. The conversion
 * only matches `FeatureType` elements, so any other element fails with a `MatchError`.
 *
 * @param schema schema of the resulting dataframe
 * @param data   products to convert, one per row
 * @param spark  spark session
 * @return dataframe built by `dataframeOfRows` from the converted rows
 */
private def dataframe[T <: Product](schema: StructType, data: Seq[T])(implicit spark: SparkSession): DataFrame = {
  def toRow(product: T): Row = {
    val sparkValues = product.productIterator.toSeq.map {
      case value: FeatureType => FeatureTypeSparkConverter.toSpark(value)
    }
    Row.fromSeq(sparkValues)
  }

  dataframeOfRows(schema, data.map(toRow))
}

/**
 * Create a dataframe from Spark rows
 *
 * @param schema schema the row encoder is built from
 * @param data   rows of the dataframe
 * @param spark  spark session
 * @return dataframe containing the given rows
 */
private def dataframeOfRows(schema: StructType, data: Seq[Row])(implicit spark: SparkSession): DataFrame = {
  import spark.implicits._
  // Encoder for the rows, picked up implicitly by `toDF()`
  implicit val rowEncoder = RowEncoder(schema)

  // `data` already holds Spark rows, so it is converted directly; no per-element conversion is needed here
  data.toDF()
}

/**
 * Create a feature with the given name via `FeatureBuilder.fromRow` and return it `asPredictor`
 *
 * @param name feature name
 * @param tt   type tag of the feature type
 * @tparam T feature type
 */
private def feature[T <: FeatureType](name: String)(implicit tt: WeakTypeTag[T]) = {
  val builder = FeatureBuilder.fromRow[T](name)(tt)
  builder.asPredictor
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

package com.salesforce.op.testkit

import com.salesforce.op.features.types.{FeatureType}
import com.salesforce.op.features.types.FeatureType

import scala.language.postfixOps
import scala.util.Random
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ object RandomReal {
): RandomReal[DataType] =
RandomReal[DataType](new GammaGenerator(shape, scale))


/**
* Generator of real-number feature types with log-normal distribution
*
Expand All @@ -155,8 +154,10 @@ object RandomReal {
* @tparam DataType the type of data
* @return a generator of reals
*/
def weibull[DataType <: Real : WeakTypeTag](
  alpha: Double = 1.0, beta: Double = 5.0
): RandomReal[DataType] =
  RandomReal[DataType](new WeibullGenerator(alpha, beta))

class UniformDistribution(min: Double, max: Double) extends RandomDataGenerator[Double] {
private val source = new UniformGenerator
Expand Down
Loading

0 comments on commit 3ac0a50

Please sign in to comment.