From 56e652034f4085c2e676acc50957060667b1e2ef Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Fri, 26 May 2023 10:22:51 -0700 Subject: [PATCH] add flint spark configuration (#1661) * add spark configuration Signed-off-by: Peng Huo * fix IT, streaming write to flint Signed-off-by: Peng Huo --------- Signed-off-by: Peng Huo --- .../opensearch/flint/core/FlintOptions.java | 4 +- .../spark/sql/flint/FlintDataSourceV2.scala | 34 ++++--- .../flint/FlintPartitionReaderFactory.scala | 9 +- .../sql/flint/FlintPartitionWriter.scala | 24 +---- .../flint/FlintPartitionWriterFactory.scala | 13 ++- .../apache/spark/sql/flint/FlintScan.scala | 7 +- .../spark/sql/flint/FlintScanBuilder.scala | 12 +-- .../apache/spark/sql/flint/FlintTable.scala | 19 ++-- .../apache/spark/sql/flint/FlintWrite.scala | 5 +- .../spark/sql/flint/config/FlintConfig.scala | 58 +++++++++++ .../sql/flint/config/FlintSparkConf.scala | 98 +++++++++++++++++++ .../opensearch/flint/spark/FlintSpark.scala | 22 +---- .../flint/config/FlintSparkConfSuite.scala | 36 +++++++ .../spark/FlintDataSourceV2ITSuite.scala | 59 ++++++++++- .../core/FlintOpenSearchClientSuite.scala | 4 +- .../spark/FlintSparkSkippingIndexSuite.scala | 6 +- 16 files changed, 309 insertions(+), 101 deletions(-) create mode 100644 flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintConfig.scala create mode 100644 flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala create mode 100644 flint/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala diff --git a/flint/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java b/flint/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java index a8ecb8016b..d133db4449 100644 --- a/flint/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java +++ b/flint/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java @@ -20,10 +20,10 @@ public class FlintOptions implements Serializable { /** * Used by {@link org.opensearch.flint.core.storage.OpenSearchScrollReader} */ - public static final String SCROLL_SIZE = "scroll_size"; + public static final String SCROLL_SIZE = "read.scroll_size"; public static final int DEFAULT_SCROLL_SIZE = 100; - public static final String REFRESH_POLICY = "refresh_policy"; + public static final String REFRESH_POLICY = "write.refresh_policy"; /** * NONE("false") * diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintDataSourceV2.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintDataSourceV2.scala index db2d6740e6..c967cf0966 100644 --- a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintDataSourceV2.scala +++ b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintDataSourceV2.scala @@ -6,16 +6,15 @@ package org.apache.spark.sql.flint import java.util -import java.util.NoSuchElementException -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.catalog.{Table, TableProvider} +import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, Table, TableProvider} import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap -class FlintDataSourceV2 extends TableProvider 
with DataSourceRegister { +class FlintDataSourceV2 extends TableProvider with DataSourceRegister with SessionConfigSupport { private var table: FlintTable = null @@ -37,22 +36,29 @@ class FlintDataSourceV2 extends TableProvider with DataSourceRegister { } } - protected def getTableName(properties: util.Map[String, String]): String = { - if (properties.containsKey("path")) properties.get("path") - else if (properties.containsKey("index")) properties.get("index") - else throw new NoSuchElementException("index or path not found") - } - protected def getFlintTable( schema: Option[StructType], - properties: util.Map[String, String]): FlintTable = { - FlintTable(getTableName(properties), SparkSession.active, schema) - } + properties: util.Map[String, String]): FlintTable = FlintTable(properties, schema) /** * format name. for instance, `sql.read.format("flint")` */ - override def shortName(): String = "flint" + override def shortName(): String = FLINT_DATASOURCE override def supportsExternalMetadata(): Boolean = true + + // scalastyle:off + /** + * extract datasource session configs and remove prefix. for example, it extract xxx.yyy from + * spark.datasource.flint.xxx.yyy. more + * reading.https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache + * /spark/sql/execution/datasources/v2/DataSourceV2Utils.scala#L52 + */ + // scalastyle:off + override def keyPrefix(): String = FLINT_DATASOURCE +} + +object FlintDataSourceV2 { + + val FLINT_DATASOURCE = "flint" } diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionReaderFactory.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionReaderFactory.scala index 0721a61335..74e5c29689 100644 --- a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionReaderFactory.scala +++ b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionReaderFactory.scala @@ -5,25 +5,24 @@ package org.apache.spark.sql.flint -import java.util - -import org.opensearch.flint.core.{FlintClientBuilder, FlintOptions} +import org.opensearch.flint.core.FlintClientBuilder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.flint.storage.FlintQueryCompiler import org.apache.spark.sql.types.StructType case class FlintPartitionReaderFactory( tableName: String, schema: StructType, - properties: util.Map[String, String], + options: FlintSparkConf, pushedPredicates: Array[Predicate]) extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { val query = FlintQueryCompiler(schema).compile(pushedPredicates) - val flintClient = FlintClientBuilder.build(new FlintOptions(properties)) + val flintClient = FlintClientBuilder.build(options.flintOptions()) new FlintPartitionReader(flintClient.createReader(tableName, query), schema) } } diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriter.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriter.scala index 8098656835..0aa03100ba 100644 --- a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriter.scala +++ 
b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriter.scala @@ -5,11 +5,8 @@ package org.apache.spark.sql.flint -import java.util import java.util.TimeZone -import scala.collection.JavaConverters.mapAsScalaMapConverter - import org.opensearch.flint.core.storage.FlintWriter import org.apache.spark.internal.Logging @@ -17,7 +14,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} -import org.apache.spark.sql.flint.FlintPartitionWriter.{BATCH_SIZE, ID_NAME} +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.flint.json.FlintJacksonGenerator import org.apache.spark.sql.types.StructType @@ -28,7 +25,7 @@ import org.apache.spark.sql.types.StructType case class FlintPartitionWriter( flintWriter: FlintWriter, dataSchema: StructType, - properties: util.Map[String, String], + options: FlintSparkConf, partitionId: Int, taskId: Long, epochId: Long = -1) @@ -40,13 +37,10 @@ case class FlintPartitionWriter( } private lazy val gen = FlintJacksonGenerator(dataSchema, flintWriter, jsonOptions) - private lazy val idOrdinal = properties.asScala.toMap - .get(ID_NAME) + private lazy val idOrdinal = options + .docIdColumnName() .flatMap(filedName => dataSchema.getFieldIndex(filedName)) - private lazy val batchSize = - properties.asScala.toMap.get(BATCH_SIZE).map(_.toInt).filter(_ > 0).getOrElse(1000) - /** * total write doc count. */ @@ -62,7 +56,7 @@ case class FlintPartitionWriter( gen.writeLineEnding() docCount += 1 - if (docCount >= batchSize) { + if (docCount >= options.batchSize()) { gen.flush() docCount = 0 } @@ -86,11 +80,3 @@ case class FlintPartitionWriter( case class FlintWriterCommitMessage(partitionId: Int, taskId: Long, epochId: Long) extends WriterCommitMessage - -/** - * Todo. Move to FlintSparkConfiguration. 
- */ -object FlintPartitionWriter { - val ID_NAME = "spark.flint.write.id.name" - val BATCH_SIZE = "spark.flint.write.batch.size" -} diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriterFactory.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriterFactory.scala index 5e5c443759..d9acdc6263 100644 --- a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriterFactory.scala +++ b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriterFactory.scala @@ -5,32 +5,31 @@ package org.apache.spark.sql.flint -import java.util - -import org.opensearch.flint.core.{FlintClientBuilder, FlintOptions} +import org.opensearch.flint.core.FlintClientBuilder import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory} import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types.StructType case class FlintPartitionWriterFactory( tableName: String, schema: StructType, - properties: util.Map[String, String]) + options: FlintSparkConf) extends DataWriterFactory with StreamingDataWriterFactory with Logging { - private lazy val flintClient = FlintClientBuilder.build(new FlintOptions(properties)) + private lazy val flintClient = FlintClientBuilder.build(options.flintOptions()) override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { logDebug(s"create writer for partition: $partitionId, task: $taskId") FlintPartitionWriter( flintClient.createWriter(tableName), schema, - properties, + options, partitionId, taskId) } @@ -42,7 +41,7 @@ case class FlintPartitionWriterFactory( FlintPartitionWriter( flintClient.createWriter(tableName), schema, - properties, + options, partitionId, taskId, epochId) diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScan.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScan.scala index bfe6fa41e4..154e954764 100644 --- a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScan.scala +++ b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScan.scala @@ -5,16 +5,15 @@ package org.apache.spark.sql.flint -import java.util - import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan} +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types.StructType case class FlintScan( tableName: String, schema: StructType, - properties: util.Map[String, String], + options: FlintSparkConf, pushedPredicates: Array[Predicate]) extends Scan with Batch { @@ -26,7 +25,7 @@ case class FlintScan( } override def createReaderFactory(): PartitionReaderFactory = { - FlintPartitionReaderFactory(tableName, schema, properties, pushedPredicates) + FlintPartitionReaderFactory(tableName, schema, options, pushedPredicates) } override def toBatch: Batch = this diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScanBuilder.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScanBuilder.scala index 6c06d11d92..71bfe36e81 100644 --- 
a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScanBuilder.scala +++ b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintScanBuilder.scala @@ -5,20 +5,14 @@ package org.apache.spark.sql.flint -import java.util - import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownV2Filters} +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.flint.storage.FlintQueryCompiler import org.apache.spark.sql.types.StructType -case class FlintScanBuilder( - tableName: String, - sparkSession: SparkSession, - schema: StructType, - properties: util.Map[String, String]) +case class FlintScanBuilder(tableName: String, schema: StructType, options: FlintSparkConf) extends ScanBuilder with SupportsPushDownV2Filters with Logging { @@ -26,7 +20,7 @@ case class FlintScanBuilder( private var pushedPredicate = Array.empty[Predicate] override def build(): Scan = { - FlintScan(tableName, schema, properties, pushedPredicate) + FlintScan(tableName, schema, options, pushedPredicate) } override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = { diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintTable.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintTable.scala index fdd525b932..9bf3937902 100644 --- a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintTable.scala +++ b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintTable.scala @@ -7,31 +7,28 @@ package org.apache.spark.sql.flint import java.util -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability} import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, STREAMING_WRITE, TRUNCATE} import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap /** - * OpenSearchTable represent an index in OpenSearch. - * @param name - * OpenSearch index name. - * @param sparkSession - * sparkSession + * FlintTable. 
+ * @param conf + * configuration * @param userSpecifiedSchema * userSpecifiedSchema */ -case class FlintTable( - name: String, - sparkSession: SparkSession, - userSpecifiedSchema: Option[StructType]) +case class FlintTable(conf: util.Map[String, String], userSpecifiedSchema: Option[StructType]) extends Table with SupportsRead with SupportsWrite { + val name = FlintSparkConf(conf).tableName() + var schema: StructType = { if (schema == null) { schema = if (userSpecifiedSchema.isDefined) { @@ -47,7 +44,7 @@ case class FlintTable( util.EnumSet.of(BATCH_READ, BATCH_WRITE, TRUNCATE, STREAMING_WRITE) override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - FlintScanBuilder(name, sparkSession, schema, options.asCaseSensitiveMap()) + FlintScanBuilder(name, schema, FlintSparkConf(options.asCaseSensitiveMap())) } override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWrite.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWrite.scala index ea87c8fea8..dca009cdd9 100644 --- a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWrite.scala +++ b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWrite.scala @@ -8,6 +8,7 @@ package org.apache.spark.sql.flint import org.apache.spark.internal.Logging import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} +import org.apache.spark.sql.flint.config.FlintSparkConf case class FlintWrite(tableName: String, logicalWriteInfo: LogicalWriteInfo) extends Write @@ -21,7 +22,7 @@ case class FlintWrite(tableName: String, logicalWriteInfo: LogicalWriteInfo) FlintPartitionWriterFactory( tableName, logicalWriteInfo.schema(), - logicalWriteInfo.options().asCaseSensitiveMap()) + FlintSparkConf(logicalWriteInfo.options().asCaseSensitiveMap())) } override def commit(messages: Array[WriterCommitMessage]): Unit = { @@ -37,7 +38,7 @@ case class FlintWrite(tableName: String, logicalWriteInfo: LogicalWriteInfo) FlintPartitionWriterFactory( tableName, logicalWriteInfo.schema(), - logicalWriteInfo.options().asCaseSensitiveMap()) + FlintSparkConf(logicalWriteInfo.options().asCaseSensitiveMap())) } override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintConfig.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintConfig.scala new file mode 100644 index 0000000000..bc0c7e52a9 --- /dev/null +++ b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintConfig.scala @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.flint.config + +import org.apache.spark.internal.config.ConfigReader + +/** + * Similar to SPARK ConfigEntry. ConfigEntry register the configuration which can not been + * modified. 
+ */ +private case class FlintConfig(key: String) { + + private var doc = "" + + def doc(s: String): FlintConfig = { + doc = s + this + } + + def createWithDefault(defaultValue: String): FlintConfigEntry[String] = { + new FlintConfigEntryWithDefault(key, defaultValue, doc) + } + + def createOptional(): FlintConfigEntry[Option[String]] = { + new FlintOptionalConfigEntry(key, doc) + } +} + +abstract class FlintConfigEntry[T](val key: String, val doc: String) { + protected def readString(reader: ConfigReader): Option[String] = { + reader.get(key) + } + + def readFrom(reader: ConfigReader): T + + def defaultValue: Option[String] = None +} + +private class FlintConfigEntryWithDefault(key: String, defaultValue: String, doc: String) + extends FlintConfigEntry[String](key, doc) { + + override def defaultValue: Option[String] = Some(defaultValue) + + def readFrom(reader: ConfigReader): String = { + readString(reader).getOrElse(defaultValue) + } +} + +private class FlintOptionalConfigEntry(key: String, doc: String) + extends FlintConfigEntry[Option[String]](key, doc) { + + def readFrom(reader: ConfigReader): Option[String] = { + readString(reader) + } +} diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala new file mode 100644 index 0000000000..8973586957 --- /dev/null +++ b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.flint.config + +import java.util.{Map => JMap, NoSuchElementException} + +import scala.collection.JavaConverters._ + +import org.opensearch.flint.core.FlintOptions + +import org.apache.spark.internal.config.ConfigReader +import org.apache.spark.sql.RuntimeConfig +import org.apache.spark.sql.flint.config.FlintSparkConf._ + +/** + * Define all the Flint Spark Related configuration.
+ * Users define each config as xxx.yyy using {@link FlintConfig}.
+ *
+ * How to use config:
+ *   1. Define the config as spark.datasource.flint.xxx.yyy in the Spark conf.
+ *   2. Define the config as xxx.yyy in the datasource options.
+ *   3. Configurations defined in the datasource options override the same configurations
+ *      present in the Spark conf, as shown in the sketch below.
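+ *
+ * A minimal sketch of both styles (the index name "my_index", the schema value "indexSchema"
+ * and the scroll sizes are illustrative placeholders, not part of this patch):
+ * {{{
+ *   // Spark conf: session-wide default, key prefixed with spark.datasource.flint.
+ *   spark.conf.set("spark.datasource.flint.read.scroll_size", "100")
+ *
+ *   // Datasource option: per-query value, no prefix, overrides the Spark conf value.
+ *   spark.read.format("flint")
+ *     .option("read.scroll_size", "10")
+ *     .schema(indexSchema) // schema of the Flint index being read
+ *     .load("my_index")
+ * }}}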
+ */ +object FlintSparkConf { + + val PREFIX = "spark.datasource.flint." + + def apply(conf: JMap[String, String]): FlintSparkConf = new FlintSparkConf(conf) + + /** + * Helper class, create {@link FlintOptions} from spark conf. + */ + def apply(sparkConf: RuntimeConfig): FlintOptions = new FlintOptions( + Seq(HOST_ENDPOINT, HOST_PORT, REFRESH_POLICY, SCROLL_SIZE) + .map(conf => (conf.key, sparkConf.get(PREFIX + conf.key, conf.defaultValue.get))) + .toMap + .asJava) + + def sparkConf(key: String): String = PREFIX + key + + val HOST_ENDPOINT = FlintConfig("host") + .createWithDefault("localhost") + + val HOST_PORT = FlintConfig("port") + .createWithDefault("9200") + + val DOC_ID_COLUMN_NAME = FlintConfig("write.id_name") + .doc( + "spark write task use spark.flint.write.id.name defined column as doc id when write to " + + "flint. if not provided, use system generated random id") + .createOptional() + + val BATCH_SIZE = FlintConfig("write.batch_size") + .doc( + "The number of documents written to Flint in a single batch request is determined by the " + + "overall size of the HTTP request, which should not exceed 100MB. The actual number of " + + "documents will vary depending on the individual size of each document.") + .createWithDefault("1000") + + val REFRESH_POLICY = FlintConfig("write.refresh_policy") + .doc("refresh_policy, possible value are NONE(false), IMMEDIATE(true), WAIT_UNTIL(wait_for)") + .createWithDefault("false") + + val SCROLL_SIZE = FlintConfig("read.scroll_size") + .doc("scroll read size") + .createWithDefault("100") +} + +class FlintSparkConf(properties: JMap[String, String]) extends Serializable { + + lazy val reader = new ConfigReader(properties) + + def batchSize(): Int = BATCH_SIZE.readFrom(reader).toInt + + def docIdColumnName(): Option[String] = DOC_ID_COLUMN_NAME.readFrom(reader) + + def tableName(): String = { + if (properties.containsKey("path")) properties.get("path") + else throw new NoSuchElementException("index or path not found") + } + + /** + * Helper class, create {@link FlintOptions}. + */ + def flintOptions(): FlintOptions = { + new FlintOptions( + Seq(HOST_ENDPOINT, HOST_PORT, REFRESH_POLICY, SCROLL_SIZE) + .map(conf => (conf.key, conf.readFrom(reader))) + .toMap + .asJava) + } +} diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala index a1a6ffc16b..dcde38a68f 100644 --- a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala +++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala @@ -5,10 +5,7 @@ package org.opensearch.flint.spark -import scala.collection.JavaConverters._ - -import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions} -import org.opensearch.flint.core.FlintOptions._ +import org.opensearch.flint.core.{FlintClient, FlintClientBuilder} import org.opensearch.flint.core.metadata.FlintMetadata import org.opensearch.flint.spark.FlintSpark._ import org.opensearch.flint.spark.skipping.{FlintSparkSkippingIndex, FlintSparkSkippingStrategy} @@ -16,6 +13,7 @@ import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalog.Column +import org.apache.spark.sql.flint.config.FlintSparkConf /** * Flint Spark integration API entrypoint. 
@@ -23,13 +21,7 @@ import org.apache.spark.sql.catalog.Column class FlintSpark(val spark: SparkSession) { /** Flint client for low-level index operation */ - private val flintClient: FlintClient = { - val options = new FlintOptions( - Map( - HOST -> spark.conf.get(FLINT_INDEX_STORE_LOCATION, FLINT_INDEX_STORE_LOCATION_DEFAULT), - PORT -> spark.conf.get(FLINT_INDEX_STORE_PORT, FLINT_INDEX_STORE_PORT_DEFAULT)).asJava) - FlintClientBuilder.build(options) - } + private val flintClient: FlintClient = FlintClientBuilder.build(FlintSparkConf(spark.conf)) /** * Create index builder for creating index with fluent API. @@ -94,14 +86,6 @@ class FlintSpark(val spark: SparkSession) { object FlintSpark { - /** - * Flint configurations in Spark. TODO: shared with Flint data source config? - */ - val FLINT_INDEX_STORE_LOCATION = "spark.flint.indexstore.location" - val FLINT_INDEX_STORE_LOCATION_DEFAULT = "localhost" - val FLINT_INDEX_STORE_PORT = "spark.flint.indexstore.port" - val FLINT_INDEX_STORE_PORT_DEFAULT = "9200" - /** * Helper class for index class construct. For now only skipping index supported. */ diff --git a/flint/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala b/flint/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala new file mode 100644 index 0000000000..4021ee989f --- /dev/null +++ b/flint/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.flint.config + +import scala.collection.JavaConverters._ + +import org.apache.spark.FlintSuite + +class FlintSparkConfSuite extends FlintSuite { + test("test spark conf") { + spark.conf.set("spark.datasource.flint.host", "127.0.0.1") + spark.conf.set("spark.datasource.flint.read.scroll_size", "10") + + val flintOptions = FlintSparkConf(spark.conf) + assert(flintOptions.getHost == "127.0.0.1") + assert(flintOptions.getScrollSize == 10) + + // default value + assert(flintOptions.getPort == 9200) + assert(flintOptions.getRefreshPolicy == "false") + } + + test("test spark options") { + val options = FlintSparkConf(Map("write.batch_size" -> "10", "write.id_name" -> "id").asJava) + assert(options.batchSize() == 10) + assert(options.docIdColumnName().isDefined) + assert(options.docIdColumnName().get == "id") + + // default value + assert(options.flintOptions().getHost == "localhost") + assert(options.flintOptions().getPort == 9200) + } +} diff --git a/flint/integ-test/src/test/scala/org/apache/spark/FlintDataSourceV2ITSuite.scala b/flint/integ-test/src/test/scala/org/apache/spark/FlintDataSourceV2ITSuite.scala index 2cf211e7f7..4fb9f8fb04 100644 --- a/flint/integ-test/src/test/scala/org/apache/spark/FlintDataSourceV2ITSuite.scala +++ b/flint/integ-test/src/test/scala/org/apache/spark/FlintDataSourceV2ITSuite.scala @@ -11,6 +11,7 @@ import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row} import org.apache.spark.sql.catalyst.plans.logical.Filter import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.functions.asc import org.apache.spark.sql.streaming.{StreamingQuery, StreamTest} import org.apache.spark.sql.types._ @@ -56,7 +57,7 @@ class FlintDataSourceV2ITSuite val df = spark.sqlContext.read 
.format("flint") - .options(openSearchOptions + ("scroll_size" -> "1")) + .options(openSearchOptions + (s"${FlintSparkConf.SCROLL_SIZE.key}" -> "1")) .schema(schema) .load(indexName) .sort(asc("id")) @@ -147,7 +148,8 @@ class FlintDataSourceV2ITSuite | } |}""".stripMargin val options = - openSearchOptions + ("refresh_policy" -> "wait_for", "spark.flint.write.id.name" -> "aInt") + openSearchOptions + (s"${FlintSparkConf.REFRESH_POLICY.key}" -> "wait_for", + s"${FlintSparkConf.DOC_ID_COLUMN_NAME.key}" -> "aInt") Seq(Seq.empty, 1 to 14).foreach(data => { withIndexName(indexName) { index(indexName, oneNodeSetting, mappings, Seq.empty) @@ -184,7 +186,8 @@ class FlintDataSourceV2ITSuite test("write dataframe to flint with batch size configuration") { val indexName = "t0004" val options = - openSearchOptions + ("refresh_policy" -> "wait_for", "spark.flint.write.id.name" -> "aInt") + openSearchOptions + (s"${FlintSparkConf.REFRESH_POLICY.key}" -> "wait_for", + s"${FlintSparkConf.DOC_ID_COLUMN_NAME.key}" -> "aInt") Seq(0, 1).foreach(batchSize => { withIndexName(indexName) { val mappings = @@ -241,8 +244,8 @@ class FlintDataSourceV2ITSuite .option("checkpointLocation", checkpointDir) .format("flint") .options(openSearchOptions) - .option("refresh_policy", "wait_for") - .option("spark.flint.write.id.name", "aInt") + .option(s"${FlintSparkConf.REFRESH_POLICY.key}", "wait_for") + .option(s"${FlintSparkConf.DOC_ID_COLUMN_NAME.key}", "aInt") .start(indexName) inputData.addData(1, 2, 3) @@ -267,6 +270,52 @@ class FlintDataSourceV2ITSuite } } + test("read index with spark conf") { + val indexName = "t0001" + withIndexName(indexName) { + simpleIndex(indexName) + spark.conf.set(FlintSparkConf.sparkConf(FlintSparkConf.HOST_ENDPOINT.key), openSearchHost) + spark.conf.set(FlintSparkConf.sparkConf(FlintSparkConf.HOST_PORT.key), openSearchPort) + val schema = StructType( + Seq( + StructField("accountId", StringType, true), + StructField("eventName", StringType, true), + StructField("eventSource", StringType, true))) + val df = spark.sqlContext.read + .format("flint") + .schema(schema) + .load(indexName) + + assert(df.count() == 1) + checkAnswer(df, Row("123", "event", "source")) + } + } + + test("datasource option should overwrite spark conf") { + val indexName = "t0001" + withIndexName(indexName) { + simpleIndex(indexName) + // set invalid host name and port which should be overwrite by datasource option. + spark.conf.set(FlintSparkConf.sparkConf(FlintSparkConf.HOST_ENDPOINT.key), "invalid host") + spark.conf.set(FlintSparkConf.sparkConf(FlintSparkConf.HOST_PORT.key), "0") + + val schema = StructType( + Seq( + StructField("accountId", StringType, true), + StructField("eventName", StringType, true), + StructField("eventSource", StringType, true))) + val df = spark.sqlContext.read + .format("flint") + // override spark conf + .options(openSearchOptions) + .schema(schema) + .load(indexName) + + assert(df.count() == 1) + checkAnswer(df, Row("123", "event", "source")) + } + } + /** * Copy from SPARK JDBCV2Suite. 
*/ diff --git a/flint/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala b/flint/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala index a54d40a4e8..2e5f490f55 100644 --- a/flint/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala +++ b/flint/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala @@ -14,6 +14,8 @@ import org.opensearch.flint.core.storage.FlintOpenSearchClient import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers +import org.apache.spark.sql.flint.config.FlintSparkConf.REFRESH_POLICY + class FlintOpenSearchClientSuite extends AnyFlatSpec with OpenSearchSuite with Matchers { /** Lazy initialize after container started. */ @@ -71,7 +73,7 @@ class FlintOpenSearchClientSuite extends AnyFlatSpec with OpenSearchSuite with M | } |}""".stripMargin - val options = openSearchOptions + ("refresh_policy" -> "wait_for") + val options = openSearchOptions + (s"${REFRESH_POLICY.key}" -> "wait_for") val flintClient = new FlintOpenSearchClient(new FlintOptions(options.asJava)) index(indexName, oneNodeSetting, mappings, Seq.empty) val writer = flintClient.createWriter(indexName) diff --git a/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSuite.scala b/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSuite.scala index de78e45526..9cc06fdf21 100644 --- a/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSuite.scala +++ b/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSuite.scala @@ -9,19 +9,19 @@ import scala.Option._ import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson import org.opensearch.flint.OpenSearchSuite -import org.opensearch.flint.spark.FlintSpark.FLINT_INDEX_STORE_LOCATION import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex import org.scalatest.matchers.must.Matchers.defined import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.apache.spark.FlintSuite +import org.apache.spark.sql.flint.config.FlintSparkConf class FlintSparkSkippingIndexSuite extends FlintSuite with OpenSearchSuite { /** Flint Spark high level API being tested */ lazy val flint: FlintSpark = { - spark.conf.set(FLINT_INDEX_STORE_LOCATION, openSearchHost) - spark.conf.set(FlintSpark.FLINT_INDEX_STORE_PORT, openSearchPort) + spark.conf.set(FlintSparkConf.sparkConf(FlintSparkConf.HOST_ENDPOINT.key), openSearchHost) + spark.conf.set(FlintSparkConf.sparkConf(FlintSparkConf.HOST_PORT.key), openSearchPort) new FlintSpark(spark) }
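
For reviewers, a minimal sketch of how the two FlintSparkConf entry points added by this patch are intended to be used, mirroring FlintSparkConfSuite. The host value, batch size, and column name are placeholders, and an active SparkSession is assumed.

import scala.collection.JavaConverters._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.flint.config.FlintSparkConf

val spark = SparkSession.builder().master("local").getOrCreate()

// 1. Build a FlintSparkConf from datasource options (keys carry no prefix).
val options = FlintSparkConf(
  Map("write.batch_size" -> "500", "write.id_name" -> "aInt").asJava)
assert(options.batchSize() == 500)
assert(options.docIdColumnName().contains("aInt"))

// 2. Build FlintOptions from the Spark session conf (keys prefixed with
//    spark.datasource.flint.), as FlintSpark now does for its low-level FlintClient.
spark.conf.set("spark.datasource.flint.host", "opensearch.example.com") // placeholder host
spark.conf.set("spark.datasource.flint.port", "9200")
val flintOptions = FlintSparkConf(spark.conf)
assert(flintOptions.getHost == "opensearch.example.com")
assert(flintOptions.getPort == 9200)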