diff --git a/core/src/main/scala/io/qbeast/spark/index/QbeastColumns.scala b/core/src/main/scala/io/qbeast/spark/index/QbeastColumns.scala
index fe12ba5ef..62d19761f 100644
--- a/core/src/main/scala/io/qbeast/spark/index/QbeastColumns.scala
+++ b/core/src/main/scala/io/qbeast/spark/index/QbeastColumns.scala
@@ -35,17 +35,11 @@ object QbeastColumns {
   val cubeColumnName = "_qbeastCube"
 
   /**
-   * Cube to rollup column name.
+   * Destination file UUID column name.
    */
   val fileUUIDColumnName = "_qbeastFileUUID"
 
-  /**
-   * Cube to rollup file name column name.
-   */
-  val filenameColumnName = "_qbeastFilename"
-
-  val columnNames: Set[String] =
-    Set(weightColumnName, cubeColumnName, fileUUIDColumnName, filenameColumnName)
+  val columnNames: Set[String] = Set(weightColumnName, cubeColumnName, fileUUIDColumnName)
 
   /**
    * Creates an instance for a given data frame.
@@ -70,8 +64,7 @@ object QbeastColumns {
     QbeastColumns(
       weightColumnIndex = columnIndexes.getOrElse(weightColumnName, -1),
       cubeColumnIndex = columnIndexes.getOrElse(cubeColumnName, -1),
-      fileUUIDColumnIndex = columnIndexes.getOrElse(fileUUIDColumnName, -1),
-      filenameColumnIndex = columnIndexes.getOrElse(filenameColumnName, -1))
+      fileUUIDColumnIndex = columnIndexes.getOrElse(fileUUIDColumnName, -1))
   }
 
   /**
@@ -106,14 +99,8 @@ object QbeastColumns {
 *   the cube column index or -1 if it is missing
 * @param fileUUIDColumnIndex
 *   target file UUID column index or -1 if it is missing
- * @param filenameColumnIndex
- *   the cube to rollup file name column index or -1 if it is missing
 */
-case class QbeastColumns(
-    weightColumnIndex: Int,
-    cubeColumnIndex: Int,
-    fileUUIDColumnIndex: Int,
-    filenameColumnIndex: Int) {
+case class QbeastColumns(weightColumnIndex: Int, cubeColumnIndex: Int, fileUUIDColumnIndex: Int) {
 
   /**
    * Returns whether a given column is one of the Qbeast columns.
@@ -126,8 +113,7 @@ case class QbeastColumns(
   def contains(columnIndex: Int): Boolean = {
     columnIndex == weightColumnIndex ||
     columnIndex == cubeColumnIndex ||
-    columnIndex == fileUUIDColumnIndex ||
-    columnIndex == filenameColumnIndex
+    columnIndex == fileUUIDColumnIndex
   }
 
   /**
@@ -147,18 +133,11 @@ case class QbeastColumns(
   def hasCubeColumn: Boolean = cubeColumnIndex >= 0
 
   /**
-   * Returns whether the cube to rollup column exists.
+   * Returns whether the destination file UUID column exists.
    *
    * @return
-   *   the cube to rollup column exists
+   *   the destination file UUID column exists
    */
   def hasFileUUIDColumn: Boolean = fileUUIDColumnIndex >= 0
 
-  /**
-   * Returns whether the filename column exists.
-   *
-   * @return
-   *   the filename column exists
-   */
-  def hasFilenameColumn: Boolean = filenameColumnIndex >= 0
 }
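Reviewer note: with `filenameColumnName` gone, `QbeastColumns` tracks only the weight, cube, and file UUID columns, and target filenames are derived downstream from the UUID. A minimal sketch of the slimmed-down API (the user field `id` and the column types are hypothetical):

```scala
import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType, StructField, StructType}

import io.qbeast.spark.index.QbeastColumns

// A user schema extended with the remaining Qbeast metadata columns.
val extendedSchema = StructType(
  Seq(
    StructField("id", StringType),
    StructField(QbeastColumns.weightColumnName, IntegerType),
    StructField(QbeastColumns.cubeColumnName, BinaryType),
    StructField(QbeastColumns.fileUUIDColumnName, StringType)))

val columns = QbeastColumns(extendedSchema)
assert(columns.hasWeightColumn && columns.hasCubeColumn && columns.hasFileUUIDColumn)
// The filename column is no longer a recognized Qbeast column.
assert(!QbeastColumns.columnNames.contains("_qbeastFilename"))
```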
diff --git a/core/src/main/scala/io/qbeast/spark/writer/RollupDataWriter.scala b/core/src/main/scala/io/qbeast/spark/writer/RollupDataWriter.scala
index afb450dd1..c08d79cc0 100644
--- a/core/src/main/scala/io/qbeast/spark/writer/RollupDataWriter.scala
+++ b/core/src/main/scala/io/qbeast/spark/writer/RollupDataWriter.scala
@@ -37,23 +37,32 @@ import scala.collection.mutable
  */
 trait RollupDataWriter extends DataWriter {
 
-  type GetCubeMaxWeight = CubeId => Weight
-  type Extract = InternalRow => (InternalRow, Weight, CubeId, String)
-  type WriteRows = Iterator[InternalRow] => Iterator[(IndexFile, TaskStats)]
+  type ProcessRows = (InternalRow, String) => (InternalRow, String)
+  private type GetCubeMaxWeight = CubeId => Weight
+  private type Extract = InternalRow => (InternalRow, Weight, CubeId, String)
+  private type WriteRows = Iterator[InternalRow] => Iterator[(IndexFile, TaskStats)]
 
   protected def doWrite(
       tableId: QTableID,
       schema: StructType,
       extendedData: DataFrame,
       tableChanges: TableChanges,
-      trackers: Seq[WriteJobStatsTracker]): IISeq[(IndexFile, TaskStats)] = {
+      trackers: Seq[WriteJobStatsTracker],
+      processRow: Option[ProcessRows] = None): IISeq[(IndexFile, TaskStats)] = {
     val revision = tableChanges.updatedRevision
     val getCubeMaxWeight = { cubeId: CubeId =>
       tableChanges.cubeWeight(cubeId).getOrElse(Weight.MaxValue)
     }
     val writeRows =
-      getWriteRows(tableId, schema, extendedData, revision, getCubeMaxWeight, trackers)
+      getWriteRows(
+        tableId,
+        schema,
+        extendedData,
+        revision,
+        getCubeMaxWeight,
+        trackers,
+        processRow)
     extendedData
       .repartition(col(QbeastColumns.fileUUIDColumnName))
       .queryExecution
@@ -70,8 +79,9 @@ trait RollupDataWriter extends DataWriter {
       extendedData: DataFrame,
       revision: Revision,
       getCubeMaxWeight: GetCubeMaxWeight,
-      trackers: Seq[WriteJobStatsTracker]): WriteRows = {
-    val extract = getExtract(extendedData, revision)
+      trackers: Seq[WriteJobStatsTracker],
+      processRow: Option[ProcessRows]): WriteRows = {
+    val extract = getExtract(extendedData, revision, processRow)
     val revisionId = revision.revisionID
     val writerFactory =
       getIndexFileWriterFactory(tableId, schema, extendedData, revisionId, trackers)
@@ -88,7 +98,10 @@ trait RollupDataWriter extends DataWriter {
     }
   }
 
-  private def getExtract(extendedData: DataFrame, revision: Revision): Extract = {
+  private def getExtract(
+      extendedData: DataFrame,
+      revision: Revision,
+      processRow: Option[ProcessRows]): Extract = {
     val schema = extendedData.schema
     val qbeastColumns = QbeastColumns(extendedData)
     val extractors = schema.fields.indices
@@ -97,16 +110,16 @@ trait RollupDataWriter extends DataWriter {
         row.get(i, schema(i).dataType)
       }
     extendedRow => {
+      val fileUUID = extendedRow.getString(qbeastColumns.fileUUIDColumnIndex)
       val row = InternalRow.fromSeq(extractors.map(_.apply(extendedRow)))
+      val (processedRow, filename) = processRow match {
+        case Some(func) => func(row, fileUUID)
+        case None => (row, s"$fileUUID.parquet")
+      }
       val weight = Weight(extendedRow.getInt(qbeastColumns.weightColumnIndex))
       val cubeIdBytes = extendedRow.getBinary(qbeastColumns.cubeColumnIndex)
       val cubeId = revision.createCubeId(cubeIdBytes)
-      val filename = {
-        if (qbeastColumns.hasFilenameColumn)
-          extendedRow.getString(qbeastColumns.filenameColumnIndex)
-        else extendedRow.getString(qbeastColumns.fileUUIDColumnIndex) + ".parquet"
-      }
-      (row, weight, cubeId, filename)
+      (processedRow, weight, cubeId, filename)
     }
   }
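Reviewer note: the `processRow` hook generalizes the old filename-column mechanism. A hook receives each data row together with its destination file UUID and returns the (possibly rewritten) row plus the filename to write it to; passing `None` keeps the default `s"$fileUUID.parquet"` naming. A sketch of such a hook, assuming a caller inside the writer hierarchy (the hook name and the `-optimized` suffix are hypothetical):

```scala
import org.apache.spark.sql.catalyst.InternalRow

// Shape of RollupDataWriter#ProcessRows: (row, fileUUID) => (row, filename).
// This hypothetical hook leaves rows untouched and only renames the target file.
val tagOptimizedFiles: (InternalRow, String) => (InternalRow, String) =
  (row, fileUUID) => (row, s"$fileUUID-optimized.parquet")

// doWrite(..., processRow = Some(tagOptimizedFiles)) would send each rollup
// group to "<uuid>-optimized.parquet"; processRow = None keeps "<uuid>.parquet".
```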
diff --git a/src/test/scala/io/qbeast/spark/index/QbeastColumnsTest.scala b/src/test/scala/io/qbeast/spark/index/QbeastColumnsTest.scala
index cbd141c2c..5fb6b14b0 100644
--- a/src/test/scala/io/qbeast/spark/index/QbeastColumnsTest.scala
+++ b/src/test/scala/io/qbeast/spark/index/QbeastColumnsTest.scala
@@ -32,7 +32,6 @@ class QbeastColumnsTest extends AnyFlatSpec with Matchers {
     QbeastColumns.weightColumnName should startWith("_qbeast")
     QbeastColumns.cubeColumnName should startWith("_qbeast")
     QbeastColumns.fileUUIDColumnName should startWith("_qbeast")
-    QbeastColumns.filenameColumnName should startWith("_qbeast")
   }
 
   it should "create instance from schema correctly" in {
@@ -50,8 +49,6 @@ class QbeastColumnsTest extends AnyFlatSpec with Matchers {
     columns1.hasCubeColumn shouldBe true
     columns1.fileUUIDColumnIndex shouldBe -1
     columns1.hasFileUUIDColumn shouldBe false
-    columns1.filenameColumnIndex shouldBe -1
-    columns1.hasFilenameColumn shouldBe false
 
     val schema2 = StructType(
       Seq(
@@ -59,8 +56,7 @@ class QbeastColumnsTest extends AnyFlatSpec with Matchers {
         StructField("B", StringType),
         StructField("C", StringType),
         StructField(QbeastColumns.fileUUIDColumnName, BinaryType),
-        StructField("D", StringType),
-        StructField(QbeastColumns.filenameColumnName, StringType)))
+        StructField("D", StringType)))
     val columns2 = QbeastColumns(schema2)
     columns2.weightColumnIndex shouldBe -1
     columns2.hasWeightColumn shouldBe false
@@ -68,8 +64,6 @@ class QbeastColumnsTest extends AnyFlatSpec with Matchers {
     columns2.hasCubeColumn shouldBe false
     columns2.fileUUIDColumnIndex shouldBe 3
     columns2.hasFileUUIDColumn shouldBe true
-    columns2.filenameColumnIndex shouldBe 5
-    columns2.hasFilenameColumn shouldBe true
   }
 
   it should "implement contains correctly" in {
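Reviewer suggestion: since the shape of `columnNames` changed, it may be worth pinning it down directly. A possible addition to `QbeastColumnsTest`, assuming the existing FlatSpec/Matchers style:

```scala
it should "not expose the removed filename column" in {
  // The set now contains exactly the three surviving Qbeast columns.
  QbeastColumns.columnNames shouldBe Set(
    QbeastColumns.weightColumnName,
    QbeastColumns.cubeColumnName,
    QbeastColumns.fileUUIDColumnName)
  QbeastColumns.columnNames should not contain "_qbeastFilename"
}
```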