Support FlintTable batch write #1653

Merged

@@ -7,6 +7,9 @@

import org.opensearch.flint.core.metadata.FlintMetadata;
import org.opensearch.flint.core.storage.FlintReader;
import org.opensearch.flint.core.storage.FlintWriter;

import java.io.Writer;

/**
* Flint index client that provides API for metadata and data operations
@@ -53,4 +56,12 @@ public interface FlintClient {
* @return {@link FlintReader}.
*/
FlintReader createReader(String indexName, String query);

/**
* Create {@link FlintWriter}.
*
* @param indexName - index name
* @return {@link FlintWriter}
*/
FlintWriter createWriter(String indexName);
}
@@ -23,6 +23,16 @@ public class FlintOptions implements Serializable {
public static final String SCROLL_SIZE = "scroll_size";
public static final int DEFAULT_SCROLL_SIZE = 100;

public static final String REFRESH_POLICY = "refresh_policy";
/**
 * Supported values, mapped to WriteRequest.RefreshPolicy:
 * NONE("false"), IMMEDIATE("true"), WAIT_UNTIL("wait_for")
 */
public static final String DEFAULT_REFRESH_POLICY = "false";

public FlintOptions(Map<String, String> options) {
this.options = options;
}
@@ -38,4 +48,6 @@ public int getPort() {
public int getScrollSize() {
return Integer.parseInt(options.getOrDefault(SCROLL_SIZE, String.valueOf(DEFAULT_SCROLL_SIZE)));
}

public String getRefreshPolicy() {
return options.getOrDefault(REFRESH_POLICY, DEFAULT_REFRESH_POLICY);
}
}
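
As a usage note on the new refresh_policy option, here is a minimal sketch (not part of this change) of how the values map through FlintOptions; the option map contents are illustrative.

import java.util.{Collections, HashMap => JHashMap}
import org.opensearch.flint.core.FlintOptions

// Default: no refresh_policy entry -> "false", i.e. NONE.
val defaults = new FlintOptions(Collections.emptyMap[String, String]())
assert(defaults.getRefreshPolicy == "false")

// Explicit "wait_for" -> WAIT_UNTIL: the bulk call returns once changes are searchable.
val props = new JHashMap[String, String]()
props.put(FlintOptions.REFRESH_POLICY, "wait_for")
assert(new FlintOptions(props).getRefreshPolicy == "wait_for")
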
@@ -6,6 +6,7 @@
package org.opensearch.flint.core.storage;

import java.io.IOException;
import java.io.Writer;
import java.util.ArrayList;

import org.apache.http.HttpHost;
@@ -127,6 +128,10 @@ public void deleteIndex(String indexName) {
}
}

public FlintWriter createWriter(String indexName) {
return new OpenSearchWriter(createClient(), indexName, options.getRefreshPolicy());
}

private RestHighLevelClient createClient() {
return new RestHighLevelClient(
RestClient.builder(new HttpHost(host, port, "http")));
@@ -0,0 +1,22 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.core.storage;

import java.io.Writer;

/**
* Extends {@link Writer}; no specific methods are defined for now.
*/
public abstract class FlintWriter extends Writer {

/**
* Creates a document if it doesn’t already exist and returns an error otherwise. The next line must include a JSON document.
*
* { "create": { "_index": "movies", "_id": "tt1392214" } }
* { "title": "Prisoners", "year": 2013 }
*/
public static final String ACTION_CREATE = "create";
}
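
For context on how ACTION_CREATE is consumed, a minimal sketch (the helper name and document payload are illustrative, not part of this PR) of the two-line NDJSON sequence a caller writes into a FlintWriter before flushing:

import org.opensearch.flint.core.storage.FlintWriter

// `writer` is assumed to come from FlintClient.createWriter(indexName) introduced above.
def writeCreate(writer: FlintWriter, id: String, docJson: String): Unit = {
  writer.write(s"""{ "${FlintWriter.ACTION_CREATE}": { "_id": "$id" } }""")
  writer.write("\n")
  writer.write(docJson) // e.g. """{ "title": "Prisoners", "year": 2013 }"""
  writer.write("\n")
  writer.flush() // sends the buffered lines as one bulk request
}
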
@@ -0,0 +1,93 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.core.storage;

import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.bulk.BulkItemResponse;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.client.RequestOptions;
import org.opensearch.client.RestHighLevelClient;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.rest.RestStatus;

import java.io.BufferedWriter;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.StringWriter;
import java.io.Writer;
import java.util.Arrays;

/**
* OpenSearch bulk writer. For details on the bulk API, see https://opensearch.org/docs/1.2/opensearch/rest-api/document-apis/bulk/.
* It is not thread safe.
*/
public class OpenSearchWriter extends FlintWriter {

private final String indexName;

private final String refreshPolicy;

private StringBuilder sb;

private RestHighLevelClient client;

public OpenSearchWriter(RestHighLevelClient client, String indexName, String refreshPolicy) {
this.client = client;
this.indexName = indexName;
this.sb = new StringBuilder();
this.refreshPolicy = refreshPolicy;
}

@Override public void write(char[] cbuf, int off, int len) {
sb.append(cbuf, off, len);
}

/**
 * Flush the buffered data as one bulk request.
 * TODO: buffering in a StringBuilder is not efficient; the characters are copied again when converted to bytes.
 */
@Override public void flush() {
try {
if (sb.length() > 0) {
byte[] bytes = sb.toString().getBytes();
BulkResponse response = client.bulk(
new BulkRequest(indexName).setRefreshPolicy(refreshPolicy).add(bytes, 0, bytes.length, XContentType.JSON),
RequestOptions.DEFAULT);
// Fail the entire bulk request if any doc failed for a reason other than a create conflict.
if (response.hasFailures() && Arrays.stream(response.getItems()).anyMatch(itemResp -> !isCreateConflict(itemResp))) {
throw new RuntimeException(response.buildFailureMessage());
}
}
} catch (IOException e) {
throw new RuntimeException(String.format("Failed to execute bulk request on index: %s", indexName), e);
} finally {
sb.setLength(0);
}
}

@Override public void close() {
try {
if (client != null) {
client.close();
client = null;
}
} catch (IOException e) {
throw new RuntimeException(e);
}
}

private boolean isCreateConflict(BulkItemResponse itemResp) {
return itemResp.getOpType() == DocWriteRequest.OpType.CREATE
&& (itemResp.getFailure() == null || itemResp.getFailure().getStatus() == RestStatus.CONFLICT);
}
}
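
A short sketch of the conflict-tolerant flush behavior above; the host, port, index name, and document are placeholders. Because isCreateConflict() treats version conflicts on "create" actions as success, flushing the same _id twice does not fail the writer:

import org.apache.http.HttpHost
import org.opensearch.client.{RestClient, RestHighLevelClient}
import org.opensearch.flint.core.storage.OpenSearchWriter

val client = new RestHighLevelClient(RestClient.builder(new HttpHost("localhost", 9200, "http")))
val writer = new OpenSearchWriter(client, "movies", "wait_for")
val payload =
  "{ \"create\": { \"_id\": \"tt1392214\" } }\n" +
  "{ \"title\": \"Prisoners\", \"year\": 2013 }\n"

writer.write(payload.toCharArray, 0, payload.length)
writer.flush() // indexes the document
writer.write(payload.toCharArray, 0, payload.length)
writer.flush() // version conflict on the create action is swallowed, not rethrown
writer.close() // also closes the underlying RestHighLevelClient
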


@@ -0,0 +1,94 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql.flint

import java.util
import java.util.TimeZone

import scala.collection.JavaConverters.mapAsScalaMapConverter

import org.opensearch.flint.core.storage.FlintWriter

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.JSONOptions
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.flint.FlintPartitionWriter.{BATCH_SIZE, ID_NAME}
import org.apache.spark.sql.flint.json.FlintJacksonGenerator
import org.apache.spark.sql.types.StructType

/**
 * Submits bulk requests of "create" (put-if-absent) actions through FlintWriter. The "create"
 * action is used so that existing documents are not deleted and re-created.
 */
case class FlintPartitionWriter(
flintWriter: FlintWriter,
dataSchema: StructType,
properties: util.Map[String, String],
partitionId: Int,
taskId: Long)
extends DataWriter[InternalRow]
with Logging {

private lazy val jsonOptions = {
new JSONOptions(CaseInsensitiveMap(Map.empty[String, String]), TimeZone.getDefault.getID, "")
}
private lazy val gen = FlintJacksonGenerator(dataSchema, flintWriter, jsonOptions)

private lazy val idOrdinal = properties.asScala.toMap
.get(ID_NAME)
.flatMap(fieldName => dataSchema.getFieldIndex(fieldName))

private lazy val batchSize =
properties.asScala.toMap.get(BATCH_SIZE).map(_.toInt).filter(_ > 0).getOrElse(1000)

/**
 * Number of documents written since the last flush.
 */
private var docCount = 0

/**
 * Writes the bulk action line followed by the document line, e.g.
 * { "create": { "_id": "id1" } }
 * { "title": "Prisoners", "year": 2013 }
 */
override def write(record: InternalRow): Unit = {
gen.writeAction(FlintWriter.ACTION_CREATE, idOrdinal, record)
gen.writeLineEnding()
gen.write(record)
gen.writeLineEnding()

docCount += 1
if (docCount >= batchSize) {
gen.flush()
docCount = 0
}
}

override def commit(): WriterCommitMessage = {
gen.flush()
logDebug(s"Write commit on partitionId: $partitionId, taskId: $taskId")
FlintWriterCommitMessage(partitionId, taskId)
}

override def abort(): Unit = {
// do nothing.
}

override def close(): Unit = {
gen.close()
logDebug(s"Write close on partitionId: $partitionId, taskId: $taskId")
}
}

case class FlintWriterCommitMessage(partitionId: Int, taskId: Long) extends WriterCommitMessage

/**
 * TODO: move these option keys to FlintSparkConfiguration.
 */
object FlintPartitionWriter {
val ID_NAME = "spark.flint.write.id.name"
val BATCH_SIZE = "spark.flint.write.batch.size"
}
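
On the Spark side, the two keys above would be set as write options. A hedged sketch follows; the "flint" provider short name, connection settings, and the save() target are assumptions, since provider registration is not part of this diff:

// Assumes `spark` is an active SparkSession and the Flint DSv2 source is registered as "flint".
spark.range(3)
  .selectExpr("CAST(id AS STRING) AS movie_id", "concat('title-', id) AS title")
  .write
  .format("flint")                                  // assumed provider short name
  .option("spark.flint.write.id.name", "movie_id")  // ID_NAME: column whose value becomes _id
  .option("spark.flint.write.batch.size", "500")    // BATCH_SIZE: docs buffered per bulk flush
  .mode("append")
  .save("movies")                                   // index name passed as the save target
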
@@ -0,0 +1,33 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql.flint

import java.util

import org.opensearch.flint.core.{FlintClientBuilder, FlintOptions}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory}
import org.apache.spark.sql.types.StructType

case class FlintPartitionWriterFactory(
tableName: String,
schema: StructType,
properties: util.Map[String, String])
extends DataWriterFactory
with Logging {
override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
logDebug(s"create writer for partition: $partitionId, task: $taskId")
val flintClient = FlintClientBuilder.build(new FlintOptions(properties))
FlintPartitionWriter(
flintClient.createWriter(tableName),
schema,
properties,
partitionId,
taskId)
}
}
@@ -8,9 +8,10 @@ package org.apache.spark.sql.flint
import java.util

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability}
import org.apache.spark.sql.connector.catalog.TableCapability.BATCH_READ
import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, TRUNCATE}
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

@@ -28,7 +29,8 @@ case class FlintTable(
sparkSession: SparkSession,
userSpecifiedSchema: Option[StructType])
extends Table
with SupportsRead {
with SupportsRead
with SupportsWrite {

var schema: StructType = {
if (schema == null) {
@@ -42,9 +44,13 @@
}

override def capabilities(): util.Set[TableCapability] =
util.EnumSet.of(BATCH_READ)
util.EnumSet.of(BATCH_READ, BATCH_WRITE, TRUNCATE)

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
FlintScanBuilder(name, sparkSession, schema, options.asCaseSensitiveMap())
}

override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
FlintWriteBuilder(name, info)
}
}
@@ -0,0 +1,32 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql.flint

import org.apache.spark.internal.Logging
import org.apache.spark.sql.connector.write._

case class FlintWrite(tableName: String, logicalWriteInfo: LogicalWriteInfo)
extends Write
with BatchWrite
with Logging {

override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
logDebug(
s"Create factory of ${logicalWriteInfo.queryId()} with ${info.numPartitions()} partitions")
FlintPartitionWriterFactory(
tableName,
logicalWriteInfo.schema(),
logicalWriteInfo.options().asCaseSensitiveMap())
}

override def commit(messages: Array[WriterCommitMessage]): Unit = {
logDebug(s"Write of ${logicalWriteInfo.queryId()} committed for: ${messages.length} tasks")
}

override def abort(messages: Array[WriterCommitMessage]): Unit = {}

override def toBatch: BatchWrite = this
}