From fac825c5ac0e083b337ca7b78254747a21137dec Mon Sep 17 00:00:00 2001
From: Chen Dai
Date: Tue, 23 May 2023 14:26:30 -0700
Subject: [PATCH 1/3] Serialize and deserialize partition index in Flint metadata

Signed-off-by: Chen Dai
---
 .../scala/org/opensearch/flint/spark/FlintSpark.scala     | 3 ++-
 .../flint/spark/skipping/FlintSparkSkippingIndex.scala    | 2 +-
 .../flint/spark/skipping/FlintSparkSkippingStrategy.scala | 8 +++++++-
 .../skipping/partition/PartitionSkippingStrategy.scala    | 7 +++++--
 .../flint/spark/FlintSparkSkippingIndexSuite.scala        | 8 ++++++--
 5 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
index a1a6ffc16b..da631de206 100644
--- a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
+++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
@@ -144,7 +144,8 @@ object FlintSpark {
         allColumns.getOrElse(
           colName,
           throw new IllegalArgumentException(s"Column $colName does not exist")))
-        .map(col => new PartitionSkippingStrategy(col.name, col.dataType))
+        .map(col =>
+          new PartitionSkippingStrategy(columnName = col.name, columnType = col.dataType))
         .foreach(indexedCol => indexedColumns = indexedColumns :+ indexedCol)
       this
     }
diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
index 211fc585ea..080442842c 100644
--- a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
+++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
@@ -48,7 +48,7 @@ class FlintSparkSkippingIndex(tableName: String, indexedColumns: Seq[FlintSparkS
   }
 
   private def getMetaInfo: String = {
-    Serialization.write(indexedColumns.map(_.indexedColumn))
+    Serialization.write(indexedColumns)
   }
 
   private def getSchema: String = {
diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
index 2b5b14b503..e9f29be100 100644
--- a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
+++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
@@ -10,10 +10,16 @@ package org.opensearch.flint.spark.skipping
  */
 trait FlintSparkSkippingStrategy {
 
+  /**
+   * Skipping strategy kind.
+   */
+  val kind: String
+
   /**
    * Indexed column name and its Spark SQL type.
   */
-  val indexedColumn: (String, String)
+  val columnName: String
+  val columnType: String
 
   /**
    * @return
diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala
index e3c4003c70..fd814873e1 100644
--- a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala
+++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala
@@ -10,11 +10,14 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
 /**
  * Skipping strategy for partitioned columns of source table.
  */
-class PartitionSkippingStrategy(override val indexedColumn: (String, String))
+class PartitionSkippingStrategy(
+    override val kind: String = "partition",
+    override val columnName: String,
+    override val columnType: String)
     extends FlintSparkSkippingStrategy {
 
   override def outputSchema(): Map[String, String] = {
-    Map(indexedColumn._1 -> convertToFlintType(indexedColumn._2))
+    Map(columnName -> convertToFlintType(columnType))
   }
 
   // TODO: move this mapping info to single place
diff --git a/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSuite.scala b/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSuite.scala
index de78e45526..0cca56bfd4 100644
--- a/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSuite.scala
+++ b/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSuite.scala
@@ -72,10 +72,14 @@ class FlintSparkSkippingIndexSuite extends FlintSuite with OpenSearchSuite {
        |    "kind": "SkippingIndex",
        |    "indexedColumns": [
        |      {
-       |        "year": "int"
+       |        "kind": "partition",
+       |        "columnName": "year",
+       |        "columnType": "int"
        |      },
        |      {
-       |        "month": "int"
+       |        "kind": "partition",
+       |        "columnName": "month",
+       |        "columnType": "int"
        |      }]
        |  },
        |  "properties": {
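The integration test above pins down the exact JSON shape each indexed column now serializes to. A minimal json4s sketch of that shape, using a hypothetical case class as a stand-in for the strategy (the real PartitionSkippingStrategy is a plain class carrying the same three fields):

    import org.json4s.NoTypeHints
    import org.json4s.native.Serialization

    object MetadataShapeDemo extends App {
      // Hypothetical stand-in, used only to illustrate the serialized form.
      case class IndexedColumn(kind: String, columnName: String, columnType: String)

      implicit val formats = Serialization.formats(NoTypeHints)

      // Prints: [{"kind":"partition","columnName":"year","columnType":"int"}]
      println(Serialization.write(Seq(IndexedColumn("partition", "year", "int"))))
    }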
From 05cde40f76a6d666256429ea0bcaa41a4446e5bf Mon Sep 17 00:00:00 2001
From: Chen Dai
Date: Tue, 23 May 2023 15:59:47 -0700
Subject: [PATCH 2/3] Add query rewriter rule

Signed-off-by: Chen Dai
---
 .../flint/spark/FlintSparkExtensions.scala         |  8 +-
 .../skipping/ApplyFlintSparkSkippingIndex.scala    | 80 +++++++++++++++++++
 .../flint/spark/FlintSparkSkippingIndexSuite.scala | 19 ++++-
 3 files changed, 105 insertions(+), 2 deletions(-)
 create mode 100644 flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala

diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkExtensions.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkExtensions.scala
index f0a80c92a6..2afc32b0dc 100644
--- a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkExtensions.scala
+++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkExtensions.scala
@@ -5,9 +5,15 @@
 
 package org.opensearch.flint.spark
 
+import org.opensearch.flint.spark.skipping.ApplyFlintSparkSkippingIndex
+
 import org.apache.spark.sql.SparkSessionExtensions
 
 class FlintSparkExtensions extends (SparkSessionExtensions => Unit) {
 
-  override def apply(v1: SparkSessionExtensions): Unit = {}
+  override def apply(extensions: SparkSessionExtensions): Unit = {
+    extensions.injectOptimizerRule { spark =>
+      new ApplyFlintSparkSkippingIndex(new FlintSpark(spark))
+    }
+  }
 }
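The injected rule only takes effect in sessions that load the extension. A minimal setup sketch using Spark's standard extension mechanism, assuming the flint-spark-integration jar is on the classpath:

    import org.apache.spark.sql.SparkSession

    object FlintSessionDemo extends App {
      // Registering FlintSparkExtensions makes ApplyFlintSparkSkippingIndex run
      // as an optimizer rule on every query planned in this session.
      val spark = SparkSession
        .builder()
        .master("local[*]")
        .config("spark.sql.extensions", "org.opensearch.flint.spark.FlintSparkExtensions")
        .getOrCreate()
    }

The same registration works via --conf spark.sql.extensions=org.opensearch.flint.spark.FlintSparkExtensions on spark-submit.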
diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala
new file mode 100644
index 0000000000..5fd885f3e0
--- /dev/null
+++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala
@@ -0,0 +1,80 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.skipping
+
+import org.json4s._
+import org.json4s.native.JsonMethods._
+import org.json4s.native.Serialization
+import org.opensearch.flint.core.metadata.FlintMetadata
+import org.opensearch.flint.spark.FlintSpark
+import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy
+
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
+
+/**
+ * Flint Spark skipping index apply rule that rewrites an applicable query's filter condition
+ * and table scan operator to leverage the additional skipping data structure, accelerating
+ * the query by significantly reducing the data scanned.
+ *
+ * @param flint
+ *   Flint Spark API
+ */
+class ApplyFlintSparkSkippingIndex(val flint: FlintSpark) extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case filter @ Filter(
+          condition,
+          relation @ LogicalRelation(
+            baseRelation @ HadoopFsRelation(location, _, _, _, _, _),
+            _,
+            Some(table),
+            false)) =>
+      val indexName =
+        FlintSparkSkippingIndex.getIndexName(table.identifier.table) // TODO: ignore schema name
+      val index = flint.describeIndex(indexName)
+      val indexedCols = parseMetadata(index)
+
+      if (indexedCols.isEmpty) {
+        filter
+      } else {
+        filter
+      }
+  }
+
+  private def parseMetadata(index: Option[FlintMetadata]): Seq[FlintSparkSkippingStrategy] = {
+    implicit val formats: Formats = Serialization.formats(NoTypeHints)
+
+    if (index.isDefined) {
+
+      // TODO: move all these JSON parsing to FlintMetadata once Flint spec finalized
+      val json = parse(index.get.getContent)
+      val kind = (json \ "_meta" \ "kind").extract[String]
+
+      if (kind == "SkippingIndex") {
+        val indexedColumns = (json \ "_meta" \ "indexedColumns").asInstanceOf[JArray]
+
+        indexedColumns.arr.map { colInfo =>
+          val kind = (colInfo \ "kind").extract[String]
+          val columnName = (colInfo \ "columnName").extract[String]
+          val columnType = (colInfo \ "columnType").extract[String]
+
+          kind match {
+            case "partition" =>
+              new PartitionSkippingStrategy(columnName = columnName, columnType = columnType)
+            case other =>
+              throw new IllegalStateException(s"Unknown skipping strategy: $other")
+          }
+        }
+      } else {
+        Seq()
+      }
+    } else {
+      Seq()
+    }
+  }
+}
diff --git a/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSuite.scala b/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSuite.scala
index 0cca56bfd4..6ca75b671f 100644
--- a/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSuite.scala
+++ b/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSuite.scala
@@ -10,7 +10,7 @@ import scala.Option._
 
 import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson
 import org.opensearch.flint.OpenSearchSuite
 import org.opensearch.flint.spark.FlintSpark.FLINT_INDEX_STORE_LOCATION
-import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
+import org.opensearch.flint.spark.skipping.{ApplyFlintSparkSkippingIndex, FlintSparkSkippingIndex}
 import org.scalatest.matchers.must.Matchers.defined
 import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
@@ -97,6 +97,23 @@ class FlintSparkSkippingIndexSuite extends FlintSuite with OpenSearchSuite {
       |""".stripMargin)
   }
 
+  test("applicable query can be rewritten with skipping index") {
+    flint
+      .skippingIndex()
+      .onTable(testTable)
+      .addPartitionIndex("year", "month")
+      .create()
+
+    val query = sql(s"""
+       | SELECT name
+       | FROM $testTable
+       | WHERE year = 2023 AND month = 04
+       |""".stripMargin)
+
+    val rewriter = new ApplyFlintSparkSkippingIndex(flint)
+    rewriter.apply(query.queryExecution.optimizedPlan)
+  }
+
   test("can have only 1 skipping index on a table") {
     flint
       .skippingIndex()
From d6dbd46279adadeed4c599295219d40855e959c9 Mon Sep 17 00:00:00 2001
From: Chen Dai
Date: Wed, 24 May 2023 15:59:08 -0700
Subject: [PATCH 3/3] Refactor skipping index interface with query

Signed-off-by: Chen Dai
---
 .../opensearch/flint/spark/FlintSpark.scala   | 47 ++++++++++--
 .../flint/spark/FlintSparkIndex.scala         | 14 ++++
 .../ApplyFlintSparkSkippingIndex.scala        | 71 +++++++++----------
 .../skipping/FlintSparkSkippingIndex.scala    | 43 +++++++++--
 .../skipping/FlintSparkSkippingStrategy.scala | 13 ++++
 .../partition/PartitionSkippingStrategy.scala | 11 +++
 .../spark/FlintSparkSkippingIndexSuite.scala  | 13 ++--
 7 files changed, 158 insertions(+), 54 deletions(-)

diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
index da631de206..dd56cbd2b9 100644
--- a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
+++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
@@ -7,6 +7,9 @@ package org.opensearch.flint.spark
 
 import scala.collection.JavaConverters._
 
+import org.json4s.{Formats, JArray, NoTypeHints}
+import org.json4s.native.JsonMethods.parse
+import org.json4s.native.Serialization
 import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions}
 import org.opensearch.flint.core.FlintOptions._
 import org.opensearch.flint.core.metadata.FlintMetadata
@@ -66,16 +69,17 @@ class FlintSpark(val spark: SparkSession) {
    * @return
    *   Flint index metadata
    */
-  def describeIndex(indexName: String): Option[FlintMetadata] = {
+  def describeIndex(indexName: String): Option[FlintSparkIndex] = {
     if (flintClient.exists(indexName)) {
-      Some(flintClient.getIndexMetadata(indexName))
+      val metadata = flintClient.getIndexMetadata(indexName)
+      Some(deserialize(metadata))
     } else {
       Option.empty
     }
   }
 
   /**
-   * Delete index.
+   * Delete a Flint index.
    *
    * @param indexName
    *   index name
@@ -90,6 +94,41 @@ class FlintSpark(val spark: SparkSession) {
       false
     }
   }
+
+  /*
+   * TODO: Remove all this JSON parsing logic once the Flint spec is finalized
+   * and FlintMetadata is strongly typed
+   *
+   * For now, deserialize skipping strategies out of the Flint metadata JSON,
+   * ex. extract Seq(Partition("year", "int"), ValueList("name")) from
+   *   { "_meta": { "indexedColumns": [ {...partition...}, {...value list...} ] } }
+   *
+   */
+  private def deserialize(metadata: FlintMetadata): FlintSparkIndex = {
+    implicit val formats: Formats = Serialization.formats(NoTypeHints)
+
+    val meta = parse(metadata.getContent) \ "_meta"
+    val tableName = (meta \ "source").extract[String]
+    val indexType = (meta \ "kind").extract[String]
+    val indexedColumns = (meta \ "indexedColumns").asInstanceOf[JArray]
+
+    indexType match {
+      case "SkippingIndex" =>
+        val strategies = indexedColumns.arr.map { colInfo =>
+          val skippingType = (colInfo \ "kind").extract[String]
+          val columnName = (colInfo \ "columnName").extract[String]
+          val columnType = (colInfo \ "columnType").extract[String]
+
+          skippingType match {
+            case "partition" =>
+              new PartitionSkippingStrategy(columnName = columnName, columnType = columnType)
+            case other =>
+              throw new IllegalStateException(s"Unknown skipping strategy: $other")
+          }
+        }
+        new FlintSparkSkippingIndex(spark, tableName, strategies)
+    }
+  }
 }
 
 object FlintSpark {
@@ -156,7 +195,7 @@ object FlintSpark {
     def create(): Unit = {
       require(tableName.nonEmpty, "table name cannot be empty")
 
-      flint.createIndex(new FlintSparkSkippingIndex(tableName, indexedColumns))
+      flint.createIndex(new FlintSparkSkippingIndex(flint.spark, tableName, indexedColumns))
     }
   }
 }
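The traversal in deserialize above is plain json4s. A trimmed, self-contained sketch of the same extraction; the metadata document and table name here are hypothetical but follow the shape asserted in the integration test:

    import org.json4s.{Formats, NoTypeHints}
    import org.json4s.native.JsonMethods.parse
    import org.json4s.native.Serialization

    object MetadataParseDemo extends App {
      implicit val formats: Formats = Serialization.formats(NoTypeHints)

      // Hypothetical metadata content, trimmed to the fields deserialize reads.
      val content =
        """{"_meta": {"kind": "SkippingIndex", "source": "test_table",
          |  "indexedColumns": [
          |    {"kind": "partition", "columnName": "year", "columnType": "int"}]}}""".stripMargin

      val meta = parse(content) \ "_meta"
      println((meta \ "kind").extract[String])   // SkippingIndex
      println((meta \ "source").extract[String]) // test_table
    }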
diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
index 949d367c3c..244e3ce37b 100644
--- a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
+++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
@@ -7,11 +7,18 @@ package org.opensearch.flint.spark
 
 import org.opensearch.flint.core.metadata.FlintMetadata
 
+import org.apache.spark.sql.DataFrame
+
 /**
  * Flint index interface in Spark.
  */
 trait FlintSparkIndex {
 
+  /**
+   * Index type
+   */
+  val kind: String
+
   /**
    * @return
    *   Flint index name
@@ -24,4 +31,11 @@ trait FlintSparkIndex {
    */
   def metadata(): FlintMetadata
 
+  /**
+   * Query the current Flint index as a Spark data frame.
+   *
+   * @return
+   *   data frame
+   */
+  def query(): DataFrame
 }
diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala
index 5fd885f3e0..ca8aadf8dd 100644
--- a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala
+++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala
@@ -5,13 +5,11 @@
 
 package org.opensearch.flint.spark.skipping
 
-import org.json4s._
-import org.json4s.native.JsonMethods._
-import org.json4s.native.Serialization
-import org.opensearch.flint.core.metadata.FlintMetadata
 import org.opensearch.flint.spark.FlintSpark
-import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE}
 
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.{And, Predicate}
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
@@ -28,53 +26,52 @@ class ApplyFlintSparkSkippingIndex(val flint: FlintSpark) extends Rule[LogicalPl
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case filter @ Filter(
-          condition,
+          condition: Predicate,
           relation @ LogicalRelation(
            baseRelation @ HadoopFsRelation(location, _, _, _, _, _),
            _,
            Some(table),
            false)) =>
-      val indexName =
-        FlintSparkSkippingIndex.getIndexName(table.identifier.table) // TODO: ignore schema name
+      // Spark optimize recursively
+      // if (location.isInstanceOf[FlintSparkSkippingFileIndex]) {
+      //   return filter
+      // }
+
+      val indexName = getSkippingIndexName(table.identifier.table) // TODO: ignore schema name
       val index = flint.describeIndex(indexName)
-      val indexedCols = parseMetadata(index)
 
-      if (indexedCols.isEmpty) {
+      if (index.exists(_.kind == SKIPPING_INDEX_TYPE)) {
+        val skippingIndex = index.get.asInstanceOf[FlintSparkSkippingIndex]
+        val rewrittenPredicate = rewriteToPredicateOnSkippingIndex(skippingIndex, condition)
+        val selectedFiles = getSelectedFilesToScanAfterSkip(skippingIndex, rewrittenPredicate)
+
         filter
       } else {
         filter
       }
   }
 
-  private def parseMetadata(index: Option[FlintMetadata]): Seq[FlintSparkSkippingStrategy] = {
-    implicit val formats: Formats = Serialization.formats(NoTypeHints)
-
-    if (index.isDefined) {
-
-      // TODO: move all these JSON parsing to FlintMetadata once Flint spec finalized
-      val json = parse(index.get.getContent)
-      val kind = (json \ "_meta" \ "kind").extract[String]
-
-      if (kind == "SkippingIndex") {
-        val indexedColumns = (json \ "_meta" \ "indexedColumns").asInstanceOf[JArray]
-
-        indexedColumns.arr.map { colInfo =>
-          val kind = (colInfo \ "kind").extract[String]
-          val columnName = (colInfo \ "columnName").extract[String]
-          val columnType = (colInfo \ "columnType").extract[String]
-
-          kind match {
-            case "partition" =>
-              new PartitionSkippingStrategy(columnName = columnName, columnType = columnType)
-            case other =>
-              throw new IllegalStateException(s"Unknown skipping strategy: $other")
-          }
-        }
-      } else {
-        Seq()
-      }
-    } else {
-      Seq()
-    }
+  private def rewriteToPredicateOnSkippingIndex(
+      index: FlintSparkSkippingIndex,
+      condition: Predicate): Predicate = {
+
+    index.indexedColumns
+      .map(index => index.rewritePredicate(condition))
+      .filter(pred => pred.isDefined)
+      .map(pred => pred.get)
+      .reduce(And(_, _))
+  }
+
+  private def getSelectedFilesToScanAfterSkip(
+      index: FlintSparkSkippingIndex,
+      rewrittenPredicate: Predicate): Set[String] = {
+
+    index
+      .query()
+      .filter(new Column(rewrittenPredicate))
+      .select(FILE_PATH_COLUMN)
+      .collect
+      .map(_.getString(0))
+      .toSet
   }
 }
"file_path" @@ -75,5 +104,5 @@ object FlintSparkSkippingIndex { * @return * Flint skipping index name */ - def getIndexName(tableName: String): String = s"flint_${tableName}_skipping_index" + def getSkippingIndexName(tableName: String): String = s"flint_${tableName}_skipping_index" } diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala index e9f29be100..6f2aac9066 100644 --- a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala +++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala @@ -5,6 +5,8 @@ package org.opensearch.flint.spark.skipping +import org.apache.spark.sql.catalyst.expressions.Predicate + /** * Skipping index strategy that defines skipping data structure building and reading logic. */ @@ -26,4 +28,15 @@ trait FlintSparkSkippingStrategy { * output schema mapping from Flint field name to Flint field type */ def outputSchema(): Map[String, String] + + /** + * Rewrite a predicate (filtering condition) on source table into another predicate on index + * data based on current skipping strategy. + * + * @param predicate + * filtering condition on source table + * @return + * rewritten filtering condition on index data + */ + def rewritePredicate(predicate: Predicate): Option[Predicate] } diff --git a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala index fd814873e1..fd55e556bf 100644 --- a/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala +++ b/flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategy.scala @@ -7,6 +7,9 @@ package org.opensearch.flint.spark.skipping.partition import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, EqualTo, Expression, Literal, Predicate} + /** * Skipping strategy for partitioned columns of source table. 
@@ -27,4 +30,12 @@ class PartitionSkippingStrategy(
       case "int" => "integer"
     }
   }
+
+  override def rewritePredicate(predicate: Predicate): Option[Predicate] = {
+    val newPred = predicate.collect {
+      case EqualTo(AttributeReference(columnName, _, _, _), value: Literal) =>
+        EqualTo(UnresolvedAttribute(columnName), value)
+    }
+    newPred.headOption
+  }
 }
diff --git a/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSuite.scala b/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSuite.scala
index 6ca75b671f..7fe78f36ef 100644
--- a/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSuite.scala
+++ b/flint/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexSuite.scala
@@ -53,7 +53,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite with OpenSearchSuite {
 
   override def afterEach(): Unit = {
     super.afterEach()
-    val indexName = FlintSparkSkippingIndex.getIndexName(testTable)
+    val indexName = FlintSparkSkippingIndex.getSkippingIndexName(testTable)
     flint.deleteIndex(indexName)
   }
 
@@ -65,9 +65,9 @@ class FlintSparkSkippingIndexSuite extends FlintSuite with OpenSearchSuite {
       .create()
 
     val indexName = s"flint_${testTable}_skipping_index"
-    val metadata = flint.describeIndex(indexName)
-    metadata shouldBe defined
-    metadata.get.getContent should matchJson(""" {
+    val index = flint.describeIndex(indexName)
+    index shouldBe defined
+    index.get.metadata().getContent should matchJson(s"""{
        |  "_meta": {
        |    "kind": "SkippingIndex",
        |    "indexedColumns": [
        |      {
        |        "kind": "partition",
        |        "columnName": "year",
        |        "columnType": "int"
        |      },
        |      {
        |        "kind": "partition",
        |        "columnName": "month",
        |        "columnType": "int"
-       |      }]
+       |      }],
+       |    "source": "$testTable"
        |  },
        |  "properties": {
 
   test("applicable query can be rewritten with skipping index") {
     flint
       .skippingIndex()
       .onTable(testTable)
       .addPartitionIndex("year", "month")
       .create()
 
     val query = sql(s"""
        | SELECT name
        | FROM $testTable
-       | WHERE year = 2023 AND month = 04
+       | WHERE year = 2023 AND month = 5
        |""".stripMargin)
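Taken together, the three patches give roughly the following end-to-end flow. A sketch only, assuming a session built with FlintSparkExtensions (patch 2) and a source table partitioned by year and month; the table name is hypothetical:

    // Build the skipping index once, via the DSL exercised in the suite above.
    val flint = new FlintSpark(spark)
    flint
      .skippingIndex()
      .onTable("alb_logs")
      .addPartitionIndex("year", "month")
      .create()

    // ApplyFlintSparkSkippingIndex can now match the Filter in this plan, rewrite the
    // predicate onto the index data, and collect the file paths that survive skipping.
    spark.sql("SELECT name FROM alb_logs WHERE year = 2023 AND month = 5").explain(true)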