From a8c818df4314a2a6a82b4272e9b49b28afa60507 Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Tue, 21 Oct 2014 13:19:29 +0800
Subject: [PATCH] Refines tests

---
 .../columnar/InMemoryColumnarTableScan.scala  | 19 ++++-
 .../spark/sql/columnar/ColumnStatsSuite.scala |  6 ++
 .../columnar/PartitionBatchPruningSuite.scala | 76 ++++++++++++-------
 3 files changed, 70 insertions(+), 31 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index fad947be86311..ee63134f56d8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -66,14 +66,29 @@ private[sql] case class InMemoryRelation(
     batchStats.value.map(row => sizeOfRow.eval(row).asInstanceOf[Long]).sum
   }
 
+  // Statistics propagation contracts:
+  // 1. Non-null `_statistics` must reflect the actual statistics of the underlying data
+  // 2. Only propagate statistics when `_statistics` is non-null
+  private def statisticsToBePropagated = if (_statistics == null) {
+    val updatedStats = statistics
+    if (_statistics == null) null else updatedStats
+  } else {
+    _statistics
+  }
+
   override def statistics = if (_statistics == null) {
     if (batchStats.value.isEmpty) {
+      // Underlying columnar RDD hasn't been materialized, no useful statistics information
+      // available, return the default statistics.
       Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes)
     } else {
+      // Underlying columnar RDD has been materialized, required information has also been collected
+      // via the `batchStats` accumulator, compute the final statistics, and update `_statistics`.
       _statistics = Statistics(sizeInBytes = computeSizeInBytes)
       _statistics
     }
   } else {
+    // Pre-computed statistics
     _statistics
   }
 
@@ -129,7 +144,7 @@ private[sql] case class InMemoryRelation(
   def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
     InMemoryRelation(
       newOutput, useCompression, batchSize, storageLevel, child)(
-      _cachedColumnBuffers, if (_statistics == null) statistics else _statistics)
+      _cachedColumnBuffers, statisticsToBePropagated)
   }
 
   override def children = Seq.empty
