From 90a159b28a6e89b404937179cd3ca6afb34da2aa Mon Sep 17 00:00:00 2001
From: ulysses-you
Date: Thu, 4 Aug 2022 01:03:45 +0800
Subject: [PATCH] [SPARK-39952][SQL] SaveIntoDataSourceCommand should recache
 result relation

### What changes were proposed in this pull request?

Call `recacheByPlan` on the result relation inside `SaveIntoDataSourceCommand`.

### Why are the changes needed?

The behavior of `SaveIntoDataSourceCommand` is similar to `InsertIntoDataSourceCommand`, which supports appending or overwriting data. To keep cached data consistent, we should always recache the result relation via `recacheByPlan` after the write.

### Does this PR introduce _any_ user-facing change?

Yes, bug fix.

### How was this patch tested?

Added a unit test.

Closes #37380 from ulysses-you/refresh.

Authored-by: ulysses-you
Signed-off-by: Wenchen Fan
(cherry picked from commit 5fe0b245f7891a05bc4e1e641fd0aa9130118ea4)
Signed-off-by: Wenchen Fan
(cherry picked from commit 15ebd56de6ae37587d750bb1e106c5dcb3e22958)
Signed-off-by: Dongjoon Hyun
---
 .../SaveIntoDataSourceCommand.scala           | 12 +++-
 .../SaveIntoDataSourceCommandSuite.scala      | 61 ++++++++++++++++++-
 2 files changed, 70 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
index 486f73cab44f7..ef74036b23bef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources
 
+import scala.util.control.NonFatal
+
 import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -41,9 +43,17 @@ case class SaveIntoDataSourceCommand(
   override def innerChildren: Seq[QueryPlan[_]] = Seq(query)
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
-    dataSource.createRelation(
+    val relation = dataSource.createRelation(
       sparkSession.sqlContext, mode, options, Dataset.ofRows(sparkSession, query))
 
+    try {
+      val logicalRelation = LogicalRelation(relation, relation.schema.toAttributes, None, false)
+      sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation)
+    } catch {
+      case NonFatal(_) =>
+        // Some data sources cannot return a valid relation, e.g. `KafkaSourceProvider`.
+    }
+
     Seq.empty[Row]
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala
index e843d1d328425..e68d6561fb8fa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.sql.execution.datasources
 
-import org.apache.spark.sql.SaveMode
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, QueryTest, Row, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, TableScan}
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{LongType, StructField, StructType}
 
-class SaveIntoDataSourceCommandSuite extends SharedSparkSession {
+class SaveIntoDataSourceCommandSuite extends QueryTest with SharedSparkSession {
 
   test("simpleString is redacted") {
     val URL = "connection.url"
@@ -41,4 +44,58 @@ class SaveIntoDataSourceCommandSuite extends SharedSparkSession {
     assert(!logicalPlanString.contains(PASS))
     assert(logicalPlanString.contains(DRIVER))
   }
+
+  test("SPARK-39952: SaveIntoDataSourceCommand should recache result relation") {
+    val provider = classOf[FakeV1DataSource].getName
+
+    def saveIntoDataSource(data: Int): Unit = {
+      spark.range(data)
+        .write
+        .mode("append")
+        .format(provider)
+        .save()
+    }
+
+    def loadData: DataFrame = {
+      spark.read
+        .format(provider)
+        .load()
+    }
+
+    saveIntoDataSource(1)
+    val cached = loadData.cache()
+    checkAnswer(cached, Row(0))
+
+    saveIntoDataSource(2)
+    checkAnswer(loadData, Row(0) :: Row(1) :: Nil)
+
+    FakeV1DataSource.data = null
+  }
+}
+
+object FakeV1DataSource {
+  var data: RDD[Row] = _
+}
+
+class FakeV1DataSource extends RelationProvider with CreatableRelationProvider {
+  override def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String]): BaseRelation = {
+    FakeRelation()
+  }
+
+  override def createRelation(
+      sqlContext: SQLContext,
+      mode: SaveMode,
+      parameters: Map[String, String],
+      data: DataFrame): BaseRelation = {
+    FakeV1DataSource.data = data.rdd
+    FakeRelation()
+  }
+}
+
+case class FakeRelation() extends BaseRelation with TableScan {
+  override def sqlContext: SQLContext = SparkSession.getActiveSession.get.sqlContext
+  override def schema: StructType = StructType(Seq(StructField("id", LongType)))
+  override def buildScan(): RDD[Row] = FakeV1DataSource.data
 }
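
The new test above boils down to the following spark-shell style scenario. This is only a minimal sketch, not part of the patch; it assumes the test-only `FakeV1DataSource` provider defined above is on the classpath (the fully qualified name below assumes the test file's package), and any V1 `CreatableRelationProvider` that is also a `RelationProvider` behaves the same way:

```scala
// Sketch only: this provider and its package are the ones defined in the test file above.
val providerName = "org.apache.spark.sql.execution.datasources.FakeV1DataSource"

// Write once, then cache a read of the same source.
spark.range(1).write.mode("append").format(providerName).save()
val cached = spark.read.format(providerName).load().cache()
cached.show()  // one row: 0

// A second write goes through SaveIntoDataSourceCommand again. Before this
// patch the cached plan was not refreshed, so a new read matching the cached
// plan still returned the stale single-row result. With recacheByPlan on the
// result relation, the cache entry is rebuilt from the source.
spark.range(2).write.mode("append").format(providerName).save()
spark.read.format(providerName).load().show()  // rows 0 and 1
```

The recache is wrapped in try/catch because, as the added comment notes, some providers (e.g. `KafkaSourceProvider`) do not return a relation that can be scanned, so the refresh is best-effort.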