add flint spark configuration (#1661)
* add spark configuration

Signed-off-by: Peng Huo <penghuo@gmail.com>

* fix IT, streaming write to flint

Signed-off-by: Peng Huo <penghuo@gmail.com>

---------

Signed-off-by: Peng Huo <penghuo@gmail.com>
penghuo authored May 26, 2023
1 parent 252f54c commit 56e6520
Showing 16 changed files with 309 additions and 101 deletions.
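The changes below introduce a typed FlintSparkConf, register the Flint source for Spark session configuration via SessionConfigSupport, and namespace the option keys as read.* / write.*. As orientation before the per-file diffs, here is a minimal sketch of how the new configuration surface appears to be used, assuming the Flint connector is on the classpath; the index name and option values are hypothetical, while the keys read.scroll_size and write.refresh_policy come from the FlintOptions change in this commit.

```scala
// Minimal sketch, not part of this commit. Assumes the Flint data source is on
// the classpath and an OpenSearch cluster is reachable; "alb_logs" is a
// hypothetical index name.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("flint-config-sketch")
  // Session-wide defaults: SessionConfigSupport forwards any
  // spark.datasource.flint.* entry to the source with the prefix stripped.
  .config("spark.datasource.flint.read.scroll_size", "500")
  .config("spark.datasource.flint.write.refresh_policy", "false")
  .getOrCreate()

// Per-query options use the un-prefixed keys and override the session values.
val df = spark.read
  .format("flint")
  .option("read.scroll_size", "100")
  .load("alb_logs")
```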
FlintOptions.java
@@ -20,10 +20,10 @@ public class FlintOptions implements Serializable {
/**
* Used by {@link org.opensearch.flint.core.storage.OpenSearchScrollReader}
*/
public static final String SCROLL_SIZE = "scroll_size";
public static final String SCROLL_SIZE = "read.scroll_size";
public static final int DEFAULT_SCROLL_SIZE = 100;

public static final String REFRESH_POLICY = "refresh_policy";
public static final String REFRESH_POLICY = "write.refresh_policy";
/**
* NONE("false")
*
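The keys are now namespaced by direction: read.scroll_size controls the OpenSearch scroll reader, write.refresh_policy the writer. A small sketch of building FlintOptions with the renamed keys, using the java.util.Map form that this diff passes to FlintClientBuilder; the values are illustrative only.

```scala
// Sketch only: constructs FlintOptions with the renamed keys. FlintOptions is
// built from a java.util.Map[String, String] elsewhere in this diff.
import scala.collection.JavaConverters._

import org.opensearch.flint.core.FlintOptions

val props: java.util.Map[String, String] = Map(
  "read.scroll_size" -> "200",       // previously "scroll_size"
  "write.refresh_policy" -> "false"  // previously "refresh_policy"
).asJava

val flintOptions = new FlintOptions(props)
```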
FlintDataSourceV2.scala
@@ -6,16 +6,15 @@
package org.apache.spark.sql.flint

import java.util
import java.util.NoSuchElementException

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, Table, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class FlintDataSourceV2 extends TableProvider with DataSourceRegister {
class FlintDataSourceV2 extends TableProvider with DataSourceRegister with SessionConfigSupport {

private var table: FlintTable = null

@@ -37,22 +36,29 @@ class FlintDataSourceV2 extends TableProvider with DataSourceRegister {
}
}

protected def getTableName(properties: util.Map[String, String]): String = {
if (properties.containsKey("path")) properties.get("path")
else if (properties.containsKey("index")) properties.get("index")
else throw new NoSuchElementException("index or path not found")
}

protected def getFlintTable(
schema: Option[StructType],
properties: util.Map[String, String]): FlintTable = {
FlintTable(getTableName(properties), SparkSession.active, schema)
}
properties: util.Map[String, String]): FlintTable = FlintTable(properties, schema)

/**
* format name. for instance, `sql.read.format("flint")`
*/
override def shortName(): String = "flint"
override def shortName(): String = FLINT_DATASOURCE

override def supportsExternalMetadata(): Boolean = true

// scalastyle:off
/**
* extract datasource session configs and remove prefix. for example, it extract xxx.yyy from
* spark.datasource.flint.xxx.yyy. more
* reading.https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache
* /spark/sql/execution/datasources/v2/DataSourceV2Utils.scala#L52
*/
// scalastyle:off
override def keyPrefix(): String = FLINT_DATASOURCE
}

object FlintDataSourceV2 {

val FLINT_DATASOURCE = "flint"
}
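Registering keyPrefix() as FLINT_DATASOURCE means Spark's DataSourceV2Utils forwards any spark.datasource.flint.* session config to the source with the prefix stripped, so session-level defaults and per-query options share the same key names. The companion object also exposes the format name as a constant; a short sketch of using it on the write path, with a hypothetical DataFrame and index name:

```scala
// Sketch only: writes through the Flint source using the new FLINT_DATASOURCE
// constant instead of the bare "flint" string. df and "alb_logs" are hypothetical.
import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE

df.write
  .format(FLINT_DATASOURCE)
  .option("write.refresh_policy", "false")
  .save("alb_logs")
```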
FlintPartitionReaderFactory.scala
@@ -5,25 +5,24 @@

package org.apache.spark.sql.flint

import java.util

import org.opensearch.flint.core.{FlintClientBuilder, FlintOptions}
import org.opensearch.flint.core.FlintClientBuilder

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.storage.FlintQueryCompiler
import org.apache.spark.sql.types.StructType

case class FlintPartitionReaderFactory(
tableName: String,
schema: StructType,
properties: util.Map[String, String],
options: FlintSparkConf,
pushedPredicates: Array[Predicate])
extends PartitionReaderFactory {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
val query = FlintQueryCompiler(schema).compile(pushedPredicates)
val flintClient = FlintClientBuilder.build(new FlintOptions(properties))
val flintClient = FlintClientBuilder.build(options.flintOptions())
new FlintPartitionReader(flintClient.createReader(tableName, query), schema)
}
}
FlintPartitionWriter.scala
@@ -5,19 +5,16 @@

package org.apache.spark.sql.flint

import java.util
import java.util.TimeZone

import scala.collection.JavaConverters.mapAsScalaMapConverter

import org.opensearch.flint.core.storage.FlintWriter

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.JSONOptions
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.flint.FlintPartitionWriter.{BATCH_SIZE, ID_NAME}
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.json.FlintJacksonGenerator
import org.apache.spark.sql.types.StructType

@@ -28,7 +25,7 @@ import org.apache.spark.sql.types.StructType
case class FlintPartitionWriter(
flintWriter: FlintWriter,
dataSchema: StructType,
properties: util.Map[String, String],
options: FlintSparkConf,
partitionId: Int,
taskId: Long,
epochId: Long = -1)
@@ -40,13 +37,10 @@ case class FlintPartitionWriter(
}
private lazy val gen = FlintJacksonGenerator(dataSchema, flintWriter, jsonOptions)

private lazy val idOrdinal = properties.asScala.toMap
.get(ID_NAME)
private lazy val idOrdinal = options
.docIdColumnName()
.flatMap(filedName => dataSchema.getFieldIndex(filedName))

private lazy val batchSize =
properties.asScala.toMap.get(BATCH_SIZE).map(_.toInt).filter(_ > 0).getOrElse(1000)

/**
* total write doc count.
*/
@@ -62,7 +56,7 @@ case class FlintPartitionWriter(
gen.writeLineEnding()

docCount += 1
if (docCount >= batchSize) {
if (docCount >= options.batchSize()) {
gen.flush()
docCount = 0
}
@@ -86,11 +80,3 @@

case class FlintWriterCommitMessage(partitionId: Int, taskId: Long, epochId: Long)
extends WriterCommitMessage

/**
* Todo. Move to FlintSparkConfiguration.
*/
object FlintPartitionWriter {
val ID_NAME = "spark.flint.write.id.name"
val BATCH_SIZE = "spark.flint.write.batch.size"
}
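FlintPartitionWriter now reads its document-id column and flush batch size from FlintSparkConf rather than raw string properties (the old spark.flint.write.* keys in the removed FlintPartitionWriter object above). The sketch below mirrors the idOrdinal lookup to show the intended effect; the schema, column name, and the assumption about what docIdColumnName() returns are illustrative.

```scala
// Sketch only: mirrors the idOrdinal lookup in FlintPartitionWriter, with a
// stand-in for options.docIdColumnName().
import org.apache.spark.sql.types.{StringType, StructField, StructType}

val dataSchema = StructType(Seq(
  StructField("id", StringType),
  StructField("message", StringType)))

val docIdColumnName: Option[String] = Some("id") // stand-in for options.docIdColumnName()

// Resolves to Some(0); the writer can use this ordinal to pick the document id
// out of each row, while None presumably lets OpenSearch auto-generate ids.
val idOrdinal = docIdColumnName.flatMap(fieldName => dataSchema.getFieldIndex(fieldName))
```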
FlintPartitionWriterFactory.scala
@@ -5,32 +5,31 @@

package org.apache.spark.sql.flint

import java.util

import org.opensearch.flint.core.{FlintClientBuilder, FlintOptions}
import org.opensearch.flint.core.FlintClientBuilder

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory}
import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.types.StructType

case class FlintPartitionWriterFactory(
tableName: String,
schema: StructType,
properties: util.Map[String, String])
options: FlintSparkConf)
extends DataWriterFactory
with StreamingDataWriterFactory
with Logging {

private lazy val flintClient = FlintClientBuilder.build(new FlintOptions(properties))
private lazy val flintClient = FlintClientBuilder.build(options.flintOptions())

override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
logDebug(s"create writer for partition: $partitionId, task: $taskId")
FlintPartitionWriter(
flintClient.createWriter(tableName),
schema,
properties,
options,
partitionId,
taskId)
}
@@ -42,7 +41,7 @@ case class FlintPartitionWriterFactory(
FlintPartitionWriter(
flintClient.createWriter(tableName),
schema,
properties,
options,
partitionId,
taskId,
epochId)
FlintScan.scala
@@ -5,16 +5,15 @@

package org.apache.spark.sql.flint

import java.util

import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan}
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.types.StructType

case class FlintScan(
tableName: String,
schema: StructType,
properties: util.Map[String, String],
options: FlintSparkConf,
pushedPredicates: Array[Predicate])
extends Scan
with Batch {
@@ -26,7 +25,7 @@ }
}

override def createReaderFactory(): PartitionReaderFactory = {
FlintPartitionReaderFactory(tableName, schema, properties, pushedPredicates)
FlintPartitionReaderFactory(tableName, schema, options, pushedPredicates)
}

override def toBatch: Batch = this
FlintScanBuilder.scala
@@ -5,28 +5,22 @@

package org.apache.spark.sql.flint

import java.util

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownV2Filters}
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.storage.FlintQueryCompiler
import org.apache.spark.sql.types.StructType

case class FlintScanBuilder(
tableName: String,
sparkSession: SparkSession,
schema: StructType,
properties: util.Map[String, String])
case class FlintScanBuilder(tableName: String, schema: StructType, options: FlintSparkConf)
extends ScanBuilder
with SupportsPushDownV2Filters
with Logging {

private var pushedPredicate = Array.empty[Predicate]

override def build(): Scan = {
FlintScan(tableName, schema, properties, pushedPredicate)
FlintScan(tableName, schema, options, pushedPredicate)
}

override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
FlintTable.scala
@@ -7,31 +7,28 @@ package org.apache.spark.sql.flint

import java.util

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, STREAMING_WRITE, TRUNCATE}
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
* OpenSearchTable represent an index in OpenSearch.
* @param name
* OpenSearch index name.
* @param sparkSession
* sparkSession
* FlintTable.
* @param conf
* configuration
* @param userSpecifiedSchema
* userSpecifiedSchema
*/
case class FlintTable(
name: String,
sparkSession: SparkSession,
userSpecifiedSchema: Option[StructType])
case class FlintTable(conf: util.Map[String, String], userSpecifiedSchema: Option[StructType])
extends Table
with SupportsRead
with SupportsWrite {

val name = FlintSparkConf(conf).tableName()

var schema: StructType = {
if (schema == null) {
schema = if (userSpecifiedSchema.isDefined) {
@@ -47,7 +44,7 @@ case class FlintTable(
util.EnumSet.of(BATCH_READ, BATCH_WRITE, TRUNCATE, STREAMING_WRITE)

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
FlintScanBuilder(name, sparkSession, schema, options.asCaseSensitiveMap())
FlintScanBuilder(name, schema, FlintSparkConf(options.asCaseSensitiveMap()))
}

override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
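FlintTable now takes the full option map and derives its name through FlintSparkConf(conf).tableName(). A sketch of constructing it directly, under the assumption that tableName() still resolves the index from the path/index option the way the removed getTableName in FlintDataSourceV2 did:

```scala
// Sketch only: builds a FlintTable from an option map. Assumes tableName()
// resolves the index from the "path" option; in a real run the schema block
// would likely fetch the index mapping from OpenSearch.
import scala.collection.JavaConverters._

import org.apache.spark.sql.flint.FlintTable

val conf = Map("path" -> "alb_logs", "read.scroll_size" -> "100").asJava
val table = FlintTable(conf, userSpecifiedSchema = None)
// table.name is expected to be "alb_logs" under the assumption above.
```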
FlintWrite.scala
@@ -8,6 +8,7 @@ package org.apache.spark.sql.flint
import org.apache.spark.internal.Logging
import org.apache.spark.sql.connector.write._
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.flint.config.FlintSparkConf

case class FlintWrite(tableName: String, logicalWriteInfo: LogicalWriteInfo)
extends Write
@@ -21,7 +22,7 @@ case class FlintWrite(tableName: String, logicalWriteInfo: LogicalWriteInfo)
FlintPartitionWriterFactory(
tableName,
logicalWriteInfo.schema(),
logicalWriteInfo.options().asCaseSensitiveMap())
FlintSparkConf(logicalWriteInfo.options().asCaseSensitiveMap()))
}

override def commit(messages: Array[WriterCommitMessage]): Unit = {
@@ -37,7 +38,7 @@ case class FlintWrite(tableName: String, logicalWriteInfo: LogicalWriteInfo)
FlintPartitionWriterFactory(
tableName,
logicalWriteInfo.schema(),
logicalWriteInfo.options().asCaseSensitiveMap())
FlintSparkConf(logicalWriteInfo.options().asCaseSensitiveMap()))
}

override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
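Both writer factories in FlintWrite now receive a FlintSparkConf built from the write options, which is the path the "fix IT, streaming write to flint" part of the commit message exercises. A sketch of a streaming write through the connector, using Spark's built-in rate source; the checkpoint location and index name are hypothetical.

```scala
// Sketch only: streams rows into a Flint index, exercising the STREAMING_WRITE
// capability declared in FlintTable. Checkpoint path and index are hypothetical.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("flint-streaming-sketch").getOrCreate()

// Toy streaming source; in practice this would be Kafka, files, etc.
val streamingDf = spark.readStream
  .format("rate")
  .option("rowsPerSecond", "5")
  .load()

val query = streamingDf.writeStream
  .format("flint")
  .option("checkpointLocation", "/tmp/flint-checkpoints/rate_sink")
  .option("write.refresh_policy", "false")
  .start("rate_sink")

query.awaitTermination()
```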