Refines tests
liancheng committed Oct 21, 2014
1 parent 1d01074 commit a8c818d
Showing 3 changed files with 70 additions and 31 deletions.
@@ -66,14 +66,29 @@ private[sql] case class InMemoryRelation(
    batchStats.value.map(row => sizeOfRow.eval(row).asInstanceOf[Long]).sum
  }

  // Statistics propagation contracts:
  // 1. Non-null `_statistics` must reflect the actual statistics of the underlying data
  // 2. Only propagate statistics when `_statistics` is non-null
  private def statisticsToBePropagated = if (_statistics == null) {
    val updatedStats = statistics
    if (_statistics == null) null else updatedStats
  } else {
    _statistics
  }

  override def statistics = if (_statistics == null) {
    if (batchStats.value.isEmpty) {
      // Underlying columnar RDD hasn't been materialized, no useful statistics information
      // available, return the default statistics.
      Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes)
    } else {
      // Underlying columnar RDD has been materialized, required information has also been
      // collected via the `batchStats` accumulator, compute the final statistics, and
      // update `_statistics`.
      _statistics = Statistics(sizeInBytes = computeSizeInBytes)
      _statistics
    }
  } else {
    // Pre-computed statistics
    _statistics
  }
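
The contract above is subtle: calling `statistics` may fill in `_statistics` as a side effect, and `statisticsToBePropagated` re-checks the field afterwards so that only a genuinely computed value, never the default fallback, is handed to copies of the relation. A minimal standalone sketch of the same pattern, with illustrative names (`LazyStats`, `stats`, `statsToBePropagated` are not Spark APIs):

class LazyStats(compute: () => Option[Long], default: Long) {
  private var cached: Option[Long] = None  // plays the role of `_statistics`

  // Like `statistics`: prefer the cached value; otherwise try to compute and
  // cache one, falling back to `default` when nothing is materialized yet.
  def stats: Long = cached.getOrElse {
    compute() match {
      case Some(s) => cached = Some(s); s
      case None => default
    }
  }

  // Like `statisticsToBePropagated`: trigger a computation attempt, then
  // propagate only what actually landed in the cache, never the default.
  def statsToBePropagated: Option[Long] = {
    stats
    cached
  }
}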

@@ -129,7 +144,7 @@ private[sql] case class InMemoryRelation(
  def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
    InMemoryRelation(
      newOutput, useCompression, batchSize, storageLevel, child)(
-     _cachedColumnBuffers, if (_statistics == null) statistics else _statistics)
+     _cachedColumnBuffers, statisticsToBePropagated)
  }

  override def children = Seq.empty
@@ -142,7 +157,7 @@ private[sql] case class InMemoryRelation(
      storageLevel,
      child)(
      _cachedColumnBuffers,
-     if (_statistics == null) statistics else _statistics).asInstanceOf[this.type]
+     statisticsToBePropagated).asInstanceOf[this.type]
  }

  def cachedColumnBuffers = _cachedColumnBuffers
@@ -61,6 +61,12 @@ class ColumnStatsSuite extends FunSuite {
      assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
      assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
      assertResult(10, "Wrong null count")(stats(2))
+     assertResult(20, "Wrong row count")(stats(3))
+     assertResult(stats(4), "Wrong size in bytes") {
+       rows.map { row =>
+         if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
+       }.sum
+     }
    }
  }
}
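
The new assertions pin down the layout of the statistics row: slots 0 and 1 hold the lower and upper bounds, slot 2 the null count, slot 3 the row count, and slot 4 the total size in bytes, with each null charged a fixed 4 bytes. A toy accumulator that produces the same five slots for an Int column, a sketch rather than Spark's actual ColumnStats classes:

case class IntColumnStats(
    lower: Int = Int.MaxValue,
    upper: Int = Int.MinValue,
    nullCount: Int = 0,
    count: Int = 0,
    sizeInBytes: Long = 0L) {

  // Nulls are counted and charged a fixed 4 bytes, matching the
  // `if (row.isNullAt(0)) 4 else ...` term in the test above.
  def gatherNull: IntColumnStats =
    copy(nullCount = nullCount + 1, count = count + 1, sizeInBytes = sizeInBytes + 4)

  def gather(v: Int): IntColumnStats =
    copy(lower = lower min v, upper = upper max v, count = count + 1, sizeInBytes = sizeInBytes + 4)
}

// REPL check: 10 values plus 10 nulls reproduces the counts asserted above,
// i.e. IntColumnStats(1, 10, 10, 20, 80).
val stats = (1 to 10).foldLeft(IntColumnStats())(_ gather _)
val withNulls = (1 to 10).foldLeft(stats)((s, _) => s.gatherNull)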
@@ -22,17 +22,19 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
import org.apache.spark.sql._
import org.apache.spark.sql.test.TestSQLContext._

- case class IntegerData(i: Int)
-
class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {
  val originalColumnBatchSize = columnBatchSize
  val originalInMemoryPartitionPruning = inMemoryPartitionPruning

  override protected def beforeAll(): Unit = {
    // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
    setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
-   val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData)
-   rawData.registerTempTable("intData")
+   val rawData = sparkContext.makeRDD((1 to 100).map { key =>
+     val string = if (((key - 1) / 10) % 2 == 0) null else key.toString
+     TestData(key, string)
+   }, 5)
+   rawData.registerTempTable("testData")

    // Enable in-memory partition pruning
    setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
@@ -44,48 +46,64 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {
}

  before {
-   cacheTable("intData")
+   cacheTable("testData")
  }

  after {
-   uncacheTable("intData")
+   uncacheTable("testData")
  }

  // Comparisons
- checkBatchPruning("i = 1", Seq(1), 1, 1)
- checkBatchPruning("1 = i", Seq(1), 1, 1)
- checkBatchPruning("i < 12", 1 to 11, 1, 2)
- checkBatchPruning("i <= 11", 1 to 11, 1, 2)
- checkBatchPruning("i > 88", 89 to 100, 1, 2)
- checkBatchPruning("i >= 89", 89 to 100, 1, 2)
- checkBatchPruning("12 > i", 1 to 11, 1, 2)
- checkBatchPruning("11 >= i", 1 to 11, 1, 2)
- checkBatchPruning("88 < i", 89 to 100, 1, 2)
- checkBatchPruning("89 <= i", 89 to 100, 1, 2)
+ checkBatchPruning("SELECT key FROM testData WHERE key = 1", 1, 1)(Seq(1))
+ checkBatchPruning("SELECT key FROM testData WHERE 1 = key", 1, 1)(Seq(1))
+ checkBatchPruning("SELECT key FROM testData WHERE key < 12", 1, 2)(1 to 11)
+ checkBatchPruning("SELECT key FROM testData WHERE key <= 11", 1, 2)(1 to 11)
+ checkBatchPruning("SELECT key FROM testData WHERE key > 88", 1, 2)(89 to 100)
+ checkBatchPruning("SELECT key FROM testData WHERE key >= 89", 1, 2)(89 to 100)
+ checkBatchPruning("SELECT key FROM testData WHERE 12 > key", 1, 2)(1 to 11)
+ checkBatchPruning("SELECT key FROM testData WHERE 11 >= key", 1, 2)(1 to 11)
+ checkBatchPruning("SELECT key FROM testData WHERE 88 < key", 1, 2)(89 to 100)
+ checkBatchPruning("SELECT key FROM testData WHERE 89 <= key", 1, 2)(89 to 100)

+ // IS NULL
+ checkBatchPruning("SELECT key FROM testData WHERE value IS NULL", 5, 5) {
+   (1 to 10) ++ (21 to 30) ++ (41 to 50) ++ (61 to 70) ++ (81 to 90)
+ }
+
+ // IS NOT NULL
+ checkBatchPruning("SELECT key FROM testData WHERE value IS NOT NULL", 5, 5) {
+   (11 to 20) ++ (31 to 40) ++ (51 to 60) ++ (71 to 80) ++ (91 to 100)
+ }
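
These expected ranges follow directly from the fixture built in beforeAll: with COLUMN_BATCH_SIZE set to 10, the 100 keys split into 5 partitions of 20 consecutive keys, i.e. 2 batches of 10 per partition, and `((key - 1) / 10) % 2 == 0` nulls out `value` in every even batch. A one-line REPL check (not part of the suite) reproduces the IS NULL ranges:

val nullKeys = (1 to 100).filter(key => ((key - 1) / 10) % 2 == 0)
// nullKeys == (1 to 10) ++ (21 to 30) ++ (41 to 50) ++ (61 to 70) ++ (81 to 90)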

  // Conjunction and disjunction
- checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3)
- checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2)
- checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4)
- checkBatchPruning("NOT (i < 88)", 88 to 100, 1, 2)
+ checkBatchPruning("SELECT key FROM testData WHERE key > 8 AND key <= 21", 2, 3)(9 to 21)
+ checkBatchPruning("SELECT key FROM testData WHERE key < 2 OR key > 99", 2, 2)(Seq(1, 100))
+ checkBatchPruning("SELECT key FROM testData WHERE key < 2 OR (key > 78 AND key < 92)", 3, 4) {
+   Seq(1) ++ (79 to 91)
+ }

  // With unsupported predicate
- checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2)
- checkBatchPruning(s"NOT (i in (${(1 to 30).mkString(",")}))", 31 to 100, 5, 10)
+ checkBatchPruning("SELECT key FROM testData WHERE NOT (key < 88)", 1, 2)(88 to 100)
+ checkBatchPruning("SELECT key FROM testData WHERE key < 12 AND key IS NOT NULL", 1, 2)(1 to 11)
+
+ {
+   val seq = (1 to 30).mkString(", ")
+   checkBatchPruning(s"SELECT key FROM testData WHERE NOT (key IN ($seq))", 5, 10)(31 to 100)
+ }

  def checkBatchPruning(
-     filter: String,
-     expectedQueryResult: Seq[Int],
+     query: String,
      expectedReadPartitions: Int,
-     expectedReadBatches: Int): Unit = {
+     expectedReadBatches: Int)(
+     expectedQueryResult: => Seq[Int]): Unit = {

-   test(filter) {
-     val query = sql(s"SELECT * FROM intData WHERE $filter")
+   test(query) {
+     val schemaRdd = sql(query)
      assertResult(expectedQueryResult.toArray, "Wrong query result") {
-       query.collect().map(_.head).toArray
+       schemaRdd.collect().map(_.head).toArray
      }

-     val (readPartitions, readBatches) = query.queryExecution.executedPlan.collect {
+     val (readPartitions, readBatches) = schemaRdd.queryExecution.executedPlan.collect {
        case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value)
      }.head

