[SPARK-32364][SQL][2.4] Use CaseInsensitiveMap for DataFrameReader/Writer options #29209

Closed
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala
@@ -35,15 +35,21 @@ class CaseInsensitiveMap[T] private (val originalMap: Map[String, T]) extends Ma
override def contains(k: String): Boolean =
keyLowerCasedMap.contains(k.toLowerCase(Locale.ROOT))

- override def +[B1 >: T](kv: (String, B1)): Map[String, B1] = {
+ override def +[B1 >: T](kv: (String, B1)): CaseInsensitiveMap[B1] = {
new CaseInsensitiveMap(originalMap.filter(!_._1.equalsIgnoreCase(kv._1)) + kv)
}

+ def ++(xs: TraversableOnce[(String, T)]): CaseInsensitiveMap[T] = {
+ xs.foldLeft(this)(_ + _)
+ }

override def iterator: Iterator[(String, T)] = keyLowerCasedMap.iterator

override def -(key: String): Map[String, T] = {
new CaseInsensitiveMap(originalMap.filterKeys(!_.equalsIgnoreCase(key)))
}

+ def toMap: Map[String, T] = originalMap
}

object CaseInsensitiveMap {
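The signature change above is the heart of the patch: the inherited `+` returned a plain `Map`, so the first option added through it silently discarded case-insensitivity, and the new `++` folds every pair through the overridden `+`. A minimal sketch of the intended behavior (assuming `spark-catalyst` on the classpath; keys and values are illustrative):

```scala
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

// A `var` plus the narrowed return types lets `+=`/`++=` reassign
// without widening the result back to a plain Map.
var opts = CaseInsensitiveMap[String](Map.empty)

// The overridden `+` first drops any existing key that matches
// case-insensitively, then inserts the new pair.
opts += ("PATH" -> "a")
opts += ("path" -> "b")
assert(opts.size == 1 && opts.get("Path").contains("b"))

// The new `++` folds through `+`, so bulk updates override existing
// entries regardless of key casing.
opts ++= Seq("URL" -> "u1", "url" -> "u2")
assert(opts.get("uRl").contains("u2"))
```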
25 changes: 22 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -30,6 +30,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
+ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{DataSource, FailureSafeParser}
import org.apache.spark.sql.execution.datasources.csv._
@@ -91,6 +92,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
/**
* Adds an input option for the underlying data source.
*
+ * All options are maintained in a case-insensitive way in terms of key names.
+ * If a new option has the same key case-insensitively, it will override the existing option.
+ *
* You can set the following option(s):
* <ul>
* <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
@@ -107,27 +111,39 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
/**
* Adds an input option for the underlying data source.
*
+ * All options are maintained in a case-insensitive way in terms of key names.
+ * If a new option has the same key case-insensitively, it will override the existing option.
+ *
* @since 2.0.0
*/
def option(key: String, value: Boolean): DataFrameReader = option(key, value.toString)

/**
* Adds an input option for the underlying data source.
*
+ * All options are maintained in a case-insensitive way in terms of key names.
+ * If a new option has the same key case-insensitively, it will override the existing option.
+ *
* @since 2.0.0
*/
def option(key: String, value: Long): DataFrameReader = option(key, value.toString)

/**
* Adds an input option for the underlying data source.
*
+ * All options are maintained in a case-insensitive way in terms of key names.
+ * If a new option has the same key case-insensitively, it will override the existing option.
+ *
* @since 2.0.0
*/
def option(key: String, value: Double): DataFrameReader = option(key, value.toString)

/**
* (Scala-specific) Adds input options for the underlying data source.
*
+ * All options are maintained in a case-insensitive way in terms of key names.
+ * If a new option has the same key case-insensitively, it will override the existing option.
+ *
* You can set the following option(s):
* <ul>
* <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
@@ -144,6 +160,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
/**
* Adds input options for the underlying data source.
*
+ * All options are maintained in a case-insensitive way in terms of key names.
+ * If a new option has the same key case-insensitively, it will override the existing option.
+ *
* You can set the following option(s):
* <ul>
* <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
@@ -234,7 +253,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
// properties should override settings in extraOptions.
this.extraOptions ++= properties.asScala
// explicit url and dbtable should override all
- this.extraOptions += (JDBCOptions.JDBC_URL -> url, JDBCOptions.JDBC_TABLE_NAME -> table)
+ this.extraOptions ++= Seq(JDBCOptions.JDBC_URL -> url, JDBCOptions.JDBC_TABLE_NAME -> table)
format("jdbc").load()
}
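A note on the `+=` → `++= Seq(...)` change here (and the matching one in `DataFrameWriter`): with `extraOptions` now a `var` holding an immutable `CaseInsensitiveMap`, the two-pair form `m += (a, b)` would desugar to the varargs `+` defined on `scala.collection.immutable.Map`, which returns a plain `Map` and cannot be assigned back to the more specific variable, while `++=` routes through the new `++`. A sketch under those assumptions (the URL is a placeholder):

```scala
var extraOptions = CaseInsensitiveMap[String](Map.empty)

// Does not compile with the patched class: Map's varargs `+` returns a
// plain Map[String, String], which can't be reassigned to this var.
// extraOptions += ("url" -> "jdbc:h2:mem:testdb", "dbtable" -> "t")

// Compiles, and the result stays a CaseInsensitiveMap.
extraOptions ++= Seq("url" -> "jdbc:h2:mem:testdb", "dbtable" -> "t")
```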

@@ -305,7 +324,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
connectionProperties: Properties): DataFrame = {
assertNoSpecifiedSchema("jdbc")
// connectionProperties should override settings in extraOptions.
- val params = extraOptions.toMap ++ connectionProperties.asScala.toMap
+ val params = extraOptions ++ connectionProperties.asScala
val options = new JDBCOptions(url, table, params)
val parts: Array[Partition] = predicates.zipWithIndex.map { case (part, i) =>
JDBCPartition(part, i) : Partition
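Dropping the `.toMap` calls above matters: the merge now happens through `CaseInsensitiveMap`'s own `++`, so entries from `connectionProperties` override reader options whose keys differ only in casing. The new `JDBCSuite` test below ("SPARK-32364: JDBCOption constructor") exercises exactly this merge.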
@@ -790,6 +809,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {

private var userSpecifiedSchema: Option[StructType] = None

- private val extraOptions = new scala.collection.mutable.HashMap[String, String]
+ private var extraOptions = CaseInsensitiveMap[String](Map.empty)

}
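To make the documented override semantics concrete, a usage sketch (assuming an active `SparkSession` named `spark`; the format, keys, and path are illustrative):

```scala
// The later `option` call wins even though only the key casing differs;
// a path passed to `load` would likewise override any earlier "path" option.
val df = spark.read
  .format("csv")
  .option("HEADER", "false")
  .option("header", "true") // replaces the entry above case-insensitively
  .load("/tmp/people.csv")  // hypothetical path
```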
25 changes: 22 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan}
+ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, LogicalRelation}
@@ -98,6 +99,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
/**
* Adds an output option for the underlying data source.
*
+ * All options are maintained in a case-insensitive way in terms of key names.
+ * If a new option has the same key case-insensitively, it will override the existing option.
+ *
* You can set the following option(s):
* <ul>
* <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
@@ -114,27 +118,39 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
/**
* Adds an output option for the underlying data source.
*
+ * All options are maintained in a case-insensitive way in terms of key names.
+ * If a new option has the same key case-insensitively, it will override the existing option.
+ *
* @since 2.0.0
*/
def option(key: String, value: Boolean): DataFrameWriter[T] = option(key, value.toString)

/**
* Adds an output option for the underlying data source.
*
+ * All options are maintained in a case-insensitive way in terms of key names.
+ * If a new option has the same key case-insensitively, it will override the existing option.
+ *
* @since 2.0.0
*/
def option(key: String, value: Long): DataFrameWriter[T] = option(key, value.toString)

/**
* Adds an output option for the underlying data source.
*
+ * All options are maintained in a case-insensitive way in terms of key names.
+ * If a new option has the same key case-insensitively, it will override the existing option.
+ *
* @since 2.0.0
*/
def option(key: String, value: Double): DataFrameWriter[T] = option(key, value.toString)

/**
* (Scala-specific) Adds output options for the underlying data source.
*
+ * All options are maintained in a case-insensitive way in terms of key names.
+ * If a new option has the same key case-insensitively, it will override the existing option.
+ *
* You can set the following option(s):
* <ul>
* <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
@@ -151,6 +167,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
/**
* Adds output options for the underlying data source.
*
+ * All options are maintained in a case-insensitive way in terms of key names.
+ * If a new option has the same key case-insensitively, it will override the existing option.
+ *
* You can set the following option(s):
* <ul>
* <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
@@ -251,7 +270,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
source,
df.sparkSession.sessionState.conf)
- val options = sessionOptions ++ extraOptions
+ val options = sessionOptions.filterKeys(!extraOptions.contains(_)) ++ extraOptions.toMap

val writer = ws.createWriter(
UUID.randomUUID.toString, df.logicalPlan.output.toStructType, mode,
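The `filterKeys` step above exists because `sessionOptions` is a plain, case-sensitive `Map`: user-supplied options must win over session configs even when only the key casing differs, which a bare `sessionOptions ++ extraOptions` would only guarantee for exact-case matches. A self-contained sketch of the merge (option names are illustrative):

```scala
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

val sessionOptions = Map("fetchsize" -> "10")
val extraOptions = CaseInsensitiveMap[String](Map("fetchSize" -> "100"))

// Drop session entries whose keys are already present (case-insensitively)
// in the user options, then overlay the user options.
val merged = sessionOptions.filterKeys(!extraOptions.contains(_)) ++ extraOptions.toMap
assert(merged.size == 1 && merged("fetchSize") == "100")
```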
@@ -512,7 +531,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
// connectionProperties should override settings in extraOptions.
this.extraOptions ++= connectionProperties.asScala
// explicit url and dbtable should override all
this.extraOptions += ("url" -> url, "dbtable" -> table)
this.extraOptions ++= Seq("url" -> url, "dbtable" -> table)
format("jdbc").save()
}

@@ -692,7 +711,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {

private var mode: SaveMode = SaveMode.ErrorIfExists

- private val extraOptions = new scala.collection.mutable.HashMap[String, String]
+ private var extraOptions = CaseInsensitiveMap[String](Map.empty)

private var partitioningColumns: Option[Seq[String]] = None

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -21,6 +21,8 @@ import java.math.BigDecimal
import java.sql.{Date, DriverManager, SQLException, Timestamp}
import java.util.{Calendar, GregorianCalendar, Properties}

+ import scala.collection.JavaConverters._

import org.h2.jdbc.JdbcSQLException
import org.scalatest.{BeforeAndAfter, PrivateMethodTester}

@@ -1261,7 +1263,8 @@ class JDBCSuite extends QueryTest
testJdbcOptions(new JDBCOptions(parameters))
testJdbcOptions(new JDBCOptions(CaseInsensitiveMap(parameters)))
// test add/remove key-value from the case-insensitive map
- var modifiedParameters = CaseInsensitiveMap(Map.empty) ++ parameters
+ var modifiedParameters =
+ (CaseInsensitiveMap(Map.empty) ++ parameters).asInstanceOf[Map[String, String]]
testJdbcOptions(new JDBCOptions(modifiedParameters))
modifiedParameters -= "dbtable"
assert(modifiedParameters.get("dbTAblE").isEmpty)
@@ -1585,4 +1588,23 @@ class JDBCSuite extends QueryTest
checkNotPushdown(sql("SELECT name, theid FROM predicateOption WHERE theid = 1")),
Row("fred", 1) :: Nil)
}

test("SPARK-32364: JDBCOption constructor") {
val extraOptions = CaseInsensitiveMap[String](Map("UrL" -> "url1", "dBTable" -> "table1"))
val connectionProperties = new Properties()
connectionProperties.put("url", "url2")
connectionProperties.put("dbtable", "table2")

// connection property should override the options in extraOptions
val params = extraOptions ++ connectionProperties.asScala
assert(params.size == 2)
assert(params.get("uRl").contains("url2"))
assert(params.get("DbtaBle").contains("table2"))

// JDBCOptions constructor parameter should overwrite the existing conf
val options = new JDBCOptions(url, "table3", params)
assert(options.asProperties.size == 2)
assert(options.asProperties.get("url") == url)
assert(options.asProperties.get("dbtable") == "table3")
}
}
sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -212,6 +212,28 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
assert(LastOptions.parameters("opt3") == "3")
}

test("SPARK-32364: path argument of load function should override all existing options") {
spark.read
.format("org.apache.spark.sql.test")
.option("paTh", "1")
.option("PATH", "2")
.option("Path", "3")
.option("patH", "4")
.load("5")
assert(LastOptions.parameters("path") == "5")
}

test("SPARK-32364: path argument of save function should override all existing options") {
Seq(1).toDF.write
.format("org.apache.spark.sql.test")
.option("paTh", "1")
.option("PATH", "2")
.option("Path", "3")
.option("patH", "4")
.save("5")
assert(LastOptions.parameters("path") == "5")
}

test("pass partitionBy as options") {
Seq(true, false).foreach { flag =>
withSQLConf(SQLConf.LEGACY_PASS_PARTITION_BY_AS_OPTIONS.key -> s"$flag") {