Skip to content

Commit

Permalink
Add comments, use ProviderID, and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeyu Chen authored and Zeyu Chen committed Feb 26, 2025
1 parent 538cbef commit cf2f46d
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming.state

import java.io._
import java.util
import java.util.Locale
import java.util.{Locale, UUID}
import java.util.concurrent.atomic.LongAdder

import scala.collection.mutable
Expand All @@ -36,7 +36,7 @@ import org.apache.spark.internal.{Logging, LogKeys, MDC, MessageWithContext}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.streaming.CheckpointFileManager
import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StreamExecution}
import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{SizeEstimator, Utils}
Expand Down Expand Up @@ -699,6 +699,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
log"${MDC(LogKeys.STATE_STORE_PROVIDER, this)} at ${MDC(LogKeys.FILE_NAME, targetFile)} " +
log"for ${MDC(LogKeys.OP_TYPE, opType)}")
lastSnapshotUploadedVersion = version
// Report snapshot upload event back to the coordinator
StateStore.reportSnapshotUploaded(
StateStoreProviderId(stateStoreId, UUID.fromString(getRunId(hadoopConf))),
version
)
}

/**
Expand Down Expand Up @@ -1043,6 +1048,16 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
keySchema, valueSchema)
}

/**
 * Look up the streaming query run ID from the Hadoop configuration.
 *
 * The run ID is read from the [[StreamExecution.RUN_ID_KEY]] entry. When the key is
 * absent (which the assert only permits while testing), a random UUID is generated
 * so callers still receive a parseable UUID string.
 *
 * @param hadoopConf Hadoop configuration expected to carry the run ID entry
 * @return the query run ID, or a fresh random UUID in test-only contexts
 */
private def getRunId(hadoopConf: Configuration): String = {
  Option(hadoopConf.get(StreamExecution.RUN_ID_KEY)).getOrElse {
    // Production runs must always carry a run ID; only tests may fall through here.
    assert(Utils.isTesting, "Failed to find query id/batch Id in task context")
    UUID.randomUUID().toString
  }
}
}

/** [[StateStoreChangeDataReader]] implementation for [[HDFSBackedStateStoreProvider]] */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1472,7 +1472,7 @@ class RocksDB(
log"time taken: ${MDC(LogKeys.TIME_UNITS, uploadTime)} ms. " +
log"Current lineage: ${MDC(LogKeys.LINEAGE, lineageManager)}")
lastUploadedSnapshotVersion.set(snapshot.version)
providerListener.foreach(_.onSnapshotUploaded(snapshot.version))
providerListener.foreach(_.reportSnapshotUploaded(snapshot.version))
} finally {
snapshot.close()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.Platform
import org.apache.spark.util.{NonFateSharingCache, Utils}

/** Trait representing the different events reported from RocksDB instance */
trait RocksDBEventListener {
def onSnapshotUploaded(version: Long): Unit
def reportSnapshotUploaded(version: Long): Unit
}

private[sql] class RocksDBStateStoreProvider
Expand Down Expand Up @@ -395,6 +396,9 @@ private[sql] class RocksDBStateStoreProvider
}

rocksDB // lazy initialization

// Give the RocksDB instance a reference to this provider so it can call back to report
// specific events like snapshot uploads
rocksDB.setListener(this)

val dataEncoderCacheKey = StateRowEncoderCacheKey(
Expand Down Expand Up @@ -650,8 +654,21 @@ private[sql] class RocksDBStateStoreProvider
}
}

