Added first set of tests
brkyvz committed Dec 20, 2019
1 parent 0a87228 commit a441604
Showing 4 changed files with 171 additions and 18 deletions.
CatalogV2Util.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.AlterTable
import org.apache.spark.sql.connector.catalog.TableChange._
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.{ArrayType, MapType, StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

private[sql] object CatalogV2Util {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -315,4 +316,14 @@ private[sql] object CatalogV2Util {
val unresolved = UnresolvedV2Relation(originalNameParts, tableCatalog, ident)
AlterTable(tableCatalog, ident, unresolved, changes)
}

def getTableProviderCatalog(
provider: SupportsCatalogOptions,
catalogManager: CatalogManager,
options: CaseInsensitiveStringMap): TableCatalog = {
Option(provider.extractCatalog(options))
.map(catalogManager.catalog)
.getOrElse(catalogManager.v2SessionCatalog)
.asTableCatalog
}
}
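
For context, a minimal sketch of how a caller might use the new helper; provider and dsOptions are assumed to already be in scope (a SupportsCatalogOptions instance and its CaseInsensitiveStringMap of options), so this is illustrative rather than code from the commit.

val catalog: TableCatalog = CatalogV2Util.getTableProviderCatalog(
  provider,                            // the SupportsCatalogOptions data source
  spark.sessionState.catalogManager,   // the active CatalogManager
  dsOptions)                           // the data source options
// If extractCatalog(dsOptions) returns null, the session catalog is used instead.
val table = catalog.loadTable(provider.extractIdentifier(dsOptions))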
18 changes: 14 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, Univocit
import org.apache.spark.sql.catalyst.expressions.ExprUtils
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.catalyst.util.FailureSafeParser
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsCatalogOptions, SupportsRead}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
@@ -215,9 +215,19 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {

val finalOptions = sessionOptions ++ extraOptions.toMap ++ pathsOption
val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava)
val table = userSpecifiedSchema match {
case Some(schema) => provider.getTable(dsOptions, schema)
case _ => provider.getTable(dsOptions)
val table = provider match {
case hasCatalog: SupportsCatalogOptions =>
val ident = hasCatalog.extractIdentifier(dsOptions)
val catalog = CatalogV2Util.getTableProviderCatalog(
hasCatalog,
sparkSession.sessionState.catalogManager,
dsOptions)
catalog.loadTable(ident)
case other =>
userSpecifiedSchema match {
case Some(schema) => provider.getTable(dsOptions, schema)
case _ => provider.getTable(dsOptions)
}
}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
table match {
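For context, a hedged usage sketch of the new read path: when a source implements SupportsCatalogOptions, load() resolves the table through the extracted catalog instead of calling provider.getTable. The format name below is hypothetical, and the "catalog"/"name" option keys are conventions of such a source, not part of the DataFrameReader API.

val df = spark.read
  .format("com.example.CatalogAwareSource")   // hypothetical SupportsCatalogOptions source
  .option("catalog", "testcat")               // consumed by extractCatalog; omit to fall back to the session catalog
  .option("name", "t1")                       // consumed by extractIdentifier
  .load()                                     // resolves the catalog, then loads table "t1" via catalog.loadTable
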
26 changes: 12 additions & 14 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, InsertIntoStatement, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Catalogs, Identifier, SupportsCatalogOptions, SupportsWrite, Table, TableCatalog, TableProvider, V1Table}
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Util, Identifier, SupportsCatalogOptions, SupportsWrite, Table, TableCatalog, TableProvider, V1Table}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, LiteralValue, Transform}
import org.apache.spark.sql.execution.SQLExecution
@@ -263,13 +263,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
lazy val relation = DataSourceV2Relation.create(table, dsOptions)
mode match {
case SaveMode.Append =>
verifyV2Partitioning(table)
checkPartitioningMatchesV2Table(table)
runCommand(df.sparkSession, "save") {
AppendData.byName(relation, df.logicalPlan, extraOptions.toMap)
}

case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) =>
verifyV2Partitioning(table)
checkPartitioningMatchesV2Table(table)
// truncate the table
runCommand(df.sparkSession, "save") {
OverwriteByExpression.byName(
@@ -280,18 +280,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val catalogOptions = provider.asInstanceOf[SupportsCatalogOptions]
val ident = catalogOptions.extractIdentifier(dsOptions)
val sessionState = df.sparkSession.sessionState
val catalog = Option(catalogOptions.extractCatalog(dsOptions))
.map(Catalogs.load(_, sessionState.conf))
.getOrElse(sessionState.catalogManager.v2SessionCatalog)
.asInstanceOf[TableCatalog]
val catalog = CatalogV2Util.getTableProviderCatalog(
catalogOptions, sessionState.catalogManager, dsOptions)

val location = Option(dsOptions.get("path")).map(TableCatalog.PROP_LOCATION -> _)

runCommand(df.sparkSession, "save") {
CreateTableAsSelect(
catalog,
ident,
getV2Transforms,
partitioningAsV2,
df.queryExecution.analyzed,
Map(TableCatalog.PROP_PROVIDER -> source) ++ location,
extraOptions.toMap,
@@ -538,14 +536,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
return saveAsTable(TableIdentifier(ident.name(), ident.namespace().headOption))

case (SaveMode.Append, Some(table)) =>
verifyV2Partitioning(table)
checkPartitioningMatchesV2Table(table)
AppendData.byName(DataSourceV2Relation.create(table), df.logicalPlan, extraOptions.toMap)

case (SaveMode.Overwrite, _) =>
ReplaceTableAsSelect(
catalog,
ident,
getV2Transforms,
partitioningAsV2,
df.queryExecution.analyzed,
Map(TableCatalog.PROP_PROVIDER -> source) ++ getLocationIfExists,
extraOptions.toMap,
@@ -558,7 +556,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
CreateTableAsSelect(
catalog,
ident,
getV2Transforms,
partitioningAsV2,
df.queryExecution.analyzed,
Map(TableCatalog.PROP_PROVIDER -> source) ++ getLocationIfExists,
extraOptions.toMap,
@@ -637,7 +635,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}

/** Converts the provided partitioning and bucketing information to DataSourceV2 Transforms. */
private def getV2Transforms: Seq[Transform] = {
private def partitioningAsV2: Seq[Transform] = {
val partitioning = partitioningColumns.map { colNames =>
colNames.map(name => IdentityTransform(FieldReference(name)))
}.getOrElse(Seq.empty[Transform])
@@ -651,8 +649,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* For V2 DataSources, checks if the provided partitioning matches that of the table.
* Partitioning information is not required when appending data to V2 tables.
*/
private def verifyV2Partitioning(existingTable: Table): Unit = {
val v2Partitions = getV2Transforms
private def checkPartitioningMatchesV2Table(existingTable: Table): Unit = {
val v2Partitions = partitioningAsV2
if (v2Partitions.isEmpty) return
require(v2Partitions.sameElements(existingTable.partitioning()),
"The provided partitioning does not match of the table.\n" +
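For context, a hedged illustration of checkPartitioningMatchesV2Table when appending to an existing table; the source and column names are hypothetical.

// Fails the require above (IllegalArgumentException): explicit partitioning
// that differs from the existing table's partitioning.
df.write.format("com.example.CatalogAwareSource")
  .option("name", "t1")
  .partitionBy("other_col")
  .mode(SaveMode.Append)
  .save()

// Succeeds: no partitioning supplied, so the check is skipped entirely.
df.write.format("com.example.CatalogAwareSource")
  .option("name", "t1")
  .mode(SaveMode.Append)
  .save()
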
SupportsCatalogOptionsSuite.scala (new file)
@@ -0,0 +1,134 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector

import scala.language.implicitConversions

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.{QueryTest, SaveMode}
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.connector.catalog.{Identifier, SupportsCatalogOptions, TableCatalog}
import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{LongType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with BeforeAndAfter {

import testImplicits._

private val catalogName = "testcat"
private val format = classOf[CatalogSupportingInMemoryTableProvider].getName

private def catalog(name: String): InMemoryTableSessionCatalog = {
spark.sessionState.catalogManager.catalog(name).asInstanceOf[InMemoryTableSessionCatalog]
}

private implicit def stringToIdentifier(value: String): Identifier = {
Identifier.of(Array.empty, value)
}

before {
spark.conf.set(
V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[InMemoryTableSessionCatalog].getName)
spark.conf.set(
s"spark.sql.catalog.$catalogName", classOf[InMemoryTableSessionCatalog].getName)
}

override def afterEach(): Unit = {
super.afterEach()
catalog(SESSION_CATALOG_NAME).clearTables()
catalog(catalogName).clearTables()
spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
spark.conf.unset(s"spark.sql.catalog.$catalogName")
}

def dataFrameWriterTests(withCatalogOption: Option[String]): Unit = {
Seq(SaveMode.ErrorIfExists, SaveMode.Ignore).foreach { saveMode =>
test(s"save works with $saveMode - no table, no partitioning, session catalog") {
val df = spark.range(10)
val dfw = df.write.format(format).mode(saveMode).option("name", "t1")
withCatalogOption.foreach(cName => dfw.option("catalog", cName))
dfw.save()

val table = catalog(SESSION_CATALOG_NAME).loadTable("t1")
assert(table.name() === "t1", "Table identifier was wrong")
assert(table.partitioning().isEmpty, "Partitioning should be empty")
assert(table.schema() === df.schema.asNullable, "Schema did not match")
}

test(s"save works with $saveMode - no table, with partitioning, session catalog") {
val df = spark.range(10).withColumn("part", 'id % 5)
val dfw = df.write.format(format).mode(saveMode).option("name", "t1").partitionBy("part")
withCatalogOption.foreach(cName => dfw.option("catalog", cName))
dfw.save()

val table = catalog(SESSION_CATALOG_NAME).loadTable("t1")
assert(table.name() === "t1", "Table identifier was wrong")
assert(table.partitioning().length === 1, "Partitioning should not be empty")
assert(table.partitioning().head.references().head.fieldNames().head === "part",
"Partitioning was incorrect")
assert(table.schema() === df.schema.asNullable, "Schema did not match")
}
}

test("save fails with ErrorIfExists if table exists") {
sql("create table t1 (id bigint) using foo")
val df = spark.range(10)
intercept[TableAlreadyExistsException] {
val dfw = df.write.format(format).option("name", "t1")
withCatalogOption.foreach(cName => dfw.option("catalog", cName))
dfw.save()
}
}

test("Ignore mode if table exists") {
sql("create table t1 (id bigint) using foo")
val df = spark.range(10).withColumn("part", 'id % 5)
val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1")
withCatalogOption.foreach(cName => dfw.option("catalog", cName))
dfw.save()

val table = catalog(SESSION_CATALOG_NAME).loadTable("t1")
assert(table.partitioning().isEmpty, "Partitioning should be empty")
assert(table.schema() === new StructType().add("id", LongType), "Schema did not match")
}
}

dataFrameWriterTests(None)

dataFrameWriterTests(Some(catalogName))
}

class CatalogSupportingInMemoryTableProvider
extends InMemoryTableProvider
with SupportsCatalogOptions {

override def extractIdentifier(options: CaseInsensitiveStringMap): Identifier = {
val name = options.get("name")
assert(name != null, "The name should be provided for this table")
Identifier.of(Array.empty, name)
}

override def extractCatalog(options: CaseInsensitiveStringMap): String = {
options.get("catalog")
}
}
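
For context, a short sketch of how this provider and the catalog option are meant to combine; the values are illustrative.

spark.range(5).write
  .format(classOf[CatalogSupportingInMemoryTableProvider].getName)
  .option("name", "t1")           // consumed by extractIdentifier above
  .option("catalog", "testcat")   // consumed by extractCatalog; dropping it targets the session catalog
  .mode(SaveMode.ErrorIfExists)
  .save()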
