From 746e0d1d3a11af21604dc8acc9d098fc721a0bb7 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Fri, 20 Dec 2019 16:56:05 -0800
Subject: [PATCH] implement for append and overwrite as well

---
 .../apache/spark/sql/DataFrameWriter.scala    | 76 ++++++++++++-------
 .../SupportsCatalogOptionsSuite.scala         | 16 ++--
 2 files changed, 57 insertions(+), 35 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 80a9c868f7331..0c55f5ca52bca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -258,26 +258,42 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       val dsOptions = new CaseInsensitiveStringMap(options.asJava)
 
       import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
-      provider.getTable(dsOptions) match {
-        case table: SupportsWrite if table.supports(BATCH_WRITE) =>
-          lazy val relation = DataSourceV2Relation.create(table, dsOptions)
-          mode match {
-            case SaveMode.Append =>
-              checkPartitioningMatchesV2Table(table)
-              runCommand(df.sparkSession, "save") {
-                AppendData.byName(relation, df.logicalPlan, extraOptions.toMap)
-              }
+      mode match {
+        case SaveMode.Append | SaveMode.Overwrite =>
+          val table = provider match {
+            case supportsExtract: SupportsCatalogOptions =>
+              val ident = supportsExtract.extractIdentifier(dsOptions)
+              val sessionState = df.sparkSession.sessionState
+              val catalog = CatalogV2Util.getTableProviderCatalog(
+                supportsExtract, sessionState.catalogManager, dsOptions)
 
-            case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) =>
-              checkPartitioningMatchesV2Table(table)
-              // truncate the table
-              runCommand(df.sparkSession, "save") {
-                OverwriteByExpression.byName(
-                  relation, df.logicalPlan, Literal(true), extraOptions.toMap)
-              }
+              catalog.loadTable(ident)
+            case tableProvider: TableProvider => tableProvider.getTable(dsOptions)
+            case _ =>
+              // Streaming also uses the data source V2 API. So it may be that the data source
+              // implements v2, but has no v2 implementation for batch writes. In that case, we fall
+              // back to saving as though it's a V1 source.
+              return saveToV1Source()
+          }
+
+          val relation = DataSourceV2Relation.create(table, dsOptions)
+          checkPartitioningMatchesV2Table(table)
+          if (mode == SaveMode.Append) {
+            runCommand(df.sparkSession, "save") {
+              AppendData.byName(relation, df.logicalPlan, extraOptions.toMap)
+            }
+          } else {
+            // Truncate the table. TableCapabilityCheck will throw a nice exception if this
+            // isn't supported
+            runCommand(df.sparkSession, "save") {
+              OverwriteByExpression.byName(
+                relation, df.logicalPlan, Literal(true), extraOptions.toMap)
+            }
+          }
 
-            case other if classOf[SupportsCatalogOptions].isAssignableFrom(provider.getClass) =>
-              val supportsExtract = provider.asInstanceOf[SupportsCatalogOptions]
+        case create =>
+          provider match {
+            case supportsExtract: SupportsCatalogOptions =>
               val ident = supportsExtract.extractIdentifier(dsOptions)
               val sessionState = df.sparkSession.sessionState
               val catalog = CatalogV2Util.getTableProviderCatalog(
@@ -293,20 +309,22 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
                   df.queryExecution.analyzed,
                   Map(TableCatalog.PROP_PROVIDER -> source) ++ location,
                   extraOptions.toMap,
-                  ignoreIfExists = other == SaveMode.Ignore)
+                  ignoreIfExists = create == SaveMode.Ignore)
+              }
+            case tableProvider: TableProvider =>
+              if (tableProvider.getTable(dsOptions).supports(BATCH_WRITE)) {
+                throw new AnalysisException(s"TableProvider implementation $source cannot be " +
+                  s"written with $create mode, please use Append or Overwrite " +
+                  "modes instead.")
+              } else {
+                // Streaming also uses the data source V2 API. So it may be that the data source
+                // implements v2, but has no v2 implementation for batch writes. In that case, we
+                // fallback to saving as though it's a V1 source.
+                saveToV1Source()
              }
-
-            case other =>
-              throw new AnalysisException(s"TableProvider implementation $source cannot be " +
-                s"written with $other mode, please use Append or Overwrite " +
-                "modes instead.")
          }
-
-        // Streaming also uses the data source V2 API. So it may be that the data source implements
-        // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving
-        // as though it's a V1 source.
-        case _ => saveToV1Source()
      }
+      }
    } else {
      saveToV1Source()
    }

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
index 079fa21a3a585..51a5a3bf15c37 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.sql.connector
 
+import java.util
+
 import scala.language.implicitConversions
+import scala.util.Try
 
 import org.scalatest.BeforeAndAfter
 
@@ -25,6 +28,7 @@ import org.apache.spark.sql.{QueryTest, SaveMode}
 import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
 import org.apache.spark.sql.connector.catalog.{Identifier, SupportsCatalogOptions}
 import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
+import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{LongType, StructType}
@@ -54,8 +58,8 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
 
   override def afterEach(): Unit = {
     super.afterEach()
-    catalog(SESSION_CATALOG_NAME).clearTables()
-    catalog(catalogName).clearTables()
+    Try(catalog(SESSION_CATALOG_NAME).clearTables())
+    Try(catalog(catalogName).clearTables())
     spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
     spark.conf.unset(s"spark.sql.catalog.$catalogName")
   }
@@ -98,7 +102,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
   }
 
   test("save fails with ErrorIfExists if table exists - session catalog") {
-    sql("create table t1 (id bigint) using foo")
+    sql(s"create table t1 (id bigint) using $format")
     val df = spark.range(10)
     intercept[TableAlreadyExistsException] {
       val dfw = df.write.format(format).option("name", "t1")
@@ -107,7 +111,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
   }
 
   test("save fails with ErrorIfExists if table exists - testcat catalog") {
-    sql("create table t1 (id bigint) using foo")
+    sql(s"create table testcat.t1 (id bigint) using $format")
     val df = spark.range(10)
     intercept[TableAlreadyExistsException] {
       val dfw = df.write.format(format).option("name", "t1").option("catalog", catalogName)
@@ -116,7 +120,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
   }
 
   test("Ignore mode if table exists - session catalog") {
-    sql("create table t1 (id bigint) using foo")
+    sql(s"create table t1 (id bigint) using $format")
     val df = spark.range(10).withColumn("part", 'id % 5)
     intercept[TableAlreadyExistsException] {
       val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1")
@@ -129,7 +133,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
   }
 
   test("Ignore mode if table exists - testcat catalog") {
-    sql("create table t1 (id bigint) using foo")
+    sql(s"create table testcat.t1 (id bigint) using $format")
    val df = spark.range(10).withColumn("part", 'id % 5)
    intercept[TableAlreadyExistsException] {
      val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1")
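
Note (not part of the patch): a minimal sketch of how the Append/Overwrite path added above would be driven from the DataFrameWriter API. The source name "catalog-options-source" is hypothetical, and the "name"/"catalog" option keys and catalog name "testcat" simply mirror the test suite in this patch; any TableProvider implementing SupportsCatalogOptions registered under another name would behave the same way.

    val df = spark.range(10)

    // ErrorIfExists/Ignore still go through CreateTableAsSelect: the identifier and
    // catalog are extracted from the write options via SupportsCatalogOptions.
    df.write.format("catalog-options-source")
      .option("name", "t1")
      .option("catalog", "testcat")
      .save()

    // With this change, Append and Overwrite also resolve the table through the catalog
    // (catalog.loadTable(ident)) and plan AppendData / OverwriteByExpression by name.
    df.write.format("catalog-options-source")
      .option("name", "t1")
      .option("catalog", "testcat")
      .mode(SaveMode.Append)
      .save()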