def onSnapshotUploaded(version: Long): Unit = {
StateStore.reportSnapshotUploaded(stateStoreId, version)
/**
 * Callback function from RocksDB to report events to the coordinator.
 * Additional information such as state store ID and query run ID are populated here
 * to report back to the coordinator.
 *
 * @param version The snapshot version that was just uploaded from RocksDB
 */
def reportSnapshotUploaded(version: Long): Unit = {
  // Attach the provider identity (state store ID plus query run ID) so the
  // coordinator can track snapshot uploads per provider instance.
  val providerId = StateStoreProviderId(stateStoreId, UUID.fromString(getRunId(hadoopConf)))
  StateStore.reportSnapshotUploaded(providerId, version)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1126,8 +1126,8 @@ object StateStore extends Logging {
}
}

def reportSnapshotUploaded(storeId: StateStoreId, snapshotVersion: Long): Unit = {
coordinatorRef.foreach(_.snapshotUploaded(storeId, snapshotVersion))
def reportSnapshotUploaded(storeProviderId: StateStoreProviderId, snapshotVersion: Long): Unit = {
coordinatorRef.foreach(_.snapshotUploaded(storeProviderId, snapshotVersion))
}

private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ private case class GetLocation(storeId: StateStoreProviderId)
private case class DeactivateInstances(runId: UUID)
extends StateStoreCoordinatorMessage

private case class SnapshotUploaded(storeId: StateStoreId, version: Long)
private case class SnapshotUploaded(storeId: StateStoreProviderId, version: Long)
extends StateStoreCoordinatorMessage

/** Message to query the coordinator for the latest snapshot version a store has uploaded */
private case class GetLatestSnapshotVersion(storeId: StateStoreProviderId)
extends StateStoreCoordinatorMessage

private object StopCoordinator
Expand Down Expand Up @@ -123,8 +126,14 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) {
}

/** Inform that an executor has uploaded a snapshot */
private[sql] def snapshotUploaded(storeId: StateStoreId, version: Long): Unit = {
rpcEndpointRef.askSync[Boolean](SnapshotUploaded(storeId, version))
private[sql] def snapshotUploaded(storeProviderId: StateStoreProviderId, version: Long): Unit = {
rpcEndpointRef.askSync[Boolean](SnapshotUploaded(storeProviderId, version))
}

/**
 * Get the latest snapshot version uploaded for a state store.
 *
 * @param stateStoreProviderId the provider whose snapshot upload history is queried
 * @return the most recently reported snapshot version, or None if the coordinator
 *         has not recorded any snapshot upload for this provider
 */
private[sql] def getLatestSnapshotVersion(
stateStoreProviderId: StateStoreProviderId): Option[Long] = {
rpcEndpointRef.askSync[Option[Long]](GetLatestSnapshotVersion(stateStoreProviderId))
}

private[state] def stop(): Unit = {
Expand All @@ -141,7 +150,7 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv)
extends ThreadSafeRpcEndpoint with Logging {
private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation]

private val stateStoreSnapshotVersions = new mutable.HashMap[StateStoreId, Long]
private val stateStoreSnapshotVersions = new mutable.HashMap[StateStoreProviderId, Long]

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case ReportActiveInstance(id, host, executorId, providerIdsToCheck) =>
Expand Down Expand Up @@ -178,31 +187,37 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv)
storeIdsToRemove.mkString(", "))
context.reply(true)

