Merge pull request #17 from osopardo1/14-reduce-collect
Reduce number of collect calls
osopardo1 authored Oct 25, 2021
2 parents 8a91b75 + 523293d commit b4f061d
Showing 4 changed files with 38 additions and 50 deletions.
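The change follows one pattern throughout: the per-file index metadata is small, so instead of running a separate Spark action (collect(), count(), first()) for every query against it, QbeastSnapshot now collects the AddFile entries once, caches them on the driver, and answers subsequent calls with plain Scala collection operations. A minimal sketch of that pattern, using simplified stand-in types rather than the actual Qbeast classes:

```scala
// Minimal sketch of the pattern applied in this commit (stand-in types, not the
// real Qbeast classes): collect the small metadata Dataset once, keep the result
// on the driver, and answer later queries with plain Scala collection operations
// instead of launching a new Spark job per call.
import org.apache.spark.sql.Dataset

final case class FileEntry(path: String, tag: String)

class MetadataView(files: Dataset[FileEntry]) {

  // Single Spark action; every later access reuses this driver-side array.
  lazy val allFiles: Array[FileEntry] = files.collect()

  // Each of these previously triggered its own collect()/count()/first() job.
  def distinctTags: Seq[String] = allFiles.map(_.tag).distinct.toSeq
  def countFor(tag: String): Int = allFiles.count(_.tag == tag)
  def firstFor(tag: String): Option[FileEntry] = allFiles.find(_.tag == tag)
}
```

The trade-off is that the collected metadata has to fit in driver memory; the change assumes the per-file tag information is small enough for that.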
7 changes: 2 additions & 5 deletions src/main/scala/io/qbeast/spark/sql/files/OTreeIndex.scala
@@ -63,11 +63,9 @@ case class OTreeIndex(index: TahoeLogFileIndex, desiredCubeSize: Int)
       partitionFilters: Seq[Expression],
       dataFilters: Seq[Expression]): Seq[AddFile] = {
 
-    val (qbeastDataFilters, tahoeDataFilters) = extractDataFilters(dataFilters)
-    val tahoeMatchingFiles = index.matchingFiles(partitionFilters, tahoeDataFilters)
-
+    val (qbeastDataFilters, _) = extractDataFilters(dataFilters)
     val (minWeight, maxWeight) = extractWeightRange(qbeastDataFilters)
-    val files = sample(minWeight, maxWeight, tahoeMatchingFiles)
+    val files = sample(minWeight, maxWeight, qbeastSnapshot.allFiles)
 
     files
   }
@@ -95,7 +93,6 @@ case class OTreeIndex(index: TahoeLogFileIndex, desiredCubeSize: Int)
 
     val filesVector = files.toVector
     qbeastSnapshot.spaceRevisions
-      .collect()
       .flatMap(spaceRevision => {
         val querySpace = QuerySpaceFromTo(originalFrom, originalTo, spaceRevision)
 
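In OTreeIndex.listFiles the Tahoe matchingFiles lookup is dropped: the candidate files are now taken from the collected qbeastSnapshot.allFiles and narrowed by the weight range extracted from the qbeast data filters. A rough sketch of a weight-range sample over collected AddFile entries follows; the tag keys and the overlap test are assumptions for illustration, not the actual OTreeIndex.sample implementation:

```scala
// Rough sketch (assumed, not the actual OTreeIndex.sample): keep every AddFile
// whose recorded weight interval overlaps the [minWeight, maxWeight] range
// requested by the qbeast data filters. The tag keys here are hypothetical.
import org.apache.spark.sql.delta.actions.AddFile

def sampleByWeight(minWeight: Int, maxWeight: Int, files: Seq[AddFile]): Seq[AddFile] =
  files.filter { file =>
    val fileMin = file.tags("minWeight").toInt // hypothetical tag key
    val fileMax = file.tags("maxWeight").toInt // hypothetical tag key
    fileMin <= maxWeight && fileMax >= minWeight
  }
```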
67 changes: 30 additions & 37 deletions src/main/scala/io/qbeast/spark/sql/qbeast/QbeastSnapshot.scala
@@ -14,7 +14,6 @@ import org.apache.spark.sql.delta.util.JsonUtils
 import org.apache.spark.sql.delta.{DeltaLogFileIndex, Snapshot}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
-import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{BinaryType, LongType, StructField, StructType}
 import org.apache.spark.sql.{Dataset, DatasetFactory, SparkSession}
 
@@ -30,23 +29,25 @@ case class QbeastSnapshot(snapshot: Snapshot, desiredCubeSize: Int) {
 
   def isInitial: Boolean = snapshot.version == -1
 
+  lazy val allFiles = snapshot.allFiles.collect()
+
   val indexedCols: Seq[String] = {
-    if (isInitial || snapshot.allFiles.isEmpty) Seq.empty
-    else ColumnsToIndex.decode(snapshot.allFiles.head.tags(indexedColsTag))
+    if (isInitial || allFiles.isEmpty) Seq.empty
+    else ColumnsToIndex.decode(allFiles.head.tags(indexedColsTag))
   }
 
   val dimensionCount: Int = indexedCols.length
 
+  private val spark = SparkSession.active
+  import spark.implicits._
+
   private val logSchema = StructType(
     Array(
       StructField(name = cubeColumnName, dataType = BinaryType, nullable = false),
       StructField(name = revisionColumnName, dataType = LongType, nullable = false)))
 
   private def fileToDataframe(fileStatus: Array[FileStatus]): Dataset[(Array[Byte], Long)] = {
 
-    val spark = SparkSession.active
-    import spark.implicits._
-
     val index = DeltaLogFileIndex(new ParquetFileFormat, fileStatus)
 
     val relation = HadoopFsRelation(
@@ -67,9 +68,7 @@ case class QbeastSnapshot(snapshot: Snapshot, desiredCubeSize: Int) {
    * @return a SpaceRevision with the corresponding timestamp if any
    */
   def getRevisionAt(timestamp: Long): Option[SpaceRevision] = {
-    val spaceRevision = spaceRevisions.filter(_.timestamp.equals(timestamp))
-    if (spaceRevision.isEmpty) None
-    else Some(spaceRevision.first())
+    spaceRevisions.find(_.timestamp.equals(timestamp))
   }
 
   /**
@@ -79,7 +78,6 @@ case class QbeastSnapshot(snapshot: Snapshot, desiredCubeSize: Int) {
    */
   def cubeWeights(spaceRevision: SpaceRevision): Map[CubeId, Weight] = {
     indexState(spaceRevision)
-      .collect()
       .map(info => (CubeId(dimensionCount, info.cube), info.maxWeight))
       .toMap
   }
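getRevisionAt and cubeWeights now run entirely on the collected data: an Option-returning find replaces the earlier Dataset filter plus isEmpty plus first() combination, and the cube-weight map is assembled with an in-memory map and toMap. A tiny self-contained illustration with a simplified stand-in type:

```scala
// Tiny illustration of the Seq-based lookups (simplified stand-in type):
// find returns the first match wrapped in Option, replacing the earlier
// Dataset filter + isEmpty + first() pair of Spark actions.
final case class Rev(timestamp: Long, name: String)

val revisions = Seq(Rev(10L, "r0"), Rev(20L, "r1"), Rev(30L, "r2"))

val hit: Option[Rev]  = revisions.find(_.timestamp == 20L) // Some(Rev(20, "r1"))
val miss: Option[Rev] = revisions.find(_.timestamp == 25L) // None

// Building a lookup map in memory, as cubeWeights now does:
val byName: Map[String, Long] = revisions.map(r => r.name -> r.timestamp).toMap
```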
@@ -91,15 +89,12 @@ case class QbeastSnapshot(snapshot: Snapshot, desiredCubeSize: Int) {
    * @return a map with key cube and value max weight
    */
   def cubeNormalizedWeights(spaceRevision: SpaceRevision): Map[CubeId, Double] = {
-    indexState(spaceRevision)
-      .collect()
-      .map {
-        case CubeInfo(cube, Weight.MaxValue, size) =>
-          (CubeId(dimensionCount, cube), NormalizedWeight(desiredCubeSize, size))
-        case CubeInfo(cube, maxWeight, _) =>
-          (CubeId(dimensionCount, cube), NormalizedWeight(maxWeight))
-      }
-      .toMap
+    indexState(spaceRevision).map {
+      case CubeInfo(cube, Weight.MaxValue, size) =>
+        (CubeId(dimensionCount, cube), NormalizedWeight(desiredCubeSize, size))
+      case CubeInfo(cube, maxWeight, _) =>
+        (CubeId(dimensionCount, cube), NormalizedWeight(maxWeight))
+    }.toMap
   }
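cubeNormalizedWeights keeps its pattern match but applies it to the in-memory CubeInfo sequence: a cube whose max weight is still the sentinel Weight.MaxValue is normalized from its element count, any other cube from its recorded max weight. The sketch below mirrors only that dispatch; the formulas are placeholders, not the real NormalizedWeight implementation:

```scala
// Dispatch-only sketch of cubeNormalizedWeights (stand-in types; the two
// branch formulas below are placeholders, just to show which inputs each
// case uses, not the actual NormalizedWeight logic).
val SentinelMaxWeight: Int = Int.MaxValue

final case class CubeStat(cube: String, maxWeight: Int, size: Long)

def normalizedWeights(desiredCubeSize: Int, stats: Seq[CubeStat]): Map[String, Double] =
  stats.map {
    case CubeStat(cube, SentinelMaxWeight, size) =>
      // Cube never overflowed: derive a weight from how full it is (placeholder formula).
      cube -> (size.toDouble / desiredCubeSize)
    case CubeStat(cube, maxWeight, _) =>
      // Cube has a recorded max weight: normalize it directly (placeholder formula).
      cube -> (maxWeight.toDouble / Int.MaxValue)
  }.toMap
```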

   /**
@@ -111,7 +106,6 @@ case class QbeastSnapshot(snapshot: Snapshot, desiredCubeSize: Int) {
   def overflowedSet(spaceRevision: SpaceRevision): Set[CubeId] = {
     indexState(spaceRevision)
       .filter(_.maxWeight != Weight.MaxValue)
-      .collect()
       .map(cubeInfo => CubeId(dimensionCount, cubeInfo.cube))
       .toSet
   }
@@ -123,6 +117,7 @@ case class QbeastSnapshot(snapshot: Snapshot, desiredCubeSize: Int) {
    */
   def replicatedSet(spaceRevision: SpaceRevision): Set[CubeId] = {
 
+    val spark = SparkSession.active
     val hadoopConf = spark.sessionState.newHadoopConf()
 
     snapshot.setTransactions.filter(_.appId.equals(indexId)) match {
@@ -149,33 +144,28 @@ case class QbeastSnapshot(snapshot: Snapshot, desiredCubeSize: Int) {
    * Returns available space revisions ordered by timestamp
    * @return a Dataset of SpaceRevision
    */
-  def spaceRevisions: Dataset[SpaceRevision] =
-    snapshot.allFiles
-      .select(s"tags.$spaceTag")
+  def spaceRevisions: Seq[SpaceRevision] =
+    allFiles
+      .map(_.tags(spaceTag))
       .distinct
-      .map(a => JsonUtils.fromJson[SpaceRevision](a.getString(0)))
-      .orderBy(col("timestamp").desc)
+      .map(a => JsonUtils.fromJson[SpaceRevision](a))
+      .sortBy(_.timestamp)
 
   /**
    * Returns the space revision with the higher timestamp
    * @return the space revision
    */
   def lastSpaceRevision: SpaceRevision = {
-    // Dataset spaceRevisions is ordered by timestamp
-    spaceRevisions
-      .first()
-
+    spaceRevisions.last
   }
 
   /**
    * Returns the index state for the given space revision
    * @param spaceRevision space revision
    * @return Dataset containing cube information
    */
-  private def indexState(spaceRevision: SpaceRevision): Dataset[CubeInfo] = {
 
-    val allFiles = snapshot.allFiles
-    val weightValueTag = weightMaxTag + ".value"
+  private def indexState(spaceRevision: SpaceRevision): Seq[CubeInfo] = {
 
     allFiles
       .filter(_.tags(spaceTag).equals(spaceRevision.toString))
@@ -186,9 +176,13 @@ case class QbeastSnapshot(snapshot: Snapshot, desiredCubeSize: Int)
           Weight(a.tags(weightMinTag).toInt),
           a.tags(stateTag),
           a.tags(elementCountTag).toLong))
-      .groupBy(cubeTag)
-      .agg(min(weightValueTag), sum(elementCountTag))
-      .map(row => CubeInfo(row.getAs[String](0), Weight(row.getAs[Int](1)), row.getAs[Long](2)))
+      .groupBy(_.cube)
+      .map { case (cube: String, blocks: Array[BlockStats]) =>
+        val weightMax = blocks.map(_.maxWeight.value).min
+        val numElements = blocks.map(_.rowCount).sum
+        CubeInfo(cube, Weight(weightMax), numElements)
+      }
+      .toSeq
   }
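indexState no longer needs Spark's groupBy and agg over tag columns: after the single collect, a plain Scala groupBy does the same aggregation, taking the minimum max weight and the summed element count per cube. A self-contained sketch of that aggregation with simplified stand-in types:

```scala
// Self-contained sketch of the in-memory aggregation that replaces the Spark
// groupBy/agg (simplified stand-ins for BlockStats and CubeInfo): per cube,
// take the minimum of the blocks' max weights and the sum of their row counts.
final case class Block(cube: String, maxWeight: Int, rowCount: Long)
final case class CubeSummary(cube: String, maxWeight: Int, size: Long)

def summarize(blocks: Seq[Block]): Seq[CubeSummary] =
  blocks
    .groupBy(_.cube)
    .map { case (cube, cubeBlocks) =>
      CubeSummary(cube, cubeBlocks.map(_.maxWeight).min, cubeBlocks.map(_.rowCount).sum)
    }
    .toSeq

// summarize(Seq(Block("A", 10, 100), Block("A", 7, 50), Block("B", 3, 25)))
//   returns CubeSummary("A", 7, 150) and CubeSummary("B", 3, 25) (order not guaranteed).
```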

   /**
@@ -199,11 +193,10 @@ case class QbeastSnapshot(snapshot: Snapshot, desiredCubeSize: Int) {
    */
   def getCubeBlocks(cubes: Set[CubeId], spaceRevision: SpaceRevision): Seq[AddFile] = {
     val dimensionCount = this.dimensionCount
-    snapshot.allFiles
+    allFiles
       .filter(_.tags(spaceTag).equals(spaceRevision.toString))
       .filter(_.tags(stateTag) != ANNOUNCED)
       .filter(a => cubes.contains(CubeId(dimensionCount, a.tags(cubeTag))))
-      .collect()
   }
 
 }
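Taken together, a caller now pays for one collect() the first time allFiles is touched and can interrogate the snapshot repeatedly without launching further Spark jobs. A usage sketch (the table path is a placeholder):

```scala
// Usage sketch: one collect() when allFiles is first touched, then pure
// driver-side lookups. The path below is a placeholder.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.delta.DeltaLog
import io.qbeast.spark.sql.qbeast.QbeastSnapshot

val spark = SparkSession.active
val deltaLog = DeltaLog.forTable(spark, "/tmp/qbeast-table")
val qbeastSnapshot = QbeastSnapshot(deltaLog.snapshot, 10000)

val latest = qbeastSnapshot.lastSpaceRevision    // Seq-based, no extra Spark job
val weights = qbeastSnapshot.cubeWeights(latest) // Map[CubeId, Weight] built in memory
val revisionCount = qbeastSnapshot.spaceRevisions.size
```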
4 changes: 2 additions & 2 deletions src/test/scala/io/qbeast/spark/index/NewRevisionTest.scala
@@ -42,7 +42,7 @@ class NewRevisionTest
     val deltaLog = DeltaLog.forTable(spark, tmpDir)
     val qbeastSnapshot = QbeastSnapshot(deltaLog.snapshot, 10000)
 
-    qbeastSnapshot.spaceRevisions.count() shouldBe spaceMultipliers.length
+    qbeastSnapshot.spaceRevisions.size shouldBe spaceMultipliers.length
 
   }
 
@@ -55,7 +55,7 @@
     val deltaLog = DeltaLog.forTable(spark, tmpDir)
     val qbeastSnapshot = QbeastSnapshot(deltaLog.snapshot, 10000)
 
-    val allWM = qbeastSnapshot.spaceRevisions.collect().map(qbeastSnapshot.cubeWeights)
+    val allWM = qbeastSnapshot.spaceRevisions.map(qbeastSnapshot.cubeWeights)
     allWM.foreach(wm => assert(wm.nonEmpty))
     assert(allWM.distinct.length == allWM.length)
 
10 changes: 4 additions & 6 deletions src/test/scala/io/qbeast/spark/utils/QbeastSnapshotTest.scala
@@ -8,7 +8,7 @@ import io.qbeast.spark.index.OTreeAlgorithmTest.Client3
 import io.qbeast.spark.index.{Weight}
 import io.qbeast.spark.model.CubeInfo
 import io.qbeast.spark.sql.qbeast.QbeastSnapshot
-import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.delta.DeltaLog
 import org.scalatest.PrivateMethodTester
 import org.scalatest.flatspec.AnyFlatSpec
@@ -25,9 +25,8 @@ class QbeastSnapshotTest
 
     val rdd =
       spark.sparkContext.parallelize(
-        Seq(Client3(size * size, s"student-$size", 20, 1000 + 123, 2567.3432143)) ++
-          1.until(size)
-            .map(i => Client3(i * i, s"student-$i", 20 + i, 1000 + 123 + i, 2567.3432143 + i)))
+        1.to(size)
+          .map(i => Client3(i * i, s"student-$i", i, i * 1000 + 123, i * 2567.3432143)))
 
     assert(rdd.count() == size)
     spark.createDataFrame(rdd)
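The test data generator is also simplified: 1.to(size) already produces exactly size elements, so the extra seed Client3 that compensated for 1.until(size), which produces size - 1 elements, is no longer needed:

```scala
// Range endpoints in Scala: `to` is inclusive, `until` is exclusive.
1.to(5).length    // 5  -> Seq(1, 2, 3, 4, 5)
1.until(5).length // 4  -> Seq(1, 2, 3, 4)
```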
@@ -145,15 +144,14 @@ class QbeastSnapshotTest
     val deltaLog = DeltaLog.forTable(spark, tmpDir)
     val qbeastSnapshot = QbeastSnapshot(deltaLog.snapshot, oTreeAlgorithm.desiredCubeSize)
 
-    val indexStateMethod = PrivateMethod[Dataset[CubeInfo]]('indexState)
+    val indexStateMethod = PrivateMethod[Seq[CubeInfo]]('indexState)
     val indexState =
       qbeastSnapshot invokePrivate indexStateMethod(qbeastSnapshot.lastSpaceRevision)
     val overflowed =
       qbeastSnapshot.overflowedSet(qbeastSnapshot.lastSpaceRevision).map(_.string)
 
     indexState
       .filter(cubeInfo => overflowed.contains(cubeInfo.cube))
-      .collect()
       .foreach(cubeInfo =>
         assert(
           cubeInfo.size > oTreeAlgorithm.desiredCubeSize * 0.9,
