rdar://112431669: Auto load Iceberg extensions (apache#1805)
* Auto load IcebergSparkExtensions

* Add test

* Review comments

* Temp: upgrade to 1.3.0.1-apple to try to pass tests

* Fix tests

Co-authored-by: Szehon Ho <szehon.apache@gmail.com>
2 people authored and GitHub Enterprise committed Jul 18, 2023
1 parent 2f432eb commit 5d61242
Showing 4 changed files with 54 additions and 6 deletions.
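Note: after this change, a Spark application picks up IcebergSparkSessionExtensions automatically whenever it is on the classpath. A minimal sketch of how a job can opt out via the new flag (the config key comes from the diff below; the rest of the snippet is ordinary SparkSession usage and the names are illustrative):

import org.apache.spark.sql.SparkSession

// Set the flag before the SparkContext starts: it is read from the
// SparkConf when extensions are applied at session creation.
val spark = SparkSession.builder()
  .master("local[*]")  // illustrative; any deployment mode works
  .config("spark.sql.extensions.iceberg.enabled", "false")
  .getOrCreate()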
@@ -4206,6 +4206,15 @@ object SQLConf {
      .booleanConf
      .createWithDefault(true)

+  val ICEBERG_ENABLED =
+    buildConf("spark.sql.extensions.iceberg.enabled")
+      .internal()
+      .doc("Whether to automatically load org.apache.iceberg.spark.extensions" +
+        ".IcebergSparkSessionExtensions by default.")
+      .version("3.4.0")
+      .booleanConf
+      .createWithDefault(true)

  /**
   * Holds information about keys that have been deprecated.
   *
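For reference, a sketch of how such a registered entry is typically read inside Spark code, assuming the usual typed SQLConf accessor pattern (the entry is the one added above; the val name is illustrative):

import org.apache.spark.sql.internal.SQLConf

// Typed read; yields the declared default (true) unless the user set
// spark.sql.extensions.iceberg.enabled explicitly.
val autoLoadIceberg: Boolean = SQLConf.get.getConf(SQLConf.ICEBERG_ENABLED)

Note that the loader in SparkSession.scala below instead reads the key from the SparkConf, with the classpath probe isIcebergEnabled as the fallback, so the effective default is "load if Iceberg is on the classpath" rather than a flat true.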
21 changes: 20 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}

import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.TypeTag
+import scala.util.Try
import scala.util.control.NonFatal

import com.apple.boson.BosonConf
@@ -839,6 +840,8 @@ class SparkSession private(
@Stable
object SparkSession extends Logging {

+  def icebergClass: String = "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions"

  /**
   * Builder for [[SparkSession]].
   */
@@ -1312,6 +1315,14 @@ object SparkSession extends Logging {
    }
  }

+  private def loadIcebergExtension(sparkContext: SparkContext): Seq[String] = {
+    if (sparkContext.getConf.getBoolean(SQLConf.ICEBERG_ENABLED.key, isIcebergEnabled)) {
+      Seq(icebergClass)
+    } else {
+      Seq.empty
+    }
+  }

  /**
   * Initialize extensions specified in [[StaticSQLConf]]. The classes will be applied to the
   * extensions passed into this function.
@@ -1320,7 +1331,8 @@
      sparkContext: SparkContext,
      extensions: SparkSessionExtensions): SparkSessionExtensions = {
    val extensionConfClassNames = sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
-      .getOrElse(Seq.empty) ++ loadBosonExtension(sparkContext)
+      .getOrElse(Seq.empty) ++ loadBosonExtension(sparkContext) ++
+      loadIcebergExtension(sparkContext)
    extensionConfClassNames.foreach { extensionConfClassName =>
      try {
        val extensionConfClass = Utils.classForName(extensionConfClassName)
@@ -1363,4 +1375,11 @@ object SparkSession extends Logging {
    val v = System.getenv("BOSON")
    v == null || v.toBoolean
  }

+  /**
+   * Whether Iceberg extension is enabled
+   */
+  def isIcebergEnabled: Boolean = {
+    Try(Utils.classForName(icebergClass, false)).isSuccess
+  }
}
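isIcebergEnabled is a classpath probe: load the extension class reflectively without running its static initializers and treat any failure as absence. A self-contained sketch of the same pattern in plain Scala, with Class.forName standing in for the Spark-internal Utils.classForName:

import scala.util.Try

// Probe for a class without initializing it; initialize = false mirrors
// Utils.classForName(icebergClass, false) above.
def classPresent(name: String): Boolean =
  Try(Class.forName(name, false, Thread.currentThread().getContextClassLoader)).isSuccess

val icebergOnClasspath =
  classPresent("org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")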
@@ -39,10 +39,11 @@ import org.apache.spark.sql.connector.write.WriterCommitMessage
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec, WriteFilesSpec}
+import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Strategy
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE
+import org.apache.spark.sql.internal.SQLConf.{COLUMN_BATCH_SIZE, ICEBERG_ENABLED}
import org.apache.spark.sql.internal.StaticSQLConf.SPARK_SESSION_EXTENSIONS
import org.apache.spark.sql.types.{DataType, Decimal, IntegerType, LongType, Metadata, StructType}
import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarBatch, ColumnarMap, ColumnVector}
@@ -62,15 +63,31 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
  }

  private def withSession(
-      builders: Seq[SparkSessionExtensionsProvider])(f: SparkSession => Unit): Unit = {
+      builders: Seq[SparkSessionExtensionsProvider], pairs: (String, String)*)
+      (f: SparkSession => Unit): Unit = {
    val builder = SparkSession.builder().master("local[1]")
    builders.foreach(builder.withExtensions)
-    val spark = builder.getOrCreate()
+    val configuredBuilder =
+      if (!SQLConf.get.contains(SQLConf.ICEBERG_ENABLED.key)) {
+        builder.config(SQLConf.ICEBERG_ENABLED.key, "false")
+      } else {
+        builder
+      }
+    val spark = configuredBuilder.getOrCreate()
    try f(spark) finally {
      stop(spark)
    }
  }

test("Test Iceberg extension") {
withSQLConf(SQLConf.ICEBERG_ENABLED.key -> "true") {
withSession(Seq()) { session =>
assert(session.sessionState.planner.strategies.contains(
ExtendedDataSourceV2Strategy(session)))
}
}
}

test("inject analyzer rule") {
withSession(Seq(_.injectResolutionRule(MyRule))) { session =>
assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
@@ -341,6 +358,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
    val session = SparkSession.builder()
      .master("local[1]")
      .config(SPARK_SESSION_EXTENSIONS.key, classOf[MyExtensions].getCanonicalName)
+      .config(ICEBERG_ENABLED.key, false)
      .getOrCreate()
    try {
      assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
@@ -364,6 +382,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
      .config(SPARK_SESSION_EXTENSIONS.key, Seq(
        classOf[MyExtensions2].getCanonicalName,
        classOf[MyExtensions].getCanonicalName).mkString(","))
+      .config(ICEBERG_ENABLED.key, false)
      .getOrCreate()
    try {
      assert(session.sessionState.planner.strategies.containsSlice(
@@ -392,6 +411,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
      .config(SPARK_SESSION_EXTENSIONS.key, Seq(
        classOf[MyExtensions].getCanonicalName,
        classOf[MyExtensions].getCanonicalName).mkString(","))
+      .config(ICEBERG_ENABLED.key, false)
      .getOrCreate()
    try {
      assert(session.sessionState.planner.strategies.count(_ === MySparkStrategy(session)) === 2)
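The builder-based tests above pin the flag to false so that assertions over the injected strategies stay deterministic regardless of what the test classpath contains. A hedged standalone sketch of the same guard, assuming (as the new test does) that the Iceberg extension is what contributes ExtendedDataSourceV2Strategy:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

val session = SparkSession.builder()
  .master("local[1]")
  .config(SQLConf.ICEBERG_ENABLED.key, "false")  // opt out explicitly
  .getOrCreate()
try {
  // With auto-load disabled, the Iceberg-contributed strategy is absent.
  assert(!session.sessionState.planner.strategies
    .exists(_.getClass.getSimpleName.contains("ExtendedDataSourceV2Strategy")))
} finally {
  session.stop()
}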
@@ -663,8 +663,8 @@ class PersistedViewTestSuite extends SQLViewTestSuite with SharedSparkSession {
      val message = intercept[AnalysisException] {
        sql("SELECT * FROM v")
      }.getMessage
-      assert(message.contains(s"Invalid view text: $dropView." +
-        s" The view ${table.qualifiedName} may have been tampered with"))
+      assert(message.contains(s"has an incompatible schema change " +
+        s"and column 1 cannot be resolved"))
    }
  }
}

