Issue #511: Add support for custom functions in RollupDataWriter (#513)
* Remove fileName from QbeastColumns and add interface for custom logic
JosepSampe authored Dec 12, 2024
1 parent ea4bcd8 commit 1d8361d
Showing 3 changed files with 34 additions and 48 deletions.
core/src/main/scala/io/qbeast/spark/index/QbeastColumns.scala (33 changes: 6 additions & 27 deletions)
@@ -35,17 +35,11 @@ object QbeastColumns {
val cubeColumnName = "_qbeastCube"

/**
- * Cube to rollup column name.
+ * Destination file UUID column name.
*/
val fileUUIDColumnName = "_qbeastFileUUID"

- /**
- * Cube to rollup file name column name.
- */
- val filenameColumnName = "_qbeastFilename"
-
- val columnNames: Set[String] =
- Set(weightColumnName, cubeColumnName, fileUUIDColumnName, filenameColumnName)
+ val columnNames: Set[String] = Set(weightColumnName, cubeColumnName, fileUUIDColumnName)

/**
* Creates an instance for a given data frame.
@@ -70,8 +64,7 @@ object QbeastColumns {
QbeastColumns(
weightColumnIndex = columnIndexes.getOrElse(weightColumnName, -1),
cubeColumnIndex = columnIndexes.getOrElse(cubeColumnName, -1),
- fileUUIDColumnIndex = columnIndexes.getOrElse(fileUUIDColumnName, -1),
- filenameColumnIndex = columnIndexes.getOrElse(filenameColumnName, -1))
+ fileUUIDColumnIndex = columnIndexes.getOrElse(fileUUIDColumnName, -1))
}

/**
@@ -106,14 +99,8 @@ object QbeastColumns {
* the cube column index or -1 if it is missing
* @param fileUUIDColumnIndex
* target file UUID column index or -1 if it is missing
- * @param filenameColumnIndex
- * the cube to rollup file name column index or -1 if it is missing
*/
- case class QbeastColumns(
- weightColumnIndex: Int,
- cubeColumnIndex: Int,
- fileUUIDColumnIndex: Int,
- filenameColumnIndex: Int) {
+ case class QbeastColumns(weightColumnIndex: Int, cubeColumnIndex: Int, fileUUIDColumnIndex: Int) {

/**
* Returns whether a given column is one of the Qbeast columns.
@@ -126,8 +113,7 @@ case class QbeastColumns(
def contains(columnIndex: Int): Boolean = {
columnIndex == weightColumnIndex ||
columnIndex == cubeColumnIndex ||
- columnIndex == fileUUIDColumnIndex ||
- columnIndex == filenameColumnIndex
+ columnIndex == fileUUIDColumnIndex
}

/**
@@ -147,18 +133,11 @@ case class QbeastColumns(
def hasCubeColumn: Boolean = cubeColumnIndex >= 0

/**
- * Returns whether the cube to rollup column exists.
+ * Returns whether the destination file UUID column exists.
*
* @return
* the cube to rollup column exists
*/
def hasFileUUIDColumn: Boolean = fileUUIDColumnIndex >= 0

- /**
- * Returns whether the filename column exists.
- *
- * @return
- * the filename column exists
- */
- def hasFilenameColumn: Boolean = filenameColumnIndex >= 0
}
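
For reference, a minimal sketch (not part of this commit) of how the trimmed-down QbeastColumns is built and queried after this change. The Qbeast column names and the QbeastColumns(schema) factory come from the code above and the test below; the user column names are made up.

import io.qbeast.spark.index.QbeastColumns
import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType, StructField, StructType}

// Two hypothetical user columns followed by the three remaining Qbeast helper columns.
val schema = StructType(
  Seq(
    StructField("id", StringType),
    StructField("value", StringType),
    StructField(QbeastColumns.weightColumnName, IntegerType),
    StructField(QbeastColumns.cubeColumnName, BinaryType),
    StructField(QbeastColumns.fileUUIDColumnName, StringType)))

val qbeastColumns = QbeastColumns(schema)
qbeastColumns.hasFileUUIDColumn // true, found at index 4
qbeastColumns.contains(0)       // false: "id" is a user column, not a Qbeast column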
core/src/main/scala/io/qbeast/spark/writer/RollupDataWriter.scala (41 changes: 27 additions & 14 deletions)
@@ -37,23 +37,32 @@ import scala.collection.mutable
*/
trait RollupDataWriter extends DataWriter {

- type GetCubeMaxWeight = CubeId => Weight
- type Extract = InternalRow => (InternalRow, Weight, CubeId, String)
- type WriteRows = Iterator[InternalRow] => Iterator[(IndexFile, TaskStats)]
+ type ProcessRows = (InternalRow, String) => (InternalRow, String)
+ private type GetCubeMaxWeight = CubeId => Weight
+ private type Extract = InternalRow => (InternalRow, Weight, CubeId, String)
+ private type WriteRows = Iterator[InternalRow] => Iterator[(IndexFile, TaskStats)]

protected def doWrite(
tableId: QTableID,
schema: StructType,
extendedData: DataFrame,
tableChanges: TableChanges,
- trackers: Seq[WriteJobStatsTracker]): IISeq[(IndexFile, TaskStats)] = {
+ trackers: Seq[WriteJobStatsTracker],
+ processRow: Option[ProcessRows] = None): IISeq[(IndexFile, TaskStats)] = {

val revision = tableChanges.updatedRevision
val getCubeMaxWeight = { cubeId: CubeId =>
tableChanges.cubeWeight(cubeId).getOrElse(Weight.MaxValue)
}
val writeRows =
- getWriteRows(tableId, schema, extendedData, revision, getCubeMaxWeight, trackers)
+ getWriteRows(
+ tableId,
+ schema,
+ extendedData,
+ revision,
+ getCubeMaxWeight,
+ trackers,
+ processRow)
extendedData
.repartition(col(QbeastColumns.fileUUIDColumnName))
.queryExecution
@@ -70,8 +79,9 @@ trait RollupDataWriter extends DataWriter {
extendedData: DataFrame,
revision: Revision,
getCubeMaxWeight: GetCubeMaxWeight,
- trackers: Seq[WriteJobStatsTracker]): WriteRows = {
- val extract = getExtract(extendedData, revision)
+ trackers: Seq[WriteJobStatsTracker],
+ processRow: Option[ProcessRows]): WriteRows = {
+ val extract = getExtract(extendedData, revision, processRow)
val revisionId = revision.revisionID
val writerFactory =
getIndexFileWriterFactory(tableId, schema, extendedData, revisionId, trackers)
@@ -88,7 +98,10 @@ trait RollupDataWriter extends DataWriter {
}
}

- private def getExtract(extendedData: DataFrame, revision: Revision): Extract = {
+ private def getExtract(
+ extendedData: DataFrame,
+ revision: Revision,
+ processRow: Option[ProcessRows]): Extract = {
val schema = extendedData.schema
val qbeastColumns = QbeastColumns(extendedData)
val extractors = schema.fields.indices
@@ -97,16 +110,16 @@
row.get(i, schema(i).dataType)
}
extendedRow => {
+ val fileUUID = extendedRow.getString(qbeastColumns.fileUUIDColumnIndex)
val row = InternalRow.fromSeq(extractors.map(_.apply(extendedRow)))
+ val (processedRow, filename) = processRow match {
+ case Some(func) => func(row, fileUUID)
+ case None => (row, s"$fileUUID.parquet")
+ }
val weight = Weight(extendedRow.getInt(qbeastColumns.weightColumnIndex))
val cubeIdBytes = extendedRow.getBinary(qbeastColumns.cubeColumnIndex)
val cubeId = revision.createCubeId(cubeIdBytes)
- val filename = {
- if (qbeastColumns.hasFilenameColumn)
- extendedRow.getString(qbeastColumns.filenameColumnIndex)
- else extendedRow.getString(qbeastColumns.fileUUIDColumnIndex) + ".parquet"
- }
- (row, weight, cubeId, filename)
+ (processedRow, weight, cubeId, filename)
}
}

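
To illustrate the new hook, here is a minimal sketch (not part of this commit) of how a caller mixing in RollupDataWriter might pass a custom ProcessRows function to doWrite. The writer hands each row to the function together with its destination file UUID, and the function returns the possibly rewritten row plus the file name to write; passing None keeps the default "<uuid>.parquet" naming shown in getExtract. The renaming function and the call-site values (tableId, schema, extendedData, tableChanges, trackers) are hypothetical.

import org.apache.spark.sql.catalyst.InternalRow

// Hypothetical custom logic: leave the row untouched but derive a different
// file name from the destination file UUID.
val renameOutputFile: (InternalRow, String) => (InternalRow, String) =
  (row, fileUUID) => (row, s"$fileUUID.custom.parquet")

// Hypothetical call site inside a concrete DataWriter that mixes in RollupDataWriter:
// doWrite(tableId, schema, extendedData, tableChanges, trackers,
//   processRow = Some(renameOutputFile))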
src/test/scala/io/qbeast/spark/index/QbeastColumnsTest.scala (8 changes: 1 addition & 7 deletions)
@@ -32,7 +32,6 @@ class QbeastColumnsTest extends AnyFlatSpec with Matchers {
QbeastColumns.weightColumnName should startWith("_qbeast")
QbeastColumns.cubeColumnName should startWith("_qbeast")
QbeastColumns.fileUUIDColumnName should startWith("_qbeast")
- QbeastColumns.filenameColumnName should startWith("_qbeast")
}

it should "create instance from schema correctly" in {
@@ -50,26 +49,21 @@
columns1.hasCubeColumn shouldBe true
columns1.fileUUIDColumnIndex shouldBe -1
columns1.hasFileUUIDColumn shouldBe false
- columns1.filenameColumnIndex shouldBe -1
- columns1.hasFilenameColumn shouldBe false

val schema2 = StructType(
Seq(
StructField("A", StringType),
StructField("B", StringType),
StructField("C", StringType),
StructField(QbeastColumns.fileUUIDColumnName, BinaryType),
StructField("D", StringType),
StructField(QbeastColumns.filenameColumnName, StringType)))
StructField("D", StringType)))
val columns2 = QbeastColumns(schema2)
columns2.weightColumnIndex shouldBe -1
columns2.hasWeightColumn shouldBe false
columns2.cubeColumnIndex shouldBe -1
columns2.hasCubeColumn shouldBe false
columns2.fileUUIDColumnIndex shouldBe 3
columns2.hasFileUUIDColumn shouldBe true
- columns2.filenameColumnIndex shouldBe 5
- columns2.hasFilenameColumn shouldBe true
}

it should "implement contains correctly" in {
