Skip to content

Commit

Permalink
Rename to QbeastColumnStatsBuilder and add test spec
Browse files Browse the repository at this point in the history
  • Loading branch information
osopardo1 committed Dec 20, 2024
1 parent 6f67fd5 commit 81ab4a2
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 9 deletions.
25 changes: 18 additions & 7 deletions core/src/main/scala/io/qbeast/core/model/QbeastColumnStats.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@ import org.apache.spark.sql.SparkSession
/**
* Container for Qbeast Column Stats
*
 * @param columnStatsSchema
 *   the column stats schema
* @param columnStatsRow
* the column stats row
*/
case class QbeastColumnStats(
statsString: String,
columnStatsSchema: StructType,
columnStatsRow: Row)
case class QbeastColumnStats(columnStatsSchema: StructType, columnStatsRow: Row)

object QbeastColumnStats {
object QbeastColumnStatsBuilder {

/**
* Builds the column stats schema
Expand Down Expand Up @@ -67,6 +65,16 @@ object QbeastColumnStats {
columnStatsSchema
}

/**
* Builds the column stats row
*
* @param stats
* the stats in a JSON string
* @param columnStatsSchema
* the column stats schema
 * @return
 *   the column stats row parsed from the JSON stats string
 */
def buildColumnStatsRow(stats: String, columnStatsSchema: StructType): Row = {
// If the stats are empty, return an empty row
if (stats.isEmpty) return Row.empty
Expand Down Expand Up @@ -95,8 +103,11 @@ object QbeastColumnStats {
* Builds the QbeastColumnStats
*
* @param statsString
* the stats in a JSON string
* @param columnTransformers
* the set of columnTransformers to build the Stats from
* @param dataSchema
* the data schema to build the Stats from
 * @return
 *   the QbeastColumnStats built from the stats string, transformers and data schema
*/
def build(
Expand All @@ -105,7 +116,7 @@ object QbeastColumnStats {
dataSchema: StructType): QbeastColumnStats = {
val columnStatsSchema = buildColumnStatsSchema(dataSchema, columnTransformers)
val columnStatsRow = buildColumnStatsRow(statsString, columnStatsSchema)
QbeastColumnStats(statsString, columnStatsSchema, columnStatsRow)
QbeastColumnStats(columnStatsSchema, columnStatsRow)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ package io.qbeast.spark.index

import io.qbeast.core.model.ColumnToIndex
import io.qbeast.core.model.QTableID
import io.qbeast.core.model.QbeastColumnStats
import io.qbeast.core.model.QbeastColumnStatsBuilder
import io.qbeast.core.model.QbeastOptions
import io.qbeast.core.model.Revision
import io.qbeast.core.model.RevisionChange
Expand Down Expand Up @@ -305,7 +305,8 @@ trait SparkRevisionChangesUtils extends StagingUtils with Logging {
// 1. Get the columnStats from the options
val columnStatsString = options.columnStats.getOrElse("")
// 2. Build the QbeastColumnStats
val qbeastColumnStats = QbeastColumnStats.build(columnStatsString, transformers, dataSchema)
val qbeastColumnStats =
QbeastColumnStatsBuilder.build(columnStatsString, transformers, dataSchema)
// 3. Compute transformations from the columnStats
val columnStatsRow = qbeastColumnStats.columnStatsRow
transformers.map { t =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package io.qbeast.core.model

import io.qbeast.core.transform.CDFQuantilesTransformer
import io.qbeast.core.transform.LinearTransformer
import io.qbeast.QbeastIntegrationTestSpec
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.FloatType
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType

/**
 * Test spec for [[QbeastColumnStatsBuilder]]: verifies that the column stats schema and row
 * are built correctly from a JSON stats string for linear and quantile transformers.
 */
class QbeastColumnStatsBuilderTest extends QbeastIntegrationTestSpec {

  "QbeastColumnStats" should "build the schema for linear transformations" in withSpark { _ =>
    // One LinearTransformer per numeric column; min/max stats are provided for each.
    val dataSchema =
      StructType(
        Seq(
          StructField("int_col", IntegerType),
          StructField("float_col", FloatType),
          StructField("long_col", LongType),
          StructField("double_col", DoubleType)))
    val columnTransformers = Seq(
      LinearTransformer("int_col", IntegerDataType),
      LinearTransformer("float_col", FloatDataType),
      LinearTransformer("long_col", LongDataType),
      LinearTransformer("double_col", DoubleDataType))

    val statsString =
      """{"int_col_min":0,"int_col_max":0,
        |"float_col_min":0.0,"float_col_max":0.0,
        |"long_col_min":0,"long_col_max":0,
        |"double_col_min":0.0,"double_col_max":0.0}""".stripMargin
    val qbeastColumnStats =
      QbeastColumnStatsBuilder.build(statsString, columnTransformers, dataSchema)
    val qbeastColumnStatsSchema = qbeastColumnStats.columnStatsSchema
    val qbeastColumnStatsRow = qbeastColumnStats.columnStatsRow

    // Use ScalaTest matchers consistently (was a bare assert on fields.length).
    qbeastColumnStatsSchema.fields should not be empty
    qbeastColumnStatsRow.getAs[Int]("int_col_min") shouldBe 0
    qbeastColumnStatsRow.getAs[Int]("int_col_max") shouldBe 0
    qbeastColumnStatsRow.getAs[Float]("float_col_min") shouldBe 0.0f
    qbeastColumnStatsRow.getAs[Float]("float_col_max") shouldBe 0.0f
    qbeastColumnStatsRow.getAs[Long]("long_col_min") shouldBe 0L
    qbeastColumnStatsRow.getAs[Long]("long_col_max") shouldBe 0L
    qbeastColumnStatsRow.getAs[Double]("double_col_min") shouldBe 0.0
    // double_col_max was present in the stats JSON but previously left unasserted.
    qbeastColumnStatsRow.getAs[Double]("double_col_max") shouldBe 0.0
  }

  it should "build the schema for quantiles" in withSpark { _ =>
    // string_col has no transformer, so only int_col_quantiles should appear in the schema.
    val dataSchema =
      StructType(Seq(StructField("int_col", IntegerType), StructField("string_col", StringType)))
    val columnTransformers = Seq(CDFQuantilesTransformer("int_col", IntegerDataType))
    val statsString = """{"int_col_quantiles":[0.0,0.25,0.5,0.75,1.0]}"""
    val qbeastColumnStats =
      QbeastColumnStatsBuilder.build(statsString, columnTransformers, dataSchema)

    qbeastColumnStats.columnStatsSchema shouldBe StructType(
      StructField("int_col_quantiles", ArrayType(DoubleType), nullable = true) :: Nil)
    qbeastColumnStats.columnStatsRow.getAs[Array[Double]]("int_col_quantiles") shouldBe
      Array(0.0, 0.25, 0.5, 0.75, 1.0)
  }

}

0 comments on commit 81ab4a2

Please sign in to comment.