Reduce number of collect calls #17

Merged (3 commits) on Oct 25, 2021
7 changes: 2 additions & 5 deletions src/main/scala/io/qbeast/spark/sql/files/OTreeIndex.scala
@@ -63,11 +63,9 @@ case class OTreeIndex(index: TahoeLogFileIndex, desiredCubeSize: Int)
partitionFilters: Seq[Expression],
dataFilters: Seq[Expression]): Seq[AddFile] = {

- val (qbeastDataFilters, tahoeDataFilters) = extractDataFilters(dataFilters)
- val tahoeMatchingFiles = index.matchingFiles(partitionFilters, tahoeDataFilters)

+ val (qbeastDataFilters, _) = extractDataFilters(dataFilters)
val (minWeight, maxWeight) = extractWeightRange(qbeastDataFilters)
- val files = sample(minWeight, maxWeight, tahoeMatchingFiles)
+ val files = sample(minWeight, maxWeight, qbeastSnapshot.allFiles)

files
}
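For context, the `sample(minWeight, maxWeight, files)` step is what narrows the collected file list down to the blocks relevant to the requested weight range. A minimal sketch of that kind of pruning, with a simplified `Weight` type and literal tag keys that are assumptions for illustration, not the exact qbeast-spark internals:

```scala
import org.apache.spark.sql.delta.actions.AddFile

// Simplified stand-in for io.qbeast.spark.index.Weight (illustration only).
final case class Weight(value: Int) extends Ordered[Weight] {
  def compare(that: Weight): Int = value.compareTo(that.value)
}

// Hypothetical pruning: keep files whose [minWeight, maxWeight] tag interval
// overlaps the range extracted from the qbeast data filters.
def sampleSketch(minWeight: Weight, maxWeight: Weight, files: Seq[AddFile]): Seq[AddFile] =
  files.filter { file =>
    val blockMin = Weight(file.tags("minWeight").toInt) // tag key assumed
    val blockMax = Weight(file.tags("maxWeight").toInt) // tag key assumed
    blockMin <= maxWeight && blockMax >= minWeight
  }
```

The real logic lives in OTreeIndex; the point of this change is only that the filter now runs over an in-memory sequence of AddFile entries rather than through a second Delta file-index query.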
@@ -95,7 +93,6 @@ case class OTreeIndex(index: TahoeLogFileIndex, desiredCubeSize: Int)

val filesVector = files.toVector
qbeastSnapshot.spaceRevisions
- .collect()
.flatMap(spaceRevision => {
val querySpace = QuerySpaceFromTo(originalFrom, originalTo, spaceRevision)

67 changes: 30 additions & 37 deletions src/main/scala/io/qbeast/spark/sql/qbeast/QbeastSnapshot.scala
@@ -14,7 +14,6 @@ import org.apache.spark.sql.delta.util.JsonUtils
import org.apache.spark.sql.delta.{DeltaLogFileIndex, Snapshot}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
- import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{BinaryType, LongType, StructField, StructType}
import org.apache.spark.sql.{Dataset, DatasetFactory, SparkSession}

@@ -30,23 +29,25 @@ case class QbeastSnapshot(snapshot: Snapshot, desiredCubeSize: Int)

def isInitial: Boolean = snapshot.version == -1

+ lazy val allFiles = snapshot.allFiles.collect()

val indexedCols: Seq[String] = {
- if (isInitial || snapshot.allFiles.isEmpty) Seq.empty
- else ColumnsToIndex.decode(snapshot.allFiles.head.tags(indexedColsTag))
+ if (isInitial || allFiles.isEmpty) Seq.empty
+ else ColumnsToIndex.decode(allFiles.head.tags(indexedColsTag))
}

val dimensionCount: Int = indexedCols.length

- private val spark = SparkSession.active
- import spark.implicits._

private val logSchema = StructType(
Array(
StructField(name = cubeColumnName, dataType = BinaryType, nullable = false),
StructField(name = revisionColumnName, dataType = LongType, nullable = false)))

private def fileToDataframe(fileStatus: Array[FileStatus]): Dataset[(Array[Byte], Long)] = {

+ val spark = SparkSession.active
+ import spark.implicits._

val index = DeltaLogFileIndex(new ParquetFileFormat, fileStatus)

val relation = HadoopFsRelation(
@@ -67,9 +68,7 @@ case class QbeastSnapshot(snapshot: Snapshot, desiredCubeSize: Int)
* @return a SpaceRevision with the corresponding timestamp if any
*/
def getRevisionAt(timestamp: Long): Option[SpaceRevision] = {
- val spaceRevision = spaceRevisions.filter(_.timestamp.equals(timestamp))
- if (spaceRevision.isEmpty) None
- else Some(spaceRevision.first())
+ spaceRevisions.find(_.timestamp.equals(timestamp))
}

/**
@@ -79,7 +78,6 @@ case class QbeastSnapshot(snapshot: Snapshot, desiredCubeSize: Int)
*/
def cubeWeights(spaceRevision: SpaceRevision): Map[CubeId, Weight] = {
indexState(spaceRevision)
- .collect()
.map(info => (CubeId(dimensionCount, info.cube), info.maxWeight))
.toMap
}
@@ -91,15 +89,12 @@ case class QbeastSnapshot(snapshot: Snapshot, desiredCubeSize: Int)
* @return a map with key cube and value max weight
*/
def cubeNormalizedWeights(spaceRevision: SpaceRevision): Map[CubeId, Double] = {
- indexState(spaceRevision)
- .collect()
- .map {
- case CubeInfo(cube, Weight.MaxValue, size) =>
- (CubeId(dimensionCount, cube), NormalizedWeight(desiredCubeSize, size))
- case CubeInfo(cube, maxWeight, _) =>
- (CubeId(dimensionCount, cube), NormalizedWeight(maxWeight))
- }
- .toMap
+ indexState(spaceRevision).map {
+ case CubeInfo(cube, Weight.MaxValue, size) =>
+ (CubeId(dimensionCount, cube), NormalizedWeight(desiredCubeSize, size))
+ case CubeInfo(cube, maxWeight, _) =>
+ (CubeId(dimensionCount, cube), NormalizedWeight(maxWeight))
+ }.toMap
}
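The two match arms distinguish cubes that reached a weight cap (maxWeight below Weight.MaxValue) from cubes that never filled up. A toy illustration of that branching; the formulas below are assumptions made for the example, not the library's actual NormalizedWeight definitions:

```scala
// Simplified stand-ins (illustration only); the real types live in io.qbeast.spark.index.
final case class Weight(value: Int)
object Weight { val MaxValue: Weight = Weight(Int.MaxValue) }

// Assumed formulas, for the sake of the example:
//  - an uncapped cube (maxWeight == MaxValue) is summarized by how full it is
//    relative to the desired cube size,
//  - a capped cube by the fraction of the weight space its maxWeight covers.
def normalizedWeightSketch(desiredCubeSize: Int, maxWeight: Weight, size: Long): Double =
  if (maxWeight == Weight.MaxValue) desiredCubeSize.toDouble / size
  else maxWeight.value.toDouble / Int.MaxValue
```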

/**
@@ -111,7 +106,6 @@ case class QbeastSnapshot(snapshot: Snapshot, desiredCubeSize: Int)
def overflowedSet(spaceRevision: SpaceRevision): Set[CubeId] = {
indexState(spaceRevision)
.filter(_.maxWeight != Weight.MaxValue)
- .collect()
.map(cubeInfo => CubeId(dimensionCount, cubeInfo.cube))
.toSet
}
@@ -123,6 +117,7 @@ case class QbeastSnapshot(snapshot: Snapshot, desiredCubeSize: Int)
*/
def replicatedSet(spaceRevision: SpaceRevision): Set[CubeId] = {

+ val spark = SparkSession.active
val hadoopConf = spark.sessionState.newHadoopConf()

snapshot.setTransactions.filter(_.appId.equals(indexId)) match {
@@ -149,33 +144,28 @@ case class QbeastSnapshot(snapshot: Snapshot, desiredCubeSize: Int)
* Returns available space revisions ordered by timestamp
* @return a Dataset of SpaceRevision
*/
- def spaceRevisions: Dataset[SpaceRevision] =
- snapshot.allFiles
- .select(s"tags.$spaceTag")
+ def spaceRevisions: Seq[SpaceRevision] =
+ allFiles
+ .map(_.tags(spaceTag))
.distinct
- .map(a => JsonUtils.fromJson[SpaceRevision](a.getString(0)))
- .orderBy(col("timestamp").desc)
+ .map(a => JsonUtils.fromJson[SpaceRevision](a))
+ .sortBy(_.timestamp)

/**
* Returns the space revision with the higher timestamp
* @return the space revision
*/
def lastSpaceRevision: SpaceRevision = {
- // Dataset spaceRevisions is ordered by timestamp
- spaceRevisions
- .first()

+ spaceRevisions.last
}
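Because spaceRevisions is now sorted ascending by `sortBy(_.timestamp)`, the newest revision is the last element rather than the first row of a descending-ordered Dataset. A tiny self-contained illustration (the single-field SpaceRevision here is a stand-in for the real model):

```scala
// Stand-in with just the field the ordering uses (illustration only).
final case class SpaceRevision(timestamp: Long)

val revisions = Seq(SpaceRevision(3L), SpaceRevision(1L), SpaceRevision(2L)).sortBy(_.timestamp)
// revisions.map(_.timestamp) == Seq(1L, 2L, 3L)
val latest = revisions.last // SpaceRevision(3), the highest timestamp
```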

/**
* Returns the index state for the given space revision
* @param spaceRevision space revision
* @return Dataset containing cube information
*/
- private def indexState(spaceRevision: SpaceRevision): Dataset[CubeInfo] = {
-
- val allFiles = snapshot.allFiles
- val weightValueTag = weightMaxTag + ".value"
+ private def indexState(spaceRevision: SpaceRevision): Seq[CubeInfo] = {

allFiles
.filter(_.tags(spaceTag).equals(spaceRevision.toString))
@@ -186,9 +176,13 @@ case class QbeastSnapshot(snapshot: Snapshot, desiredCubeSize: Int)
Weight(a.tags(weightMinTag).toInt),
a.tags(stateTag),
a.tags(elementCountTag).toLong))
- .groupBy(cubeTag)
- .agg(min(weightValueTag), sum(elementCountTag))
- .map(row => CubeInfo(row.getAs[String](0), Weight(row.getAs[Int](1)), row.getAs[Long](2)))
+ .groupBy(_.cube)
+ .map { case (cube: String, blocks: Array[BlockStats]) =>
+ val weightMax = blocks.map(_.maxWeight.value).min
+ val numElements = blocks.map(_.rowCount).sum
+ CubeInfo(cube, Weight(weightMax), numElements)
+ }
+ .toSeq
}
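The DataFrame `groupBy(cubeTag).agg(min(...), sum(...))` is replaced by an ordinary Scala aggregation over the collected blocks. A self-contained sketch of that aggregation with simplified case classes (the real BlockStats and CubeInfo carry more fields):

```scala
// Simplified models with only the fields the aggregation needs (illustration only).
final case class BlockStats(cube: String, maxWeight: Int, rowCount: Long)
final case class CubeInfo(cube: String, maxWeight: Int, size: Long)

val blocks = Seq(
  BlockStats("A", maxWeight = 100, rowCount = 10L),
  BlockStats("A", maxWeight = 80, rowCount = 5L),
  BlockStats("B", maxWeight = 200, rowCount = 7L))

// Per cube: the smallest per-block max weight and the total element count,
// mirroring the old agg(min(weightValueTag), sum(elementCountTag)).
val cubeInfos: Seq[CubeInfo] = blocks
  .groupBy(_.cube)
  .map { case (cube, cubeBlocks) =>
    CubeInfo(cube, cubeBlocks.map(_.maxWeight).min, cubeBlocks.map(_.rowCount).sum)
  }
  .toSeq
// cubeInfos contains CubeInfo("A", 80, 15) and CubeInfo("B", 200, 7) (order not guaranteed).
```

A practical consequence of this style is that the grouping happens on the driver over the already-collected AddFile entries, so no extra Spark job is launched per call.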

/**
@@ -199,11 +193,10 @@ case class QbeastSnapshot(snapshot: Snapshot, desiredCubeSize: Int)
*/
def getCubeBlocks(cubes: Set[CubeId], spaceRevision: SpaceRevision): Seq[AddFile] = {
val dimensionCount = this.dimensionCount
- snapshot.allFiles
+ allFiles
.filter(_.tags(spaceTag).equals(spaceRevision.toString))
.filter(_.tags(stateTag) != ANNOUNCED)
.filter(a => cubes.contains(CubeId(dimensionCount, a.tags(cubeTag))))
- .collect()
}

}
4 changes: 2 additions & 2 deletions src/test/scala/io/qbeast/spark/index/NewRevisionTest.scala
@@ -42,7 +42,7 @@ class NewRevisionTest
val deltaLog = DeltaLog.forTable(spark, tmpDir)
val qbeastSnapshot = QbeastSnapshot(deltaLog.snapshot, 10000)

- qbeastSnapshot.spaceRevisions.count() shouldBe spaceMultipliers.length
+ qbeastSnapshot.spaceRevisions.size shouldBe spaceMultipliers.length

}

@@ -55,7 +55,7 @@ class NewRevisionTest
val deltaLog = DeltaLog.forTable(spark, tmpDir)
val qbeastSnapshot = QbeastSnapshot(deltaLog.snapshot, 10000)

- val allWM = qbeastSnapshot.spaceRevisions.collect().map(qbeastSnapshot.cubeWeights)
+ val allWM = qbeastSnapshot.spaceRevisions.map(qbeastSnapshot.cubeWeights)
allWM.foreach(wm => assert(wm.nonEmpty))
assert(allWM.distinct.length == allWM.length)

10 changes: 4 additions & 6 deletions src/test/scala/io/qbeast/spark/utils/QbeastSnapshotTest.scala
@@ -8,7 +8,7 @@ import io.qbeast.spark.index.OTreeAlgorithmTest.Client3
import io.qbeast.spark.index.{Weight}
import io.qbeast.spark.model.CubeInfo
import io.qbeast.spark.sql.qbeast.QbeastSnapshot
- import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+ import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.delta.DeltaLog
import org.scalatest.PrivateMethodTester
import org.scalatest.flatspec.AnyFlatSpec
@@ -25,9 +25,8 @@ class QbeastSnapshotTest

val rdd =
spark.sparkContext.parallelize(
- Seq(Client3(size * size, s"student-$size", 20, 1000 + 123, 2567.3432143)) ++
- 1.until(size)
- .map(i => Client3(i * i, s"student-$i", 20 + i, 1000 + 123 + i, 2567.3432143 + i)))
+ 1.to(size)
+ .map(i => Client3(i * i, s"student-$i", i, i * 1000 + 123, i * 2567.3432143)))

assert(rdd.count() == size)
spark.createDataFrame(rdd)
@@ -145,15 +144,14 @@ class QbeastSnapshotTest
val deltaLog = DeltaLog.forTable(spark, tmpDir)
val qbeastSnapshot = QbeastSnapshot(deltaLog.snapshot, oTreeAlgorithm.desiredCubeSize)

- val indexStateMethod = PrivateMethod[Dataset[CubeInfo]]('indexState)
+ val indexStateMethod = PrivateMethod[Seq[CubeInfo]]('indexState)
val indexState =
qbeastSnapshot invokePrivate indexStateMethod(qbeastSnapshot.lastSpaceRevision)
val overflowed =
qbeastSnapshot.overflowedSet(qbeastSnapshot.lastSpaceRevision).map(_.string)

indexState
.filter(cubeInfo => overflowed.contains(cubeInfo.cube))
- .collect()
.foreach(cubeInfo =>
assert(
cubeInfo.size > oTreeAlgorithm.desiredCubeSize * 0.9,