@@ -142,7 +157,7 @@ private[sql] case class InMemoryRelation(
       storageLevel,
       child)(
       _cachedColumnBuffers,
-      if (_statistics == null) statistics else _statistics).asInstanceOf[this.type]
+      statisticsToBePropagated).asInstanceOf[this.type]
   }
 
   def cachedColumnBuffers = _cachedColumnBuffers
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 6bdf741134e2f..a9f0851f8826c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -61,6 +61,12 @@ class ColumnStatsSuite extends FunSuite {
       assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
       assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
       assertResult(10, "Wrong null count")(stats(2))
+      assertResult(20, "Wrong row count")(stats(3))
+      assertResult(stats(4), "Wrong size in bytes") {
+        rows.map { row =>
+          if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
+        }.sum
+      }
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index f53acc8c9f718..9fc077b181743 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -22,8 +22,6 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
 import org.apache.spark.sql._
 import org.apache.spark.sql.test.TestSQLContext._
 
-case class IntegerData(i: Int)
-
 class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {
   val originalColumnBatchSize = columnBatchSize
   val originalInMemoryPartitionPruning = inMemoryPartitionPruning
@@ -31,8 +29,12 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
   override protected def beforeAll(): Unit = {
     // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
     setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
-    val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData)
-    rawData.registerTempTable("intData")
+
+    val rawData = sparkContext.makeRDD((1 to 100).map { key =>
+      val string = if (((key - 1) / 10) % 2 == 0) null else key.toString
+      TestData(key, string)
+    }, 5)
+    rawData.registerTempTable("testData")
 
     // Enable in-memory partition pruning
     setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
@@ -44,48 +46,64 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
   }
 
   before {
-    cacheTable("intData")
+    cacheTable("testData")
   }
 
   after {
-    uncacheTable("intData")
+    uncacheTable("testData")
  }
 
   // Comparisons
-  checkBatchPruning("i = 1", Seq(1), 1, 1)
-  checkBatchPruning("1 = i", Seq(1), 1, 1)
-  checkBatchPruning("i < 12", 1 to 11, 1, 2)
-  checkBatchPruning("i <= 11", 1 to 11, 1, 2)
-  checkBatchPruning("i > 88", 89 to 100, 1, 2)
checkBatchPruning("i >= 89", 89 to 100, 1, 2) - checkBatchPruning("12 > i", 1 to 11, 1, 2) - checkBatchPruning("11 >= i", 1 to 11, 1, 2) - checkBatchPruning("88 < i", 89 to 100, 1, 2) - checkBatchPruning("89 <= i", 89 to 100, 1, 2) + checkBatchPruning("SELECT key FROM testData WHERE key = 1", 1, 1)(Seq(1)) + checkBatchPruning("SELECT key FROM testData WHERE 1 = key", 1, 1)(Seq(1)) + checkBatchPruning("SELECT key FROM testData WHERE key < 12", 1, 2)(1 to 11) + checkBatchPruning("SELECT key FROM testData WHERE key <= 11", 1, 2)(1 to 11) + checkBatchPruning("SELECT key FROM testData WHERE key > 88", 1, 2)(89 to 100) + checkBatchPruning("SELECT key FROM testData WHERE key >= 89", 1, 2)(89 to 100) + checkBatchPruning("SELECT key FROM testData WHERE 12 > key", 1, 2)(1 to 11) + checkBatchPruning("SELECT key FROM testData WHERE 11 >= key", 1, 2)(1 to 11) + checkBatchPruning("SELECT key FROM testData WHERE 88 < key", 1, 2)(89 to 100) + checkBatchPruning("SELECT key FROM testData WHERE 89 <= key", 1, 2)(89 to 100) + + // IS NULL + checkBatchPruning("SELECT key FROM testData WHERE value IS NULL", 5, 5) { + (1 to 10) ++ (21 to 30) ++ (41 to 50) ++ (61 to 70) ++ (81 to 90) + } + + // IS NOT NULL + checkBatchPruning("SELECT key FROM testData WHERE value IS NOT NULL", 5, 5) { + (11 to 20) ++ (31 to 40) ++ (51 to 60) ++ (71 to 80) ++ (91 to 100) + } // Conjunction and disjunction - checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3) - checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2) - checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4) - checkBatchPruning("NOT (i < 88)", 88 to 100, 1, 2) + checkBatchPruning("SELECT key FROM testData WHERE key > 8 AND key <= 21", 2, 3)(9 to 21) + checkBatchPruning("SELECT key FROM testData WHERE key < 2 OR key > 99", 2, 2)(Seq(1, 100)) + checkBatchPruning("SELECT key FROM testData WHERE key < 2 OR (key > 78 AND key < 92)", 3, 4) { + Seq(1) ++ (79 to 91) + } // With unsupported predicate - checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2) - checkBatchPruning(s"NOT (i in (${(1 to 30).mkString(",")}))", 31 to 100, 5, 10) + checkBatchPruning("SELECT key FROM testData WHERE NOT (key < 88)", 1, 2)(88 to 100) + checkBatchPruning("SELECT key FROM testData WHERE key < 12 AND key IS NOT NULL", 1, 2)(1 to 11) + + { + val seq = (1 to 30).mkString(", ") + checkBatchPruning(s"SELECT key FROM testData WHERE NOT (key IN ($seq))", 5, 10)(31 to 100) + } def checkBatchPruning( - filter: String, - expectedQueryResult: Seq[Int], + query: String, expectedReadPartitions: Int, - expectedReadBatches: Int): Unit = { + expectedReadBatches: Int)( + expectedQueryResult: => Seq[Int]): Unit = { - test(filter) { - val query = sql(s"SELECT * FROM intData WHERE $filter") + test(query) { + val schemaRdd = sql(query) assertResult(expectedQueryResult.toArray, "Wrong query result") { - query.collect().map(_.head).toArray + schemaRdd.collect().map(_.head).toArray } - val (readPartitions, readBatches) = query.queryExecution.executedPlan.collect { + val (readPartitions, readBatches) = schemaRdd.queryExecution.executedPlan.collect { case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value) }.head