implement for append and overwrite as well
brkyvz committed Dec 21, 2019
1 parent 33ae658 commit 746e0d1
Showing 2 changed files with 57 additions and 35 deletions.
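For context, a minimal usage sketch of what this change enables (not part of the commit): for a V2 source that implements SupportsCatalogOptions, save() in Append or Overwrite mode now resolves the table through the provider's catalog and plans AppendData or OverwriteByExpression against it, rather than only supporting the create-style modes. The source name and the `name` option below are hypothetical stand-ins; an active `spark` session is assumed.

```scala
import org.apache.spark.sql.SaveMode

val df = spark.range(10).toDF("id")

// Append: the writer asks the provider for an identifier via extractIdentifier(options),
// loads the table from the provider's catalog, and plans AppendData against it.
// "com.example.CatalogBackedSource" is an illustrative provider name, not from the commit.
df.write
  .format("com.example.CatalogBackedSource")
  .option("name", "t1")
  .mode(SaveMode.Append)
  .save()

// Overwrite: same table resolution, planned as OverwriteByExpression with a `true`
// filter, i.e. truncate-and-write.
df.write
  .format("com.example.CatalogBackedSource")
  .option("name", "t1")
  .mode(SaveMode.Overwrite)
  .save()
```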
76 changes: 47 additions & 29 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -258,26 +258,42 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val dsOptions = new CaseInsensitiveStringMap(options.asJava)

import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
provider.getTable(dsOptions) match {
case table: SupportsWrite if table.supports(BATCH_WRITE) =>
lazy val relation = DataSourceV2Relation.create(table, dsOptions)
mode match {
case SaveMode.Append =>
checkPartitioningMatchesV2Table(table)
runCommand(df.sparkSession, "save") {
AppendData.byName(relation, df.logicalPlan, extraOptions.toMap)
}
mode match {
case SaveMode.Append | SaveMode.Overwrite =>
val table = provider match {
case supportsExtract: SupportsCatalogOptions =>
val ident = supportsExtract.extractIdentifier(dsOptions)
val sessionState = df.sparkSession.sessionState
val catalog = CatalogV2Util.getTableProviderCatalog(
supportsExtract, sessionState.catalogManager, dsOptions)

case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) =>
checkPartitioningMatchesV2Table(table)
// truncate the table
runCommand(df.sparkSession, "save") {
OverwriteByExpression.byName(
relation, df.logicalPlan, Literal(true), extraOptions.toMap)
}
catalog.loadTable(ident)
case tableProvider: TableProvider => tableProvider.getTable(dsOptions)
case _ =>
// Streaming also uses the data source V2 API. So it may be that the data source
// implements v2, but has no v2 implementation for batch writes. In that case, we fall
// back to saving as though it's a V1 source.
return saveToV1Source()
}

val relation = DataSourceV2Relation.create(table, dsOptions)
checkPartitioningMatchesV2Table(table)
if (mode == SaveMode.Append) {
runCommand(df.sparkSession, "save") {
AppendData.byName(relation, df.logicalPlan, extraOptions.toMap)
}
} else {
// Truncate the table. TableCapabilityCheck will throw a nice exception if this
// isn't supported
runCommand(df.sparkSession, "save") {
OverwriteByExpression.byName(
relation, df.logicalPlan, Literal(true), extraOptions.toMap)
}
}

case other if classOf[SupportsCatalogOptions].isAssignableFrom(provider.getClass) =>
val supportsExtract = provider.asInstanceOf[SupportsCatalogOptions]
case create =>
provider match {
case supportsExtract: SupportsCatalogOptions =>
val ident = supportsExtract.extractIdentifier(dsOptions)
val sessionState = df.sparkSession.sessionState
val catalog = CatalogV2Util.getTableProviderCatalog(
@@ -293,20 +309,22 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
df.queryExecution.analyzed,
Map(TableCatalog.PROP_PROVIDER -> source) ++ location,
extraOptions.toMap,
ignoreIfExists = other == SaveMode.Ignore)
ignoreIfExists = create == SaveMode.Ignore)
}
case tableProvider: TableProvider =>
if (tableProvider.getTable(dsOptions).supports(BATCH_WRITE)) {
throw new AnalysisException(s"TableProvider implementation $source cannot be " +
s"written with $create mode, please use Append or Overwrite " +
"modes instead.")
} else {
// Streaming also uses the data source V2 API. So it may be that the data source
// implements v2, but has no v2 implementation for batch writes. In that case, we
// fallback to saving as though it's a V1 source.
saveToV1Source()
}

case other =>
throw new AnalysisException(s"TableProvider implementation $source cannot be " +
s"written with $other mode, please use Append or Overwrite " +
"modes instead.")
}

// Streaming also uses the data source V2 API. So it may be that the data source implements
// v2, but has no v2 implementation for batch writes. In that case, we fall back to saving
// as though it's a V1 source.
case _ => saveToV1Source()
}

} else {
saveToV1Source()
}
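A rough provider-side sketch of the contract the write path above relies on (class and option names are illustrative, not from the commit): a TableProvider that also implements SupportsCatalogOptions maps the writer options to a catalog Identifier, which the new Append/Overwrite branch uses to load the table through the provider's catalog rather than through getTable.

```scala
import org.apache.spark.sql.connector.catalog.{Identifier, SupportsCatalogOptions, Table}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Hypothetical example; only the pieces the DataFrameWriter change touches are shown.
class ExampleCatalogOptionsSource extends SupportsCatalogOptions {

  // Called by the new Append/Overwrite branch: turns the writer options into the
  // identifier of the table to load from the provider's catalog.
  override def extractIdentifier(options: CaseInsensitiveStringMap): Identifier =
    Identifier.of(Array.empty[String], options.get("name"))

  // SupportsCatalogOptions extends TableProvider, so the options-based getTable hook
  // still has to exist; it is elided here because the catalog path above bypasses it
  // for Append/Overwrite.
  override def getTable(options: CaseInsensitiveStringMap): Table = ???
}
```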
16 changes: 10 additions & 6 deletions sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
@@ -17,14 +17,18 @@

package org.apache.spark.sql.connector

import java.util

import scala.language.implicitConversions
import scala.util.Try

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.{QueryTest, SaveMode}
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.connector.catalog.{Identifier, SupportsCatalogOptions}
import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{LongType, StructType}
@@ -54,8 +58,8 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with

override def afterEach(): Unit = {
super.afterEach()
catalog(SESSION_CATALOG_NAME).clearTables()
catalog(catalogName).clearTables()
Try(catalog(SESSION_CATALOG_NAME).clearTables())
Try(catalog(catalogName).clearTables())
spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
spark.conf.unset(s"spark.sql.catalog.$catalogName")
}
@@ -98,7 +102,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
}

test("save fails with ErrorIfExists if table exists - session catalog") {
sql("create table t1 (id bigint) using foo")
sql(s"create table t1 (id bigint) using $format")
val df = spark.range(10)
intercept[TableAlreadyExistsException] {
val dfw = df.write.format(format).option("name", "t1")
@@ -107,7 +111,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
}

test("save fails with ErrorIfExists if table exists - testcat catalog") {
sql("create table t1 (id bigint) using foo")
sql(s"create table testcat.t1 (id bigint) using $format")
val df = spark.range(10)
intercept[TableAlreadyExistsException] {
val dfw = df.write.format(format).option("name", "t1").option("catalog", catalogName)
@@ -116,7 +120,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
}

test("Ignore mode if table exists - session catalog") {
sql("create table t1 (id bigint) using foo")
sql(s"create table t1 (id bigint) using $format")
val df = spark.range(10).withColumn("part", 'id % 5)
intercept[TableAlreadyExistsException] {
val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1")
@@ -129,7 +133,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
}

test("Ignore mode if table exists - testcat catalog") {
sql("create table t1 (id bigint) using foo")
sql(s"create table testcat.t1 (id bigint) using $format")
val df = spark.range(10).withColumn("part", 'id % 5)
intercept[TableAlreadyExistsException] {
val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1")
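Roughly how the modes line up from the caller's side, per the tests above: Append and Overwrite write into the existing table, while the create-style modes still resolve the table through the catalog first, so ErrorIfExists fails when the table already exists. The format string below is a hypothetical placeholder (the suite uses its own `format` value), and table t1 is assumed to exist already.

```scala
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException

val data = spark.range(10).toDF("id")
val writer = data.write
  .format("com.example.CatalogBackedSource") // hypothetical SupportsCatalogOptions source
  .option("name", "t1")

// Overwrite truncates t1 and writes the new rows; Append would add to it instead.
writer.mode(SaveMode.Overwrite).save()

// Writing with ErrorIfExists against an existing table surfaces
// TableAlreadyExistsException, as the suite above asserts.
try {
  writer.mode(SaveMode.ErrorIfExists).save()
} catch {
  case _: TableAlreadyExistsException => // expected: t1 already exists
}
```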
