[SPARK-18127] Add hooks and extension points to Spark
## What changes were proposed in this pull request?

This patch adds support for customizing a Spark session by injecting user-defined extensions. This allows a user to add custom analyzer rules/checks, optimizer rules, planning strategies, or even a customized parser.
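
For example, extensions can be registered on the session builder before the session is created. A minimal sketch (`MyRule` is a hypothetical no-op rule used only for illustration, not part of this patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical rule that leaves the plan unchanged; a real rule would rewrite it.
case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

val spark = SparkSession.builder()
  .master("local[2]")
  .withExtensions { extensions =>
    extensions.injectOptimizerRule { session => MyRule(session) }
  }
  .getOrCreate()
```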

## How was this patch tested?

Unit tests in `SparkSessionExtensionSuite`.

Author: Sameer Agarwal <sameerag@cs.berkeley.edu>

Closes #17724 from sameeragarwal/session-extensions.
sameeragarwal authored and gatorsmile committed Apr 26, 2017
1 parent 0a7f5f2 commit caf3920
Showing 7 changed files with 418 additions and 25 deletions.
@@ -34,8 +34,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
abstract class AbstractSqlParser extends ParserInterface with Logging {

/** Creates/Resolves DataType for a given SQL string. */
-def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
-// TODO add this to the parser interface.
+override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
astBuilder.visitSingleDataType(parser.singleDataType())
}

@@ -50,8 +49,10 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
}

/** Creates FunctionIdentifier for a given SQL string. */
-def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = parse(sqlText) { parser =>
-astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier())
+override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
+parse(sqlText) { parser =>
+astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier())
+}
}

/**
@@ -17,30 +17,51 @@

package org.apache.spark.sql.catalyst.parser

+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}

/**
* Interface for a parser.
*/
+@DeveloperApi
trait ParserInterface {
-/** Creates LogicalPlan for a given SQL string. */
+/**
+ * Parse a string to a [[LogicalPlan]].
+ */
+@throws[ParseException]("Text cannot be parsed to a LogicalPlan")
def parsePlan(sqlText: String): LogicalPlan

-/** Creates Expression for a given SQL string. */
+/**
+ * Parse a string to an [[Expression]].
+ */
+@throws[ParseException]("Text cannot be parsed to an Expression")
def parseExpression(sqlText: String): Expression

-/** Creates TableIdentifier for a given SQL string. */
+/**
+ * Parse a string to a [[TableIdentifier]].
+ */
+@throws[ParseException]("Text cannot be parsed to a TableIdentifier")
def parseTableIdentifier(sqlText: String): TableIdentifier

-/** Creates FunctionIdentifier for a given SQL string. */
+/**
+ * Parse a string to a [[FunctionIdentifier]].
+ */
+@throws[ParseException]("Text cannot be parsed to a FunctionIdentifier")
def parseFunctionIdentifier(sqlText: String): FunctionIdentifier

/**
- * Creates StructType for a given SQL string, which is a comma separated list of field
- * definitions which will preserve the correct Hive metadata.
+ * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list
+ * of field definitions which will preserve the correct Hive metadata.
*/
+@throws[ParseException]("Text cannot be parsed to a schema")
def parseTableSchema(sqlText: String): StructType

+/**
+ * Parse a string to a [[DataType]].
+ */
+@throws[ParseException]("Text cannot be parsed to a DataType")
+def parseDataType(sqlText: String): DataType
}
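
A parser injected through the new extension points receives the previously configured parser and can delegate to it for anything it does not handle itself. A minimal sketch of such a pass-through implementation (the class name and behavior are illustrative, not part of this patch):

```scala
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.{DataType, StructType}

// Hypothetical parser: intercept parsePlan, delegate everything else untouched.
class MyParser(delegate: ParserInterface) extends ParserInterface {
  override def parsePlan(sqlText: String): LogicalPlan =
    delegate.parsePlan(sqlText) // a real parser could handle custom syntax here
  override def parseExpression(sqlText: String): Expression =
    delegate.parseExpression(sqlText)
  override def parseTableIdentifier(sqlText: String): TableIdentifier =
    delegate.parseTableIdentifier(sqlText)
  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
    delegate.parseFunctionIdentifier(sqlText)
  override def parseTableSchema(sqlText: String): StructType =
    delegate.parseTableSchema(sqlText)
  override def parseDataType(sqlText: String): DataType =
    delegate.parseDataType(sqlText)
}
```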
@@ -81,4 +81,10 @@ object StaticSQLConf {
"SQL configuration and the current database.")
.booleanConf
.createWithDefault(false)

+val SPARK_SESSION_EXTENSIONS = buildStaticConf("spark.sql.extensions")
+.doc("Name of the class used to configure Spark Session extensions. The class should " +
+"implement Function1[SparkSessionExtensions, Unit], and must have a no-args constructor.")
+.stringConf
+.createOptional
}
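
A configurator class named by `spark.sql.extensions` is just a `Function1[SparkSessionExtensions, Unit]` with a no-args constructor. A minimal sketch (the class and package names are hypothetical):

```scala
package com.example

import org.apache.spark.sql.SparkSessionExtensions

// Enable with: --conf spark.sql.extensions=com.example.MyExtensions
class MyExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    // Accepts every plan; a real check rule would throw on a problematic plan.
    extensions.injectCheckRule { session => plan => () }
  }
}
```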
45 changes: 39 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.ui.SQLListener
-import org.apache.spark.sql.internal.{BaseSessionStateBuilder, CatalogImpl, SessionState, SessionStateBuilder, SharedState}
+import org.apache.spark.sql.internal._
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.streaming._
@@ -77,11 +77,12 @@ import org.apache.spark.util.Utils
class SparkSession private(
@transient val sparkContext: SparkContext,
@transient private val existingSharedState: Option[SharedState],
-@transient private val parentSessionState: Option[SessionState])
+@transient private val parentSessionState: Option[SessionState],
+@transient private[sql] val extensions: SparkSessionExtensions)
extends Serializable with Closeable with Logging { self =>

private[sql] def this(sc: SparkContext) {
-this(sc, None, None)
+this(sc, None, None, new SparkSessionExtensions)
}

sparkContext.assertNotStopped()
@@ -219,7 +220,7 @@ class SparkSession private(
* @since 2.0.0
*/
def newSession(): SparkSession = {
-new SparkSession(sparkContext, Some(sharedState), parentSessionState = None)
+new SparkSession(sparkContext, Some(sharedState), parentSessionState = None, extensions)
}

/**
@@ -235,7 +236,7 @@
* implementation is Hive, this will initialize the metastore, which may take some time.
*/
private[sql] def cloneSession(): SparkSession = {
-val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState))
+val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState), extensions)
result.sessionState // force copy of SessionState
result
}
@@ -754,6 +755,8 @@ object SparkSession {

private[this] val options = new scala.collection.mutable.HashMap[String, String]

+private[this] val extensions = new SparkSessionExtensions

private[this] var userSuppliedContext: Option[SparkContext] = None

private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized {
@@ -847,6 +850,17 @@
}
}

+/**
+ * Inject extensions into the [[SparkSession]]. This allows a user to add Analyzer rules,
+ * Optimizer rules, Planning Strategies or a customized parser.
+ *
+ * @since 2.2.0
+ */
+def withExtensions(f: SparkSessionExtensions => Unit): Builder = {
+f(extensions)
+this
+}

/**
* Gets an existing [[SparkSession]] or, if there is no existing one, creates a new
* one based on the options set in this builder.
@@ -903,7 +917,26 @@
}
sc
}
-session = new SparkSession(sparkContext)

+// Initialize extensions if the user has defined a configurator class.
+val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
+if (extensionConfOption.isDefined) {
+val extensionConfClassName = extensionConfOption.get
+try {
+val extensionConfClass = Utils.classForName(extensionConfClassName)
+val extensionConf = extensionConfClass.newInstance()
+.asInstanceOf[SparkSessionExtensions => Unit]
+extensionConf(extensions)
+} catch {
+// Ignore the error if we cannot find the class or when the class has the wrong type.
+case e @ (_: ClassCastException |
+_: ClassNotFoundException |
+_: NoClassDefFoundError) =>
+logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
+}
+}
+
+session = new SparkSession(sparkContext, None, None, extensions)
options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) }
defaultSession.set(session)

@@ -0,0 +1,171 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import scala.collection.mutable

import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

/**
* :: Experimental ::
* Holder for injection points to the [[SparkSession]]. We make NO guarantee about the binary
* compatibility and source compatibility of methods here.
*
* This currently provides the following extension points:
* - Analyzer Rules.
* - Check Analysis Rules.
* - Optimizer Rules.
* - Planning Strategies.
* - Customized Parser.
* - (External) Catalog listeners.
*
* The extensions can be used by calling withExtensions on the [[SparkSession.Builder]], for
* example:
* {{{
* SparkSession.builder()
* .master("...")
* .conf("...", true)
* .withExtensions { extensions =>
* extensions.injectResolutionRule { session =>
* ...
* }
* extensions.injectParser { (session, parser) =>
* ...
* }
* }
* .getOrCreate()
* }}}
*
* Note that the injected builders should not assume that the [[SparkSession]] is fully
* initialized, and they should not touch the session's internals (e.g. the SessionState).
*/
@DeveloperApi
@Experimental
@InterfaceStability.Unstable
class SparkSessionExtensions {
type RuleBuilder = SparkSession => Rule[LogicalPlan]
type CheckRuleBuilder = SparkSession => LogicalPlan => Unit
type StrategyBuilder = SparkSession => Strategy
type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface

private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]

/**
* Build the analyzer resolution `Rule`s using the given [[SparkSession]].
*/
private[sql] def buildResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
resolutionRuleBuilders.map(_.apply(session))
}

/**
* Inject an analyzer resolution `Rule` builder into the [[SparkSession]]. These analyzer
* rules will be executed as part of the resolution phase of analysis.
*/
def injectResolutionRule(builder: RuleBuilder): Unit = {
resolutionRuleBuilders += builder
}

private[this] val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]

/**
* Build the analyzer post-hoc resolution `Rule`s using the given [[SparkSession]].
*/
private[sql] def buildPostHocResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
postHocResolutionRuleBuilders.map(_.apply(session))
}

/**
* Inject an analyzer `Rule` builder into the [[SparkSession]]. These analyzer
* rules will be executed after resolution.
*/
def injectPostHocResolutionRule(builder: RuleBuilder): Unit = {
postHocResolutionRuleBuilders += builder
}

private[this] val checkRuleBuilders = mutable.Buffer.empty[CheckRuleBuilder]

/**
* Build the check analysis `Rule`s using the given [[SparkSession]].
*/
private[sql] def buildCheckRules(session: SparkSession): Seq[LogicalPlan => Unit] = {
checkRuleBuilders.map(_.apply(session))
}

/**
* Inject a check analysis `Rule` builder into the [[SparkSession]]. The injected rules will
* be executed after the analysis phase. A check analysis rule is used to detect problems with a
* LogicalPlan and should throw an exception when a problem is found.
*/
def injectCheckRule(builder: CheckRuleBuilder): Unit = {
checkRuleBuilders += builder
}

private[this] val optimizerRules = mutable.Buffer.empty[RuleBuilder]

/**
* Build the optimizer `Rule`s using the given [[SparkSession]].
*/
private[sql] def buildOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
optimizerRules.map(_.apply(session))
}

/**
* Inject an optimizer `Rule` builder into the [[SparkSession]]. The injected rules will be
* executed during the operator optimization batch. An optimizer rule is used to improve the
* quality of an analyzed logical plan; these rules should never modify the result of the
* LogicalPlan.
*/
def injectOptimizerRule(builder: RuleBuilder): Unit = {
optimizerRules += builder
}

private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder]

/**
* Build the planner `Strategy`s using the given [[SparkSession]].
*/
private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = {
plannerStrategyBuilders.map(_.apply(session))
}

/**
* Inject a planner `Strategy` builder into the [[SparkSession]]. The injected strategy will
* be used to convert a `LogicalPlan` into an executable
* [[org.apache.spark.sql.execution.SparkPlan]].
*/
def injectPlannerStrategy(builder: StrategyBuilder): Unit = {
plannerStrategyBuilders += builder
}

private[this] val parserBuilders = mutable.Buffer.empty[ParserBuilder]

/**
* Build the parser by folding the injected parser builders over the initial parser.
*/
private[sql] def buildParser(
session: SparkSession,
initial: ParserInterface): ParserInterface = {
parserBuilders.foldLeft(initial) { (parser, builder) =>
builder(session, parser)
}
}

/**
* Inject a custom parser into the [[SparkSession]]. Note that the builder is passed a session
* and an initial parser. The latter allows a user to create a partial parser and to delegate
* to the underlying parser for completeness. If a user injects more parsers, then the parsers
* are stacked on top of each other.
*/
def injectParser(builder: ParserBuilder): Unit = {
parserBuilders += builder
}
}
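
Putting the pieces together, several injection points can be stacked in a single builder call. A sketch reusing the hypothetical `MyRule` and `MyParser` from the earlier snippets:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[2]")
  .withExtensions { extensions =>
    // Run MyRule during the resolution phase of analysis.
    extensions.injectResolutionRule { session => MyRule(session) }
    // Wrap whatever parser is already configured; parsers stack in injection order.
    extensions.injectParser { (session, parser) => new MyParser(parser) }
  }
  .getOrCreate()
```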
