From 11f54f60b2596adcca260e50b548e9817aec8ea4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 24 Dec 2019 22:52:02 +0800 Subject: [PATCH] refine TableProvider --- .../sql/kafka010/KafkaSourceProvider.scala | 3 +- .../sql/connector/catalog/TableProvider.java | 57 ++++++++++++++----- .../catalog/CatalogV2Implicits.scala | 5 +- .../connector/SimpleTableProvider.scala | 50 ++++++++++++++++ .../apache/spark/sql/DataFrameReader.scala | 4 ++ .../apache/spark/sql/DataFrameWriter.scala | 21 +++++-- .../datasources/noop/NoopDataSource.scala | 5 +- .../datasources/v2/DataSourceV2Utils.scala | 28 ++++++++- .../datasources/v2/FileDataSourceV2.scala | 42 +++++++++++++- .../execution/datasources/v2/FileTable.scala | 2 +- .../sql/execution/streaming/console.scala | 5 +- .../sql/execution/streaming/memory.scala | 5 +- .../sources/RateStreamProvider.scala | 5 +- .../sources/TextSocketSourceProvider.scala | 5 +- .../sql/streaming/DataStreamReader.scala | 10 ++-- .../sql/streaming/DataStreamWriter.scala | 10 +++- .../connector/JavaAdvancedDataSourceV2.java | 6 +- .../connector/JavaColumnarDataSourceV2.java | 4 +- .../JavaPartitionAwareDataSource.java | 4 +- .../JavaReportStatisticsDataSource.java | 4 +- .../JavaSchemaRequiredDataSource.java | 21 +++++-- .../sql/connector/JavaSimpleBatchTable.java | 3 +- .../sql/connector/JavaSimpleDataSourceV2.java | 4 +- .../sql/connector/JavaSimpleScanBuilder.java | 3 +- ...SourceV2DataFrameSessionCatalogSuite.scala | 8 +-- .../sql/connector/DataSourceV2SQLSuite.scala | 3 +- .../sql/connector/DataSourceV2Suite.scala | 48 ++++++++++++---- .../connector/SimpleWritableDataSource.scala | 5 +- .../SupportsCatalogOptionsSuite.scala | 4 +- .../connector/TableCapabilityCheckSuite.scala | 8 +-- .../sql/connector/V1WriteFallbackSuite.scala | 5 +- .../command/PlanResolutionSuite.scala | 4 +- .../sources/TextSocketStreamSuite.scala | 3 +- .../sources/StreamingDataSourceV2Suite.scala | 19 ++++--- .../streaming/util/BlockOnStopSource.scala | 5 +- 35 files changed, 306 insertions(+), 112 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SimpleTableProvider.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 4ffa70f9f31dd..a5e5d01152db8 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBat import org.apache.spark.sql.connector.write.{BatchWrite, LogicalWriteInfo, WriteBuilder} import org.apache.spark.sql.connector.write.streaming.StreamingWrite import org.apache.spark.sql.execution.streaming.{Sink, Source} +import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -51,7 +52,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider - with TableProvider + with SimpleTableProvider with Logging { import KafkaSourceProvider._ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java index e9fd87d0e2d40..732c5352a15ac 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java @@ -17,7 +17,10 @@ package org.apache.spark.sql.connector.catalog; +import java.util.Map; + import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; @@ -36,26 +39,50 @@ public interface TableProvider { /** - * Return a {@link Table} instance to do read/write with user-specified options. + * Infer the schema of the table identified by the given options. + * + * @param options an immutable case-insensitive string-to-string map that can identify a table, + * e.g. file path, Kafka topic name, etc. + */ + StructType inferSchema(CaseInsensitiveStringMap options); + + /** + * Infer the partitioning of the table identified by the given options. + *

+ * By default this method returns empty partitioning; please override it if this source supports + * partitioning. + * + * @param options an immutable case-insensitive string-to-string map that can identify a table, + * e.g. file path, Kafka topic name, etc. + */ + default Transform[] inferPartitioning(CaseInsensitiveStringMap options) { + return new Transform[0]; + } + + /** + * Return a {@link Table} instance with the specified table schema, partitioning and properties + * to do read/write. The returned table should report the same schema and partitioning as the + * specified ones, or Spark may fail the operation. * - * @param options the user-specified options that can identify a table, e.g. file path, Kafka - * topic name, etc. It's an immutable case-insensitive string-to-string map. + * @param schema The specified table schema. + * @param partitioning The specified table partitioning. + * @param properties The specified table properties. It's case-preserving (contains exactly what + * users specified) and implementations are free to use it case-sensitively or + * case-insensitively. It should be able to identify a table, e.g. file path, Kafka + * topic name, etc. */ - Table getTable(CaseInsensitiveStringMap options); + Table getTable(StructType schema, Transform[] partitioning, Map<String, String> properties); /** - * Return a {@link Table} instance to do read/write with user-specified schema and options. + * Returns true if the source can accept external table metadata when getting + * tables. The external table metadata includes the user-specified schema from + * `DataFrameReader`/`DataStreamReader` and the schema/partitioning stored in the Spark catalog. *

- * By default this method throws {@link UnsupportedOperationException}, implementations should - * override this method to handle user-specified schema. - *

- * @param options the user-specified options that can identify a table, e.g. file path, Kafka - * topic name, etc. It's an immutable case-insensitive string-to-string map. - * @param schema the user-specified schema. - * @throws UnsupportedOperationException + * By default this method returns false, which means the schema and partitioning passed to + * `getTable` are from the infer methods. Please override it if this source has expensive + * schema/partitioning inference and wants external table metadata to avoid inference. */ - default Table getTable(CaseInsensitiveStringMap options, StructType schema) { - throw new UnsupportedOperationException( - this.getClass().getSimpleName() + " source does not support user-specified schema"); + default boolean supportsExternalMetadata() { + return false; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index 16aec23521f9f..3478af8783af6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.connector.expressions.{BucketTransform, IdentityTransform, LogicalExpressions, Transform} -import org.apache.spark.sql.types.StructType /** * Conversion helpers for working with v2 [[CatalogPlugin]]. @@ -29,9 +28,9 @@ import org.apache.spark.sql.types.StructType private[sql] object CatalogV2Implicits { import LogicalExpressions._ - implicit class PartitionTypeHelper(partitionType: StructType) { + implicit class PartitionTypeHelper(colNames: Seq[String]) { def asTransforms: Array[Transform] = { - partitionType.names.map(col => identity(reference(Seq(col)))).toArray + colNames.map(col => identity(reference(Seq(col)))).toArray } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SimpleTableProvider.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SimpleTableProvider.scala new file mode 100644 index 0000000000000..7bfe1df1117ac --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SimpleTableProvider.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.internal.connector + +import java.util + +import org.apache.spark.sql.connector.catalog.{Table, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +// A simple version of `TableProvider` which doesn't support user-specified table schema/partitioning +// and treats table properties case-insensitively. This is private and only used in builtin sources. +trait SimpleTableProvider extends TableProvider { + + def getTable(options: CaseInsensitiveStringMap): Table + + private[this] var loadedTable: Table = _ + private def getOrLoadTable(options: CaseInsensitiveStringMap): Table = { + if (loadedTable == null) loadedTable = getTable(options) + loadedTable + } + + override def inferSchema(options: CaseInsensitiveStringMap): StructType = { + getOrLoadTable(options).schema() + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + assert(partitioning.isEmpty) + getOrLoadTable(new CaseInsensitiveStringMap(properties)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b5d7bbca9064d..870ef6b3caabf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -219,11 +219,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { dsOptions) (catalog.loadTable(ident), Some(catalog), Some(ident)) case _ => - // TODO: Non-catalog paths for DSV2 are currently not well defined. - userSpecifiedSchema match { - case Some(schema) => (provider.getTable(dsOptions, schema), None, None) - case _ => (provider.getTable(dsOptions), None, None) - } + DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema) } import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index c041d14c8b8df..59eb5721a4762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -257,6 +257,19 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val options = sessionOptions ++ extraOptions val dsOptions = new CaseInsensitiveStringMap(options.asJava) + def getTable: Table = { + // For file sources, it's expensive to infer the schema/partitioning at each write. Here we pass + // the schema of the input query and the user-specified partitioning to `getTable`. If the + // query schema is not compatible with the existing data, the write can still succeed but + // subsequent reads will fail.
+ if (provider.isInstanceOf[FileDataSourceV2]) { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + provider.getTable(df.schema, partitioningAsV2, dsOptions.asCaseSensitiveMap()) + } else { + DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema = None) + } + } + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ val catalogManager = df.sparkSession.sessionState.catalogManager mode match { @@ -268,8 +281,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { supportsExtract, catalogManager, dsOptions) (catalog.loadTable(ident), Some(catalog), Some(ident)) - case tableProvider: TableProvider => - val t = tableProvider.getTable(dsOptions) + case _: TableProvider => + val t = getTable if (t.supports(BATCH_WRITE)) { (t, None, None) } else { @@ -314,8 +327,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { extraOptions.toMap, ignoreIfExists = createMode == SaveMode.Ignore) } - case tableProvider: TableProvider => - if (tableProvider.getTable(dsOptions).supports(BATCH_WRITE)) { + case _: TableProvider => + if (getTable.supports(BATCH_WRITE)) { throw new AnalysisException(s"TableProvider implementation $source cannot be " + s"written with $createMode mode, please use Append or Overwrite " + "modes instead.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index b6149ce7290b7..4fad0a2484cde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -22,9 +22,10 @@ import java.util import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability} import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} +import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -33,7 +34,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * This is no-op datasource. It does not do anything besides consuming its input. * This can be useful for benchmarking or to cache data without any additional overhead. 
*/ -class NoopDataSource extends TableProvider with DataSourceRegister { +class NoopDataSource extends SimpleTableProvider with DataSourceRegister { override def shortName(): String = "noop" override def getTable(options: CaseInsensitiveStringMap): Table = NoopTable } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index 52294ae2cb851..b50b8295463eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.execution.datasources.v2 import java.util.regex.Pattern import org.apache.spark.internal.Logging -import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, TableProvider} +import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, Table, TableProvider} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap private[sql] object DataSourceV2Utils extends Logging { @@ -57,4 +59,28 @@ private[sql] object DataSourceV2Utils extends Logging { case _ => Map.empty } } + + def getTableFromProvider( + provider: TableProvider, + options: CaseInsensitiveStringMap, + userSpecifiedSchema: Option[StructType]): Table = { + userSpecifiedSchema match { + case Some(schema) => + if (provider.supportsExternalMetadata()) { + provider.getTable( + schema, + provider.inferPartitioning(options), + options.asCaseSensitiveMap()) + } else { + throw new UnsupportedOperationException( + s"${provider.getClass.getSimpleName} source does not support user-specified schema.") + } + + case None => + provider.getTable( + provider.inferSchema(options), + provider.inferPartitioning(options), + options.asCaseSensitiveMap()) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala index e0091293d1669..30a964d7e643f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala @@ -16,13 +16,17 @@ */ package org.apache.spark.sql.execution.datasources.v2 +import java.util + import com.fasterxml.jackson.databind.ObjectMapper import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.catalog.TableProvider +import org.apache.spark.sql.connector.catalog.{Table, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.Utils @@ -59,4 +63,40 @@ trait FileDataSourceV2 extends TableProvider with DataSourceRegister { val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toString } + + // TODO: To reduce code diff of SPARK-29665, we create stub implementations for file source v2, so + // that we don't need to touch all the file source v2 classes. 
We should remove the stub + // implementation and directly implement the TableProvider APIs. + protected def getTable(options: CaseInsensitiveStringMap): Table + protected def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { + throw new UnsupportedOperationException("user-specified schema") + } + + override def supportsExternalMetadata(): Boolean = true + + private var t: Table = null + + override def inferSchema(options: CaseInsensitiveStringMap): StructType = { + if (t == null) t = getTable(options) + t.schema() + } + + // TODO: implement a light-weight partition inference which only looks at the path of one leaf + // file and return partition column names. For now the partition inference happens in + // `getTable`, because we don't know the user-specified schema here. + override def inferPartitioning(options: CaseInsensitiveStringMap): Array[Transform] = { + Array.empty + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + // If the table is already loaded during schema inference, return it directly. + if (t != null) { + t + } else { + getTable(new CaseInsensitiveStringMap(properties), schema) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala index 5329e09916bd6..59dc3ae56bf25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala @@ -102,7 +102,7 @@ abstract class FileTable( StructType(fields) } - override def partitioning: Array[Transform] = fileIndex.partitionSchema.asTransforms + override def partitioning: Array[Transform] = fileIndex.partitionSchema.names.toSeq.asTransforms override def properties: util.Map[String, String] = options.asCaseSensitiveMap diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 63e40891942ae..e471e6c601d16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -22,10 +22,11 @@ import java.util import scala.collection.JavaConverters._ import org.apache.spark.sql._ -import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability} import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsTruncate, WriteBuilder} import org.apache.spark.sql.connector.write.streaming.StreamingWrite import org.apache.spark.sql.execution.streaming.sources.ConsoleWrite +import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -35,7 +36,7 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) override def schema: StructType = data.schema } -class ConsoleSinkProvider extends TableProvider +class ConsoleSinkProvider extends SimpleTableProvider with DataSourceRegister with CreatableRelationProvider { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 911a526428cf4..395811b72d32f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -31,10 +31,11 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability} import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream, Offset => OffsetV2, SparkDataStream} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -94,7 +95,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Spa // This class is used to indicate the memory stream data source. We don't actually use it, as // memory stream is for test only and we never look it up by name. -object MemoryStreamTableProvider extends TableProvider { +object MemoryStreamTableProvider extends SimpleTableProvider { override def getTable(options: CaseInsensitiveStringMap): Table = { throw new IllegalStateException("MemoryStreamTableProvider should not be used.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 3f7b0377f1eab..a093bf54b2107 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -23,10 +23,11 @@ import scala.collection.JavaConverters._ import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousStream +import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -45,7 +46,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. 
*/ -class RateStreamProvider extends TableProvider with DataSourceRegister { +class RateStreamProvider extends SimpleTableProvider with DataSourceRegister { import RateStreamProvider._ override def getTable(options: CaseInsensitiveStringMap): Table = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala index fae3cb765c0c9..a4dcb2049eb87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala @@ -26,15 +26,16 @@ import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousStream +import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap -class TextSocketSourceProvider extends TableProvider with DataSourceRegister with Logging { +class TextSocketSourceProvider extends SimpleTableProvider with DataSourceRegister with Logging { private def checkParameters(params: CaseInsensitiveStringMap): Unit = { logWarning("The socket source should not be used for production applications! " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index cfe6192e7d5c5..0eb4776988d9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsRead, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDataSourceV2} import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.types.StructType @@ -173,15 +173,13 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo case _ => None } ds match { - case provider: TableProvider => + // file source v2 does not support streaming yet. 
+ case provider: TableProvider if !provider.isInstanceOf[FileDataSourceV2] => val sessionOptions = DataSourceV2Utils.extractSessionConfigs( source = provider, conf = sparkSession.sessionState.conf) val options = sessionOptions ++ extraOptions val dsOptions = new CaseInsensitiveStringMap(options.asJava) - val table = userSpecifiedSchema match { - case Some(schema) => provider.getTable(dsOptions, schema) - case _ => provider.getTable(dsOptions) - } + val table = DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { case _: SupportsRead if table.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 62a1add8b6d94..1c21a30dd5bd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsWrite, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDataSourceV2} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources._ import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -308,7 +308,9 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } else { val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") - val useV1Source = disabledSources.contains(cls.getCanonicalName) + val useV1Source = disabledSources.contains(cls.getCanonicalName) || + // file source v2 does not support streaming yet. 
+ classOf[FileDataSourceV2].isAssignableFrom(cls) val sink = if (classOf[TableProvider].isAssignableFrom(cls) && !useV1Source) { val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] @@ -316,8 +318,10 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { source = provider, conf = df.sparkSession.sessionState.conf) val options = sessionOptions ++ extraOptions val dsOptions = new CaseInsensitiveStringMap(options.asJava) + val table = DataSourceV2Utils.getTableFromProvider( + provider, dsOptions, userSpecifiedSchema = None) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ - provider.getTable(dsOptions) match { + table match { case table: SupportsWrite if table.supports(STREAMING_WRITE) => table case _ => createV1Sink() diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java index 9386ab51d64f0..1a55d198361ee 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java @@ -22,15 +22,15 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.read.*; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -public class JavaAdvancedDataSourceV2 implements TableProvider { +public class JavaAdvancedDataSourceV2 implements TestingV2Source { @Override public Table getTable(CaseInsensitiveStringMap options) { @@ -45,7 +45,7 @@ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { static class AdvancedScanBuilder implements ScanBuilder, Scan, SupportsPushDownFilters, SupportsPushDownRequiredColumns { - private StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); + private StructType requiredSchema = TestingV2Source.schema(); private Filter[] filters = new Filter[0]; @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java index 76da45e182b3c..2f10c84c999f9 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java @@ -20,8 +20,8 @@ import java.io.IOException; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.PartitionReader; import org.apache.spark.sql.connector.read.PartitionReaderFactory; @@ -33,7 +33,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch; -public class JavaColumnarDataSourceV2 implements TableProvider { +public class JavaColumnarDataSourceV2 implements TestingV2Source { class MyScanBuilder extends JavaSimpleScanBuilder { diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java index fbbc457b2945d..9c1db7a379602 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java @@ -22,17 +22,17 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.expressions.Expressions; import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.read.*; import org.apache.spark.sql.connector.read.partitioning.ClusteredDistribution; import org.apache.spark.sql.connector.read.partitioning.Distribution; import org.apache.spark.sql.connector.read.partitioning.Partitioning; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -public class JavaPartitionAwareDataSource implements TableProvider { +public class JavaPartitionAwareDataSource implements TestingV2Source { class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportPartitioning { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java index 49438fe668d56..9a787c3d2d92c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java @@ -19,15 +19,15 @@ import java.util.OptionalLong; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.connector.read.Statistics; import org.apache.spark.sql.connector.read.SupportsReportStatistics; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -public class JavaReportStatisticsDataSource implements TableProvider { +public class JavaReportStatisticsDataSource implements TestingV2Source { class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportStatistics { @Override public Statistics estimateStatistics() { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java index 2181887ae54e2..5f73567ade025 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java @@ -17,8 +17,11 @@ package test.org.apache.spark.sql.connector; +import java.util.Map; + import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableProvider; +import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.types.StructType; @@ -46,7 
+49,18 @@ public InputPartition[] planInputPartitions() { } @Override - public Table getTable(CaseInsensitiveStringMap options, StructType schema) { + public boolean supportsExternalMetadata() { + return true; + } + + @Override + public StructType inferSchema(CaseInsensitiveStringMap options) { + throw new IllegalArgumentException("requires a user-supplied schema"); + } + + @Override + public Table getTable( + StructType schema, Transform[] partitioning, Map<String, String> properties) { return new JavaSimpleBatchTable() { @Override @@ -60,9 +74,4 @@ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { } }; } - - @Override - public Table getTable(CaseInsensitiveStringMap options) { - throw new IllegalArgumentException("requires a user-supplied schema"); - } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleBatchTable.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleBatchTable.java index 97b00477e1764..71cf97b56fe54 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleBatchTable.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleBatchTable.java @@ -21,6 +21,7 @@ import java.util.HashSet; import java.util.Set; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.catalog.SupportsRead; import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableCapability; @@ -34,7 +35,7 @@ abstract class JavaSimpleBatchTable implements Table, SupportsRead { @Override public StructType schema() { - return new StructType().add("i", "int").add("j", "int"); + return TestingV2Source.schema(); } @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java index 8b6d71b986ff7..8852249d8a01f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java @@ -17,13 +17,13 @@ package test.org.apache.spark.sql.connector; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -public class JavaSimpleDataSourceV2 implements TableProvider { +public class JavaSimpleDataSourceV2 implements TestingV2Source { class MyScanBuilder extends JavaSimpleScanBuilder { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java index 7cbba00420928..bdd9dd3ea0ce0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java @@ -17,6 +17,7 @@ package test.org.apache.spark.sql.connector; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.read.Batch; import org.apache.spark.sql.connector.read.PartitionReaderFactory; import org.apache.spark.sql.connector.read.Scan; @@ -37,7 +38,7 @@ public Batch toBatch() { @Override public StructType readSchema() { - return new StructType().add("i", "int").add("j", 
"int"); + return TestingV2Source.schema(); } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala index 08627e681f9e6..4c67888cbdc48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala @@ -92,12 +92,6 @@ class DataSourceV2DataFrameSessionCatalogSuite } } -class InMemoryTableProvider extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { - throw new UnsupportedOperationException("D'oh!") - } -} - class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable] { override def newTable( name: String, @@ -140,7 +134,7 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio spark.sessionState.catalogManager.catalog(name) } - protected val v2Format: String = classOf[InMemoryTableProvider].getName + protected val v2Format: String = classOf[FakeV2Provider].getName protected val catalogClassName: String = classOf[InMemoryTableSessionCatalog].getName diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 04e5a8dfd78ba..2c8349a0e6a75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION +import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources.SimpleScanSource import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -2230,7 +2231,7 @@ class DataSourceV2SQLSuite /** Used as a V2 DataSource for V2SessionCatalog DDL */ -class FakeV2Provider extends TableProvider { +class FakeV2Provider extends SimpleTableProvider { override def getTable(options: CaseInsensitiveStringMap): Table = { throw new UnsupportedOperationException("Unnecessary for DDL tests") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 85ff86ef3fc5b..2d8761f872da7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.read.partitioning.{ClusteredDistribution, Distribution, Partitioning} import 
org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -418,7 +419,7 @@ object SimpleReaderFactory extends PartitionReaderFactory { abstract class SimpleBatchTable extends Table with SupportsRead { - override def schema(): StructType = new StructType().add("i", "int").add("j", "int") + override def schema(): StructType = TestingV2Source.schema override def name(): String = this.getClass.toString @@ -432,12 +433,31 @@ abstract class SimpleScanBuilder extends ScanBuilder override def toBatch: Batch = this - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + override def readSchema(): StructType = TestingV2Source.schema override def createReaderFactory(): PartitionReaderFactory = SimpleReaderFactory } -class SimpleSinglePartitionSource extends TableProvider { +trait TestingV2Source extends TableProvider { + override def inferSchema(options: CaseInsensitiveStringMap): StructType = { + TestingV2Source.schema + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + getTable(new CaseInsensitiveStringMap(properties)) + } + + def getTable(options: CaseInsensitiveStringMap): Table +} + +object TestingV2Source { + val schema = new StructType().add("i", "int").add("j", "int") +} + +class SimpleSinglePartitionSource extends TestingV2Source { class MyScanBuilder extends SimpleScanBuilder { override def planInputPartitions(): Array[InputPartition] = { @@ -452,9 +472,10 @@ class SimpleSinglePartitionSource extends TableProvider { } } + // This class is used by pyspark tests. If this class is modified/moved, make sure pyspark // tests still pass. -class SimpleDataSourceV2 extends TableProvider { +class SimpleDataSourceV2 extends TestingV2Source { class MyScanBuilder extends SimpleScanBuilder { override def planInputPartitions(): Array[InputPartition] = { @@ -469,7 +490,7 @@ class SimpleDataSourceV2 extends TableProvider { } } -class AdvancedDataSourceV2 extends TableProvider { +class AdvancedDataSourceV2 extends TestingV2Source { override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { @@ -481,7 +502,7 @@ class AdvancedDataSourceV2 extends TableProvider { class AdvancedScanBuilder extends ScanBuilder with Scan with SupportsPushDownFilters with SupportsPushDownRequiredColumns { - var requiredSchema = new StructType().add("i", "int").add("j", "int") + var requiredSchema = TestingV2Source.schema var filters = Array.empty[Filter] override def pruneColumns(requiredSchema: StructType): Unit = { @@ -567,11 +588,16 @@ class SchemaRequiredDataSource extends TableProvider { override def readSchema(): StructType = schema } - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def supportsExternalMetadata(): Boolean = true + + override def inferSchema(options: CaseInsensitiveStringMap): StructType = { throw new IllegalArgumentException("requires a user-supplied schema") } - override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { val userGivenSchema = schema new SimpleBatchTable { override def schema(): StructType = userGivenSchema @@ -583,7 +609,7 @@ class SchemaRequiredDataSource extends TableProvider { } } -class ColumnarDataSourceV2 extends TableProvider { +class ColumnarDataSourceV2 
extends TestingV2Source { class MyScanBuilder extends SimpleScanBuilder { @@ -648,7 +674,7 @@ object ColumnarReaderFactory extends PartitionReaderFactory { } } -class PartitionAwareDataSource extends TableProvider { +class PartitionAwareDataSource extends TestingV2Source { class MyScanBuilder extends SimpleScanBuilder with SupportsReportPartitioning{ @@ -716,7 +742,7 @@ class SimpleWriteOnlyDataSource extends SimpleWritableDataSource { } } -class ReportStatisticsDataSource extends TableProvider { +class ReportStatisticsDataSource extends SimpleWritableDataSource { class MyScanBuilder extends SimpleScanBuilder with SupportsReportStatistics { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala index 0070076459f19..f9306ba28e7f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala @@ -27,10 +27,11 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, SupportsWrite, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, SupportsWrite, Table, TableCapability} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory, ScanBuilder} import org.apache.spark.sql.connector.write._ +import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -40,7 +41,7 @@ import org.apache.spark.util.SerializableConfiguration * Each task writes data to `target/_temporary/uniqueId/$jobId-$partitionId-$attemptNumber`. * Each job moves files from `target/_temporary/uniqueId/` to `target`. 
*/ -class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { +class SimpleWritableDataSource extends SimpleTableProvider with SessionConfigSupport { private val tableSchema = new StructType().add("i", "long").add("j", "long") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala index cec48bb368aef..7bff955b18360 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.connector -import java.util - import scala.language.implicitConversions import scala.util.Try @@ -275,7 +273,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with } class CatalogSupportingInMemoryTableProvider - extends InMemoryTableProvider + extends FakeV2Provider with SupportsCatalogOptions { override def extractIdentifier(options: CaseInsensitiveStringMap): Identifier = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala index 5196ca65276e4..23e4c293cbc28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala @@ -40,7 +40,7 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession { private val emptyMap = CaseInsensitiveStringMap.empty private def createStreamingRelation(table: Table, v1Relation: Option[StreamingRelation]) = { StreamingRelationV2( - TestTableProvider, + new FakeV2Provider, "fake", table, CaseInsensitiveStringMap.empty(), @@ -211,12 +211,6 @@ private case object TestRelation extends LeafNode with NamedRelation { override def output: Seq[AttributeReference] = TableCapabilityCheckSuite.schema.toAttributes } -private object TestTableProvider extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { - throw new UnsupportedOperationException - } -} - private case class CapabilityTable(_capabilities: TableCapability*) extends Table { override def name(): String = "capability_test_table" override def schema(): StructType = TableCapabilityCheckSuite.schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index a36e8dbdec506..10ed2048dbf61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -25,10 +25,11 @@ import scala.collection.mutable import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode, SparkSession, SQLContext} -import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder} import 
org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType @@ -173,7 +174,7 @@ private object InMemoryV1Provider { } class InMemoryV1Provider - extends TableProvider + extends SimpleTableProvider with DataSourceRegister with CreatableRelationProvider { override def getTable(options: CaseInsensitiveStringMap): Table = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 70b9b7ec12ea2..30b7e93a4beb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, StringLiteral} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, Assignment, CreateTableAsSelect, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, InsertIntoStatement, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, ShowTableProperties, SubqueryAlias, UpdateAction, UpdateTable} -import org.apache.spark.sql.connector.InMemoryTableProvider +import org.apache.spark.sql.connector.FakeV2Provider import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, Table, TableCapability, TableCatalog, TableChange, V1Table} import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -41,7 +41,7 @@ import org.apache.spark.sql.types.{CharType, DoubleType, HIVE_TYPE_STRING, Integ class PlanResolutionSuite extends AnalysisTest { import CatalystSqlParser._ - private val v2Format = classOf[InMemoryTableProvider].getName + private val v2Format = classOf[FakeV2Provider].getName private val table: Table = { val t = mock(classOf[Table]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 0f80e2d431bb1..5c66fc52592b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -194,13 +194,12 @@ class TextSocketStreamSuite extends StreamTest with SharedSparkSession { } test("user-specified schema given") { - val provider = new TextSocketSourceProvider val userSpecifiedSchema = StructType( StructField("name", StringType) :: StructField("area", StringType) :: Nil) val params = Map("host" -> "localhost", "port" -> "1234") val exception = intercept[UnsupportedOperationException] { - provider.getTable(new CaseInsensitiveStringMap(params.asJava), userSpecifiedSchema) + spark.readStream.schema(userSpecifiedSchema).format("socket").options(params).load() } assert(exception.getMessage.contains( "TextSocketSourceProvider source does not support user-specified schema")) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 13bc811a8fe9b..05cf324f8d490 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactor import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{ContinuousTrigger, RateStreamOffset, Sink, StreamingQueryWrapper} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, StreamTest, Trigger} import org.apache.spark.sql.types.StructType @@ -93,7 +94,7 @@ trait FakeStreamingWriteTable extends Table with SupportsWrite { class FakeReadMicroBatchOnly extends DataSourceRegister - with TableProvider + with SimpleTableProvider with SessionConfigSupport { override def shortName(): String = "fake-read-microbatch-only" @@ -116,7 +117,7 @@ class FakeReadMicroBatchOnly class FakeReadContinuousOnly extends DataSourceRegister - with TableProvider + with SimpleTableProvider with SessionConfigSupport { override def shortName(): String = "fake-read-continuous-only" @@ -137,7 +138,7 @@ class FakeReadContinuousOnly } } -class FakeReadBothModes extends DataSourceRegister with TableProvider { +class FakeReadBothModes extends DataSourceRegister with SimpleTableProvider { override def shortName(): String = "fake-read-microbatch-continuous" override def getTable(options: CaseInsensitiveStringMap): Table = { @@ -154,7 +155,7 @@ class FakeReadBothModes extends DataSourceRegister with TableProvider { } } -class FakeReadNeitherMode extends DataSourceRegister with TableProvider { +class FakeReadNeitherMode extends DataSourceRegister with SimpleTableProvider { override def shortName(): String = "fake-read-neither-mode" override def getTable(options: CaseInsensitiveStringMap): Table = { @@ -168,7 +169,7 @@ class FakeReadNeitherMode extends DataSourceRegister with TableProvider { class FakeWriteOnly extends DataSourceRegister - with TableProvider + with SimpleTableProvider with SessionConfigSupport { override def shortName(): String = "fake-write-microbatch-continuous" @@ -183,7 +184,7 @@ class FakeWriteOnly } } -class FakeNoWrite extends DataSourceRegister with TableProvider { +class FakeNoWrite extends DataSourceRegister with SimpleTableProvider { override def shortName(): String = "fake-write-neither-mode" override def getTable(options: CaseInsensitiveStringMap): Table = { new Table { @@ -201,7 +202,7 @@ class FakeSink extends Sink { } class FakeWriteSupportProviderV1Fallback extends DataSourceRegister - with TableProvider with StreamSinkProvider { + with SimpleTableProvider with StreamSinkProvider { override def createSink( sqlContext: SQLContext, @@ -378,10 +379,10 @@ class StreamingDataSourceV2Suite extends StreamTest { for ((read, write, trigger) <- cases) { testQuietly(s"stream with read format $read, write format $write, trigger $trigger") { val sourceTable = DataSource.lookupDataSource(read, spark.sqlContext.conf).getConstructor() - .newInstance().asInstanceOf[TableProvider].getTable(CaseInsensitiveStringMap.empty()) + 
.newInstance().asInstanceOf[SimpleTableProvider].getTable(CaseInsensitiveStringMap.empty()) val sinkTable = DataSource.lookupDataSource(write, spark.sqlContext.conf).getConstructor() - .newInstance().asInstanceOf[TableProvider].getTable(CaseInsensitiveStringMap.empty()) + .newInstance().asInstanceOf[SimpleTableProvider].getTable(CaseInsensitiveStringMap.empty()) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ trigger match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockOnStopSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockOnStopSource.scala index f25758c520691..c594a8523d15e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockOnStopSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockOnStopSource.scala @@ -25,11 +25,12 @@ import scala.collection.JavaConverters._ import org.apache.zookeeper.KeeperException.UnimplementedException import org.apache.spark.sql.{DataFrame, Row, SparkSession, SQLContext} -import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability} import org.apache.spark.sql.connector.catalog.TableCapability.CONTINUOUS_READ import org.apache.spark.sql.connector.read.{streaming, InputPartition, Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReaderFactory, ContinuousStream, PartitionOffset} import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Source} +import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -54,7 +55,7 @@ object BlockOnStopSourceProvider { } } -class BlockOnStopSourceProvider extends StreamSourceProvider with TableProvider { +class BlockOnStopSourceProvider extends StreamSourceProvider with SimpleTableProvider { override def getTable(options: CaseInsensitiveStringMap): Table = { new BlockOnStopSourceTable(BlockOnStopSourceProvider._latch) }
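
For illustration, a minimal implementation of the refined TableProvider API could look like the sketch below. It is not part of this patch: the class name ConstantSource, its fixed two-column schema, and the unimplemented ScanBuilder are hypothetical placeholders; only the Spark interfaces and methods touched by this change (inferSchema, inferPartitioning, supportsExternalMetadata, getTable(schema, partitioning, properties)) are assumed.

import java.util

import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class ConstantSource extends TableProvider {

  // Accept external table metadata: Spark will pass the user-specified or catalog-stored
  // schema/partitioning to `getTable` instead of calling the infer methods.
  override def supportsExternalMetadata(): Boolean = true

  // Cheap schema inference; the default `inferPartitioning` (empty) is fine for this source.
  override def inferSchema(options: CaseInsensitiveStringMap): StructType =
    new StructType().add("i", "int").add("j", "int")

  override def getTable(
      schema: StructType,
      partitioning: Array[Transform],
      properties: util.Map[String, String]): Table = {
    // The returned table must report exactly the schema/partitioning it was given.
    val givenSchema = schema
    val givenPartitioning = partitioning
    new Table with SupportsRead {
      override def name(): String = "constant"
      override def schema(): StructType = givenSchema
      override def partitioning(): Array[Transform] = givenPartitioning
      override def capabilities(): util.Set[TableCapability] =
        util.EnumSet.of(TableCapability.BATCH_READ)
      // A real source would return a working ScanBuilder; omitted in this sketch.
      override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
        throw new UnsupportedOperationException("scan is not implemented in this sketch")
    }
  }
}

With supportsExternalMetadata() returning true, DataSourceV2Utils.getTableFromProvider forwards a user-specified schema straight to getTable; if it returned false, the same call would throw "source does not support user-specified schema", matching the new code path introduced in this patch.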