Added aggregate & conditional readers for parquet #172

Merged 5 commits on Nov 6, 2018

@@ -170,6 +170,26 @@ object DataReaders {
aggregateParams: AggregateParams[T]
): AggregateCSVProductReader[T] = csvProduct(path, key, aggregateParams)

/**
* Creates an [[AggregateParquetProductReader]]
*/
def parquetProduct[T <: Product : Encoder : WeakTypeTag](
path: Option[String] = None,
key: T => String = randomKey _,
aggregateParams: AggregateParams[T]
): AggregateParquetProductReader[T] = new AggregateParquetProductReader[T](
readPath = path, key = key, aggregateParams = aggregateParams
)

/**
* Creates an [[AggregateParquetProductReader]], but is called parquetCase so it's easier to understand
*/
def parquetCase[T <: Product : Encoder : WeakTypeTag](
path: Option[String] = None,
key: T => String = randomKey _,
aggregateParams: AggregateParams[T]
): AggregateParquetProductReader[T] = parquetProduct(path, key, aggregateParams)

}
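For reference, a minimal usage sketch of the new aggregate factory. It mirrors the reader built in DataReadersTest further down; PassengerCaseClass, the parquet path, and the cutoff value are taken from those tests, and the imports assume AggregateParams lives alongside DataReaders in com.salesforce.op.readers.

```scala
import com.salesforce.op.aggregators.CutOffTime
import com.salesforce.op.readers.{AggregateParams, DataReaders}

// An implicit Encoder[PassengerCaseClass] (e.g. via spark.implicits._) must be in scope.
// Aggregate parquet reader keyed by passenger id: multiple records per key are
// aggregated using the timestamp function and the given unix-epoch cutoff.
val aggReader = DataReaders.Aggregate.parquetCase[PassengerCaseClass](
  path = Some("test-data/BigPassengerWithHeader.parquet"),
  key = _.passengerId.toString,
  aggregateParams = AggregateParams(
    timeStampFn = Some[PassengerCaseClass => Long](_.recordDate.getOrElse(0L)),
    cutOffTime = CutOffTime.UnixEpoch(1471046600)
  )
)
```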

/**
@@ -232,6 +252,27 @@ object DataReaders {
conditionalParams: ConditionalParams[T]
): ConditionalCSVProductReader[T] = csvProduct(path, key, conditionalParams)

/**
* Creates a [[ConditionalParquetProductReader]]
*/
def parquetProduct[T <: Product : Encoder : WeakTypeTag]
(
path: Option[String] = None,
key: T => String = randomKey _,
conditionalParams: ConditionalParams[T]
): ConditionalParquetProductReader[T] = new ConditionalParquetProductReader[T](
readPath = path, key = key, conditionalParams = conditionalParams
)

/**
* Creates a [[ConditionalParquetProductReader]], but is called parquetCase so it's easier to understand
*/
def parquetCase[T <: Product : Encoder : WeakTypeTag](
path: Option[String] = None,
key: T => String = randomKey _,
conditionalParams: ConditionalParams[T]
): ConditionalParquetProductReader[T] = parquetProduct(path, key, conditionalParams)

}

}
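Similarly, a sketch of the conditional factory. The ConditionalParams values are lifted verbatim from DataReadersTest below; the same caveats about imports and the implicit Encoder apply, and TimeStampToKeep is assumed to live in com.salesforce.op.readers.

```scala
import com.salesforce.op.readers.{ConditionalParams, DataReaders, TimeStampToKeep}
import org.joda.time.Duration

// Conditional parquet reader: predictors and responses are aggregated relative
// to the time the target condition (height == 186) is met for each key.
val condReader = DataReaders.Conditional.parquetCase[PassengerCaseClass](
  path = Some("test-data/BigPassengerWithHeader.parquet"),
  key = _.passengerId.toString,
  conditionalParams = ConditionalParams(
    timeStampFn = _.recordDate.getOrElse(0L),
    targetCondition = _.height.contains(186),
    responseWindow = Some(Duration.millis(800)),
    predictorWindow = None,
    timeStampToKeep = TimeStampToKeep.Min,
    dropIfTargetConditionNotMet = true
  )
)
```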
@@ -56,3 +56,33 @@ class ParquetProductReader[T <: Product : Encoder]
maybeRepartition(data, params)
}
}

/**
* Data Reader for Parquet events, where there may be multiple records for a given key. Each parquet record
* will be automatically converted to type T that defines an [[Encoder]].
* @param readPath default path to data
* @param key function for extracting key from record
* @param aggregateParams params for time-based aggregation
* @tparam T Product type of the records
*/
class AggregateParquetProductReader[T <: Product : Encoder : WeakTypeTag]
(
readPath: Option[String],
key: T => String,
val aggregateParams: AggregateParams[T]
) extends ParquetProductReader[T](readPath, key) with AggregateDataReader[T]

/**
* Data Reader for Parquet events, when computing conditional probabilities. There may be multiple records for
* a given key. Each parquet record will be automatically converted to type T that defines an [[Encoder]].
* @param readPath default path to data
* @param key function for extracting key from record
* @param conditionalParams params for conditional aggregation
* @tparam T Product type of the records
*/
class ConditionalParquetProductReader[T <: Product : Encoder : WeakTypeTag]
(
readPath: Option[String],
key: T => String,
val conditionalParams: ConditionalParams[T]
) extends ParquetProductReader[T](readPath, key) with ConditionalDataReader[T]
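Both new classes are thin mixins over ParquetProductReader, with AggregateDataReader / ConditionalDataReader supplying the aggregation behaviour, so they can also be instantiated directly rather than through the DataReaders factories. A sketch, reusing names from the tests in this PR and assuming an implicit Encoder (e.g. spark.implicits._) is in scope:

```scala
// Direct construction; equivalent to DataReaders.Aggregate.parquetCase[PassengerCaseClass](...)
val reader = new AggregateParquetProductReader[PassengerCaseClass](
  readPath = Some("test-data/BigPassengerWithHeader.parquet"),
  key = _.passengerId.toString,
  aggregateParams = aggregateParameters // as defined in DataReadersTest below
)
```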
Collaborator commented:

please fix scalastyle warnings

Contributor Author commented:

@tovbinm Seems like we are going to have similar logic for testing these readers. Will look at the possibility of extracting these as Table Tests and update the PR.

Contributor commented:

Any tests included? Why not?

@@ -30,18 +30,13 @@

package com.salesforce.op.readers

import org.joda.time.Duration

import com.salesforce.op.aggregators.CutOffTime
import com.salesforce.op.features.FeatureBuilder
import com.salesforce.op.features.types._
import com.salesforce.op.test.TestSparkContext
import com.salesforce.op.test.{TestCommon, TestSparkContext}
import com.salesforce.op.utils.io.csv.CSVOptions
import org.apache.spark.sql.Row
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner
import scala.reflect.runtime.universe._


// need this to be external to (not nested in) CSVProductReaderTest for spark sql to work correctly
@@ -61,23 +56,10 @@ case class PassengerCaseClass
)

@RunWith(classOf[JUnitRunner])
class CSVProductReadersTest extends FlatSpec with TestSparkContext {

def testDataPath: String = "../test-data"

def csvWithoutHeaderPath: String = s"$testDataPath/BigPassenger.csv"

def csvWithHeaderPath: String = s"$testDataPath/BigPassengerWithHeader.csv"

val age = FeatureBuilder.Integral[PassengerCaseClass]
.extract(_.age.toIntegral)
.asPredictor

val survived = FeatureBuilder.Binary[PassengerCaseClass]
.extract(_.survived.toBinary)
.aggregate(zero = Some(true), (l, r) => Some(l.getOrElse(false) && r.getOrElse(false)))
.asResponse
class CSVProductReadersTest extends FlatSpec with TestSparkContext with TestCommon {
def csvWithoutHeaderPath: String = s"$testDataDir/BigPassenger.csv"

def csvWithHeaderPath: String = s"$testDataDir/BigPassengerWithHeader.csv"

import spark.implicits._

@@ -115,50 +97,4 @@ class CSVProductReadersTest extends FlatSpec with TestSparkContext {
data.collect { case r if r.get(0) == "3" => r.get(1) } shouldBe Array(Array("this", "is", "a", "description"))
data.length shouldBe 8
}

Spec[AggregateCSVProductReader[_]] should "read and aggregate data correctly" in {
val dataReader = DataReaders.Aggregate.csvCase[PassengerCaseClass](
path = Some(csvWithoutHeaderPath),
key = _.passengerId.toString,
aggregateParams = AggregateParams(
timeStampFn = Some[PassengerCaseClass => Long](_.recordDate.getOrElse(0L)),
cutOffTime = CutOffTime.UnixEpoch(1471046600)
)
)

val data = dataReader.readDataset().collect()
data.foreach(_ shouldBe a[PassengerCaseClass])
data.length shouldBe 8

val aggregatedData = dataReader.generateDataFrame(rawFeatures = Array(age, survived)).collect()
aggregatedData.length shouldBe 6
aggregatedData.collect { case r if r.get(0) == "4" => r} shouldEqual Array(Row("4", 60, false))

dataReader.fullTypeName shouldBe typeOf[PassengerCaseClass].toString
}

Spec[ConditionalCSVProductReader[_]] should "read and conditionally aggregate data correctly" in {
val dataReader = DataReaders.Conditional.csvCase[PassengerCaseClass](
path = Some(csvWithoutHeaderPath),
key = _.passengerId.toString,
conditionalParams = ConditionalParams(
timeStampFn = _.recordDate.getOrElse(0L),
targetCondition = _.height.contains(186), // Function to figure out if target event has occurred
responseWindow = Some(Duration.millis(800)), // How many days after target event to aggregate for response
predictorWindow = None, // How many days before target event to include in predictor aggregation
timeStampToKeep = TimeStampToKeep.Min,
dropIfTargetConditionNotMet = true
)
)

val data = dataReader.readDataset().collect()
data.foreach(_ shouldBe a[PassengerCaseClass])
data.length shouldBe 8

val aggregatedData = dataReader.generateDataFrame(rawFeatures = Array(age, survived)).collect()
aggregatedData.length shouldBe 2
aggregatedData shouldEqual Array(Row("3", null, true), Row("4", 10, false))

dataReader.fullTypeName shouldBe typeOf[PassengerCaseClass].toString
}
}
@@ -32,19 +32,80 @@ package com.salesforce.op.readers

import com.salesforce.op.OpParams
import com.salesforce.op.aggregators.CutOffTime
import com.salesforce.op.features.FeatureBuilder
import com.salesforce.op.features.types._
import com.salesforce.op.test._
import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.joda.time.Duration
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

import scala.reflect.runtime.universe._


@RunWith(classOf[JUnitRunner])
class DataReadersTest extends FlatSpec with PassengerSparkFixtureTest {
class DataReadersTest extends FlatSpec with PassengerSparkFixtureTest with TestCommon {
def csvWithoutHeaderPath: String = s"$testDataDir/BigPassenger.csv"

def csvWithHeaderPath: String = s"$testDataDir/BigPassengerWithHeader.csv"

def bigPassengerFilePath: String = s"$testDataDir/BigPassengerWithHeader.parquet"

import spark.implicits._

val agePredictor = FeatureBuilder.Integral[PassengerCaseClass]
.extract(_.age.toIntegral)
.asPredictor

val survivedResponse = FeatureBuilder.Binary[PassengerCaseClass]
.extract(_.survived.toBinary)
.aggregate(zero = Some(true), (l, r) => Some(l.getOrElse(false) && r.getOrElse(false)))
.asResponse

val aggregateParameters = AggregateParams(
timeStampFn = Some[PassengerCaseClass => Long](_.recordDate.getOrElse(0L)),
cutOffTime = CutOffTime.UnixEpoch(1471046600)
)

val conditionalParameters = ConditionalParams[PassengerCaseClass](
timeStampFn = _.recordDate.getOrElse(0L),
targetCondition = _.height.contains(186), // Function to figure out if target event has occurred
responseWindow = Some(Duration.millis(800)), // How many days after target event to aggregate for response
predictorWindow = None, // How many days before target event to include in predictor aggregation
timeStampToKeep = TimeStampToKeep.Min,
dropIfTargetConditionNotMet = true
)

val parquetAggReader = DataReaders.Aggregate.parquetCase[PassengerCaseClass](
path = Some(bigPassengerFilePath),
key = _.passengerId.toString,
aggregateParams = aggregateParameters
)

val csvAggReader = DataReaders.Aggregate.csvCase[PassengerCaseClass](
path = Some(csvWithoutHeaderPath),
key = _.passengerId.toString,
aggregateParams = aggregateParameters
)

val csvConditionalReader = DataReaders.Conditional.csvCase[PassengerCaseClass](
path = Some(csvWithoutHeaderPath),
key = _.passengerId.toString,
conditionalParams = conditionalParameters
)

val parquetConditionalReader = DataReaders.Conditional.parquetCase[PassengerCaseClass](
path = Some(bigPassengerFilePath),
key = _.passengerId.toString,
conditionalParams = conditionalParameters
)

val aggReaders = Seq(csvAggReader, parquetAggReader)

val conditionalReaders = Seq(csvConditionalReader, parquetConditionalReader)

// scalastyle:off
Spec(DataReaders.getClass) should "define readers" in {
@@ -112,9 +173,34 @@ class DataReadersTest extends FlatSpec with PassengerSparkFixtureTest {
error.getMessage shouldBe "Function is not serializable"
error.getCause shouldBe a[SparkException]
}

}

// TODO: test the readers
aggReaders.foreach( reader =>
Spec(reader.getClass) should "read and aggregate data correctly" in {
val data = reader.readDataset().collect()
data.foreach(_ shouldBe a[PassengerCaseClass])
data.length shouldBe 8

val aggregatedData = reader.generateDataFrame(rawFeatures = Array(agePredictor, survivedResponse)).collect()
aggregatedData.length shouldBe 6
aggregatedData.collect { case r if r.get(0) == "4" => r} shouldEqual Array(Row("4", 60, false))

reader.fullTypeName shouldBe typeOf[PassengerCaseClass].toString
}
)

conditionalReaders.foreach( reader =>
Spec(reader.getClass) should "read and conditionally aggregate data correctly" in {
val data = reader.readDataset().collect()
data.foreach(_ shouldBe a[PassengerCaseClass])
data.length shouldBe 8

val aggregatedData = reader.generateDataFrame(rawFeatures = Array(agePredictor, survivedResponse)).collect()
aggregatedData.length shouldBe 2
aggregatedData shouldEqual Array(Row("3", null, true), Row("4", 10, false))

reader.fullTypeName shouldBe typeOf[PassengerCaseClass].toString
}
)
}

@@ -57,13 +57,13 @@ case class PassengerType

@RunWith(classOf[JUnitRunner])
class ParquetProductReaderTest extends FlatSpec with TestSparkContext with TestCommon {
def parquetFilePath: String = s"$testDataDir/PassengerDataAll.parquet"
def passengerFilePath: String = s"$testDataDir/PassengerDataAll.parquet"

val parquetRecordCount = 891

import spark.implicits._
val dataReader = new ParquetProductReader[PassengerType](
readPath = Some(parquetFilePath),
readPath = Some(passengerFilePath),
key = _.PassengerId.toString
)

@@ -75,7 +75,7 @@ class ParquetProductReaderTest extends FlatSpec with TestSparkContext with TestC

it should "read in byte arrays as valid strings" in {
val caseReader = DataReaders.Simple.parquetCase[PassengerType](
path = Some(parquetFilePath),
path = Some(passengerFilePath),
key = _.PassengerId.toString
)

@@ -85,7 +85,7 @@ class ParquetProductReaderTest extends FlatSpec with TestSparkContext with TestC

it should "map the columns of data to types defined in schema" in {
val caseReader = DataReaders.Simple.parquetCase[PassengerType](
path = Some(parquetFilePath),
path = Some(passengerFilePath),
key = _.PassengerId.toString
)

Binary file added test-data/BigPassengerWithHeader.parquet
Binary file not shown.