SPARK-51358 Introduce snapshot upload lag detection through StateStoreCoordinator
zecookiez authored and Zeyu Chen committed Mar 1, 2025
1 parent 496fe7a commit 958b491
Showing 7 changed files with 459 additions and 11 deletions.
@@ -2249,6 +2249,19 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG =
buildConf("spark.sql.streaming.stateStore.minSnapshotVersionDeltaToLog")
.internal()
      .doc(
        "Minimum number of versions by which the most recent uploaded snapshot version of " +
        "a single state store instance can trail behind the most recent version across " +
        "all state store instances before a warning message is logged."
      )
.version("4.0.0")
.intConf
.checkValue(k => k >= 0, "Must be greater than or equal to 0")
.createWithDefault(30)

val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION =
buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion")
.internal()
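The new minSnapshotVersionDeltaToLog conf above is internal, but for illustration it can be set like any other SQL conf. A minimal sketch, assuming a SparkSession named spark (the value 50 is arbitrary):

    // Require a store to trail the fleet-wide latest snapshot by at least
    // 50 versions (default: 30) before the coordinator logs a warning.
    spark.conf.set("spark.sql.streaming.stateStore.minSnapshotVersionDeltaToLog", "50")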
@@ -53,7 +53,7 @@ class StreamingQueryManager private[sql] (
with Logging {

private[sql] val stateStoreCoordinator =
StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env)
StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env, sqlConf)
private val listenerBus =
new StreamingQueryListenerBus(Some(sparkSession.sparkContext.listenerBus))

@@ -134,6 +134,9 @@ class RocksDB(
rocksDbOptions.setStatistics(new Statistics())
private val nativeStats = rocksDbOptions.statistics()

// Stores a StateStoreProvider reference for event callbacks such as snapshot upload reports
private var providerListener: Option[RocksDBEventListener] = None

private val workingDir = createTempDir("workingDir")
private val fileManager = new RocksDBFileManager(dfsRootDir, createTempDir("fileManager"),
hadoopConf, conf.compressionCodec, loggingId = loggingId)
@@ -197,6 +200,11 @@ class RocksDB(
@GuardedBy("acquireLock")
private val shouldForceSnapshot: AtomicBoolean = new AtomicBoolean(false)

/** Attaches a RocksDBStateStoreProvider reference to this RocksDB instance for event callbacks. */
def setListener(listener: RocksDBEventListener): Unit = {
providerListener = Some(listener)
}

private def getColumnFamilyInfo(cfName: String): ColumnFamilyInfo = {
colFamilyNameToInfoMap.get(cfName)
}
@@ -1467,6 +1475,11 @@
log"time taken: ${MDC(LogKeys.TIME_UNITS, uploadTime)} ms. " +
log"Current lineage: ${MDC(LogKeys.LINEAGE, lineageManager)}")
lastUploadedSnapshotVersion.set(snapshot.version)
// Report to coordinator that the snapshot has been uploaded when
// changelog checkpointing is enabled, since that is when stores can lag behind.
if (enableChangelogCheckpointing) {
providerListener.foreach(_.reportSnapshotUploaded(snapshot.version))
}
} finally {
snapshot.close()
}
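Because the callback above only fires when changelog checkpointing is enabled, lag detection is relevant only for queries that opt into it. A hedged sketch of opting in for the RocksDB provider, again assuming a SparkSession named spark:

    // Changelog checkpointing decouples snapshot uploads from commits,
    // which is what allows individual store instances to drift behind.
    spark.conf.set(
      "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled",
      "true")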
@@ -38,9 +38,14 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.Platform
import org.apache.spark.util.{NonFateSharingCache, Utils}

/** Trait for listeners that handle the different events reported from a RocksDB instance */
trait RocksDBEventListener {
def reportSnapshotUploaded(version: Long): Unit
}

private[sql] class RocksDBStateStoreProvider
extends StateStoreProvider with Logging with Closeable
with SupportsFineGrainedReplay {
with SupportsFineGrainedReplay with RocksDBEventListener {
import RocksDBStateStoreProvider._

class RocksDBStateStore(lastVersion: Long) extends StateStore {
@@ -392,6 +397,10 @@ private[sql] class RocksDBStateStoreProvider

rocksDB // lazy initialization

// Give the RocksDB instance a reference to this provider so it can call back to report
// specific events like snapshot uploads
rocksDB.setListener(this)

val dataEncoderCacheKey = StateRowEncoderCacheKey(
queryRunId = getRunId(hadoopConf),
operatorId = stateStoreId.operatorId,
@@ -644,6 +653,23 @@ private[sql] class RocksDBStateStoreProvider
throw StateStoreErrors.cannotCreateColumnFamilyWithReservedChars(colFamilyName)
}
}

  /**
   * Callback function from RocksDB to report events to the coordinator.
   * Additional information such as the state store ID and query run ID is populated
   * here and reported back to the coordinator.
   *
   * @param version The snapshot version that was just uploaded from RocksDB
   */
  override def reportSnapshotUploaded(version: Long): Unit = {
// Collect the state store ID and query run ID to report back to the coordinator
StateStore.reportSnapshotUploaded(
StateStoreProviderId(
stateStoreId,
UUID.fromString(getRunId(hadoopConf))
),
version
)
}
}


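To see the RocksDBEventListener hook in isolation, here is a minimal self-contained sketch; LoggingListener is hypothetical and not part of this commit:

    // A toy listener implementing the trait introduced above. In the commit,
    // RocksDBStateStoreProvider plays this role and forwards the event to
    // StateStore.reportSnapshotUploaded.
    class LoggingListener extends RocksDBEventListener {
      override def reportSnapshotUploaded(version: Long): Unit =
        println(s"snapshot version $version uploaded")
    }
    // rocksDB.setListener(new LoggingListener)  // same wiring the provider does above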
@@ -1126,6 +1126,12 @@ object StateStore extends Logging {
}
}

  def reportSnapshotUploaded(
      storeProviderId: StateStoreProviderId,
      snapshotVersion: Long): Unit = {
    // Attach the current timestamp as the upload time of the snapshot
    val currentTime = System.currentTimeMillis()
    coordinatorRef.foreach(_.snapshotUploaded(storeProviderId, snapshotVersion, currentTime))
}

private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized {
val env = SparkEnv.get
if (env != null) {
@@ -25,6 +25,7 @@ import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.RpcUtils

/** Trait representing all messages to [[StateStoreCoordinator]] */
@@ -55,6 +56,15 @@ private case class GetLocation(storeId: StateStoreProviderId)
private case class DeactivateInstances(runId: UUID)
extends StateStoreCoordinatorMessage

private case class SnapshotUploaded(storeId: StateStoreProviderId, version: Long, timestamp: Long)
extends StateStoreCoordinatorMessage

private case class GetLatestSnapshotVersion(storeId: StateStoreProviderId)
extends StateStoreCoordinatorMessage

private case object GetLaggingStores
extends StateStoreCoordinatorMessage

private object StopCoordinator
extends StateStoreCoordinatorMessage

@@ -66,9 +76,9 @@ object StateStoreCoordinatorRef extends Logging {
/**
* Create a reference to a [[StateStoreCoordinator]]
*/
def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized {
def forDriver(env: SparkEnv, conf: SQLConf): StateStoreCoordinatorRef = synchronized {
try {
val coordinator = new StateStoreCoordinator(env.rpcEnv)
val coordinator = new StateStoreCoordinator(env.rpcEnv, conf)
val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, coordinator)
logInfo("Registered StateStoreCoordinator endpoint")
new StateStoreCoordinatorRef(coordinatorRef)
@@ -119,6 +129,25 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) {
rpcEndpointRef.askSync[Boolean](DeactivateInstances(runId))
}

/** Inform the coordinator that an executor has uploaded a snapshot */
private[sql] def snapshotUploaded(
storeProviderId: StateStoreProviderId,
version: Long,
timestamp: Long): Unit = {
rpcEndpointRef.askSync[Boolean](SnapshotUploaded(storeProviderId, version, timestamp))
}

/** Get the latest snapshot version uploaded for a state store */
private[sql] def getLatestSnapshotVersion(
stateStoreProviderId: StateStoreProviderId): Option[Long] = {
rpcEndpointRef.askSync[Option[Long]](GetLatestSnapshotVersion(stateStoreProviderId))
}

/** Get the state store instances that are falling behind in snapshot uploads */
private[sql] def getLaggingStores(): Seq[StateStoreProviderId] = {
rpcEndpointRef.askSync[Seq[StateStoreProviderId]](GetLaggingStores)
}

private[state] def stop(): Unit = {
rpcEndpointRef.askSync[Boolean](StopCoordinator)
}
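As a usage sketch, a driver-side caller holding this ref could poll the new endpoints; these methods are private[sql], so the snippet is illustrative only:

    // Hypothetical monitoring snippet, given coordinatorRef from forDriver(...).
    val lagging = coordinatorRef.getLaggingStores()
    lagging.foreach { id =>
      println(s"store $id latest snapshot: ${coordinatorRef.getLatestSnapshotVersion(id)}")
    }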
@@ -129,10 +158,17 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) {
* Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster,
* and for getting their locations for job scheduling.
*/
private class StateStoreCoordinator(override val rpcEnv: RpcEnv)
extends ThreadSafeRpcEndpoint with Logging {
private class StateStoreCoordinator(
override val rpcEnv: RpcEnv,
val sqlConf: SQLConf)
extends ThreadSafeRpcEndpoint
with Logging {
private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation]

// Stores the latest snapshot upload event of each state store provider instance
private val stateStoreSnapshotVersions =
new mutable.HashMap[StateStoreProviderId, SnapshotUploadEvent]

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case ReportActiveInstance(id, host, executorId, providerIdsToCheck) =>
logDebug(s"Reported state store $id is active at $executorId")
@@ -168,9 +204,85 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv)
storeIdsToRemove.mkString(", "))
context.reply(true)

case SnapshotUploaded(providerId, version, timestamp) =>
stateStoreSnapshotVersions.put(providerId, SnapshotUploadEvent(version, timestamp))
logDebug(s"Snapshot uploaded at ${providerId} with version ${version}")
// Report all stores that are behind in snapshot uploads
val (laggingStores, latestSnapshot) = findLaggingStores()
if (laggingStores.nonEmpty) {
logWarning(s"Number of state stores falling behind: ${laggingStores.size}")
laggingStores.foreach { storeProviderId =>
val snapshotEvent =
stateStoreSnapshotVersions.getOrElse(storeProviderId, SnapshotUploadEvent(-1, 0))
logWarning(
s"State store falling behind $storeProviderId " +
s"(current: $snapshotEvent, latest: $latestSnapshot)"
)
}
}
context.reply(true)

case GetLatestSnapshotVersion(providerId) =>
val version = stateStoreSnapshotVersions.get(providerId).map(_.version)
logDebug(s"Got latest snapshot version of the state store $providerId: $version")
context.reply(version)

case GetLaggingStores =>
val (laggingStores, _) = findLaggingStores()
logDebug(s"Got lagging state stores ${laggingStores
.map(
id =>
s"StateStoreId(operatorId=${id.storeId.operatorId}, " +
s"partitionId=${id.storeId.partitionId}, " +
s"storeName=${id.storeId.storeName})"
)
.mkString(", ")}")
context.reply(laggingStores)

case StopCoordinator =>
stop() // Stop before replying to ensure that endpoint name has been deregistered
logInfo("StateStoreCoordinator stopped")
context.reply(true)
}

case class SnapshotUploadEvent(
version: Long,
timestamp: Long
) extends Ordered[SnapshotUploadEvent] {
def isLagging(latest: SnapshotUploadEvent): Boolean = {
val versionDelta = latest.version - version
val timeDelta = latest.timestamp - timestamp
val minVersionDeltaForLogging =
sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG)
// Use 10 times the maintenance interval as the minimum time delta for logging
val minTimeDeltaForLogging = 10 * sqlConf.getConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL)

versionDelta >= minVersionDeltaForLogging ||
(version >= 0 && timeDelta > minTimeDeltaForLogging)
}

override def compare(that: SnapshotUploadEvent): Int = {
this.version.compare(that.version)
}

override def toString(): String = {
s"SnapshotUploadEvent(version=$version, timestamp=$timestamp)"
}
}
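As a worked example of isLagging with the defaults inlined (30 versions, and 10 × the 60 s default maintenance interval), here is a self-contained sketch with illustrative names and values:

    // Standalone mirror of the predicate with hard-coded default thresholds.
    case class Snap(version: Long, timestamp: Long)
    def lagging(s: Snap, latest: Snap): Boolean = {
      val versionDelta = latest.version - s.version
      val timeDelta = latest.timestamp - s.timestamp // milliseconds
      versionDelta >= 30 || (s.version >= 0 && timeDelta > 10 * 60000L)
    }
    assert(lagging(Snap(10, 400000L), Snap(45, 1000000L))) // 45 - 10 = 35 >= 30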

private def findLaggingStores(): (Seq[StateStoreProviderId], SnapshotUploadEvent) = {
// Find the most up-to-date instance to use as the reference point
val latestSnapshot = instances.keys
  .map(stateStoreSnapshotVersions.getOrElse(_, SnapshotUploadEvent(-1, 0)))
  .max
// Look for instances that are lagging behind in snapshot uploads
val laggingStores = instances.keys.filter { storeProviderId =>
stateStoreSnapshotVersions
.getOrElse(storeProviderId, SnapshotUploadEvent(-1, 0))
.isLagging(latestSnapshot)
}.toSeq
(laggingStores, latestSnapshot)
}
}
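One edge case worth noting: .max throws on an empty collection, so findLaggingStores assumes at least one instance has been registered. A defensive variant (a sketch, not the committed code) could use Scala 2.13's maxOption:

    // Avoids UnsupportedOperationException when no instances are registered.
    val latestSnapshot = instances.keys
      .map(stateStoreSnapshotVersions.getOrElse(_, SnapshotUploadEvent(-1, 0)))
      .maxOption
      .getOrElse(SnapshotUploadEvent(-1, 0))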