From 183ada5810ef05db326d706885963cf8e67ba33e Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Tue, 23 May 2023 18:07:32 -0700 Subject: [PATCH 1/3] add write support Signed-off-by: Peng Huo --- .../opensearch/flint/core/FlintClient.java | 11 + .../opensearch/flint/core/FlintOptions.java | 12 + .../core/storage/FlintOpenSearchClient.java | 5 + .../flint/core/storage/FlintWriter.java | 22 ++ .../flint/core/storage/OpenSearchWriter.java | 91 +++++ .../sql/flint/FlintPartitionWriter.scala | 87 +++++ .../flint/FlintPartitionWriterFactory.scala | 31 ++ .../apache/spark/sql/flint/FlintTable.scala | 14 +- .../apache/spark/sql/flint/FlintWrite.scala | 30 ++ .../spark/sql/flint/FlintWriteBuilder.scala | 20 ++ .../flint/json/FlintJacksonGenerator.scala | 310 ++++++++++++++++++ .../spark/FlintDataSourceV2ITSuite.scala | 118 +++++++ .../opensearch/flint/OpenSearchSuite.scala | 21 +- 13 files changed, 759 insertions(+), 13 deletions(-) create mode 100644 flint/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintWriter.java create mode 100644 flint/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchWriter.java create mode 100644 flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriter.scala create mode 100644 flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriterFactory.scala create mode 100644 flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWrite.scala create mode 100644 flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWriteBuilder.scala create mode 100644 flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/json/FlintJacksonGenerator.scala diff --git a/flint/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java b/flint/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java index f49726d37c..8ee4903d18 100644 --- a/flint/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java +++ b/flint/flint-core/src/main/scala/org/opensearch/flint/core/FlintClient.java @@ -7,6 +7,9 @@ import org.opensearch.flint.core.metadata.FlintMetadata; import org.opensearch.flint.core.storage.FlintReader; +import org.opensearch.flint.core.storage.FlintWriter; + +import java.io.Writer; /** * Flint index client that provides API for metadata and data operations @@ -53,4 +56,12 @@ public interface FlintClient { * @return {@link FlintReader}. */ FlintReader createReader(String indexName, String query); + + /** + * Create {@link FlintWriter}. 
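+   * The writer accepts OpenSearch bulk-format input (an action line followed by a JSON document line) and sends the buffered lines as a bulk request on flush.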
+ * + * @param indexName - index name + * @return {@link FlintWriter} + */ + FlintWriter createWriter(String indexName); } diff --git a/flint/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java b/flint/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java index 6a275a5b83..a8ecb8016b 100644 --- a/flint/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java +++ b/flint/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java @@ -23,6 +23,16 @@ public class FlintOptions implements Serializable { public static final String SCROLL_SIZE = "scroll_size"; public static final int DEFAULT_SCROLL_SIZE = 100; + public static final String REFRESH_POLICY = "refresh_policy"; + /** + * NONE("false") + * + * IMMEDIATE("true") + * + * WAIT_UNTIL("wait_for") + */ + public static final String DEFAULT_REFRESH_POLICY = "false"; + public FlintOptions(Map options) { this.options = options; } @@ -38,4 +48,6 @@ public int getPort() { public int getScrollSize() { return Integer.parseInt(options.getOrDefault(SCROLL_SIZE, String.valueOf(DEFAULT_SCROLL_SIZE))); } + + public String getRefreshPolicy() {return options.getOrDefault(REFRESH_POLICY, DEFAULT_REFRESH_POLICY);} } diff --git a/flint/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java b/flint/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java index 779ed0e0fb..95315bccbc 100644 --- a/flint/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java +++ b/flint/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintOpenSearchClient.java @@ -6,6 +6,7 @@ package org.opensearch.flint.core.storage; import java.io.IOException; +import java.io.Writer; import java.util.ArrayList; import org.apache.http.HttpHost; @@ -127,6 +128,10 @@ public void deleteIndex(String indexName) { } } + public FlintWriter createWriter(String indexName) { + return new OpenSearchWriter(createClient(), indexName, options.getRefreshPolicy()); + } + private RestHighLevelClient createClient() { return new RestHighLevelClient( RestClient.builder(new HttpHost(host, port, "http"))); diff --git a/flint/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintWriter.java b/flint/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintWriter.java new file mode 100644 index 0000000000..84edd5f605 --- /dev/null +++ b/flint/flint-core/src/main/scala/org/opensearch/flint/core/storage/FlintWriter.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.storage; + +import java.io.Writer; + +/** + * Extend {@link Writer}, not specific method defined for now. + */ +public abstract class FlintWriter extends Writer { + + /** + * Creates a document if it doesn’t already exist and returns an error otherwise. The next line must include a JSON document. 
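+   * Implementations may treat a create conflict (the document id already exists) as a benign duplicate rather than a failure; see {@link OpenSearchWriter}.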
+ * + * { "create": { "_index": "movies", "_id": "tt1392214" } } + * { "title": "Prisoners", "year": 2013 } + */ + public static final String ACTION_CREATE = "create"; +} diff --git a/flint/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchWriter.java b/flint/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchWriter.java new file mode 100644 index 0000000000..da6014706b --- /dev/null +++ b/flint/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchWriter.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.storage; + +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.rest.RestStatus; + +import java.io.BufferedWriter; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.StringWriter; +import java.io.Writer; +import java.util.Arrays; + +/** + * OpenSearch Bulk writer. More reading https://opensearch.org/docs/1.2/opensearch/rest-api/document-apis/bulk/. + */ +public class OpenSearchWriter extends FlintWriter { + + private final String indexName; + + private final String refreshPolicy; + + private StringBuilder sb; + + private RestHighLevelClient client; + + public OpenSearchWriter(RestHighLevelClient client, String indexName, String refreshPolicy) { + this.client = client; + this.indexName = indexName; + this.sb = new StringBuilder(); + this.refreshPolicy = refreshPolicy; + } + + @Override public void write(char[] cbuf, int off, int len) { + sb.append(cbuf, off, len); + } + + /** + * Todo. StringWriter is not efficient. it will copy the cbuf when create bytes. + */ + @Override public void flush() { + try { + if (sb.length() > 0) { + byte[] bytes = sb.toString().getBytes(); + BulkResponse + response = + client.bulk( + new BulkRequest(indexName).setRefreshPolicy(refreshPolicy).add(bytes, 0, bytes.length, XContentType.JSON), + RequestOptions.DEFAULT); + // fail entire bulk request even one doc failed. 
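+        // Version conflicts from the "create" action (document id already exists) are tolerated and not treated as failures.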
+ if (response.hasFailures() && Arrays.stream(response.getItems()).anyMatch(itemResp -> !isCreateConflict(itemResp))) { + throw new RuntimeException(response.buildFailureMessage()); + } + } + } catch (IOException e) { + throw new RuntimeException(String.format("Failed to execute bulk request on index: %s", indexName), e); + } finally { + sb.setLength(0); + } + } + + @Override public void close() { + try { + if (client != null) { + client.close(); + client = null; + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private boolean isCreateConflict(BulkItemResponse itemResp) { + return itemResp.getOpType() == DocWriteRequest.OpType.CREATE && (itemResp.getFailure() == null || itemResp.getFailure() + .getStatus() == RestStatus.CONFLICT); + } +} + + diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriter.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriter.scala new file mode 100644 index 0000000000..fc3a87280a --- /dev/null +++ b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriter.scala @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.flint + +import java.util +import java.util.TimeZone + +import scala.collection.JavaConverters.mapAsScalaMapConverter + +import org.opensearch.flint.core.storage.FlintWriter + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.json.JSONOptions +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} +import org.apache.spark.sql.flint.FlintPartitionWriter.{BATCH_SIZE, ID_NAME} +import org.apache.spark.sql.flint.json.FlintJacksonGenerator +import org.apache.spark.sql.types.StructType + +/** + * Submit create(put if absent) bulk request using FlintWriter. Using "create" action to avoid + * delete-create docs. + */ +case class FlintPartitionWriter( + flintWriter: FlintWriter, + dataSchema: StructType, + properties: util.Map[String, String], + partitionId: Int, + taskId: Long) + extends DataWriter[InternalRow] + with Logging { + + private lazy val jsonOptions = { + new JSONOptions(CaseInsensitiveMap(Map.empty[String, String]), TimeZone.getDefault.getID, "") + } + private lazy val gen = FlintJacksonGenerator(dataSchema, flintWriter, jsonOptions) + + private lazy val idOrdinal = properties.asScala.toMap + .get(ID_NAME) + .flatMap(filedName => dataSchema.getFieldIndex(filedName)) + + private lazy val batchSize = + properties.asScala.toMap.get(BATCH_SIZE).map(_.toInt).filter(_ > 0).getOrElse(1000) + + private var count = 0; + + /** + * { "create": { "_id": "id1" } } { "title": "Prisoners", "year": 2013 } + */ + override def write(record: InternalRow): Unit = { + gen.writeAction(FlintWriter.ACTION_CREATE, idOrdinal, record) + gen.writeLineEnding() + gen.write(record) + gen.writeLineEnding() + + count += 1 + if (count >= batchSize) { + gen.flush() + count = 0 + } + } + + override def commit(): WriterCommitMessage = { + gen.flush() + logDebug(s"Write finish on partitionId: $partitionId, taskId: $taskId") + FlintWriterCommitMessage(partitionId, taskId) + } + + override def abort(): Unit = { + // do nothing. 
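+    // Writes are not transactional: documents already flushed to OpenSearch are not rolled back here.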
+ } + + override def close(): Unit = { + gen.close() + } +} + +case class FlintWriterCommitMessage(partitionId: Int, taskId: Long) extends WriterCommitMessage + +object FlintPartitionWriter { + val ID_NAME = "spark.flint.write.id.name" + val BATCH_SIZE = "spark.flint.write.batch.size" +} diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriterFactory.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriterFactory.scala new file mode 100644 index 0000000000..0f5e6734ec --- /dev/null +++ b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriterFactory.scala @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.flint + +import java.util + +import org.opensearch.flint.core.{FlintClientBuilder, FlintOptions} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory} +import org.apache.spark.sql.types.StructType + +case class FlintPartitionWriterFactory( + tableName: String, + schema: StructType, + properties: util.Map[String, String]) + extends DataWriterFactory { + override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { + val flintClient = FlintClientBuilder.build(new FlintOptions(properties)) + + FlintPartitionWriter( + flintClient.createWriter(tableName), + schema, + properties, + partitionId, + taskId) + } +} diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintTable.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintTable.scala index 0b762687fd..e92dcb7dce 100644 --- a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintTable.scala +++ b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintTable.scala @@ -8,9 +8,10 @@ package org.apache.spark.sql.flint import java.util import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability} -import org.apache.spark.sql.connector.catalog.TableCapability.BATCH_READ +import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, TRUNCATE} import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -28,7 +29,8 @@ case class FlintTable( sparkSession: SparkSession, userSpecifiedSchema: Option[StructType]) extends Table - with SupportsRead { + with SupportsRead + with SupportsWrite { var schema: StructType = { if (schema == null) { @@ -42,9 +44,13 @@ case class FlintTable( } override def capabilities(): util.Set[TableCapability] = - util.EnumSet.of(BATCH_READ) + util.EnumSet.of(BATCH_READ, BATCH_WRITE, TRUNCATE) override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { FlintScanBuilder(name, sparkSession, schema, options.asCaseSensitiveMap()) } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + FlintWriteBuilder(name, info) + } } diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWrite.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWrite.scala new 
file mode 100644 index 0000000000..b55636d6c9 --- /dev/null +++ b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWrite.scala @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.flint + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.connector.write._ + +case class FlintWrite(tableName: String, logicalWriteInfo: LogicalWriteInfo) + extends Write + with BatchWrite + with Logging { + + override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = { + FlintPartitionWriterFactory( + tableName, + logicalWriteInfo.schema(), + logicalWriteInfo.options().asCaseSensitiveMap()) + } + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + logDebug(s"Write of ${logicalWriteInfo.queryId()} committed for: ${messages.length} tasks") + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = {} + + override def toBatch: BatchWrite = this +} diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWriteBuilder.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWriteBuilder.scala new file mode 100644 index 0000000000..4326962401 --- /dev/null +++ b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWriteBuilder.scala @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.flint + +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwrite, Write, WriteBuilder} +import org.apache.spark.sql.sources.Filter + +case class FlintWriteBuilder(tableName: String, info: LogicalWriteInfo) + extends SupportsOverwrite { + + /** + * Flint client support overwrite docs with same id and does not use filters. + */ + override def overwrite(filters: Array[Filter]): WriteBuilder = this + + override def build(): Write = FlintWrite(tableName, info) +} diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/json/FlintJacksonGenerator.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/json/FlintJacksonGenerator.scala new file mode 100644 index 0000000000..1e35ca362a --- /dev/null +++ b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/json/FlintJacksonGenerator.scala @@ -0,0 +1,310 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.spark.sql.flint.json + +import java.io.Writer + +import com.fasterxml.jackson.core.JsonFactory +import com.fasterxml.jackson.core.util.DefaultPrettyPrinter + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.json.JSONOptions +import org.apache.spark.sql.catalyst.util.{ArrayData, DateFormatter, DateTimeUtils, IntervalStringStyles, IntervalUtils, MapData, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.types._ + +/** + * copy from spark {@link JacksonGenerator}. + */ +case class FlintJacksonGenerator(dataType: DataType, writer: Writer, options: JSONOptions) { + // A `ValueWriter` is responsible for writing a field of an `InternalRow` to appropriate + // JSON data. 
Here we are using `SpecializedGetters` rather than `InternalRow` so that + // we can directly access data in `ArrayData` without the help of `SpecificMutableRow`. + private type ValueWriter = (SpecializedGetters, Int) => Unit + + // `JackGenerator` can only be initialized with a `StructType`, a `MapType` or a `ArrayType`. + require( + dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType] + || dataType.isInstanceOf[ArrayType], + s"JacksonGenerator only supports to be initialized with a ${StructType.simpleString}, " + + s"${MapType.simpleString} or ${ArrayType.simpleString} but got ${dataType.catalogString}") + + // `ValueWriter`s for all fields of the schema + private lazy val rootFieldWriters: Array[ValueWriter] = dataType match { + case st: StructType => st.map(_.dataType).map(makeWriter).toArray + case _ => + throw QueryExecutionErrors.initialTypeNotTargetDataTypeError( + dataType, + StructType.simpleString) + } + + // `ValueWriter` for array data storing rows of the schema. + private lazy val arrElementWriter: ValueWriter = dataType match { + case at: ArrayType => makeWriter(at.elementType) + case _: StructType | _: MapType => makeWriter(dataType) + case _ => throw QueryExecutionErrors.initialTypeNotTargetDataTypesError(dataType) + } + + private lazy val mapElementWriter: ValueWriter = dataType match { + case mt: MapType => makeWriter(mt.valueType) + case _ => + throw QueryExecutionErrors.initialTypeNotTargetDataTypeError(dataType, MapType.simpleString) + } + + private val gen = { + val generator = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + if (options.pretty) { + generator.setPrettyPrinter(new DefaultPrettyPrinter("")) + } + if (options.writeNonAsciiCharacterAsCodePoint) { + generator.setHighestNonEscapedChar(0x7f) + } + generator + } + + private val lineSeparator: String = options.lineSeparatorInWrite + + private val timestampFormatter = TimestampFormatter( + options.timestampFormatInWrite, + options.zoneId, + options.locale, + legacyFormat = FAST_DATE_FORMAT, + isParsing = false) + private val timestampNTZFormatter = TimestampFormatter( + options.timestampNTZFormatInWrite, + options.zoneId, + legacyFormat = FAST_DATE_FORMAT, + isParsing = false, + forTimestampNTZ = true) + private val dateFormatter = DateFormatter( + options.dateFormatInWrite, + options.locale, + legacyFormat = FAST_DATE_FORMAT, + isParsing = false) + + private def makeWriter(dataType: DataType): ValueWriter = dataType match { + case NullType => + (row: SpecializedGetters, ordinal: Int) => gen.writeNull() + + case BooleanType => + (row: SpecializedGetters, ordinal: Int) => gen.writeBoolean(row.getBoolean(ordinal)) + + case ByteType => + (row: SpecializedGetters, ordinal: Int) => gen.writeNumber(row.getByte(ordinal)) + + case ShortType => + (row: SpecializedGetters, ordinal: Int) => gen.writeNumber(row.getShort(ordinal)) + + case IntegerType => + (row: SpecializedGetters, ordinal: Int) => gen.writeNumber(row.getInt(ordinal)) + + case LongType => + (row: SpecializedGetters, ordinal: Int) => gen.writeNumber(row.getLong(ordinal)) + + case FloatType => + (row: SpecializedGetters, ordinal: Int) => gen.writeNumber(row.getFloat(ordinal)) + + case DoubleType => + (row: SpecializedGetters, ordinal: Int) => gen.writeNumber(row.getDouble(ordinal)) + + case StringType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeString(row.getUTF8String(ordinal).toString) + + case TimestampType => + (row: SpecializedGetters, ordinal: Int) => + val timestampString = 
timestampFormatter.format(row.getLong(ordinal)) + gen.writeString(timestampString) + + case TimestampNTZType => + (row: SpecializedGetters, ordinal: Int) => + val timestampString = + timestampNTZFormatter.format(DateTimeUtils.microsToLocalDateTime(row.getLong(ordinal))) + gen.writeString(timestampString) + + case DateType => + (row: SpecializedGetters, ordinal: Int) => + val dateString = dateFormatter.format(row.getInt(ordinal)) + gen.writeString(dateString) + + case CalendarIntervalType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeString(row.getInterval(ordinal).toString) + + case YearMonthIntervalType(start, end) => + (row: SpecializedGetters, ordinal: Int) => + val ymString = IntervalUtils.toYearMonthIntervalString( + row.getInt(ordinal), + IntervalStringStyles.ANSI_STYLE, + start, + end) + gen.writeString(ymString) + + case DayTimeIntervalType(start, end) => + (row: SpecializedGetters, ordinal: Int) => + val dtString = IntervalUtils.toDayTimeIntervalString( + row.getLong(ordinal), + IntervalStringStyles.ANSI_STYLE, + start, + end) + gen.writeString(dtString) + + case BinaryType => + (row: SpecializedGetters, ordinal: Int) => gen.writeBinary(row.getBinary(ordinal)) + + case dt: DecimalType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeNumber(row.getDecimal(ordinal, dt.precision, dt.scale).toJavaBigDecimal) + + case st: StructType => + val fieldWriters = st.map(_.dataType).map(makeWriter) + (row: SpecializedGetters, ordinal: Int) => + writeObject(writeFields(row.getStruct(ordinal, st.length), st, fieldWriters)) + + case at: ArrayType => + val elementWriter = makeWriter(at.elementType) + (row: SpecializedGetters, ordinal: Int) => + writeArray(writeArrayData(row.getArray(ordinal), elementWriter)) + + case mt: MapType => + val valueWriter = makeWriter(mt.valueType) + (row: SpecializedGetters, ordinal: Int) => + writeObject(writeMapData(row.getMap(ordinal), mt, valueWriter)) + + // For UDT values, they should be in the SQL type's corresponding value type. + // We should not see values in the user-defined class at here. + // For example, VectorUDT's SQL type is an array of double. So, we should expect that v is + // an ArrayData at here, instead of a Vector. 
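+    // Delegating to the UDT's sqlType writer keeps the behaviour identical to Spark's JacksonGenerator.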
+ case t: UserDefinedType[_] => + makeWriter(t.sqlType) + + case _ => + (row: SpecializedGetters, ordinal: Int) => + val v = row.get(ordinal, dataType) + throw QueryExecutionErrors.failToConvertValueToJsonError(v, v.getClass, dataType) + } + + private def writeObject(f: => Unit): Unit = { + gen.writeStartObject() + f + gen.writeEndObject() + } + + private def writeFields( + row: InternalRow, + schema: StructType, + fieldWriters: Seq[ValueWriter]): Unit = { + var i = 0 + while (i < row.numFields) { + val field = schema(i) + if (!row.isNullAt(i)) { + gen.writeFieldName(field.name) + fieldWriters(i).apply(row, i) + } else if (!options.ignoreNullFields) { + gen.writeFieldName(field.name) + gen.writeNull() + } + i += 1 + } + } + + private def writeArray(f: => Unit): Unit = { + gen.writeStartArray() + f + gen.writeEndArray() + } + + private def writeArrayData(array: ArrayData, fieldWriter: ValueWriter): Unit = { + var i = 0 + while (i < array.numElements()) { + if (!array.isNullAt(i)) { + fieldWriter.apply(array, i) + } else { + gen.writeNull() + } + i += 1 + } + } + + private def writeMapData(map: MapData, mapType: MapType, fieldWriter: ValueWriter): Unit = { + val keyArray = map.keyArray() + val valueArray = map.valueArray() + var i = 0 + while (i < map.numElements()) { + gen.writeFieldName(keyArray.get(i, mapType.keyType).toString) + if (!valueArray.isNullAt(i)) { + fieldWriter.apply(valueArray, i) + } else { + gen.writeNull() + } + i += 1 + } + } + + def close(): Unit = gen.close() + + def flush(): Unit = gen.flush() + + /** + * Transforms a single `InternalRow` to JSON object using Jackson. This api calling will be + * validated through accessing `rootFieldWriters`. + * + * @param row + * The row to convert + */ + def write(row: InternalRow): Unit = { + writeObject( + writeFields( + fieldWriters = rootFieldWriters, + row = row, + schema = dataType.asInstanceOf[StructType])) + } + + /** + * Transforms multiple `InternalRow`s or `MapData`s to JSON array using Jackson + * + * @param array + * The array of rows or maps to convert + */ + def write(array: ArrayData): Unit = writeArray(writeArrayData(array, arrElementWriter)) + + /** + * Transforms a single `MapData` to JSON object using Jackson This api calling will will be + * validated through accessing `mapElementWriter`. + * + * @param map + * a map to convert + */ + def write(map: MapData): Unit = { + writeObject( + writeMapData( + fieldWriter = mapElementWriter, + map = map, + mapType = dataType.asInstanceOf[MapType])) + } + + def writeLineEnding(): Unit = { + // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8. + gen.writeRaw(lineSeparator) + } + + /** + * customized action. for instance. 
{"create": {"id": "value"}} + */ + def writeAction(action: String, idOrdinal: Option[Int], row: InternalRow): Unit = { + writeObject({ + gen.writeFieldName(action) + writeObject(idOrdinal match { + case Some(i) => + gen.writeFieldName("_id") + rootFieldWriters(i).apply(row, i) + case _ => None + }) + }) + } +} diff --git a/flint/integ-test/src/test/scala/org/apache/spark/FlintDataSourceV2ITSuite.scala b/flint/integ-test/src/test/scala/org/apache/spark/FlintDataSourceV2ITSuite.scala index c6c8b8bf24..ef24db00e0 100644 --- a/flint/integ-test/src/test/scala/org/apache/spark/FlintDataSourceV2ITSuite.scala +++ b/flint/integ-test/src/test/scala/org/apache/spark/FlintDataSourceV2ITSuite.scala @@ -132,6 +132,124 @@ class FlintDataSourceV2ITSuite } } + test("write dataframe to flint datasource") { + val indexName = "t0004" + val mappings = + """{ + | "properties": { + | "aInt": { + | "type": "integer" + | } + | } + |}""".stripMargin + val options = + openSearchOptions + ("refresh_policy" -> "wait_for", "spark.flint.write.id.name" -> "aInt") + Seq(Seq.empty, 1 to 14).foreach(data => { + withIndexName(indexName) { + index(indexName, oneNodeSetting, mappings, Seq.empty) + if (data.nonEmpty) { + data + .toDF("aInt") + .coalesce(1) + .write + .format("flint") + .options(options) + .mode("overwrite") + .save(indexName) + } + + val df = spark.range(15).toDF("aInt") + df.coalesce(1) + .write + .format("flint") + .options(options) + .mode("overwrite") + .save(indexName) + + val schema = StructType(Seq(StructField("aInt", IntegerType))) + val dfResult1 = spark.sqlContext.read + .format("flint") + .options(options) + .schema(schema) + .load(indexName) + checkAnswer(dfResult1, df) + } + }) + } + +// test("write dataframe to flint datasource") { +// val indexName = "t0004" +// val mappings = +// """{ +// | "properties": { +// | "aInt": { +// | "type": "integer" +// | } +// | } +// |}""".stripMargin +// val options = +// openSearchOptions + ("refresh_policy" -> "wait_for", "spark.flint.write.id.name" -> "aInt") +// Seq( +// Seq.empty, +// Seq("""{"aInt": 1}"""), +// for (n <- 1 to 14) yield s"""{"aInt": $n}""".stripMargin).foreach(data => { +// withIndexName(indexName) { +// index(indexName, oneNodeSetting, mappings, data) +// +// val df = spark.range(15).toDF("aInt") +// df.coalesce(1) +// .write +// .format("flint") +// .options(options) +// .mode("overwrite") +// .save(indexName) +// +// val schema = StructType(Seq(StructField("aInt", IntegerType, true))) +// val dfResult1 = spark.sqlContext.read +// .format("flint") +// .options(openSearchOptions) +// .schema(schema) +// .load(indexName) +// checkAnswer(dfResult1, df) +// } +// }) +// } + + test("write with batch size configuration") { + val indexName = "t0004" + val options = + openSearchOptions + ("refresh_policy" -> "wait_for", "spark.flint.write.id.name" -> "aInt") + Seq(0, 1).foreach(batchSize => { + withIndexName(indexName) { + val mappings = + """{ + | "properties": { + | "aInt": { + | "type": "integer" + | } + | } + |}""".stripMargin + index(indexName, oneNodeSetting, mappings, Seq.empty) + + val df = spark.range(15).toDF("aInt") + df.coalesce(1) + .write + .format("flint") + .options(options + ("spark.flint.write.batch.size" -> s"$batchSize")) + .mode("overwrite") + .save(indexName) + + val schema = StructType(Seq(StructField("aInt", IntegerType))) + val dfResult1 = spark.sqlContext.read + .format("flint") + .options(openSearchOptions) + .schema(schema) + .load(indexName) + checkAnswer(dfResult1, df) + } + }) + } + /** * Copy from SPARK JDBCV2Suite. 
*/ diff --git a/flint/integ-test/src/test/scala/org/opensearch/flint/OpenSearchSuite.scala b/flint/integ-test/src/test/scala/org/opensearch/flint/OpenSearchSuite.scala index 1f9170c645..8d21287d30 100644 --- a/flint/integ-test/src/test/scala/org/opensearch/flint/OpenSearchSuite.scala +++ b/flint/integ-test/src/test/scala/org/opensearch/flint/OpenSearchSuite.scala @@ -113,15 +113,18 @@ trait OpenSearchSuite extends BeforeAndAfterAll { /** * 1. Wait until refresh the index. */ - val request = new BulkRequest().setRefreshPolicy(RefreshPolicy.WAIT_UNTIL) - for (doc <- docs) { - request.add(new IndexRequest(index).source(doc, XContentType.JSON)) - } - val response = - openSearchClient.bulk(request, RequestOptions.DEFAULT) + if (docs.nonEmpty) { + val request = new BulkRequest().setRefreshPolicy(RefreshPolicy.WAIT_UNTIL) + for (doc <- docs) { + request.add(new IndexRequest(index).source(doc, XContentType.JSON)) + } - assume( - !response.hasFailures, - s"bulk index docs to $index failed: ${response.buildFailureMessage()}") + val response = + openSearchClient.bulk(request, RequestOptions.DEFAULT) + + assume( + !response.hasFailures, + s"bulk index docs to $index failed: ${response.buildFailureMessage()}") + } } } From 6dab0ab15a91dfbad797ddee41c3290f8c1bdabb Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Wed, 24 May 2023 07:56:32 -0700 Subject: [PATCH 2/3] add debug log Signed-off-by: Peng Huo --- .../flint/core/storage/OpenSearchWriter.java | 2 + .../sql/flint/FlintPartitionWriter.scala | 17 ++++-- .../flint/FlintPartitionWriterFactory.scala | 6 +- .../apache/spark/sql/flint/FlintWrite.scala | 2 + .../spark/FlintDataSourceV2ITSuite.scala | 55 +++---------------- 5 files changed, 29 insertions(+), 53 deletions(-) diff --git a/flint/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchWriter.java b/flint/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchWriter.java index da6014706b..1e55084b2e 100644 --- a/flint/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchWriter.java +++ b/flint/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchWriter.java @@ -25,6 +25,7 @@ /** * OpenSearch Bulk writer. More reading https://opensearch.org/docs/1.2/opensearch/rest-api/document-apis/bulk/. + * It is not thread safe. */ public class OpenSearchWriter extends FlintWriter { @@ -48,6 +49,7 @@ public OpenSearchWriter(RestHighLevelClient client, String indexName, String ref } /** + * Flush the data in buffer. * Todo. StringWriter is not efficient. it will copy the cbuf when create bytes. */ @Override public void flush() { diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriter.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriter.scala index fc3a87280a..84fa05b280 100644 --- a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriter.scala +++ b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriter.scala @@ -46,7 +46,10 @@ case class FlintPartitionWriter( private lazy val batchSize = properties.asScala.toMap.get(BATCH_SIZE).map(_.toInt).filter(_ > 0).getOrElse(1000) - private var count = 0; + /** + * total write doc count. 
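+   * Reset after each batch flush so that at most batchSize documents are buffered before a bulk request is sent.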
+ */ + private var docCount = 0; /** * { "create": { "_id": "id1" } } { "title": "Prisoners", "year": 2013 } @@ -57,16 +60,16 @@ case class FlintPartitionWriter( gen.write(record) gen.writeLineEnding() - count += 1 - if (count >= batchSize) { + docCount += 1 + if (docCount >= batchSize) { gen.flush() - count = 0 + docCount = 0 } } override def commit(): WriterCommitMessage = { gen.flush() - logDebug(s"Write finish on partitionId: $partitionId, taskId: $taskId") + logDebug(s"Write commit on partitionId: $partitionId, taskId: $taskId") FlintWriterCommitMessage(partitionId, taskId) } @@ -76,11 +79,15 @@ case class FlintPartitionWriter( override def close(): Unit = { gen.close() + logDebug(s"Write close on partitionId: $partitionId, taskId: $taskId") } } case class FlintWriterCommitMessage(partitionId: Int, taskId: Long) extends WriterCommitMessage +/** + * Todo. Move to FlintSparkConfiguration. + */ object FlintPartitionWriter { val ID_NAME = "spark.flint.write.id.name" val BATCH_SIZE = "spark.flint.write.batch.size" diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriterFactory.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriterFactory.scala index 0f5e6734ec..0229e1abd2 100644 --- a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriterFactory.scala +++ b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintPartitionWriterFactory.scala @@ -9,6 +9,7 @@ import java.util import org.opensearch.flint.core.{FlintClientBuilder, FlintOptions} +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory} import org.apache.spark.sql.types.StructType @@ -17,10 +18,11 @@ case class FlintPartitionWriterFactory( tableName: String, schema: StructType, properties: util.Map[String, String]) - extends DataWriterFactory { + extends DataWriterFactory + with Logging { override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { + logDebug(s"create writer for partition: $partitionId, task: $taskId") val flintClient = FlintClientBuilder.build(new FlintOptions(properties)) - FlintPartitionWriter( flintClient.createWriter(tableName), schema, diff --git a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWrite.scala b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWrite.scala index b55636d6c9..1f6bbfe990 100644 --- a/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWrite.scala +++ b/flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/FlintWrite.scala @@ -14,6 +14,8 @@ case class FlintWrite(tableName: String, logicalWriteInfo: LogicalWriteInfo) with Logging { override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = { + logDebug( + s"Create factory of ${logicalWriteInfo.queryId()} with ${info.numPartitions()} partitions") FlintPartitionWriterFactory( tableName, logicalWriteInfo.schema(), diff --git a/flint/integ-test/src/test/scala/org/apache/spark/FlintDataSourceV2ITSuite.scala b/flint/integ-test/src/test/scala/org/apache/spark/FlintDataSourceV2ITSuite.scala index ef24db00e0..8311e6a49c 100644 --- a/flint/integ-test/src/test/scala/org/apache/spark/FlintDataSourceV2ITSuite.scala +++ b/flint/integ-test/src/test/scala/org/apache/spark/FlintDataSourceV2ITSuite.scala @@ -132,7 +132,7 @@ class 
FlintDataSourceV2ITSuite } } - test("write dataframe to flint datasource") { + test("write dataframe to flint") { val indexName = "t0004" val mappings = """{ @@ -177,45 +177,7 @@ class FlintDataSourceV2ITSuite }) } -// test("write dataframe to flint datasource") { -// val indexName = "t0004" -// val mappings = -// """{ -// | "properties": { -// | "aInt": { -// | "type": "integer" -// | } -// | } -// |}""".stripMargin -// val options = -// openSearchOptions + ("refresh_policy" -> "wait_for", "spark.flint.write.id.name" -> "aInt") -// Seq( -// Seq.empty, -// Seq("""{"aInt": 1}"""), -// for (n <- 1 to 14) yield s"""{"aInt": $n}""".stripMargin).foreach(data => { -// withIndexName(indexName) { -// index(indexName, oneNodeSetting, mappings, data) -// -// val df = spark.range(15).toDF("aInt") -// df.coalesce(1) -// .write -// .format("flint") -// .options(options) -// .mode("overwrite") -// .save(indexName) -// -// val schema = StructType(Seq(StructField("aInt", IntegerType, true))) -// val dfResult1 = spark.sqlContext.read -// .format("flint") -// .options(openSearchOptions) -// .schema(schema) -// .load(indexName) -// checkAnswer(dfResult1, df) -// } -// }) -// } - - test("write with batch size configuration") { + test("write dataframe to flint with batch size configuration") { val indexName = "t0004" val options = openSearchOptions + ("refresh_policy" -> "wait_for", "spark.flint.write.id.name" -> "aInt") @@ -240,12 +202,13 @@ class FlintDataSourceV2ITSuite .save(indexName) val schema = StructType(Seq(StructField("aInt", IntegerType))) - val dfResult1 = spark.sqlContext.read - .format("flint") - .options(openSearchOptions) - .schema(schema) - .load(indexName) - checkAnswer(dfResult1, df) + checkAnswer( + spark.sqlContext.read + .format("flint") + .options(openSearchOptions) + .schema(schema) + .load(indexName), + df) } }) } From 8c9db09d55ac3e2181e87e6c5a00ca5a5074d9e8 Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Wed, 24 May 2023 13:16:10 -0700 Subject: [PATCH 3/3] add IT for FlintWriter Signed-off-by: Peng Huo --- .../core/FlintOpenSearchClientSuite.scala | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/flint/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala b/flint/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala index dbd488edee..a54d40a4e8 100644 --- a/flint/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala +++ b/flint/integ-test/src/test/scala/org/opensearch/flint/core/FlintOpenSearchClientSuite.scala @@ -58,4 +58,36 @@ class FlintOpenSearchClientSuite extends AnyFlatSpec with OpenSearchSuite with M reader.close() } } + + it should "write docs to index successfully " in { + val indexName = "t0001" + withIndexName(indexName) { + val mappings = + """{ + | "properties": { + | "aInt": { + | "type": "integer" + | } + | } + |}""".stripMargin + + val options = openSearchOptions + ("refresh_policy" -> "wait_for") + val flintClient = new FlintOpenSearchClient(new FlintOptions(options.asJava)) + index(indexName, oneNodeSetting, mappings, Seq.empty) + val writer = flintClient.createWriter(indexName) + writer.write("""{"create":{}}""") + writer.write("\n") + writer.write("""{"aInt":1}""") + writer.write("\n") + writer.flush() + writer.close() + + val match_all = null + val reader = flintClient.createReader(indexName, match_all) + reader.hasNext shouldBe true + reader.next shouldBe """{"aInt":1}""" + reader.hasNext shouldBe false + reader.close() + } + } }
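For reference, a minimal usage sketch of the write path added by this series, mirroring the integration tests above. The connection option keys ("host", "port") are assumptions here; in the tests they are supplied by openSearchOptions, while refresh_policy, spark.flint.write.id.name and spark.flint.write.batch.size are the options introduced in this change:

    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    // Write a DataFrame through the new "flint" batch write path.
    val df = spark.range(15).toDF("aInt")

    df.coalesce(1)
      .write
      .format("flint")
      .option("host", "localhost")                    // assumed key, provided by openSearchOptions in the tests
      .option("port", "9200")                         // assumed key
      .option("refresh_policy", "wait_for")           // "false" = NONE, "true" = IMMEDIATE, "wait_for" = WAIT_UNTIL
      .option("spark.flint.write.id.name", "aInt")    // column used as the document _id in the bulk "create" action
      .option("spark.flint.write.batch.size", "1000") // documents buffered per bulk request (default 1000)
      .mode("overwrite")
      .save("t0004")

    // Reading back requires an explicit schema, as in the integration tests.
    val readBack = spark.read
      .format("flint")
      .option("host", "localhost")
      .option("port", "9200")
      .schema(StructType(Seq(StructField("aInt", IntegerType))))
      .load("t0004")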