From a4416045827dece6128afb253c30d5c0862d594a Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Thu, 19 Dec 2019 18:02:53 -0800
Subject: [PATCH] Added first set of tests

---
 .../sql/connector/catalog/CatalogV2Util.scala      |  11 ++
 .../apache/spark/sql/DataFrameReader.scala         |  18 ++-
 .../apache/spark/sql/DataFrameWriter.scala         |  26 ++--
 .../SupportsCatalogOptionsSuite.scala              | 134 ++++++++++++++++++
 4 files changed, 171 insertions(+), 18 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
index 2f4914dd7db30..671beb3ab1500 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.AlterTable
 import org.apache.spark.sql.connector.catalog.TableChange._
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.types.{ArrayType, MapType, StructField, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 private[sql] object CatalogV2Util {
   import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -315,4 +316,14 @@ private[sql] object CatalogV2Util {
     val unresolved = UnresolvedV2Relation(originalNameParts, tableCatalog, ident)
     AlterTable(tableCatalog, ident, unresolved, changes)
   }
+
+  def getTableProviderCatalog(
+      provider: SupportsCatalogOptions,
+      catalogManager: CatalogManager,
+      options: CaseInsensitiveStringMap): TableCatalog = {
+    Option(provider.extractCatalog(options))
+      .map(catalogManager.catalog)
+      .getOrElse(catalogManager.v2SessionCatalog)
+      .asTableCatalog
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 8570e4640feea..ab3bbccb721e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, Univocit
 import org.apache.spark.sql.catalyst.expressions.ExprUtils
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
 import org.apache.spark.sql.catalyst.util.FailureSafeParser
-import org.apache.spark.sql.connector.catalog.SupportsRead
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsCatalogOptions, SupportsRead}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
@@ -215,9 +215,19 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
 
       val finalOptions = sessionOptions ++ extraOptions.toMap ++ pathsOption
       val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava)
-      val table = userSpecifiedSchema match {
-        case Some(schema) => provider.getTable(dsOptions, schema)
-        case _ => provider.getTable(dsOptions)
+      val table = provider match {
+        case hasCatalog: SupportsCatalogOptions =>
+          val ident = hasCatalog.extractIdentifier(dsOptions)
+          val catalog = CatalogV2Util.getTableProviderCatalog(
+            hasCatalog,
+            sparkSession.sessionState.catalogManager,
+            dsOptions)
+          catalog.loadTable(ident)
+        case other =>
+          userSpecifiedSchema match {
+            case Some(schema) => provider.getTable(dsOptions, schema)
+            case _ => provider.getTable(dsOptions)
+          }
       }
       import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
       table match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index b81d2bad44b3e..d8fb71cf30829 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, InsertIntoStatement, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect}
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Catalogs, Identifier, SupportsCatalogOptions, SupportsWrite, Table, TableCatalog, TableProvider, V1Table}
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Util, Identifier, SupportsCatalogOptions, SupportsWrite, Table, TableCatalog, TableProvider, V1Table}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, LiteralValue, Transform}
 import org.apache.spark.sql.execution.SQLExecution
@@ -263,13 +263,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       lazy val relation = DataSourceV2Relation.create(table, dsOptions)
       mode match {
         case SaveMode.Append =>
-          verifyV2Partitioning(table)
+          checkPartitioningMatchesV2Table(table)
           runCommand(df.sparkSession, "save") {
             AppendData.byName(relation, df.logicalPlan, extraOptions.toMap)
           }
 
         case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) =>
-          verifyV2Partitioning(table)
+          checkPartitioningMatchesV2Table(table)
           // truncate the table
           runCommand(df.sparkSession, "save") {
             OverwriteByExpression.byName(
@@ -280,10 +280,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
           val catalogOptions = provider.asInstanceOf[SupportsCatalogOptions]
           val ident = catalogOptions.extractIdentifier(dsOptions)
           val sessionState = df.sparkSession.sessionState
-          val catalog = Option(catalogOptions.extractCatalog(dsOptions))
-            .map(Catalogs.load(_, sessionState.conf))
-            .getOrElse(sessionState.catalogManager.v2SessionCatalog)
-            .asInstanceOf[TableCatalog]
+          val catalog = CatalogV2Util.getTableProviderCatalog(
+            catalogOptions, sessionState.catalogManager, dsOptions)
 
           val location = Option(dsOptions.get("path")).map(TableCatalog.PROP_LOCATION -> _)
 
@@ -291,7 +289,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
             CreateTableAsSelect(
               catalog,
               ident,
-              getV2Transforms,
+              partitioningAsV2,
               df.queryExecution.analyzed,
               Map(TableCatalog.PROP_PROVIDER -> source) ++ location,
               extraOptions.toMap,
@@ -538,14 +536,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
         return saveAsTable(TableIdentifier(ident.name(), ident.namespace().headOption))
 
       case (SaveMode.Append, Some(table)) =>
-        verifyV2Partitioning(table)
+        checkPartitioningMatchesV2Table(table)
         AppendData.byName(DataSourceV2Relation.create(table), df.logicalPlan, extraOptions.toMap)
 
       case (SaveMode.Overwrite, _) =>
         ReplaceTableAsSelect(
           catalog,
           ident,
-          getV2Transforms,
+          partitioningAsV2,
           df.queryExecution.analyzed,
           Map(TableCatalog.PROP_PROVIDER -> source) ++ getLocationIfExists,
           extraOptions.toMap,
@@ -558,7 +556,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
         CreateTableAsSelect(
           catalog,
           ident,
-          getV2Transforms,
+          partitioningAsV2,
           df.queryExecution.analyzed,
           Map(TableCatalog.PROP_PROVIDER -> source) ++ getLocationIfExists,
           extraOptions.toMap,
@@ -637,7 +635,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
   }
 
   /** Converts the provided partitioning and bucketing information to DataSourceV2 Transforms. */
-  private def getV2Transforms: Seq[Transform] = {
+  private def partitioningAsV2: Seq[Transform] = {
     val partitioning = partitioningColumns.map { colNames =>
       colNames.map(name => IdentityTransform(FieldReference(name)))
     }.getOrElse(Seq.empty[Transform])
@@ -651,8 +649,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
    * For V2 DataSources, performs if the provided partitioning matches that of the table.
    * Partitioning information is not required when appending data to V2 tables.
    */
-  private def verifyV2Partitioning(existingTable: Table): Unit = {
-    val v2Partitions = getV2Transforms
+  private def checkPartitioningMatchesV2Table(existingTable: Table): Unit = {
+    val v2Partitions = partitioningAsV2
     if (v2Partitions.isEmpty) return
     require(v2Partitions.sameElements(existingTable.partitioning()),
       "The provided partitioning does not match of the table.\n" +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
new file mode 100644
index 0000000000000..0a77c1710761c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import scala.language.implicitConversions
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.{QueryTest, SaveMode}
+import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+import org.apache.spark.sql.connector.catalog.{Identifier, SupportsCatalogOptions, TableCatalog}
+import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
+import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{LongType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with BeforeAndAfter {
+
+  import testImplicits._
+
+  private val catalogName = "testcat"
+
+  private def catalog(name: String): InMemoryTableSessionCatalog = {
+    spark.sessionState.catalogManager.catalog(name).asInstanceOf[InMemoryTableSessionCatalog]
+  }
+
+  private implicit def stringToIdentifier(value: String): Identifier = {
+    Identifier.of(Array.empty, value)
+  }
+
+  before {
+    spark.conf.set(
+      V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[InMemoryTableSessionCatalog].getName)
+    spark.conf.set(
+      s"spark.sql.catalog.$catalogName", classOf[InMemoryTableSessionCatalog].getName)
+  }
+
+  override def afterEach(): Unit = {
+    super.afterEach()
+    catalog(SESSION_CATALOG_NAME).clearTables()
+    catalog(catalogName).clearTables()
+    spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
+    spark.conf.unset(s"spark.sql.catalog.$catalogName")
+  }
+
+  def dataFrameWriterTests(withCatalogOption: Option[String]): Unit = {
+    Seq(SaveMode.ErrorIfExists, SaveMode.Ignore).foreach { saveMode =>
+      test(s"save works with $saveMode - no table, no partitioning, session catalog") {
+        val df = spark.range(10)
+        val dfw = df.write.mode(saveMode).option("name", "t1")
+        withCatalogOption.foreach(cName => dfw.option("catalog", cName))
+        dfw.save()
+
+        val table = catalog(SESSION_CATALOG_NAME).loadTable("t1")
+        assert(table.name() === "t1", "Table identifier was wrong")
+        assert(table.partitioning().isEmpty, "Partitioning should be empty")
+        assert(table.schema() === df.schema.asNullable, "Schema did not match")
+      }
+
+      test(s"save works with $saveMode - no table, with partitioning, session catalog") {
+        val df = spark.range(10).withColumn("part", 'id % 5)
+        val dfw = df.write.mode(saveMode).option("name", "t1").partitionBy("part")
+        withCatalogOption.foreach(cName => dfw.option("catalog", cName))
+        dfw.save()
+
+        val table = catalog(SESSION_CATALOG_NAME).loadTable("t1")
+        assert(table.name() === "t1", "Table identifier was wrong")
+        assert(table.partitioning().length === 1, "Partitioning should not be empty")
+        assert(table.partitioning().head.references().head.fieldNames().head === "part",
+          "Partitioning was incorrect")
+        assert(table.schema() === df.schema.asNullable, "Schema did not match")
+      }
+    }
+
+    test("save fails with ErrorIfExists if table exists") {
+      sql("create table t1 (id bigint) using foo")
+      val df = spark.range(10)
+      intercept[TableAlreadyExistsException] {
+        val dfw = df.write.option("name", "t1")
+        withCatalogOption.foreach(cName => dfw.option("catalog", cName))
+        dfw.save()
+      }
+    }
+
+    test("Ignore mode if table exists") {
+      sql("create table t1 (id bigint) using foo")
+      val df = spark.range(10).withColumn("part", 'id % 5)
+      // SaveMode.Ignore should not throw when the table already exists; the existing t1 in
+      // the session catalog must be left untouched, which the assertions below verify.
+      val dfw = df.write.mode(SaveMode.Ignore).option("name", "t1")
+      withCatalogOption.foreach(cName => dfw.option("catalog", cName))
+      dfw.save()
+
+      val table = catalog(SESSION_CATALOG_NAME).loadTable("t1")
+      assert(table.partitioning().isEmpty, "Partitioning should be empty")
+      assert(table.schema() === new StructType().add("id", LongType), "Schema did not match")
+    }
+  }
+
+  dataFrameWriterTests(None)
+
+  dataFrameWriterTests(Some(catalogName))
+}
+
+class CatalogSupportingInMemoryTableProvider
+  extends InMemoryTableProvider
+  with SupportsCatalogOptions {
+
+  override def extractIdentifier(options: CaseInsensitiveStringMap): Identifier = {
+    val name = options.get("name")
+    assert(name != null, "The name should be provided for this table")
+    Identifier.of(Array.empty, name)
+  }
+
+  override def extractCatalog(options: CaseInsensitiveStringMap): String = {
+    options.get("catalog")
+  }
+}
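
Note (illustration only, not part of the diff): end to end, these pieces let a TableProvider that also implements SupportsCatalogOptions resolve plain DataFrameReader.load() / DataFrameWriter.save() calls against a catalog table. extractIdentifier and extractCatalog pull the identifier and catalog name out of the reader/writer options, CatalogV2Util.getTableProviderCatalog returns the named catalog (falling back to the v2 session catalog), and the reader then calls loadTable while the writer plans AppendData / OverwriteByExpression / CreateTableAsSelect against it. A minimal usage sketch under the assumptions of the test suite above (the provider class is passed as the format, "testcat" is registered as in the before block, and spark is an active SparkSession):

  // Hypothetical sketch, not part of this patch.
  val fmt = classOf[CatalogSupportingInMemoryTableProvider].getName

  spark.range(10).write
    .format(fmt)
    .option("name", "t1")          // consumed by extractIdentifier
    .option("catalog", "testcat")  // consumed by extractCatalog
    .save()                        // default ErrorIfExists: planned as CreateTableAsSelect in testcat

  val roundTripped = spark.read
    .format(fmt)
    .option("name", "t1")
    .option("catalog", "testcat")
    .load()                        // DataFrameReader now resolves this via catalog.loadTable(ident)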