diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala index 893984feabf11..0024ef1a5cae8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionRead import org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceErrors import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.PATH import org.apache.spark.sql.execution.streaming.CheckpointFileManager -import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1} +import org.apache.spark.sql.execution.streaming.state.{OperatorInfoV1, OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1, OperatorStateMetadataV2, StateStoreMetadataV1} import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -46,6 +46,7 @@ case class StateMetadataTableEntry( numPartitions: Int, minBatchId: Long, maxBatchId: Long, + operatorPropertiesJson: String, numColsPrefixKey: Int) { def toRow(): InternalRow = { new GenericInternalRow( @@ -55,6 +56,7 @@ case class StateMetadataTableEntry( numPartitions, minBatchId, maxBatchId, + UTF8String.fromString(operatorPropertiesJson), numColsPrefixKey)) } } @@ -68,6 +70,7 @@ object StateMetadataTableEntry { .add("numPartitions", IntegerType) .add("minBatchId", LongType) .add("maxBatchId", LongType) + .add("operatorProperties", StringType) } } @@ -188,29 +191,59 @@ class StateMetadataPartitionReader( } else Array.empty } - private def allOperatorStateMetadata: Array[OperatorStateMetadata] = { + // Need this to be accessible from IncrementalExecution for the planning rule. 
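// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: with "operatorProperties" added to
// StateMetadataTableEntry.schema above, the new column can be projected straight
// from the state-metadata source. The SparkSession and checkpoint path are
// placeholders supplied by the caller.
import org.apache.spark.sql.{DataFrame, SparkSession}

def readOperatorProperties(spark: SparkSession, checkpointRoot: String): DataFrame = {
  spark.read
    .format("state-metadata")
    .load(checkpointRoot)
    // v1 operators report null here; TransformWithStateExec (metadata v2) reports its
    // JSON properties, e.g. {"timeMode":"NoTime","outputMode":"Update"}.
    .select("operatorId", "operatorName", "operatorProperties")
}
// ---------------------------------------------------------------------------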
+ private[sql] def allOperatorStateMetadata: Array[OperatorStateMetadata] = { val stateDir = new Path(checkpointLocation, "state") val opIds = fileManager .list(stateDir, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted opIds.map { opId => - new OperatorStateMetadataReader(new Path(stateDir, opId.toString), hadoopConf).read() + val operatorIdPath = new Path(stateDir, opId.toString) + // check if OperatorStateMetadataV2 path exists, if it does, read it + // otherwise, fall back to OperatorStateMetadataV1 + val operatorStateMetadataV2Path = OperatorStateMetadataV2.metadataDirPath(operatorIdPath) + val operatorStateMetadataVersion = if (fileManager.exists(operatorStateMetadataV2Path)) { + 2 + } else { + 1 + } + OperatorStateMetadataReader.createReader( + operatorIdPath, hadoopConf, operatorStateMetadataVersion).read() match { + case Some(metadata) => metadata + case None => OperatorStateMetadataV1(OperatorInfoV1(opId, null), + Array(StateStoreMetadataV1(null, -1, -1))) + } } } private[sql] lazy val stateMetadata: Iterator[StateMetadataTableEntry] = { allOperatorStateMetadata.flatMap { operatorStateMetadata => - require(operatorStateMetadata.version == 1) - val operatorStateMetadataV1 = operatorStateMetadata.asInstanceOf[OperatorStateMetadataV1] - operatorStateMetadataV1.stateStoreInfo.map { stateStoreMetadata => - StateMetadataTableEntry(operatorStateMetadataV1.operatorInfo.operatorId, - operatorStateMetadataV1.operatorInfo.operatorName, - stateStoreMetadata.storeName, - stateStoreMetadata.numPartitions, - if (batchIds.nonEmpty) batchIds.head else -1, - if (batchIds.nonEmpty) batchIds.last else -1, - stateStoreMetadata.numColsPrefixKey - ) + require(operatorStateMetadata.version == 1 || operatorStateMetadata.version == 2) + operatorStateMetadata match { + case v1: OperatorStateMetadataV1 => + v1.stateStoreInfo.map { stateStoreMetadata => + StateMetadataTableEntry(v1.operatorInfo.operatorId, + v1.operatorInfo.operatorName, + stateStoreMetadata.storeName, + stateStoreMetadata.numPartitions, + if (batchIds.nonEmpty) batchIds.head else -1, + if (batchIds.nonEmpty) batchIds.last else -1, + null, + stateStoreMetadata.numColsPrefixKey + ) + } + case v2: OperatorStateMetadataV2 => + v2.stateStoreInfo.map { stateStoreMetadata => + StateMetadataTableEntry(v2.operatorInfo.operatorId, + v2.operatorInfo.operatorName, + stateStoreMetadata.storeName, + stateStoreMetadata.numPartitions, + if (batchIds.nonEmpty) batchIds.head else -1, + if (batchIds.nonEmpty) batchIds.last else -1, + v2.operatorPropertiesJson, + -1 // numColsPrefixKey is not available in OperatorStateMetadataV2 + ) + } + } } - } - }.iterator + }.iterator } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 722a3bd86b7e1..567fb1b98f14c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1 -import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataWriter +import 
org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -208,13 +208,16 @@ class IncrementalExecution( } val schemaValidationResult = statefulOp. validateAndMaybeEvolveStateSchema(hadoopConf, currentBatchId, stateSchemaVersion) + val stateSchemaPaths = schemaValidationResult.map(_.schemaPath) // write out the state schema paths to the metadata file statefulOp match { - case stateStoreWriter: StateStoreWriter => - val metadata = stateStoreWriter.operatorStateMetadata() - // TODO: [SPARK-48849] Populate metadata with stateSchemaPaths if metadata version is v2 - val metadataWriter = new OperatorStateMetadataWriter(new Path( - checkpointLocation, stateStoreWriter.getStateInfo.operatorId.toString), hadoopConf) + case ssw: StateStoreWriter => + val metadata = ssw.operatorStateMetadata(stateSchemaPaths) + val metadataWriter = OperatorStateMetadataWriter.createWriter( + new Path(checkpointLocation, ssw.getStateInfo.operatorId.toString), + hadoopConf, + ssw.operatorStateMetadataVersion, + Some(currentBatchId)) metadataWriter.write(metadata) case _ => } @@ -456,8 +459,12 @@ class IncrementalExecution( val reader = new StateMetadataPartitionReader( new Path(checkpointLocation).getParent.toString, new SerializableConfiguration(hadoopConf)) - ret = reader.stateMetadata.map { metadataTableEntry => - metadataTableEntry.operatorId -> metadataTableEntry.operatorName + val opMetadataList = reader.allOperatorStateMetadata + ret = opMetadataList.map { + case OperatorStateMetadataV1(operatorInfo, _) => + operatorInfo.operatorId -> operatorInfo.operatorName + case OperatorStateMetadataV2(operatorInfo, _, _) => + operatorInfo.operatorId -> operatorInfo.operatorName }.toMap } catch { case e: Exception => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index a303d4db66a01..c54917bdb7873 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -227,10 +227,12 @@ case class StreamingSymmetricHashJoinExec( private val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) - override def operatorStateMetadata(): OperatorStateMetadata = { + override def operatorStateMetadata( + stateSchemaPaths: List[String] = List.empty): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) - val stateStoreInfo = stateStoreNames.map(StateStoreMetadataV1(_, 0, info.numPartitions)).toArray + val stateStoreInfo = + stateStoreNames.map(StateStoreMetadataV1(_, 0, info.numPartitions)).toArray OperatorStateMetadataV1(operatorInfo, stateStoreInfo) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index a4d525ad13fd4..d2b8f92aa985b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ 
-21,6 +21,10 @@ import java.util.concurrent.TimeUnit.NANOSECONDS import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ +import org.json4s.JString +import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -96,6 +100,8 @@ case class TransformWithStateExec( } } + override def operatorStateMetadataVersion: Int = 2 + /** * We initialize this processor handle in the driver to run the init function * and fetch the schemas of the state variables initialized in this processor. @@ -382,12 +388,47 @@ case class TransformWithStateExec( batchId: Long, stateSchemaVersion: Int): List[StateSchemaValidationResult] = { assert(stateSchemaVersion >= 3) - val newColumnFamilySchemas = getColFamilySchemas() + val newSchemas = getColFamilySchemas() val stateSchemaDir = stateSchemaDirPath() - val stateSchemaFilePath = new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}") - List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, - newColumnFamilySchemas.values.toList, session.sessionState, stateSchemaVersion, - schemaFilePath = Some(stateSchemaFilePath))) + val newStateSchemaFilePath = + new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}") + val metadataPath = new Path(getStateInfo.checkpointLocation, s"${getStateInfo.operatorId}") + val metadataReader = OperatorStateMetadataReader.createReader( + metadataPath, hadoopConf, operatorStateMetadataVersion) + val operatorStateMetadata = metadataReader.read() + val oldStateSchemaFilePath: Option[Path] = operatorStateMetadata match { + case Some(metadata) => + metadata match { + case v2: OperatorStateMetadataV2 => + Some(new Path(v2.stateStoreInfo.head.stateSchemaFilePath)) + case _ => None + } + case None => None + } + List(StateSchemaCompatibilityChecker. + validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf, + newSchemas.values.toList, session.sessionState, stateSchemaVersion, + storeName = StateStoreId.DEFAULT_STORE_NAME, + oldSchemaFilePath = oldStateSchemaFilePath, + newSchemaFilePath = Some(newStateSchemaFilePath))) + } + + /** Metadata of this stateful operator and its states stores. 
*/ + override def operatorStateMetadata( + stateSchemaPaths: List[String]): OperatorStateMetadata = { + val info = getStateInfo + val operatorInfo = OperatorInfoV1(info.operatorId, shortName) + // stateSchemaFilePath should be populated at this point + val stateStoreInfo = + Array(StateStoreMetadataV2( + StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions, stateSchemaPaths.head)) + + val operatorPropertiesJson: JValue = + ("timeMode" -> JString(timeMode.toString)) ~ + ("outputMode" -> JString(outputMode.toString)) + + val json = compact(render(operatorPropertiesJson)) + OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json) } private def stateSchemaDirPath(): Path = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala index dcea29085bf2b..df3de5d9ceab6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala @@ -20,13 +20,17 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{BufferedReader, InputStreamReader} import java.nio.charset.StandardCharsets +import scala.reflect.ClassTag + import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FSDataOutputStream, Path} +import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream, Path} import org.json4s.{Formats, NoTypeHints} import org.json4s.jackson.Serialization import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil} +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream +import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataUtils.{OperatorStateMetadataReader, OperatorStateMetadataWriter} /** * Metadata for a state store instance. @@ -40,6 +44,21 @@ trait StateStoreMetadata { case class StateStoreMetadataV1(storeName: String, numColsPrefixKey: Int, numPartitions: Int) extends StateStoreMetadata +case class StateStoreMetadataV2( + storeName: String, + numColsPrefixKey: Int, + numPartitions: Int, + stateSchemaFilePath: String) + extends StateStoreMetadata with Serializable + +object StateStoreMetadataV2 { + private implicit val formats: Formats = Serialization.formats(NoTypeHints) + + @scala.annotation.nowarn + private implicit val manifest = Manifest + .classType[StateStoreMetadataV2](implicitly[ClassTag[StateStoreMetadataV2]].runtimeClass) +} + /** * Information about a stateful operator. 
*/ @@ -51,7 +70,10 @@ trait OperatorInfo { case class OperatorInfoV1(operatorId: Long, operatorName: String) extends OperatorInfo trait OperatorStateMetadata { + def version: Int + + def operatorInfo: OperatorInfo } case class OperatorStateMetadataV1( @@ -60,12 +82,56 @@ case class OperatorStateMetadataV1( override def version: Int = 1 } -object OperatorStateMetadataUtils { +case class OperatorStateMetadataV2( + operatorInfo: OperatorInfoV1, + stateStoreInfo: Array[StateStoreMetadataV2], + operatorPropertiesJson: String) extends OperatorStateMetadata { + override def version: Int = 2 +} + +object OperatorStateMetadataUtils extends Logging { + + sealed trait OperatorStateMetadataReader { + def version: Int + + def read(): Option[OperatorStateMetadata] + } + + sealed trait OperatorStateMetadataWriter { + def version: Int + def write(operatorMetadata: OperatorStateMetadata): Unit + } private implicit val formats: Formats = Serialization.formats(NoTypeHints) - def metadataFilePath(stateCheckpointPath: Path): Path = - new Path(new Path(stateCheckpointPath, "_metadata"), "metadata") + def readMetadata(inputStream: FSDataInputStream): Option[OperatorStateMetadata] = { + val inputReader = + new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) + try { + val versionStr = inputReader.readLine() + val version = MetadataVersionUtil.validateVersion(versionStr, 2) + Some(deserialize(version, inputReader)) + } finally { + inputStream.close() + } + } + + def writeMetadata( + outputStream: CancellableFSDataOutputStream, + operatorMetadata: OperatorStateMetadata, + metadataFilePath: Path): Unit = { + try { + outputStream.write(s"v${operatorMetadata.version}\n".getBytes(StandardCharsets.UTF_8)) + OperatorStateMetadataUtils.serialize(outputStream, operatorMetadata) + outputStream.close() + } catch { + case e: Throwable => + logError( + log"Fail to write state metadata file to ${MDC(LogKeys.META_FILE, metadataFilePath)}", e) + outputStream.cancel() + throw e + } + } def deserialize( version: Int, @@ -73,6 +139,8 @@ object OperatorStateMetadataUtils { version match { case 1 => Serialization.read[OperatorStateMetadataV1](in) + case 2 => + Serialization.read[OperatorStateMetadataV2](in) case _ => throw new IllegalArgumentException(s"Failed to deserialize operator metadata with " + @@ -86,7 +154,8 @@ object OperatorStateMetadataUtils { operatorStateMetadata.version match { case 1 => Serialization.write(operatorStateMetadata.asInstanceOf[OperatorStateMetadataV1], out) - + case 2 => + Serialization.write(operatorStateMetadata.asInstanceOf[OperatorStateMetadataV2], out) case _ => throw new IllegalArgumentException(s"Failed to serialize operator metadata with " + s"version=${operatorStateMetadata.version}") @@ -94,54 +163,153 @@ object OperatorStateMetadataUtils { } } +object OperatorStateMetadataReader { + def createReader( + stateCheckpointPath: Path, + hadoopConf: Configuration, + version: Int): OperatorStateMetadataReader = { + version match { + case 1 => + new OperatorStateMetadataV1Reader(stateCheckpointPath, hadoopConf) + case 2 => + new OperatorStateMetadataV2Reader(stateCheckpointPath, hadoopConf) + case _ => + throw new IllegalArgumentException(s"Failed to create reader for operator metadata " + + s"with version=$version") + } + } +} + +object OperatorStateMetadataWriter { + def createWriter( + stateCheckpointPath: Path, + hadoopConf: Configuration, + version: Int, + currentBatchId: Option[Long] = None): OperatorStateMetadataWriter = { + version match { + case 1 => + new 
OperatorStateMetadataV1Writer(stateCheckpointPath, hadoopConf) + case 2 => + if (currentBatchId.isEmpty) { + throw new IllegalArgumentException("currentBatchId is required for version 2") + } + new OperatorStateMetadataV2Writer(stateCheckpointPath, hadoopConf, currentBatchId.get) + case _ => + throw new IllegalArgumentException(s"Failed to create writer for operator metadata " + + s"with version=$version") + } + } +} + +object OperatorStateMetadataV1 { + def metadataFilePath(stateCheckpointPath: Path): Path = + new Path(new Path(stateCheckpointPath, "_metadata"), "metadata") +} + +object OperatorStateMetadataV2 { + private implicit val formats: Formats = Serialization.formats(NoTypeHints) + + @scala.annotation.nowarn + private implicit val manifest = Manifest + .classType[OperatorStateMetadataV2](implicitly[ClassTag[OperatorStateMetadataV2]].runtimeClass) + + def metadataDirPath(stateCheckpointPath: Path): Path = + new Path(new Path(new Path(stateCheckpointPath, "_metadata"), "metadata"), "v2") + + def metadataFilePath(stateCheckpointPath: Path, currentBatchId: Long): Path = + new Path(metadataDirPath(stateCheckpointPath), currentBatchId.toString) +} + /** * Write OperatorStateMetadata into the state checkpoint directory. */ -class OperatorStateMetadataWriter(stateCheckpointPath: Path, hadoopConf: Configuration) - extends Logging { +class OperatorStateMetadataV1Writer( + stateCheckpointPath: Path, + hadoopConf: Configuration) + extends OperatorStateMetadataWriter with Logging { - private val metadataFilePath = OperatorStateMetadataUtils.metadataFilePath(stateCheckpointPath) + private val metadataFilePath = OperatorStateMetadataV1.metadataFilePath(stateCheckpointPath) private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf) + override def version: Int = 1 + def write(operatorMetadata: OperatorStateMetadata): Unit = { if (fm.exists(metadataFilePath)) return fm.mkdirs(metadataFilePath.getParent) val outputStream = fm.createAtomic(metadataFilePath, overwriteIfPossible = false) - try { - outputStream.write(s"v${operatorMetadata.version}\n".getBytes(StandardCharsets.UTF_8)) - OperatorStateMetadataUtils.serialize(outputStream, operatorMetadata) - outputStream.close() - } catch { - case e: Throwable => - logError( - log"Fail to write state metadata file to ${MDC(LogKeys.META_FILE, metadataFilePath)}", e) - outputStream.cancel() - throw e - } + OperatorStateMetadataUtils.writeMetadata(outputStream, operatorMetadata, metadataFilePath) } } /** - * Read OperatorStateMetadata from the state checkpoint directory. + * Read OperatorStateMetadata from the state checkpoint directory. This class will only be + * used to read OperatorStateMetadataV1. + * OperatorStateMetadataV2 will be read by the OperatorStateMetadataLog. 
*/ -class OperatorStateMetadataReader(stateCheckpointPath: Path, hadoopConf: Configuration) { +class OperatorStateMetadataV1Reader( + stateCheckpointPath: Path, + hadoopConf: Configuration) extends OperatorStateMetadataReader { + override def version: Int = 1 - private val metadataFilePath = OperatorStateMetadataUtils.metadataFilePath(stateCheckpointPath) + private val metadataFilePath = OperatorStateMetadataV1.metadataFilePath(stateCheckpointPath) private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf) - def read(): OperatorStateMetadata = { + def read(): Option[OperatorStateMetadata] = { val inputStream = fm.open(metadataFilePath) - val inputReader = - new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) - try { - val versionStr = inputReader.readLine() - val version = MetadataVersionUtil.validateVersion(versionStr, 1) - OperatorStateMetadataUtils.deserialize(version, inputReader) - } finally { - inputStream.close() + OperatorStateMetadataUtils.readMetadata(inputStream) + } +} + +class OperatorStateMetadataV2Writer( + stateCheckpointPath: Path, + hadoopConf: Configuration, + currentBatchId: Long) extends OperatorStateMetadataWriter { + + private val metadataFilePath = OperatorStateMetadataV2.metadataFilePath( + stateCheckpointPath, currentBatchId) + + private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf) + + override def version: Int = 2 + + override def write(operatorMetadata: OperatorStateMetadata): Unit = { + if (fm.exists(metadataFilePath)) return + + fm.mkdirs(metadataFilePath.getParent) + val outputStream = fm.createAtomic(metadataFilePath, overwriteIfPossible = false) + OperatorStateMetadataUtils.writeMetadata(outputStream, operatorMetadata, metadataFilePath) + } +} + +class OperatorStateMetadataV2Reader( + stateCheckpointPath: Path, + hadoopConf: Configuration) extends OperatorStateMetadataReader { + + private val metadataDirPath = OperatorStateMetadataV2.metadataDirPath(stateCheckpointPath) + private lazy val fm = CheckpointFileManager.create(metadataDirPath, hadoopConf) + + fm.mkdirs(metadataDirPath.getParent) + override def version: Int = 2 + + private def listBatches(): Array[Long] = { + if (!fm.exists(metadataDirPath)) { + return Array.empty + } + fm.list(metadataDirPath).map(_.getPath.getName.toLong).sorted + } + + override def read(): Option[OperatorStateMetadata] = { + val batches = listBatches() + if (batches.isEmpty) { + return None } + val lastBatchId = batches.last + val metadataFilePath = OperatorStateMetadataV2.metadataFilePath( + stateCheckpointPath, lastBatchId) + val inputStream = fm.open(metadataFilePath) + OperatorStateMetadataUtils.readMetadata(inputStream) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index 3230098c74cd4..ca03de6f1ad3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -48,13 +48,14 @@ case class StateStoreColFamilySchema( class StateSchemaCompatibilityChecker( providerId: StateStoreProviderId, hadoopConf: Configuration, - schemaFilePath: Option[Path] = None) extends Logging { + oldSchemaFilePath: Option[Path] = None, + newSchemaFilePath: Option[Path] = None) extends Logging { - private 
val schemaFileLocation = if (schemaFilePath.isEmpty) { + private val schemaFileLocation = if (oldSchemaFilePath.isEmpty) { val storeCpLocation = providerId.storeId.storeCheckpointLocation() schemaFile(storeCpLocation) } else { - schemaFilePath.get + oldSchemaFilePath.get } private val fm = CheckpointFileManager.create(schemaFileLocation, hadoopConf) @@ -65,10 +66,6 @@ class StateSchemaCompatibilityChecker( val inStream = fm.open(schemaFileLocation) try { val versionStr = inStream.readUTF() - // Ensure that version 3 format has schema file path provided explicitly - if (versionStr == "v3" && schemaFilePath.isEmpty) { - throw new IllegalStateException("Schema file path is required for schema version 3") - } val schemaReader = SchemaReader.createSchemaReader(versionStr) schemaReader.read(inStream) } catch { @@ -98,7 +95,7 @@ class StateSchemaCompatibilityChecker( stateStoreColFamilySchema: List[StateStoreColFamilySchema], stateSchemaVersion: Int): Unit = { // Ensure that schema file path is passed explicitly for schema version 3 - if (stateSchemaVersion == 3 && schemaFilePath.isEmpty) { + if (stateSchemaVersion == 3 && newSchemaFilePath.isEmpty) { throw new IllegalStateException("Schema file path is required for schema version 3") } @@ -110,13 +107,19 @@ class StateSchemaCompatibilityChecker( private[sql] def createSchemaFile( stateStoreColFamilySchema: List[StateStoreColFamilySchema], schemaWriter: SchemaWriter): Unit = { - val outStream = fm.createAtomic(schemaFileLocation, overwriteIfPossible = false) + val schemaFilePath = newSchemaFilePath match { + case Some(path) => + fm.mkdirs(path.getParent) + path + case None => schemaFileLocation + } + val outStream = fm.createAtomic(schemaFilePath, overwriteIfPossible = false) try { schemaWriter.write(stateStoreColFamilySchema, outStream) outStream.close() } catch { case e: Throwable => - logError(log"Fail to write schema file to ${MDC(LogKeys.PATH, schemaFileLocation)}", e) + logError(log"Fail to write schema file to ${MDC(LogKeys.PATH, schemaFilePath)}", e) outStream.cancel() throw e } @@ -208,7 +211,10 @@ object StateSchemaCompatibilityChecker { * @param stateSchemaVersion - version of the state schema to be used * @param extraOptions - any extra options to be passed for StateStoreConf creation * @param storeName - optional state store name - * @param schemaFilePath - optional schema file path + * @param oldSchemaFilePath - optional path to the old schema file. If not provided, will default + * to the schema file location + * @param newSchemaFilePath - optional path to the destination schema file. + * Needed for schema version 3 * @return - StateSchemaValidationResult containing the result of the schema validation */ def validateAndMaybeEvolveStateSchema( @@ -219,7 +225,8 @@ object StateSchemaCompatibilityChecker { stateSchemaVersion: Int, extraOptions: Map[String, String] = Map.empty, storeName: String = StateStoreId.DEFAULT_STORE_NAME, - schemaFilePath: Option[Path] = None): StateSchemaValidationResult = { + oldSchemaFilePath: Option[Path] = None, + newSchemaFilePath: Option[Path] = None): StateSchemaValidationResult = { // SPARK-47776: collation introduces the concept of binary (in)equality, which means // in some collation we no longer be able to just compare the binary format of two // UnsafeRows to determine equality. 
For example, 'aaa' and 'AAA' can be "semantically" @@ -237,7 +244,7 @@ object StateSchemaCompatibilityChecker { val providerId = StateStoreProviderId(StateStoreId(stateInfo.checkpointLocation, stateInfo.operatorId, 0, storeName), stateInfo.queryRunId) val checker = new StateSchemaCompatibilityChecker(providerId, hadoopConf, - schemaFilePath = schemaFilePath) + oldSchemaFilePath = oldSchemaFilePath, newSchemaFilePath = newSchemaFilePath) // regardless of configuration, we check compatibility to at least write schema file // if necessary // if the format validation for value schema is disabled, we also disable the schema @@ -261,6 +268,10 @@ object StateSchemaCompatibilityChecker { if (storeConf.stateSchemaCheckEnabled && result.isDefined) { throw result.get } - StateSchemaValidationResult(evolvedSchema, checker.schemaFileLocation.toString) + val schemaFileLocation = newSchemaFilePath match { + case Some(path) => path.toString + case None => checker.schemaFileLocation.toString + } + StateSchemaValidationResult(evolvedSchema, schemaFileLocation) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 14f67460763b1..43d75c4b4d137 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import scala.jdk.CollectionConverters._ import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD @@ -73,6 +74,12 @@ trait StatefulOperator extends SparkPlan { } } + def metadataFilePath(): Path = { + val stateCheckpointPath = + new Path(getStateInfo.checkpointLocation, getStateInfo.operatorId.toString) + new Path(new Path(stateCheckpointPath, "_metadata"), "metadata") + } + // Function used to record state schema for the first time and validate it against proposed // schema changes in the future. Runs as part of a planning rule on the driver. // Returns the schema file path for operators that write this to the metadata file, @@ -142,6 +149,8 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp */ def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = Some(inputWatermarkMs) + def operatorStateMetadataVersion: Int = 1 + override lazy val metrics = statefulOperatorCustomMetrics ++ Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numRowsDroppedByWatermark" -> SQLMetrics.createMetric(sparkContext, @@ -157,6 +166,20 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp "number of state store instances") ) ++ stateStoreCustomMetrics ++ pythonMetrics + def stateSchemaFilePath(storeName: Option[String] = None): Path = { + def stateInfo = getStateInfo + val stateCheckpointPath = + new Path(getStateInfo.checkpointLocation, + s"${stateInfo.operatorId.toString}") + storeName match { + case Some(storeName) => + val storeNamePath = new Path(stateCheckpointPath, storeName) + new Path(new Path(storeNamePath, "_metadata"), "schema") + case None => + new Path(new Path(stateCheckpointPath, "_metadata"), "schema") + } + } + /** * Get the progress made by this stateful operator after execution. This should be called in * the driver after this SparkPlan has been executed and metrics have been updated. 
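// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: how the operatorStateMetadataVersion
// hook added above is consumed. This mirrors the IncrementalExecution change in this
// PR; the method and parameter names (checkpointLoc, batchId, conf) are placeholders.
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.execution.streaming.StateStoreWriter
import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataWriter

def writeOperatorMetadata(
    ssw: StateStoreWriter,
    stateSchemaPaths: List[String],
    checkpointLoc: String,
    batchId: Long,
    conf: Configuration): Unit = {
  // Version 1 keeps the single <operatorId>/_metadata/metadata file; version 2
  // (TransformWithStateExec) needs the batch id because it writes one metadata file
  // per batch under <operatorId>/_metadata/metadata/v2/<batchId>.
  val writer = OperatorStateMetadataWriter.createWriter(
    new Path(checkpointLoc, ssw.getStateInfo.operatorId.toString),
    conf,
    ssw.operatorStateMetadataVersion,
    Some(batchId))
  writer.write(ssw.operatorStateMetadata(stateSchemaPaths))
}
// ---------------------------------------------------------------------------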
@@ -190,7 +213,8 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp protected def timeTakenMs(body: => Unit): Long = Utils.timeTakenMs(body)._2 /** Metadata of this stateful operator and its states stores. */ - def operatorStateMetadata(): OperatorStateMetadata = { + def operatorStateMetadata( + stateSchemaPaths: List[String] = List.empty): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) val stateStoreInfo = @@ -920,7 +944,8 @@ case class SessionWindowStateStoreSaveExec( keyWithoutSessionExpressions, getStateInfo, conf) :: Nil } - override def operatorStateMetadata(): OperatorStateMetadata = { + override def operatorStateMetadata( + stateSchemaPaths: List[String] = List.empty): OperatorStateMetadata = { val info = getStateInfo val operatorInfo = OperatorInfoV1(info.operatorId, shortName) val stateStoreInfo = Array(StateStoreMetadataV1( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala index dd8f7aab51dd0..65d32b474708a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala @@ -40,10 +40,11 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { operatorId: Int, expectedMetadata: OperatorStateMetadataV1): Unit = { val statePath = new Path(checkpointDir, s"state/$operatorId") - val operatorMetadata = new OperatorStateMetadataReader(statePath, hadoopConf).read() - .asInstanceOf[OperatorStateMetadataV1] - assert(operatorMetadata.operatorInfo == expectedMetadata.operatorInfo && - operatorMetadata.stateStoreInfo.sameElements(expectedMetadata.stateStoreInfo)) + val operatorMetadata = new OperatorStateMetadataV1Reader(statePath, hadoopConf).read() + .asInstanceOf[Option[OperatorStateMetadataV1]] + assert(operatorMetadata.isDefined) + assert(operatorMetadata.get.operatorInfo == expectedMetadata.operatorInfo && + operatorMetadata.get.stateStoreInfo.sameElements(expectedMetadata.stateStoreInfo)) } test("Serialize and deserialize stateful operator metadata") { @@ -52,14 +53,14 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { val stateStoreInfo = (1 to 4).map(i => StateStoreMetadataV1(s"store$i", 1, 200)) val operatorInfo = OperatorInfoV1(1, "Join") val operatorMetadata = OperatorStateMetadataV1(operatorInfo, stateStoreInfo.toArray) - new OperatorStateMetadataWriter(statePath, hadoopConf).write(operatorMetadata) + new OperatorStateMetadataV1Writer(statePath, hadoopConf).write(operatorMetadata) checkOperatorStateMetadata(checkpointDir.toString, 0, operatorMetadata) val df = spark.read.format("state-metadata").load(checkpointDir.toString) // Commit log is empty, there is no available batch id. 
- checkAnswer(df, Seq(Row(1, "Join", "store1", 200, -1L, -1L), - Row(1, "Join", "store2", 200, -1L, -1L), - Row(1, "Join", "store3", 200, -1L, -1L), - Row(1, "Join", "store4", 200, -1L, -1L) + checkAnswer(df, Seq(Row(1, "Join", "store1", 200, -1L, -1L, null), + Row(1, "Join", "store2", 200, -1L, -1L, null), + Row(1, "Join", "store3", 200, -1L, -1L, null), + Row(1, "Join", "store4", 200, -1L, -1L, null) )) checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), Seq(Row(1), Row(1), Row(1), Row(1))) @@ -118,10 +119,10 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { val df = spark.read.format("state-metadata") .load(checkpointDir.toString) - checkAnswer(df, Seq(Row(0, "symmetricHashJoin", "left-keyToNumValues", 5, 0L, 1L), - Row(0, "symmetricHashJoin", "left-keyWithIndexToValue", 5, 0L, 1L), - Row(0, "symmetricHashJoin", "right-keyToNumValues", 5, 0L, 1L), - Row(0, "symmetricHashJoin", "right-keyWithIndexToValue", 5, 0L, 1L) + checkAnswer(df, Seq(Row(0, "symmetricHashJoin", "left-keyToNumValues", 5, 0L, 1L, null), + Row(0, "symmetricHashJoin", "left-keyWithIndexToValue", 5, 0L, 1L, null), + Row(0, "symmetricHashJoin", "right-keyToNumValues", 5, 0L, 1L, null), + Row(0, "symmetricHashJoin", "right-keyWithIndexToValue", 5, 0L, 1L, null) )) checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), Seq(Row(0), Row(0), Row(0), Row(0))) @@ -169,7 +170,7 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata) val df = spark.read.format("state-metadata").load(checkpointDir.toString) - checkAnswer(df, Seq(Row(0, "sessionWindowStateStoreSaveExec", "default", 5, 0L, 0L))) + checkAnswer(df, Seq(Row(0, "sessionWindowStateStoreSaveExec", "default", 5, 0L, 0L, null))) checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), Seq(Row(1))) } } @@ -202,8 +203,8 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { checkOperatorStateMetadata(checkpointDir.toString, 1, expectedMetadata1) val df = spark.read.format("state-metadata").load(checkpointDir.toString) - checkAnswer(df, Seq(Row(0, "stateStoreSave", "default", 5, 0L, 1L), - Row(1, "stateStoreSave", "default", 5, 0L, 1L))) + checkAnswer(df, Seq(Row(0, "stateStoreSave", "default", 5, 0L, 1L, null), + Row(1, "stateStoreSave", "default", 5, 0L, 1L, null))) checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), Seq(Row(0), Row(0))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala index f5a5d1277d05d..38533825ece90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala @@ -275,7 +275,8 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { val schemaFilePath = Some(new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}")) val checker = new StateSchemaCompatibilityChecker(providerId, hadoopConf, - schemaFilePath = schemaFilePath) + oldSchemaFilePath = schemaFilePath, + newSchemaFilePath = schemaFilePath) checker.createSchemaFile(storeColFamilySchema, SchemaHelper.SchemaWriter.createSchemaWriter(stateSchemaVersion)) val stateSchema = checker.readSchemaFile() @@ -359,6 
+360,14 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { } } + private def getNewSchemaPath(stateSchemaDir: Path, stateSchemaVersion: Int): Option[Path] = { + if (stateSchemaVersion == 3) { + Some(new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}")) + } else { + None + } + } + private def verifyException( oldKeySchema: StructType, oldValueSchema: StructType, @@ -373,9 +382,9 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { val extraOptions = Map(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG -> formatValidationForValue.toString) + val stateSchemaDir = stateSchemaDirPath(stateInfo) Seq(2, 3).foreach { stateSchemaVersion => val schemaFilePath = if (stateSchemaVersion == 3) { - val stateSchemaDir = stateSchemaDirPath(stateInfo) Some(new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}")) } else { None @@ -384,10 +393,13 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { val oldStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, oldKeySchema, oldValueSchema, keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, oldKeySchema))) + val newSchemaFilePath = getNewSchemaPath(stateSchemaDir, stateSchemaVersion) val result = Try( StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf, oldStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion, - schemaFilePath = schemaFilePath, extraOptions = extraOptions) + oldSchemaFilePath = schemaFilePath, + newSchemaFilePath = newSchemaFilePath, + extraOptions = extraOptions) ).toEither.fold(Some(_), _ => None) val ex = if (result.isDefined) { @@ -399,7 +411,12 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, newKeySchema))) StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf, newStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion, - schemaFilePath = schemaFilePath, extraOptions = extraOptions) + extraOptions = extraOptions, + oldSchemaFilePath = stateSchemaVersion match { + case 3 => newSchemaFilePath + case _ => None + }, + newSchemaFilePath = getNewSchemaPath(stateSchemaDir, stateSchemaVersion)) } } @@ -433,9 +450,9 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { val extraOptions = Map(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG -> formatValidationForValue.toString) + val stateSchemaDir = stateSchemaDirPath(stateInfo) Seq(2, 3).foreach { stateSchemaVersion => val schemaFilePath = if (stateSchemaVersion == 3) { - val stateSchemaDir = stateSchemaDirPath(stateInfo) Some(new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}")) } else { None @@ -446,14 +463,18 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, oldKeySchema))) StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf, oldStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion, - schemaFilePath = schemaFilePath, extraOptions = extraOptions) + oldSchemaFilePath = schemaFilePath, + newSchemaFilePath = getNewSchemaPath(stateSchemaDir, stateSchemaVersion), + extraOptions = extraOptions) val newStateSchema = List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, newKeySchema, newValueSchema, keyStateEncoderSpec = getKeyStateEncoderSpec(stateSchemaVersion, 
newKeySchema))) StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(stateInfo, hadoopConf, newStateSchema, spark.sessionState, stateSchemaVersion = stateSchemaVersion, - schemaFilePath = schemaFilePath, extraOptions = extraOptions) + oldSchemaFilePath = schemaFilePath, + newSchemaFilePath = getNewSchemaPath(stateSchemaDir, stateSchemaVersion), + extraOptions = extraOptions) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 2e65748cb4673..d55a16a60eac0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkRuntimeException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Dataset, Encoders} +import org.apache.spark.sql.{Dataset, Encoders, Row} import org.apache.spark.sql.catalyst.util.stringToFile import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state._ @@ -63,6 +63,32 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S } } +class RunningCountStatefulProcessorInt + extends StatefulProcessor[String, String, (String, String)] { + @transient protected var _countState: ValueState[Int] = _ + + override def init( + outputMode: OutputMode, + timeMode: TimeMode): Unit = { + _countState = getHandle.getValueState[Int]("countState", Encoders.scalaInt) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[String], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + val count = _countState.getOption().getOrElse(0) + 1 + if (count == 3) { + _countState.clear() + Iterator.empty + } else { + _countState.update(count) + Iterator((key, count.toString)) + } + } +} + // Class to verify stateful processor usage with adding processing time timers class RunningCountStatefulProcessorWithProcTimeTimer extends RunningCountStatefulProcessor { private def handleProcessingTimeBasedTimers( @@ -886,6 +912,77 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } } + + test("transformWithState - verify that OperatorStateMetadataV2" + + " file is being written correctly") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream, + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "2")), + StopStream + ) + + val df = spark.read.format("state-metadata").load(checkpointDir.toString) + checkAnswer(df, Seq( + Row(0, "transformWithStateExec", "default", 5, 0L, 1L, + """{"timeMode":"NoTime","outputMode":"Update"}""") + )) + } + } + } + + test("test that invalid schema evolution fails query for column family") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + 
classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val result1 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream + ) + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorInt(), + TimeMode.None(), + OutputMode.Update()) + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + ExpectFailure[StateStoreValueSchemaNotCompatible] { + (t: Throwable) => { + assert(t.getMessage.contains("Please check number and type of fields.")) + } + } + ) + } + } + } } class TransformWithStateValidationSuite extends StateStoreMetricsTest {
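// ---------------------------------------------------------------------------
// Illustrative only, not part of the patch: one way downstream tooling could decode the
// operatorProperties JSON that TransformWithStateExec.operatorStateMetadata now records
// (shape confirmed by the OperatorStateMetadataV2 test above). The case class and helper
// are hypothetical; json4s is already the library the writer uses.
import org.json4s.{DefaultFormats, Formats}
import org.json4s.jackson.JsonMethods.parse

case class TwsOperatorProperties(timeMode: String, outputMode: String)

def parseOperatorProperties(json: String): TwsOperatorProperties = {
  implicit val formats: Formats = DefaultFormats
  // Expected shape, per the writer: {"timeMode":"NoTime","outputMode":"Update"}
  parse(json).extract[TwsOperatorProperties]
}

// Example: parseOperatorProperties("""{"timeMode":"NoTime","outputMode":"Update"}""")
// returns TwsOperatorProperties("NoTime", "Update").
// ---------------------------------------------------------------------------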