[SPARK-17073][SQL][FOLLOWUP] generate column-level statistics
## What changes were proposed in this pull request?
This PR adds test cases for statistics: case-sensitive column names, non-ASCII column names, and refreshing table statistics. It also improves some of the documentation.
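
For illustration (not part of the patch), the statements these tests exercise look like the following, with hypothetical table and column names and assuming an existing SparkSession `spark`:

```scala
// Table-level statistics: total size in bytes and row count.
spark.sql("ANALYZE TABLE t COMPUTE STATISTICS")

// Column-level statistics for the listed columns: null count, min/max or
// average/max length, and approximate distinct count, depending on the type.
spark.sql("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS key, name")
```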

## How was this patch tested?
Added test cases.

Author: wangzhenhua <wangzhenhua@huawei.com>

Closes apache#15360 from wzhfy/colStats2.
wzhfy authored and Robert Kruszewski committed Oct 31, 2016
1 parent c3c39f0 commit a0b3a45
Showing 3 changed files with 197 additions and 57 deletions.
@@ -59,10 +59,12 @@ case class AnalyzeColumnCommand(

def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = {
val (rowCount, columnStats) = computeColStats(sparkSession, relation)
// We also update table-level stats in order to keep them consistent with column-level stats.
val statistics = Statistics(
sizeInBytes = newTotalSize,
rowCount = Some(rowCount),
colStats = columnStats ++ catalogTable.stats.map(_.colStats).getOrElse(Map()))
// Newly computed column stats should override the existing ones.
colStats = catalogTable.stats.map(_.colStats).getOrElse(Map()) ++ columnStats)
sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics)))
// Refresh the cached data source table in the catalog.
sessionState.catalog.refreshTable(tableIdentWithDB)
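
The one-line change above swaps the operands of `++`: with Scala maps, entries from the right-hand operand win on key collisions, so the freshly computed `columnStats` now override previously stored column stats instead of being discarded. A minimal sketch of that behavior (hypothetical values, not from the patch):

```scala
// With Scala maps, `m1 ++ m2` keeps the right-hand entry on key collisions.
val existing = Map("key" -> "old stat", "other" -> "kept stat")
val fresh    = Map("key" -> "new stat")

existing ++ fresh   // Map(key -> new stat, other -> kept stat): new stats win
fresh ++ existing   // Map(key -> old stat, other -> kept stat): the pre-fix order
```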
@@ -90,8 +92,9 @@ case class AnalyzeColumnCommand(
}
}
if (duplicatedColumns.nonEmpty) {
logWarning(s"Duplicated columns ${duplicatedColumns.mkString("(", ", ", ")")} detected " +
s"when analyzing columns ${columnNames.mkString("(", ", ", ")")}, ignoring them.")
logWarning("Duplicate column names were deduplicated in `ANALYZE TABLE` statement. " +
s"Input columns: ${columnNames.mkString("(", ", ", ")")}. " +
s"Duplicate columns: ${duplicatedColumns.mkString("(", ", ", ")")}.")
}

// Collect statistics per column.
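
A hedged usage sketch of the warning above (hypothetical table and column names; assumes a SparkSession `spark`): listing a column more than once only triggers the warning, and statistics are still computed once per distinct column.

```scala
// The duplicate mention of `key` is deduplicated with a warning; statistics
// are computed a single time for the distinct column.
spark.sql("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS key, key")
```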
@@ -116,42 +119,44 @@
}

object ColumnStatStruct {
val zero = Literal(0, LongType)
val one = Literal(1, LongType)
private val zero = Literal(0, LongType)
private val one = Literal(1, LongType)

def numNulls(e: Expression): Expression = if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero
def max(e: Expression): Expression = Max(e)
def min(e: Expression): Expression = Min(e)
def ndv(e: Expression, relativeSD: Double): Expression = {
private def numNulls(e: Expression): Expression = {
if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero
}
private def max(e: Expression): Expression = Max(e)
private def min(e: Expression): Expression = Min(e)
private def ndv(e: Expression, relativeSD: Double): Expression = {
// the approximate ndv should never be larger than the number of rows
Least(Seq(HyperLogLogPlusPlus(e, relativeSD), Count(one)))
}
def avgLength(e: Expression): Expression = Average(Length(e))
def maxLength(e: Expression): Expression = Max(Length(e))
def numTrues(e: Expression): Expression = Sum(If(e, one, zero))
def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero))
private def avgLength(e: Expression): Expression = Average(Length(e))
private def maxLength(e: Expression): Expression = Max(Length(e))
private def numTrues(e: Expression): Expression = Sum(If(e, one, zero))
private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero))

def getStruct(exprs: Seq[Expression]): CreateStruct = {
private def getStruct(exprs: Seq[Expression]): CreateStruct = {
CreateStruct(exprs.map { expr: Expression =>
expr.transformUp {
case af: AggregateFunction => af.toAggregateExpression()
}
})
}

def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
private def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
Seq(numNulls(e), max(e), min(e), ndv(e, relativeSD))
}

def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
private def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
Seq(numNulls(e), avgLength(e), maxLength(e), ndv(e, relativeSD))
}

def binaryColumnStat(e: Expression): Seq[Expression] = {
private def binaryColumnStat(e: Expression): Seq[Expression] = {
Seq(numNulls(e), avgLength(e), maxLength(e))
}

def booleanColumnStat(e: Expression): Seq[Expression] = {
private def booleanColumnStat(e: Expression): Seq[Expression] = {
Seq(numNulls(e), numTrues(e), numFalses(e))
}

@@ -162,14 +167,14 @@ object ColumnStatStruct {
}
}

def apply(e: Attribute, relativeSD: Double): CreateStruct = e.dataType match {
def apply(attr: Attribute, relativeSD: Double): CreateStruct = attr.dataType match {
// Use aggregate functions to compute statistics we need.
case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(e, relativeSD))
case StringType => getStruct(stringColumnStat(e, relativeSD))
case BinaryType => getStruct(binaryColumnStat(e))
case BooleanType => getStruct(booleanColumnStat(e))
case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD))
case StringType => getStruct(stringColumnStat(attr, relativeSD))
case BinaryType => getStruct(binaryColumnStat(attr))
case BooleanType => getStruct(booleanColumnStat(attr))
case otherType =>
throw new AnalysisException("Analyzing columns is not supported for column " +
s"${e.name} of data type: ${e.dataType}.")
s"${attr.name} of data type: ${attr.dataType}.")
}
}
@@ -578,7 +578,8 @@ object SQLConf {
val NDV_MAX_ERROR =
SQLConfigBuilder("spark.sql.statistics.ndv.maxError")
.internal()
.doc("The maximum estimation error allowed in HyperLogLog++ algorithm.")
.doc("The maximum estimation error allowed in HyperLogLog++ algorithm when generating " +
"column level statistics.")
.doubleConf
.createWithDefault(0.05)
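
A usage sketch for this setting (the key and its default come from the change above; the tightened value is purely illustrative, and the config is marked internal):

```scala
// Tighten the allowed HyperLogLog++ estimation error before collecting
// column-level statistics; 0.05 is the default, and smaller values are more
// accurate but typically cost more memory and time.
spark.conf.set("spark.sql.statistics.ndv.maxError", "0.01")
spark.sql("ANALYZE TABLE tbl COMPUTE STATISTICS FOR COLUMNS key")
```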

198 changes: 166 additions & 32 deletions sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -21,7 +21,7 @@ import java.io.{File, PrintWriter}

import scala.reflect.ClassTag

import org.apache.spark.sql.{AnalysisException, QueryTest, Row, StatisticsTest}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils}
@@ -358,53 +358,187 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
}
}

test("generate column-level statistics and load them from hive metastore") {
private def getStatsBeforeAfterUpdate(isAnalyzeColumns: Boolean): (Statistics, Statistics) = {
val tableName = "tbl"
var statsBeforeUpdate: Statistics = null
var statsAfterUpdate: Statistics = null
withTable(tableName) {
val tableIndent = TableIdentifier(tableName, Some("default"))
val catalog = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog]
sql(s"CREATE TABLE $tableName (key int) USING PARQUET")
sql(s"INSERT INTO $tableName SELECT 1")
if (isAnalyzeColumns) {
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key")
} else {
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
}
// Table lookup will make the table cached.
catalog.lookupRelation(tableIndent)
statsBeforeUpdate = catalog.getCachedDataSourceTable(tableIndent)
.asInstanceOf[LogicalRelation].catalogTable.get.stats.get

sql(s"INSERT INTO $tableName SELECT 2")
if (isAnalyzeColumns) {
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key")
} else {
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
}
catalog.lookupRelation(tableIndent)
statsAfterUpdate = catalog.getCachedDataSourceTable(tableIndent)
.asInstanceOf[LogicalRelation].catalogTable.get.stats.get
}
(statsBeforeUpdate, statsAfterUpdate)
}

test("test refreshing table stats of cached data source table by `ANALYZE TABLE` statement") {
val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = false)

assert(statsBeforeUpdate.sizeInBytes > 0)
assert(statsBeforeUpdate.rowCount == Some(1))

assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes)
assert(statsAfterUpdate.rowCount == Some(2))
}

test("test refreshing column stats of cached data source table by `ANALYZE TABLE` statement") {
val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = true)

assert(statsBeforeUpdate.sizeInBytes > 0)
assert(statsBeforeUpdate.rowCount == Some(1))
StatisticsTest.checkColStat(
dataType = IntegerType,
colStat = statsBeforeUpdate.colStats("key"),
expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
rsd = spark.sessionState.conf.ndvMaxError)

assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes)
assert(statsAfterUpdate.rowCount == Some(2))
StatisticsTest.checkColStat(
dataType = IntegerType,
colStat = statsAfterUpdate.colStats("key"),
expectedColStat = ColumnStat(InternalRow(0L, 2, 1, 2L)),
rsd = spark.sessionState.conf.ndvMaxError)
}

private lazy val (testDataFrame, expectedColStatsSeq) = {
import testImplicits._

val intSeq = Seq(1, 2)
val stringSeq = Seq("a", "bb")
val binarySeq = Seq("a", "bb").map(_.getBytes)
val booleanSeq = Seq(true, false)

val data = intSeq.indices.map { i =>
(intSeq(i), stringSeq(i), booleanSeq(i))
(intSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i))
}
val tableName = "table"
withTable(tableName) {
val df = data.toDF("c1", "c2", "c3")
df.write.format("parquet").saveAsTable(tableName)
val expectedColStatsSeq = df.schema.map { f =>
val colStat = f.dataType match {
case IntegerType =>
ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong))
case StringType =>
ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble,
stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong))
case BooleanType =>
ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong,
booleanSeq.count(_.equals(false)).toLong))
}
(f, colStat)
val df: DataFrame = data.toDF("c1", "c2", "c3", "c4")
val expectedColStatsSeq: Seq[(StructField, ColumnStat)] = df.schema.map { f =>
val colStat = f.dataType match {
case IntegerType =>
ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong))
case StringType =>
ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble,
stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong))
case BinaryType =>
ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble,
binarySeq.map(_.length).max.toInt))
case BooleanType =>
ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong,
booleanSeq.count(_.equals(false)).toLong))
}
(f, colStat)
}
(df, expectedColStatsSeq)
}

private def checkColStats(
tableName: String,
isDataSourceTable: Boolean,
expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = {
val readback = spark.table(tableName)
val stats = readback.queryExecution.analyzed.collect {
case rel: MetastoreRelation =>
assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table")
rel.catalogTable.stats.get
case rel: LogicalRelation =>
assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table")
rel.catalogTable.get.stats.get
}
assert(stats.length == 1)
val columnStats = stats.head.colStats
assert(columnStats.size == expectedColStatsSeq.length)
expectedColStatsSeq.foreach { case (field, expectedColStat) =>
StatisticsTest.checkColStat(
dataType = field.dataType,
colStat = columnStats(field.name),
expectedColStat = expectedColStat,
rsd = spark.sessionState.conf.ndvMaxError)
}
}

sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1, c2, c3")
val readback = spark.table(tableName)
val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation =>
val columnStats = rel.catalogTable.get.stats.get.colStats
expectedColStatsSeq.foreach { case (field, expectedColStat) =>
assert(columnStats.contains(field.name))
val colStat = columnStats(field.name)
test("generate and load column-level stats for data source table") {
val dsTable = "dsTable"
withTable(dsTable) {
testDataFrame.write.format("parquet").saveAsTable(dsTable)
sql(s"ANALYZE TABLE $dsTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4")
checkColStats(dsTable, isDataSourceTable = true, expectedColStatsSeq)
}
}

test("generate and load column-level stats for hive serde table") {
val hTable = "hTable"
val tmp = "tmp"
withTable(hTable, tmp) {
testDataFrame.write.format("parquet").saveAsTable(tmp)
sql(s"CREATE TABLE $hTable (c1 int, c2 string, c3 binary, c4 boolean) STORED AS TEXTFILE")
sql(s"INSERT INTO $hTable SELECT * FROM $tmp")
sql(s"ANALYZE TABLE $hTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4")
checkColStats(hTable, isDataSourceTable = false, expectedColStatsSeq)
}
}

// When caseSensitive is on, columns that differ only by case are distinct columns,
// and we should generate column stats for each of them.
private def checkCaseSensitiveColStats(columnName: String): Unit = {
val tableName = "tbl"
withTable(tableName) {
val column1 = columnName.toLowerCase
val column2 = columnName.toUpperCase
withSQLConf("spark.sql.caseSensitive" -> "true") {
sql(s"CREATE TABLE $tableName (`$column1` int, `$column2` double) USING PARQUET")
sql(s"INSERT INTO $tableName SELECT 1, 3.0")
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS `$column1`, `$column2`")
val readback = spark.table(tableName)
val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation =>
val columnStats = rel.catalogTable.get.stats.get.colStats
assert(columnStats.size == 2)
StatisticsTest.checkColStat(
dataType = IntegerType,
colStat = columnStats(column1),
expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
rsd = spark.sessionState.conf.ndvMaxError)
StatisticsTest.checkColStat(
dataType = field.dataType,
colStat = colStat,
expectedColStat = expectedColStat,
dataType = DoubleType,
colStat = columnStats(column2),
expectedColStat = ColumnStat(InternalRow(0L, 3.0d, 3.0d, 1L)),
rsd = spark.sessionState.conf.ndvMaxError)
rel
}
rel
assert(relations.size == 1)
}
assert(relations.size == 1)
}
}

test("check column statistics for case sensitive column names") {
checkCaseSensitiveColStats(columnName = "c1")
}

test("check column statistics for case sensitive non-ascii column names") {
// scalastyle:off
// Non-ASCII characters are not allowed in the source code, so we disable scalastyle here.
checkCaseSensitiveColStats(columnName = "列c")
// scalastyle:on
}

test("estimates the size of a test MetastoreRelation") {
val df = sql("""SELECT * FROM src""")
val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation =>