case SnapshotUploaded(storeId, version) =>
logWarning(s"ZEYU: ! msg of uploaded Snapshot ${storeId} ${version}")
stateStoreSnapshotVersions.put(storeId, version)
case SnapshotUploaded(providerId, version) =>
stateStoreSnapshotVersions.put(providerId, version)
logWarning(s"ZEYU: Snapshot uploaded at ${providerId} with version ${version}")
// Check for state stores falling behind
val latestPartitionVersion = instances.map(
instance => stateStoreSnapshotVersions.getOrElse(instance._1.storeId, -1L)
instance => stateStoreSnapshotVersions.getOrElse(instance._1, -1L)
).max
val storesAtRisk = instances
val storesBehind = instances
.filter {
case (storeProviderId, _) =>
latestPartitionVersion - stateStoreSnapshotVersions.getOrElse(
storeProviderId.storeId,
-1L
) > 5L
val versionDelta =
latestPartitionVersion - stateStoreSnapshotVersions.getOrElse(storeProviderId, -1L)
versionDelta > 5L
}
.keys
.toSeq
if (storesAtRisk.nonEmpty) {
logWarning(s"ZEYU: number of partitions at risk: ${storesAtRisk.size}")
storesAtRisk.foreach(storeProviderId => {
logWarning(s"ZEYU: partition at risk: ${storeProviderId.storeId}")
// Report all stores that are behind in snapshot uploads
if (storesBehind.nonEmpty) {
logWarning(s"ZEYU: Number of state stores falling behind: ${storesBehind.size}")
storesBehind.foreach(storeProviderId => {
val version = stateStoreSnapshotVersions.getOrElse(storeProviderId, -1L)
logWarning(s"ZEYU: State store falling behind ${storeProviderId} with version $version")
})
}
context.reply(true)

case GetLatestSnapshotVersion(providerId) =>
val version = stateStoreSnapshotVersions.get(providerId)
logWarning(s"ZEYU: Got latest snapshot version of the state store $providerId: $version")
context.reply(version)

case StopCoordinator =>
stop() // Stop before replying to ensure that endpoint name has been deregistered
logInfo("StateStoreCoordinator stopped")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper}
import org.apache.spark.sql.functions.count
import org.apache.spark.sql.internal.SQLConf.SHUFFLE_PARTITIONS
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils

class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
Expand Down Expand Up @@ -125,7 +125,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
import spark.implicits._
coordRef = spark.streams.stateStoreCoordinator
implicit val sqlContext = spark.sqlContext
spark.conf.set(SHUFFLE_PARTITIONS.key, "1")
spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "1")

// Start a query and run a batch to load state stores
val inputData = MemoryStream[Int]
Expand Down Expand Up @@ -155,6 +155,51 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
StateStore.stop()
}
}

// Verifies that snapshot uploads performed during state store maintenance are reported
// back to the StateStoreCoordinator and retrievable via getLatestSnapshotVersion.
// NOTE(review): the test name says "RocksDB", but no state store provider class is
// configured here — unless the suite's default provider is RocksDB, this exercises
// HDFSBackedStateStoreProvider. TODO confirm and set the provider conf explicitly.
test("snapshot uploads in RocksDB are properly reported to the coordinator") {
var coordRef: StateStoreCoordinatorRef = null
try {
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
SparkSession.setActiveSession(spark)
import spark.implicits._
coordRef = spark.streams.stateStoreCoordinator
implicit val sqlContext = spark.sqlContext
// Single shuffle partition keeps the test to one state store instance (partition 0).
spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "1")
// Run maintenance every 100ms and snapshot after every delta so an upload
// happens quickly enough for the test to observe it.
spark.conf.set(SQLConf.STREAMING_MAINTENANCE_INTERVAL.key, "100")
spark.conf.set(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key, "1")

// Start a query and run a batch to load state stores
val inputData = MemoryStream[Int]
val aggregated = inputData.toDF().groupBy("value").agg(count("*")) // stateful query
val checkpointLocation = Utils.createTempDir().getAbsoluteFile
val query = aggregated.writeStream
.format("memory")
.outputMode("update")
.queryName("query")
.option("checkpointLocation", checkpointLocation.toString)
.start()
// Run several batches so multiple deltas are committed and at least one
// maintenance-triggered snapshot upload should occur.
inputData.addData(1, 2, 3)
query.processAllAvailable()
inputData.addData(1, 2, 3)
query.processAllAvailable()
inputData.addData(1, 2, 3)
query.processAllAvailable()

// Verify state store has uploaded a snapshot and this is registered with the coordinator
val stateCheckpointDir =
query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation
val providerId = StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, 0), query.runId)
// NOTE(review): this only logs the coordinator's recorded version and asserts
// nothing, so the test cannot fail if reporting breaks. It should assert, e.g.
// eventually { assert(coordRef.getLatestSnapshotVersion(providerId).isDefined) },
// using a timed retry since uploads happen on the async maintenance thread.
logWarning(
s"ZEYU: snapshot version ${coordRef.getLatestSnapshotVersion(providerId)}"
)

query.stop()
} finally {
// Clean up any active queries, the coordinator endpoint, and loaded state stores
// even if the test body throws.
SparkSession.getActiveSession.foreach(_.streams.active.foreach(_.stop()))
if (coordRef != null) coordRef.stop()
StateStore.stop()
}
}
}

object StateStoreCoordinatorSuite {
Expand Down

0 comments on commit cf2f46d

Please sign in to comment.