[SPARK-32364][SQL] Use CaseInsensitiveMap for DataFrameReader/Writer options #29160

Closed · wants to merge 5 commits
@@ -39,10 +39,14 @@ class CaseInsensitiveMap[T] private (val originalMap: Map[String, T]) extends Ma
   override def contains(k: String): Boolean =
     keyLowerCasedMap.contains(k.toLowerCase(Locale.ROOT))

-  override def +[B1 >: T](kv: (String, B1)): Map[String, B1] = {
+  override def +[B1 >: T](kv: (String, B1)): CaseInsensitiveMap[B1] = {
     new CaseInsensitiveMap(originalMap.filter(!_._1.equalsIgnoreCase(kv._1)) + kv)
   }

+  def ++(xs: TraversableOnce[(String, T)]): CaseInsensitiveMap[T] = {
+    xs.foldLeft(this)(_ + _)
+  }
+
   override def iterator: Iterator[(String, T)] = keyLowerCasedMap.iterator

   override def -(key: String): Map[String, T] = {
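A minimal sketch of what the tightened return types buy, built with the companion CaseInsensitiveMap(...) constructor seen in the test changes further down (keys and values here are made up): keys that differ only in case collapse into a single entry, the last write wins, and both + and the new ++ stay within CaseInsensitiveMap instead of widening to a plain Map.

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

val opts = CaseInsensitiveMap(Map("Path" -> "1")) ++ Seq("PATH" -> "2", "header" -> "true")
assert(opts("path") == "2")                               // lookup ignores case; last write wins
assert(opts.originalMap.keySet == Set("PATH", "header"))  // the earlier "Path" entry was replaced
val opts2 = opts + ("HEADER" -> "false")                  // + now returns CaseInsensitiveMap, not Map
assert(opts2("header") == "false")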
@@ -39,10 +39,14 @@ class CaseInsensitiveMap[T] private (val originalMap: Map[String, T]) extends Ma
   override def contains(k: String): Boolean =
     keyLowerCasedMap.contains(k.toLowerCase(Locale.ROOT))

-  override def updated[B1 >: T](key: String, value: B1): Map[String, B1] = {
+  override def updated[B1 >: T](key: String, value: B1): CaseInsensitiveMap[B1] = {
     new CaseInsensitiveMap[B1](originalMap.filter(!_._1.equalsIgnoreCase(key)) + (key -> value))
   }

+  def ++(xs: IterableOnce[(String, T)]): CaseInsensitiveMap[T] = {
+    xs.iterator.foldLeft(this) { (m, kv) => m.updated(kv._1, kv._2) }
+  }
+
   override def iterator: Iterator[(String, T)] = keyLowerCasedMap.iterator

   override def removed(key: String): Map[String, T] = {
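The two copies of this change presumably correspond to Spark's parallel Scala 2.12 and 2.13 source trees: the file above overrides + and takes a TraversableOnce, while this one overrides updated and takes an IterableOnce, tracking the collection API differences between the two Scala versions. A small hedged sketch of the updated path, with illustrative values:

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

// updated drops any key that matches ignoring case and keeps the new key's casing
val m = CaseInsensitiveMap(Map("URL" -> "a")).updated("url", "b")
assert(m("Url") == "b")
assert(m.originalMap.keySet == Set("url"))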
@@ -31,7 +31,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser}
 import org.apache.spark.sql.catalyst.expressions.ExprUtils
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
-import org.apache.spark.sql.catalyst.util.FailureSafeParser
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailureSafeParser}
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsCatalogOptions, SupportsRead}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.execution.command.DDLUtils

@@ -238,7 +238,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       Some("paths" -> objectMapper.writeValueAsString(paths.toArray))
     }

-    val finalOptions = sessionOptions ++ extraOptions.toMap ++ pathsOption
+    val finalOptions = sessionOptions ++ extraOptions.originalMap ++ pathsOption
     val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava)
     val (table, catalog, ident) = provider match {
       case _: SupportsCatalogOptions if userSpecifiedSchema.nonEmpty =>

@@ -276,7 +276,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         paths = paths,
         userSpecifiedSchema = userSpecifiedSchema,
         className = source,
-        options = extraOptions.toMap).resolveRelation())
+        options = extraOptions.originalMap).resolveRelation())
   }

   /**
@@ -290,7 +290,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     // properties should override settings in extraOptions.
     this.extraOptions ++= properties.asScala
     // explicit url and dbtable should override all
-    this.extraOptions += (JDBCOptions.JDBC_URL -> url, JDBCOptions.JDBC_TABLE_NAME -> table)
+    this.extraOptions ++= Seq(JDBCOptions.JDBC_URL -> url, JDBCOptions.JDBC_TABLE_NAME -> table)
     format("jdbc").load()
   }

@@ -879,6 +879,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {

   private var userSpecifiedSchema: Option[StructType] = None

-  private val extraOptions = new scala.collection.mutable.HashMap[String, String]
+  private var extraOptions = CaseInsensitiveMap[String](Map.empty)

 }
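A hedged reading of the reader-side changes: extraOptions is now an immutable CaseInsensitiveMap held in a var, so ++= desugars to reassignment through the new ++, and originalMap rather than toMap is handed to the sources because the overridden iterator (and anything derived from it) exposes the lower-cased keys, while originalMap keeps whatever casing the caller typed. A small sketch with illustrative values:

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

var extraOptions = CaseInsensitiveMap[String](Map.empty)
extraOptions ++= Seq("Header" -> "true")   // ++= is shorthand for extraOptions = extraOptions ++ ...
extraOptions ++= Map("HEADER" -> "false")  // a case-variant key replaces the earlier entry
assert(extraOptions.toSeq == Seq("header" -> "false"))        // iterator view: keys lower-cased
assert(extraOptions.originalMap == Map("HEADER" -> "false"))  // caller's casing preserved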
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, NoSuchT
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, CreateTableAsSelectStatement, InsertIntoStatement, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelectStatement}
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Implicits, CatalogV2Util, Identifier, SupportsCatalogOptions, Table, TableCatalog, TableProvider, V1Table}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}

@@ -768,7 +769,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     // connectionProperties should override settings in extraOptions.
     this.extraOptions ++= connectionProperties.asScala
     // explicit url and dbtable should override all
-    this.extraOptions += ("url" -> url, "dbtable" -> table)
+    this.extraOptions ++= Seq("url" -> url, "dbtable" -> table)
     format("jdbc").save()
   }

@@ -960,7 +961,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {

   private var mode: SaveMode = SaveMode.ErrorIfExists

-  private val extraOptions = new scala.collection.mutable.HashMap[String, String]
+  private var extraOptions = CaseInsensitiveMap[String](Map.empty)

   private var partitioningColumns: Option[Seq[String]] = None
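The writer-side jdbc change from += (a, b) to ++= Seq(a, b) (and likewise in DataFrameReader above) is not explained in the diff; a plausible reading is that the overridden + accepts exactly one pair, so the old multi-pair += form would no longer keep the var typed as CaseInsensitiveMap, while routing both pairs through the new ++ does, and the explicit url/dbtable still replace any case-variants set earlier. A sketch with made-up values:

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

var extraOptions = CaseInsensitiveMap[String](Map("URL" -> "jdbc:old", "fetchsize" -> "10"))
// extraOptions += ("url" -> "jdbc:new", "dbtable" -> "t")   // multi-pair += would bypass the overridden +
extraOptions ++= Seq("url" -> "jdbc:h2:mem:db", "dbtable" -> "t")
assert(extraOptions("URL") == "jdbc:h2:mem:db")   // the earlier "URL" entry is replaced, not duplicated
assert(extraOptions.size == 3)                    // url, fetchsize, dbtable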
@@ -1302,7 +1302,8 @@ class JDBCSuite extends QueryTest
     testJdbcOptions(new JDBCOptions(parameters))
     testJdbcOptions(new JDBCOptions(CaseInsensitiveMap(parameters)))
     // test add/remove key-value from the case-insensitive map
-    var modifiedParameters = CaseInsensitiveMap(Map.empty) ++ parameters
+    var modifiedParameters =
+      (CaseInsensitiveMap(Map.empty) ++ parameters).asInstanceOf[Map[String, String]]
     testJdbcOptions(new JDBCOptions(modifiedParameters))
     modifiedParameters -= "dbtable"
     assert(modifiedParameters.get("dbTAblE").isEmpty)
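A hedged note on the new asInstanceOf cast: with the new ++, the right-hand side is a CaseInsensitiveMap, so without the cast the var would be inferred as CaseInsensitiveMap[String]; the later modifiedParameters -= "dbtable" then reassigns the var through -, which is still declared to return a plain Map[String, T] in both CaseInsensitiveMap variants above and so would not type-check against the narrower var type. Pinning the var to Map[String, String] sidesteps that, while calls still dispatch to the case-insensitive overrides at runtime, e.g.:

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

var m: Map[String, String] = CaseInsensitiveMap(Map("dbtable" -> "t", "url" -> "u"))
m -= "dbtable"                   // reassignment compiles because the var's static type is Map
assert(m.get("dbTAblE").isEmpty) // removal went through the case-insensitive implementation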
@@ -224,6 +224,28 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with
     assert(LastOptions.parameters("opt3") == "3")
   }

+  test("SPARK-32364: path argument of load function should override all existing options") {
+    spark.read
+      .format("org.apache.spark.sql.test")
+      .option("paTh", "1")
+      .option("PATH", "2")
+      .option("Path", "3")
+      .option("patH", "4")
+      .load("5")
+    assert(LastOptions.parameters("path") == "5")
+  }
+
+  test("SPARK-32364: path argument of save function should override all existing options") {
+    Seq(1).toDF.write
+      .format("org.apache.spark.sql.test")
+      .option("paTh", "1")
+      .option("PATH", "2")
+      .option("Path", "3")
+      .option("patH", "4")
+      .save("5")
+    assert(LastOptions.parameters("path") == "5")
+  }
+
   test("pass partitionBy as options") {
     Seq(1).toDF.write
       .format("org.apache.spark.sql.test")