diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index 820322d1e0ee1..6ef5a885629b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -133,6 +133,7 @@ class RocksDB(
   private val dbLogger = createLogger() // for forwarding RocksDB native logs to log4j
   rocksDbOptions.setStatistics(new Statistics())
   private val nativeStats = rocksDbOptions.statistics()
+  private var providerListener: Option[RocksDBEventListener] = None
 
   private val workingDir = createTempDir("workingDir")
   private val fileManager = new RocksDBFileManager(dfsRootDir, createTempDir("fileManager"),
@@ -197,6 +198,11 @@ class RocksDB(
   @GuardedBy("acquireLock")
   private val shouldForceSnapshot: AtomicBoolean = new AtomicBoolean(false)
 
+  /** Register a listener to be notified whenever a snapshot upload completes. */
+  def setListener(listener: RocksDBEventListener): Unit = {
+    providerListener = Some(listener)
+  }
+
   private def getColumnFamilyInfo(cfName: String): ColumnFamilyInfo = {
     colFamilyNameToInfoMap.get(cfName)
   }
@@ -1467,6 +1472,7 @@ class RocksDB(
           log"time taken: ${MDC(LogKeys.TIME_UNITS, uploadTime)} ms. " +
           log"Current lineage: ${MDC(LogKeys.LINEAGE, lineageManager)}")
         lastUploadedSnapshotVersion.set(snapshot.version)
+        providerListener.foreach(_.onSnapshotUploaded(snapshot.version))
       } finally {
         snapshot.close()
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 47721cea4359f..d4bd360448ffd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -38,9 +38,14 @@ import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.util.{NonFateSharingCache, Utils}
 
+/** Callback interface for events emitted by a [[RocksDB]] instance. */
+trait RocksDBEventListener {
+  def onSnapshotUploaded(version: Long): Unit
+}
+
 private[sql] class RocksDBStateStoreProvider
   extends StateStoreProvider with Logging with Closeable
-  with SupportsFineGrainedReplay {
+  with SupportsFineGrainedReplay with RocksDBEventListener {
   import RocksDBStateStoreProvider._
 
   class RocksDBStateStore(lastVersion: Long) extends StateStore {
@@ -391,6 +395,7 @@ private[sql] class RocksDBStateStoreProvider
     }
 
     rocksDB // lazy initialization
+    rocksDB.setListener(this) // route snapshot-upload events back through this provider
 
     val dataEncoderCacheKey = StateRowEncoderCacheKey(
       queryRunId = getRunId(hadoopConf),
@@ -644,6 +649,11 @@ private[sql] class RocksDBStateStoreProvider
       throw StateStoreErrors.cannotCreateColumnFamilyWithReservedChars(colFamilyName)
     }
   }
+
+  /** Forward snapshot-upload events from the RocksDB instance to the coordinator. */
+  override def onSnapshotUploaded(version: Long): Unit = {
+    StateStore.reportSnapshotUploaded(stateStoreId, version)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 09acc24aff982..f08461741f7ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -1126,6 +1126,11 @@ object StateStore extends Logging {
     }
   }
 
+  /** Report to the coordinator that a snapshot of the given version has been uploaded. */
+  def reportSnapshotUploaded(storeId: StateStoreId, snapshotVersion: Long): Unit = {
+    coordinatorRef.foreach(_.snapshotUploaded(storeId, snapshotVersion))
+  }
+
   private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized {
     val env = SparkEnv.get
     if (env != null) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
index 84b77efea3caf..fd8d9e896cacc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
@@ -55,6 +55,9 @@ private case class GetLocation(storeId: StateStoreProviderId)
 private case class DeactivateInstances(runId: UUID)
   extends StateStoreCoordinatorMessage
 
+private case class SnapshotUploaded(storeId: StateStoreId, version: Long)
+  extends StateStoreCoordinatorMessage
+
 private object StopCoordinator
   extends StateStoreCoordinatorMessage
 
@@ -119,6 +122,11 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) {
     rpcEndpointRef.askSync[Boolean](DeactivateInstances(runId))
   }
 
+  /** Inform the coordinator that an executor has uploaded a snapshot of the given version. */
+  private[sql] def snapshotUploaded(storeId: StateStoreId, version: Long): Unit = {
+    rpcEndpointRef.askSync[Boolean](SnapshotUploaded(storeId, version))
+  }
+
   private[state] def stop(): Unit = {
     rpcEndpointRef.askSync[Boolean](StopCoordinator)
   }
@@ -133,6 +141,9 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv)
   extends ThreadSafeRpcEndpoint with Logging {
   private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation]
 
+  // Latest snapshot version uploaded for each state store, keyed by store id.
+  private val stateStoreSnapshotVersions = new mutable.HashMap[StateStoreId, Long]
+
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
     case ReportActiveInstance(id, host, executorId, providerIdsToCheck) =>
       logDebug(s"Reported state store $id is active at $executorId")
@@ -168,6 +178,31 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv)
         storeIdsToRemove.mkString(", "))
       context.reply(true)
 
+    case SnapshotUploaded(storeId, version) =>
+      logDebug(s"Snapshot version $version was uploaded for state store $storeId")
+      stateStoreSnapshotVersions.put(storeId, version)
+      // Check for state stores falling behind; fold from -1L so an empty map is safe.
+      val latestPartitionVersion = instances.keys
+        .map(id => stateStoreSnapshotVersions.getOrElse(id.storeId, -1L))
+        .foldLeft(-1L)(math.max)
+      val storesAtRisk = instances
+        .filter {
+          case (storeProviderId, _) =>
+            latestPartitionVersion - stateStoreSnapshotVersions.getOrElse(
+              storeProviderId.storeId,
+              -1L
+            ) > 5L // more than 5 snapshot versions behind the newest upload
+        }
+        .keys
+        .toSeq
+      if (storesAtRisk.nonEmpty) {
+        logWarning(s"${storesAtRisk.size} state store(s) are falling behind on snapshot uploads")
+        storesAtRisk.foreach { storeProviderId =>
+          logWarning(s"State store falling behind: ${storeProviderId.storeId}")
+        }
+      }
+      context.reply(true)
+
     case StopCoordinator =>
       stop() // Stop before replying to ensure that endpoint name has been deregistered
       logInfo("StateStoreCoordinator stopped")
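
For reviewers, here is a minimal, self-contained sketch of the flow this patch wires up: a snapshot upload is reported through a callback, and the coordinator flags stores whose uploaded version trails the newest upload by more than a fixed number of versions. The names `SnapshotListener`, `LagTrackingCoordinator`, `maxVersionLag`, and `LagDemo` are illustrative, not Spark's API, and unlike the patch this sketch only considers stores that have reported at least one upload rather than all registered instances.

```scala
import scala.collection.mutable

// Illustrative stand-in for RocksDBEventListener.
trait SnapshotListener {
  def onSnapshotUploaded(storeId: String, version: Long): Unit
}

// Illustrative stand-in for StateStoreCoordinator: tracks the latest uploaded
// snapshot version per store and warns about stores that fall behind.
class LagTrackingCoordinator(maxVersionLag: Long = 5L) extends SnapshotListener {
  private val snapshotVersions = mutable.HashMap.empty[String, Long]

  override def onSnapshotUploaded(storeId: String, version: Long): Unit = synchronized {
    snapshotVersions.put(storeId, version)
    // Fold from -1L so an empty map cannot make the reduction fail.
    val latest = snapshotVersions.values.foldLeft(-1L)(math.max)
    val lagging = snapshotVersions.collect {
      case (id, v) if latest - v > maxVersionLag => id
    }
    lagging.foreach(id => println(s"WARN: store $id is falling behind version $latest"))
  }
}

object LagDemo extends App {
  val coordinator = new LagTrackingCoordinator()
  coordinator.onSnapshotUploaded("store-0", 10L) // latest = 10, nothing lagging
  coordinator.onSnapshotUploaded("store-1", 2L)  // store-1 trails by 8 (> 5): warning fires
}
```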