Query rewrite for partition skipping index #1651

Closed
FlintSpark.scala
@@ -7,6 +7,9 @@ package org.opensearch.flint.spark

import scala.collection.JavaConverters._

import org.json4s.{Formats, JArray, NoTypeHints}
import org.json4s.native.JsonMethods.parse
import org.json4s.native.Serialization
import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions}
import org.opensearch.flint.core.FlintOptions._
import org.opensearch.flint.core.metadata.FlintMetadata
@@ -66,16 +69,17 @@ class FlintSpark(val spark: SparkSession) {
* @return
* Flint index if it exists
*/
def describeIndex(indexName: String): Option[FlintMetadata] = {
def describeIndex(indexName: String): Option[FlintSparkIndex] = {
if (flintClient.exists(indexName)) {
Some(flintClient.getIndexMetadata(indexName))
val metadata = flintClient.getIndexMetadata(indexName)
Some(deserialize(metadata))
} else {
Option.empty
}
}

/**
* Delete index.
* Delete a Flint index.
*
* @param indexName
* index name
@@ -90,6 +94,41 @@ class FlintSpark(val spark: SparkSession) {
false
}
}

/*
* TODO: Remove all of this JSON parsing logic once the Flint spec is finalized
* and FlintMetadata is strongly typed.
*
* For now, deserialize skipping strategies from the Flint metadata JSON,
* e.g. extract Seq(Partition("year", "int"), ValueList("name")) from
* { "_meta": { "indexedColumns": [ {...partition...}, {...value list...} ] } }
*/
private def deserialize(metadata: FlintMetadata): FlintSparkIndex = {
implicit val formats: Formats = Serialization.formats(NoTypeHints)

val meta = parse(metadata.getContent) \ "_meta"
val tableName = (meta \ "source").extract[String]
val indexType = (meta \ "kind").extract[String]
val indexedColumns = (meta \ "indexedColumns").asInstanceOf[JArray]

indexType match {
case "SkippingIndex" =>
val strategies = indexedColumns.arr.map { colInfo =>
val skippingType = (colInfo \ "kind").extract[String]
val columnName = (colInfo \ "columnName").extract[String]
val columnType = (colInfo \ "columnType").extract[String]

skippingType match {
case "partition" =>
new PartitionSkippingStrategy(columnName = columnName, columnType = columnType)
case other =>
throw new IllegalStateException(s"Unknown skipping strategy: $other")
}
}
new FlintSparkSkippingIndex(spark, tableName, strategies)
case other =>
throw new IllegalStateException(s"Unknown index type: $other")
}
}
}

object FlintSpark {
@@ -144,7 +183,8 @@ object FlintSpark {
allColumns.getOrElse(
colName,
throw new IllegalArgumentException(s"Column $colName does not exist")))
.map(col => new PartitionSkippingStrategy(col.name, col.dataType))
.map(col =>
new PartitionSkippingStrategy(columnName = col.name, columnType = col.dataType))
.foreach(indexedCol => indexedColumns = indexedColumns :+ indexedCol)
this
}
@@ -155,7 +195,7 @@
def create(): Unit = {
require(tableName.nonEmpty, "table name cannot be empty")

flint.createIndex(new FlintSparkSkippingIndex(tableName, indexedColumns))
flint.createIndex(new FlintSparkSkippingIndex(flint.spark, tableName, indexedColumns))
}
}
}
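
The deserialize logic above reads the `_meta` section back into a FlintSparkIndex. Below is a minimal, self-contained sketch of that json4s extraction against a hand-written metadata document; the sample JSON, table name, and column values are illustrative assumptions, not taken from a real index.

```scala
// Sketch only: parse an example "_meta" document with json4s, mirroring the
// field accesses performed in describeIndex/deserialize above.
import org.json4s.{Formats, JArray, NoTypeHints}
import org.json4s.native.JsonMethods.parse
import org.json4s.native.Serialization

object MetadataParsingSketch extends App {
  implicit val formats: Formats = Serialization.formats(NoTypeHints)

  val sample =
    """{
      |  "_meta": {
      |    "kind": "SkippingIndex",
      |    "source": "alb_logs",
      |    "indexedColumns": [
      |      { "kind": "partition", "columnName": "year", "columnType": "int" }
      |    ]
      |  }
      |}""".stripMargin

  val meta = parse(sample) \ "_meta"
  val kind = (meta \ "kind").extract[String]                          // "SkippingIndex"
  val source = (meta \ "source").extract[String]                      // "alb_logs"
  val columns = (meta \ "indexedColumns").asInstanceOf[JArray]
  val columnName = (columns.arr.head \ "columnName").extract[String]  // "year"

  println(s"kind=$kind source=$source firstColumn=$columnName")
}
```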
FlintSparkExtensions.scala
@@ -5,9 +5,15 @@

package org.opensearch.flint.spark

import org.opensearch.flint.spark.skipping.ApplyFlintSparkSkippingIndex

import org.apache.spark.sql.SparkSessionExtensions

class FlintSparkExtensions extends (SparkSessionExtensions => Unit) {

override def apply(v1: SparkSessionExtensions): Unit = {}
override def apply(extensions: SparkSessionExtensions): Unit = {
extensions.injectOptimizerRule { spark =>
new ApplyFlintSparkSkippingIndex(new FlintSpark(spark))
}
}
}
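
With apply() now injecting the optimizer rule, the extension takes effect once it is registered on the session. A minimal spark-shell style sketch, assuming registration through Spark's standard spark.sql.extensions configuration; the app name, master, and example table are placeholders.

```scala
import org.apache.spark.sql.SparkSession

// Registering FlintSparkExtensions makes ApplyFlintSparkSkippingIndex part of
// the optimizer, so eligible filters can be rewritten transparently.
val spark = SparkSession
  .builder()
  .appName("flint-skipping-demo")   // placeholder
  .master("local[*]")               // placeholder
  .config("spark.sql.extensions", "org.opensearch.flint.spark.FlintSparkExtensions")
  .getOrCreate()

// A filter on a partition column of a source table is now a candidate for the rule, e.g.:
// spark.sql("SELECT * FROM alb_logs WHERE year = 2023").show()
```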
FlintSparkIndex.scala
@@ -7,11 +7,18 @@ package org.opensearch.flint.spark

import org.opensearch.flint.core.metadata.FlintMetadata

import org.apache.spark.sql.DataFrame

/**
* Flint index interface in Spark.
*/
trait FlintSparkIndex {

/**
* Index type
*/
val kind: String

/**
* @return
* Flint index name
@@ -24,4 +31,11 @@ trait FlintSparkIndex {
*/
def metadata(): FlintMetadata

/**
* Query the current Flint index as a Spark data frame.
*
* @return
* data frame
*/
def query(): DataFrame
}
ApplyFlintSparkSkippingIndex.scala
@@ -0,0 +1,77 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping

import org.opensearch.flint.spark.FlintSpark
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE}

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{And, Predicate}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}

/**
* Flint Spark skipping index rule that rewrites an applicable query's filtering condition and
* table scan operator to leverage the skipping data structure, accelerating the query by
* significantly reducing the amount of data scanned.
*
* @param flint
* Flint Spark API
*/
class ApplyFlintSparkSkippingIndex(val flint: FlintSpark) extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case filter @ Filter(
condition: Predicate,
relation @ LogicalRelation(
baseRelation @ HadoopFsRelation(location, _, _, _, _, _),
_,
Some(table),
false)) =>
// Spark applies optimizer rules recursively; the commented check below would skip an already rewritten scan
// if (location.isInstanceOf[FlintSparkSkippingFileIndex]) {
// return filter
// }

val indexName = getSkippingIndexName(table.identifier.table) // TODO: ignore schema name
val index = flint.describeIndex(indexName)

if (index.exists(_.kind == SKIPPING_INDEX_TYPE)) {
val skippingIndex = index.get.asInstanceOf[FlintSparkSkippingIndex]
val rewrittenPredicate = rewriteToPredicateOnSkippingIndex(skippingIndex, condition)
val selectedFiles = getSelectedFilesToScanAfterSkip(skippingIndex, rewrittenPredicate)
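// selectedFiles is not pushed into the file index yet, so the original filter is returned unchanged below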

filter
} else {
filter
}
}

private def rewriteToPredicateOnSkippingIndex(
index: FlintSparkSkippingIndex,
condition: Predicate): Predicate = {

index.indexedColumns
.map(strategy => strategy.rewritePredicate(condition))
.filter(pred => pred.isDefined)
.map(pred => pred.get)
.reduce(And(_, _))
}

private def getSelectedFilesToScanAfterSkip(
index: FlintSparkSkippingIndex,
rewrittenPredicate: Predicate): Set[String] = {

index
.query()
.filter(new Column(rewrittenPredicate))
.select(FILE_PATH_COLUMN)
.collect
.map(_.getString(0))
.toSet
}
}
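
For intuition on the rewrite step, the spark-shell style fragment below reproduces the predicate folding performed in rewriteToPredicateOnSkippingIndex with plain Catalyst expressions; the column names and literal values are made up for illustration.

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal, Predicate}

// One rewritten predicate per indexed column; None means that strategy could
// not handle the source condition.
val perColumn: Seq[Option[Predicate]] = Seq(
  Some(EqualTo(UnresolvedAttribute("year"), Literal(2023))),  // illustrative
  None,
  Some(EqualTo(UnresolvedAttribute("month"), Literal(4)))     // illustrative
)

// flatten + reduce mirrors the filter(isDefined)/map(get)/reduce(And) chain above
val combined: Predicate = perColumn.flatten.reduce(And(_, _))
// combined is And(year = 2023, month = 4), which is then evaluated against the
// index data frame to collect the surviving file_path values
```

The reduce assumes at least one strategy produced a rewritten predicate; an empty sequence would fail.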
FlintSparkSkippingIndex.scala
@@ -9,22 +9,31 @@ import org.json4s._
import org.json4s.native.Serialization
import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.spark.FlintSparkIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getIndexName, FILE_PATH_COLUMN}
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE}

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

/**
* Flint skipping index in Spark.
*
* @param spark
* Spark session
* @param tableName
* source table name
* @param indexedColumns
* indexed columns with their skipping strategies
*/
class FlintSparkSkippingIndex(tableName: String, indexedColumns: Seq[FlintSparkSkippingStrategy])
case class FlintSparkSkippingIndex(
spark: SparkSession,
tableName: String,
indexedColumns: Seq[FlintSparkSkippingStrategy])
extends FlintSparkIndex {

/** Skipping index type */
override val kind: String = SKIPPING_INDEX_TYPE

/** Required by json4s write function */
implicit val formats: Formats = Serialization.formats(NoTypeHints)

/** Output schema of the skipping index */
val outputSchema: Map[String, String] = {
private val outputSchema: Map[String, String] = {
val schema = indexedColumns
.flatMap(_.outputSchema().toList)
.toMap
@@ -33,33 +42,53 @@ class FlintSparkSkippingIndex(tableName: String, indexedColumns: Seq[FlintSparkS
}

override def name(): String = {
getIndexName(tableName)
getSkippingIndexName(tableName)
}

override def metadata(): FlintMetadata = {
new FlintMetadata(s"""{
| "_meta": {
| "kind": "SkippingIndex",
| "indexedColumns": $getMetaInfo
| "kind": "$kind",
| "indexedColumns": $getMetaInfo,
| "source": "$tableName"
| },
| "properties": $getSchema
| }
|""".stripMargin)
}

override def query(): DataFrame = {
spark.read
.format("flint")
.schema(getDfSchema)
.load(name())
}

private def getMetaInfo: String = {
Serialization.write(indexedColumns.map(_.indexedColumn))
Serialization.write(indexedColumns)
}

private def getSchema: String = {
Serialization.write(outputSchema.map { case (colName, colType) =>
colName -> ("type" -> colType)
})
}

private def getDfSchema: StructType = {
StructType(outputSchema.map {
case (colName, "integer") =>
StructField(colName, IntegerType, nullable = false)
case (colName, "keyword") =>
StructField(colName, StringType, nullable = false)
}.toSeq)
}
}

object FlintSparkSkippingIndex {

/** Index type name */
val SKIPPING_INDEX_TYPE = "SkippingIndex"

/** File path column name */
val FILE_PATH_COLUMN = "file_path"

@@ -75,5 +104,5 @@
* @return
* Flint skipping index name
*/
def getIndexName(tableName: String): String = s"flint_${tableName}_skipping_index"
def getSkippingIndexName(tableName: String): String = s"flint_${tableName}_skipping_index"
}
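
The query() method reads the index back with an explicit Spark schema derived from the Flint field types. A small sketch of the mapping getDfSchema applies, assuming the hidden part of outputSchema also adds the file_path keyword column implied by FILE_PATH_COLUMN; the year column is illustrative.

```scala
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Flint field types map to Spark types when the index is read back:
//   "integer" -> IntegerType, "keyword" -> StringType
val exampleOutputSchema = Map(
  "file_path" -> "keyword",   // assumed to be added for every skipping index
  "year" -> "integer"         // from a partition strategy on year:int
)

val dfSchema = StructType(exampleOutputSchema.map {
  case (name, "integer") => StructField(name, IntegerType, nullable = false)
  case (name, "keyword") => StructField(name, StringType, nullable = false)
}.toSeq)
// dfSchema is the schema passed to spark.read.format("flint").schema(...).load(indexName)
```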
FlintSparkSkippingStrategy.scala
@@ -5,19 +5,38 @@

package org.opensearch.flint.spark.skipping

import org.apache.spark.sql.catalyst.expressions.Predicate

/**
* Skipping index strategy that defines how the skipping data structure is built and read.
*/
trait FlintSparkSkippingStrategy {

/**
* Skipping strategy kind.
*/
val kind: String

/**
* Indexed column name and its Spark SQL type.
*/
val indexedColumn: (String, String)
val columnName: String
val columnType: String

/**
* @return
* output schema mapping from Flint field name to Flint field type
*/
def outputSchema(): Map[String, String]

/**
* Rewrite a predicate (filtering condition) on the source table into a predicate on the index
* data, based on the current skipping strategy.
*
* @param predicate
* filtering condition on source table
* @return
* rewritten filtering condition on index data
*/
def rewritePredicate(predicate: Predicate): Option[Predicate]
}
PartitionSkippingStrategy.scala
@@ -7,14 +7,20 @@ package org.opensearch.flint.spark.skipping.partition

import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, EqualTo, Expression, Literal, Predicate}

/**
* Skipping strategy for partitioned columns of the source table.
*/
class PartitionSkippingStrategy(override val indexedColumn: (String, String))
class PartitionSkippingStrategy(
override val kind: String = "partition",
override val columnName: String,
override val columnType: String)
extends FlintSparkSkippingStrategy {

override def outputSchema(): Map[String, String] = {
Map(indexedColumn._1 -> convertToFlintType(indexedColumn._2))
Map(columnName -> convertToFlintType(columnType))
}

// TODO: move this mapping info to a single place
@@ -24,4 +30,12 @@ class PartitionSkippingStrategy(override val indexedColumn: (String, String))
case "int" => "integer"
}
}

override def rewritePredicate(predicate: Predicate): Option[Predicate] = {
val newPred = predicate.collect {
// A lowercase pattern name binds a new variable instead of matching this
// strategy's columnName, so compare explicitly with a guard.
case EqualTo(AttributeReference(name, _, _, _), value: Literal) if name == columnName =>
EqualTo(UnresolvedAttribute(columnName), value)
}
newPred.headOption
}
}
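
A short spark-shell style sketch of how rewritePredicate behaves for a strategy on a year:int partition column; the resolved attribute below is constructed by hand for illustration (in a real plan it comes from the analyzer).

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.types.IntegerType

val strategy = new PartitionSkippingStrategy(columnName = "year", columnType = "int")

// year = 2023 on the source table ...
val sourcePredicate = EqualTo(AttributeReference("year", IntegerType)(), Literal(2023))

// ... becomes year = 2023 over the index data, expressed with an UnresolvedAttribute
// so it can be resolved against the index schema when filtering the index data frame.
val rewritten = strategy.rewritePredicate(sourcePredicate)
// rewritten == Some(EqualTo(UnresolvedAttribute("year"), Literal(2023)))
```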