[SPARK-18141][SQL] Fix to quote column names in the predicate clause of the JDBC RDD generated sql statement #15662

Closed · Changes from 4 commits
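Background for the diff: JDBCRDD.compileFilter interpolated attribute names into the generated SQL verbatim, so a pushed-down predicate on a mixed-case (quoted) column came out unquoted, and case-folding databases such as H2 or PostgreSQL would resolve it to the wrong identifier or fail outright. A minimal sketch of the failure mode, assuming a SparkSession named spark and the H2 url and mixedCaseCols fixture used by the test suite below:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local").getOrCreate()
    import spark.implicits._

    // `url` is the suite's H2 connection string; the table and credentials
    // mirror the mixedCaseCols fixture registered in the test setup below.
    val df = spark.read
      .format("jdbc")
      .option("url", url)
      .option("dbtable", """TEST."mixedCaseCols"""")
      .option("user", "testUser")
      .option("password", "testPass")
      .load()

    // Before this patch the pushed-down predicate was rendered unquoted,
    //   ... WHERE Id < 1        -- H2 folds Id to ID and cannot find the column
    // after it, the dialect quotes the identifier:
    //   ... WHERE "Id" < 1
    df.filter($"Id" < 1).count()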
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala

@@ -27,7 +27,7 @@ import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.jdbc.JdbcDialects
+import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.CompletionIterator
@@ -105,37 +105,40 @@ object JDBCRDD extends Logging {
    * Turns a single Filter into a String representing a SQL expression.
    * Returns None for an unhandled filter.
    */
-  def compileFilter(f: Filter): Option[String] = {
+  def compileFilter(f: Filter, dialect: JdbcDialect): Option[String] = {
+    def quote(colName: String): String = dialect.quoteIdentifier(colName)
+
     Option(f match {
-      case EqualTo(attr, value) => s"$attr = ${compileValue(value)}"
+      case EqualTo(attr, value) => s"${quote(attr)} = ${compileValue(value)}"
       case EqualNullSafe(attr, value) =>
-        s"(NOT ($attr != ${compileValue(value)} OR $attr IS NULL OR " +
-          s"${compileValue(value)} IS NULL) OR ($attr IS NULL AND ${compileValue(value)} IS NULL))"
-      case LessThan(attr, value) => s"$attr < ${compileValue(value)}"
-      case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}"
-      case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}"
-      case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}"
-      case IsNull(attr) => s"$attr IS NULL"
-      case IsNotNull(attr) => s"$attr IS NOT NULL"
-      case StringStartsWith(attr, value) => s"${attr} LIKE '${value}%'"
-      case StringEndsWith(attr, value) => s"${attr} LIKE '%${value}'"
-      case StringContains(attr, value) => s"${attr} LIKE '%${value}%'"
+        val col = quote(attr)
+        s"(NOT ($col != ${compileValue(value)} OR $col IS NULL OR " +
+          s"${compileValue(value)} IS NULL) OR ($col IS NULL AND ${compileValue(value)} IS NULL))"
+      case LessThan(attr, value) => s"${quote(attr)} < ${compileValue(value)}"
+      case GreaterThan(attr, value) => s"${quote(attr)} > ${compileValue(value)}"
+      case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${compileValue(value)}"
+      case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${compileValue(value)}"
+      case IsNull(attr) => s"${quote(attr)} IS NULL"
+      case IsNotNull(attr) => s"${quote(attr)} IS NOT NULL"
+      case StringStartsWith(attr, value) => s"${quote(attr)} LIKE '${value}%'"
+      case StringEndsWith(attr, value) => s"${quote(attr)} LIKE '%${value}'"
+      case StringContains(attr, value) => s"${quote(attr)} LIKE '%${value}%'"
       case In(attr, value) if value.isEmpty =>
-        s"CASE WHEN ${attr} IS NULL THEN NULL ELSE FALSE END"
-      case In(attr, value) => s"$attr IN (${compileValue(value)})"
-      case Not(f) => compileFilter(f).map(p => s"(NOT ($p))").getOrElse(null)
+        s"CASE WHEN ${quote(attr)} IS NULL THEN NULL ELSE FALSE END"
+      case In(attr, value) => s"${quote(attr)} IN (${compileValue(value)})"
+      case Not(f) => compileFilter(f, dialect).map(p => s"(NOT ($p))").getOrElse(null)
       case Or(f1, f2) =>
         // We can't compile Or filter unless both sub-filters are compiled successfully.
         // It applies too for the following And filter.
         // If we can make sure compileFilter supports all filters, we can remove this check.
-        val or = Seq(f1, f2).flatMap(compileFilter(_))
+        val or = Seq(f1, f2).flatMap(compileFilter(_, dialect))
         if (or.size == 2) {
           or.map(p => s"($p)").mkString(" OR ")
         } else {
           null
         }
       case And(f1, f2) =>
-        val and = Seq(f1, f2).flatMap(compileFilter(_))
+        val and = Seq(f1, f2).flatMap(compileFilter(_, dialect))
         if (and.size == 2) {
           and.map(p => s"($p)").mkString(" AND ")
         } else {

@@ -214,7 +217,9 @@ private[jdbc] class JDBCRDD(
    * `filters`, but as a WHERE clause suitable for injection into a SQL query.
    */
   private val filterWhereClause: String =
-    filters.flatMap(JDBCRDD.compileFilter).map(p => s"($p)").mkString(" AND ")
+    filters
+      .flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(url)))
+      .map(p => s"($p)").mkString(" AND ")
 
   /**
    * A WHERE clause representing both `filters`, if any, and the current partition.
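The key design point in the hunk above is that quoting is delegated to JdbcDialect.quoteIdentifier rather than hard-coded: the default implementation wraps the name in double quotes, and individual dialects can override it (MySQL's, for instance, uses backticks). A small sketch of how a registered dialect would feed into compileFilter; MyDialect and the jdbc:mydb prefix are made up for illustration:

    import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
    import org.apache.spark.sql.sources.EqualTo

    // Hypothetical dialect for a made-up "jdbc:mydb" url prefix that quotes
    // identifiers with backticks instead of the default double quotes.
    case object MyDialect extends JdbcDialect {
      override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")
      override def quoteIdentifier(colName: String): String = s"`$colName`"
    }

    JdbcDialects.registerDialect(MyDialect)
    // With this registration, compileFilter picks up the backtick quoting:
    //   JDBCRDD.compileFilter(EqualTo("Id", 1), JdbcDialects.get("jdbc:mydb://host/db"))
    //   == Some("`Id` = 1")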
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala

@@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.Partition
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
 
@@ -113,7 +114,7 @@ private[sql] case class JDBCRelation(
 
   // Check if JDBCRDD.compileFilter can accept input filters
   override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
-    filters.filter(JDBCRDD.compileFilter(_).isEmpty)
+    filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty)
   }
 
   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
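For context on why unhandledFilters needs the same dialect: Spark re-evaluates whatever the relation reports as unhandled after the scan, and a filter counts as handled exactly when compileFilter renders it to SQL, so the two call sites must agree. A compressed, illustrative sketch of that contract (splitFilters is not a real Spark method, just a restatement of the two call sites above):

    import org.apache.spark.sql.sources.Filter

    // Illustrative only: how the WHERE clause and the unhandled set relate.
    // A filter that compiles is pushed to the database; one that does not is
    // handed back to Spark and re-evaluated on the fetched rows.
    def splitFilters(
        filters: Array[Filter],
        compile: Filter => Option[String]): (String, Array[Filter]) = {
      val pushed = filters.flatMap(f => compile(f)).map(p => s"($p)").mkString(" AND ")
      val unhandled = filters.filter(f => compile(f).isEmpty)
      (pushed, unhandled)
    }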
sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala (56 additions, 18 deletions)

@@ -202,6 +202,21 @@ class JDBCSuite extends SparkFunSuite
         |partitionColumn '"Dept"', lowerBound '1', upperBound '4', numPartitions '4')
       """.stripMargin.replaceAll("\n", " "))
 
+    conn.prepareStatement(
+      """create table test."mixedCaseCols" ("Name" TEXT(32), "Id" INTEGER NOT NULL)""")
+      .executeUpdate()
+    conn.prepareStatement("""insert into test."mixedCaseCols" values ('fred', 1)""").executeUpdate()
+    conn.prepareStatement("""insert into test."mixedCaseCols" values ('mary', 2)""").executeUpdate()
+    conn.prepareStatement("""insert into test."mixedCaseCols" values (null, 3)""").executeUpdate()
+    conn.commit()
+
+    sql(
+      s"""
+        |CREATE TEMPORARY TABLE mixedCaseCols
+        |USING org.apache.spark.sql.jdbc
+        |OPTIONS (url '$url', dbtable 'TEST."mixedCaseCols"', user 'testUser', password 'testPass')
+      """.stripMargin.replaceAll("\n", " "))
+
     // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types.
   }
@@ -604,30 +619,32 @@ class JDBCSuite extends SparkFunSuite
 
   test("compile filters") {
     val compileFilter = PrivateMethod[Option[String]]('compileFilter)
-    def doCompileFilter(f: Filter): String = JDBCRDD invokePrivate compileFilter(f) getOrElse("")
-    assert(doCompileFilter(EqualTo("col0", 3)) === "col0 = 3")
-    assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === "(NOT (col1 = 'abc'))")
+    def doCompileFilter(f: Filter): String =
+      JDBCRDD invokePrivate compileFilter(f, JdbcDialects.get("jdbc:")) getOrElse("")
+    assert(doCompileFilter(EqualTo("col0", 3)) === """"col0" = 3""")
+    assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === """(NOT ("col1" = 'abc'))""")
     assert(doCompileFilter(And(EqualTo("col0", 0), EqualTo("col1", "def")))
-      === "(col0 = 0) AND (col1 = 'def')")
+      === """("col0" = 0) AND ("col1" = 'def')""")
     assert(doCompileFilter(Or(EqualTo("col0", 2), EqualTo("col1", "ghi")))
-      === "(col0 = 2) OR (col1 = 'ghi')")
-    assert(doCompileFilter(LessThan("col0", 5)) === "col0 < 5")
+      === """("col0" = 2) OR ("col1" = 'ghi')""")
+    assert(doCompileFilter(LessThan("col0", 5)) === """"col0" < 5""")
     assert(doCompileFilter(LessThan("col3",
-      Timestamp.valueOf("1995-11-21 00:00:00.0"))) === "col3 < '1995-11-21 00:00:00.0'")
-    assert(doCompileFilter(LessThan("col4", Date.valueOf("1983-08-04"))) === "col4 < '1983-08-04'")
-    assert(doCompileFilter(LessThanOrEqual("col0", 5)) === "col0 <= 5")
-    assert(doCompileFilter(GreaterThan("col0", 3)) === "col0 > 3")
-    assert(doCompileFilter(GreaterThanOrEqual("col0", 3)) === "col0 >= 3")
-    assert(doCompileFilter(In("col1", Array("jkl"))) === "col1 IN ('jkl')")
+      Timestamp.valueOf("1995-11-21 00:00:00.0"))) === """"col3" < '1995-11-21 00:00:00.0'""")
+    assert(doCompileFilter(LessThan("col4", Date.valueOf("1983-08-04")))
+      === """"col4" < '1983-08-04'""")
+    assert(doCompileFilter(LessThanOrEqual("col0", 5)) === """"col0" <= 5""")
+    assert(doCompileFilter(GreaterThan("col0", 3)) === """"col0" > 3""")
+    assert(doCompileFilter(GreaterThanOrEqual("col0", 3)) === """"col0" >= 3""")
+    assert(doCompileFilter(In("col1", Array("jkl"))) === """"col1" IN ('jkl')""")
     assert(doCompileFilter(In("col1", Array.empty)) ===
-      "CASE WHEN col1 IS NULL THEN NULL ELSE FALSE END")
+      """CASE WHEN "col1" IS NULL THEN NULL ELSE FALSE END""")
     assert(doCompileFilter(Not(In("col1", Array("mno", "pqr"))))
-      === "(NOT (col1 IN ('mno', 'pqr')))")
-    assert(doCompileFilter(IsNull("col1")) === "col1 IS NULL")
-    assert(doCompileFilter(IsNotNull("col1")) === "col1 IS NOT NULL")
+      === """(NOT ("col1" IN ('mno', 'pqr')))""")
+    assert(doCompileFilter(IsNull("col1")) === """"col1" IS NULL""")
+    assert(doCompileFilter(IsNotNull("col1")) === """"col1" IS NOT NULL""")
     assert(doCompileFilter(And(EqualNullSafe("col0", "abc"), EqualTo("col1", "def")))
-      === "((NOT (col0 != 'abc' OR col0 IS NULL OR 'abc' IS NULL) "
-      + "OR (col0 IS NULL AND 'abc' IS NULL))) AND (col1 = 'def')")
+      === """((NOT ("col0" != 'abc' OR "col0" IS NULL OR 'abc' IS NULL) """
+      + """OR ("col0" IS NULL AND 'abc' IS NULL))) AND ("col1" = 'def')""")
   }
 
   test("Dialect unregister") {
@@ -824,4 +841,25 @@ class JDBCSuite extends SparkFunSuite
     val schema = JdbcUtils.schemaString(df.schema, "jdbc:mysql://localhost:3306/temp")
     assert(schema.contains("`order` TEXT"))
   }
+
+  test("SPARK-18141: Predicates on quoted column names in the jdbc data source") {
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id < 1").collect().size == 0)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id <= 1").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id > 1").collect().size == 2)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id >= 1").collect().size == 3)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id = 1").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id != 2").collect().size == 2)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id <=> 2").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name LIKE 'fr%'").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name LIKE '%ed'").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name LIKE '%re%'").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name IS NULL").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name IS NOT NULL").collect().size == 2)
+    assert(sql("SELECT * FROM mixedCaseCols")
+      .filter($"Name".isin(Array[String]() : _*)).collect().size == 0)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name IN ('mary', 'fred')").collect().size == 2)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name NOT IN ('fred')").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id = 1 OR Name = 'mary'").collect().size == 2)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name = 'mary' AND Id = 2").collect().size == 1)
+  }
 }

Review comment from gatorsmile (Member), Nov 30, 2016, on the isin line above:

    .filter($"Name".isin(Array[String]() : _*)).collect().size == 0)

  ->

    .filter($"Name".isin()).collect().size == 0)

Author (Contributor): Thanks, @gatorsmile. Fixed it.
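To see the effect end to end, one can inspect the physical plan for any of the queries above; a sketch, assuming the mixedCaseCols temporary table registered in the setup hunk (the exact plan text varies by Spark version):

    val df = sql("SELECT * FROM mixedCaseCols WHERE Id > 1")
    df.explain()
    // The JDBC scan node should list the pushed predicate, e.g.
    //   PushedFilters: [IsNotNull(Id), GreaterThan(Id,1)]
    // and the statement sent to H2 now carries quoted names:
    //   SELECT "Name","Id" FROM TEST."mixedCaseCols"
    //   WHERE ("Id" IS NOT NULL) AND ("Id" > 1)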