diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index a1e5dfdbf739e..37df283a9e5b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -27,7 +27,7 @@ import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.jdbc.JdbcDialects
+import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.CompletionIterator
@@ -105,37 +105,40 @@ object JDBCRDD extends Logging {
    * Turns a single Filter into a String representing a SQL expression.
    * Returns None for an unhandled filter.
    */
-  def compileFilter(f: Filter): Option[String] = {
+  def compileFilter(f: Filter, dialect: JdbcDialect): Option[String] = {
+    def quote(colName: String): String = dialect.quoteIdentifier(colName)
+
     Option(f match {
-      case EqualTo(attr, value) => s"$attr = ${compileValue(value)}"
+      case EqualTo(attr, value) => s"${quote(attr)} = ${compileValue(value)}"
       case EqualNullSafe(attr, value) =>
-        s"(NOT ($attr != ${compileValue(value)} OR $attr IS NULL OR " +
-          s"${compileValue(value)} IS NULL) OR ($attr IS NULL AND ${compileValue(value)} IS NULL))"
-      case LessThan(attr, value) => s"$attr < ${compileValue(value)}"
-      case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}"
-      case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}"
-      case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}"
-      case IsNull(attr) => s"$attr IS NULL"
-      case IsNotNull(attr) => s"$attr IS NOT NULL"
-      case StringStartsWith(attr, value) => s"${attr} LIKE '${value}%'"
-      case StringEndsWith(attr, value) => s"${attr} LIKE '%${value}'"
-      case StringContains(attr, value) => s"${attr} LIKE '%${value}%'"
+        val col = quote(attr)
+        s"(NOT ($col != ${compileValue(value)} OR $col IS NULL OR " +
+          s"${compileValue(value)} IS NULL) OR ($col IS NULL AND ${compileValue(value)} IS NULL))"
+      case LessThan(attr, value) => s"${quote(attr)} < ${compileValue(value)}"
+      case GreaterThan(attr, value) => s"${quote(attr)} > ${compileValue(value)}"
+      case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${compileValue(value)}"
+      case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${compileValue(value)}"
+      case IsNull(attr) => s"${quote(attr)} IS NULL"
+      case IsNotNull(attr) => s"${quote(attr)} IS NOT NULL"
+      case StringStartsWith(attr, value) => s"${quote(attr)} LIKE '${value}%'"
+      case StringEndsWith(attr, value) => s"${quote(attr)} LIKE '%${value}'"
+      case StringContains(attr, value) => s"${quote(attr)} LIKE '%${value}%'"
       case In(attr, value) if value.isEmpty =>
-        s"CASE WHEN ${attr} IS NULL THEN NULL ELSE FALSE END"
-      case In(attr, value) => s"$attr IN (${compileValue(value)})"
-      case Not(f) => compileFilter(f).map(p => s"(NOT ($p))").getOrElse(null)
+        s"CASE WHEN ${quote(attr)} IS NULL THEN NULL ELSE FALSE END"
+      case In(attr, value) => s"${quote(attr)} IN (${compileValue(value)})"
+      case Not(f) => compileFilter(f, dialect).map(p => s"(NOT ($p))").getOrElse(null)
       case Or(f1, f2) =>
         // We can't compile Or filter unless both sub-filters are compiled successfully.
         // It applies too for the following And filter.
         // If we can make sure compileFilter supports all filters, we can remove this check.
-        val or = Seq(f1, f2).flatMap(compileFilter(_))
+        val or = Seq(f1, f2).flatMap(compileFilter(_, dialect))
         if (or.size == 2) {
           or.map(p => s"($p)").mkString(" OR ")
         } else {
           null
         }
       case And(f1, f2) =>
-        val and = Seq(f1, f2).flatMap(compileFilter(_))
+        val and = Seq(f1, f2).flatMap(compileFilter(_, dialect))
         if (and.size == 2) {
           and.map(p => s"($p)").mkString(" AND ")
         } else {
@@ -214,7 +217,9 @@ private[jdbc] class JDBCRDD(
    * `filters`, but as a WHERE clause suitable for injection into a SQL query.
    */
   private val filterWhereClause: String =
-    filters.flatMap(JDBCRDD.compileFilter).map(p => s"($p)").mkString(" AND ")
+    filters
+      .flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(url)))
+      .map(p => s"($p)").mkString(" AND ")
 
   /**
    * A WHERE clause representing both `filters`, if any, and the current partition.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 672c21c6ac734..6abb27db8531e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.Partition
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
 
@@ -113,7 +114,7 @@ private[sql] case class JDBCRelation(
 
   // Check if JDBCRDD.compileFilter can accept input filters
   override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
-    filters.filter(JDBCRDD.compileFilter(_).isEmpty)
+    filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty)
   }
 
   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index b16be457ed5c3..af5f01c493e84 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -202,6 +202,21 @@ class JDBCSuite extends SparkFunSuite
         |partitionColumn '"Dept"', lowerBound '1', upperBound '4', numPartitions '4')
       """.stripMargin.replaceAll("\n", " "))
 
+    conn.prepareStatement(
+      """create table test."mixedCaseCols" ("Name" TEXT(32), "Id" INTEGER NOT NULL)""")
+      .executeUpdate()
+    conn.prepareStatement("""insert into test."mixedCaseCols" values ('fred', 1)""").executeUpdate()
+    conn.prepareStatement("""insert into test."mixedCaseCols" values ('mary', 2)""").executeUpdate()
+    conn.prepareStatement("""insert into test."mixedCaseCols" values (null, 3)""").executeUpdate()
+    conn.commit()
+
+    sql(
+      s"""
+         |CREATE TEMPORARY TABLE mixedCaseCols
+         |USING org.apache.spark.sql.jdbc
+         |OPTIONS (url '$url', dbtable 'TEST."mixedCaseCols"', user 'testUser', password 'testPass')
+       """.stripMargin.replaceAll("\n", " "))
+
     // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types.
   }
 
@@ -604,30 +619,32 @@ class JDBCSuite extends SparkFunSuite
 
   test("compile filters") {
     val compileFilter = PrivateMethod[Option[String]]('compileFilter)
-    def doCompileFilter(f: Filter): String = JDBCRDD invokePrivate compileFilter(f) getOrElse("")
-    assert(doCompileFilter(EqualTo("col0", 3)) === "col0 = 3")
-    assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === "(NOT (col1 = 'abc'))")
+    def doCompileFilter(f: Filter): String =
+      JDBCRDD invokePrivate compileFilter(f, JdbcDialects.get("jdbc:")) getOrElse("")
+    assert(doCompileFilter(EqualTo("col0", 3)) === """"col0" = 3""")
+    assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === """(NOT ("col1" = 'abc'))""")
     assert(doCompileFilter(And(EqualTo("col0", 0), EqualTo("col1", "def")))
-      === "(col0 = 0) AND (col1 = 'def')")
+      === """("col0" = 0) AND ("col1" = 'def')""")
     assert(doCompileFilter(Or(EqualTo("col0", 2), EqualTo("col1", "ghi")))
-      === "(col0 = 2) OR (col1 = 'ghi')")
-    assert(doCompileFilter(LessThan("col0", 5)) === "col0 < 5")
+      === """("col0" = 2) OR ("col1" = 'ghi')""")
+    assert(doCompileFilter(LessThan("col0", 5)) === """"col0" < 5""")
     assert(doCompileFilter(LessThan("col3",
-      Timestamp.valueOf("1995-11-21 00:00:00.0"))) === "col3 < '1995-11-21 00:00:00.0'")
-    assert(doCompileFilter(LessThan("col4", Date.valueOf("1983-08-04"))) === "col4 < '1983-08-04'")
-    assert(doCompileFilter(LessThanOrEqual("col0", 5)) === "col0 <= 5")
-    assert(doCompileFilter(GreaterThan("col0", 3)) === "col0 > 3")
-    assert(doCompileFilter(GreaterThanOrEqual("col0", 3)) === "col0 >= 3")
-    assert(doCompileFilter(In("col1", Array("jkl"))) === "col1 IN ('jkl')")
+      Timestamp.valueOf("1995-11-21 00:00:00.0"))) === """"col3" < '1995-11-21 00:00:00.0'""")
+    assert(doCompileFilter(LessThan("col4", Date.valueOf("1983-08-04")))
+      === """"col4" < '1983-08-04'""")
+    assert(doCompileFilter(LessThanOrEqual("col0", 5)) === """"col0" <= 5""")
+    assert(doCompileFilter(GreaterThan("col0", 3)) === """"col0" > 3""")
+    assert(doCompileFilter(GreaterThanOrEqual("col0", 3)) === """"col0" >= 3""")
+    assert(doCompileFilter(In("col1", Array("jkl"))) === """"col1" IN ('jkl')""")
     assert(doCompileFilter(In("col1", Array.empty)) ===
-      "CASE WHEN col1 IS NULL THEN NULL ELSE FALSE END")
+      """CASE WHEN "col1" IS NULL THEN NULL ELSE FALSE END""")
     assert(doCompileFilter(Not(In("col1", Array("mno", "pqr"))))
-      === "(NOT (col1 IN ('mno', 'pqr')))")
-    assert(doCompileFilter(IsNull("col1")) === "col1 IS NULL")
-    assert(doCompileFilter(IsNotNull("col1")) === "col1 IS NOT NULL")
+      === """(NOT ("col1" IN ('mno', 'pqr')))""")
+    assert(doCompileFilter(IsNull("col1")) === """"col1" IS NULL""")
+    assert(doCompileFilter(IsNotNull("col1")) === """"col1" IS NOT NULL""")
     assert(doCompileFilter(And(EqualNullSafe("col0", "abc"), EqualTo("col1", "def")))
-      === "((NOT (col0 != 'abc' OR col0 IS NULL OR 'abc' IS NULL) "
-        + "OR (col0 IS NULL AND 'abc' IS NULL))) AND (col1 = 'def')")
+      === """((NOT ("col0" != 'abc' OR "col0" IS NULL OR 'abc' IS NULL) """
+        + """OR ("col0" IS NULL AND 'abc' IS NULL))) AND ("col1" = 'def')""")
   }
 
   test("Dialect unregister") {
@@ -824,4 +841,24 @@ class JDBCSuite extends SparkFunSuite
     val schema = JdbcUtils.schemaString(df.schema, "jdbc:mysql://localhost:3306/temp")
     assert(schema.contains("`order` TEXT"))
   }
+
+  test("SPARK-18141: Predicates on quoted column names in the jdbc data source") {
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id < 1").collect().size == 0)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id <= 1").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id > 1").collect().size == 2)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id >= 1").collect().size == 3)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id = 1").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id != 2").collect().size == 2)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id <=> 2").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name LIKE 'fr%'").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name LIKE '%ed'").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name LIKE '%re%'").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name IS NULL").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name IS NOT NULL").collect().size == 2)
+    assert(sql("SELECT * FROM mixedCaseCols").filter($"Name".isin()).collect().size == 0)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name IN ('mary', 'fred')").collect().size == 2)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name NOT IN ('fred')").collect().size == 1)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Id = 1 OR Name = 'mary'").collect().size == 2)
+    assert(sql("SELECT * FROM mixedCaseCols WHERE Name = 'mary' AND Id = 2").collect().size == 1)
+  }
 }
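
For context, a minimal standalone sketch (not part of the patch) of the behavior the change relies on. It uses only the public JdbcDialects.get and JdbcDialect.quoteIdentifier APIs visible in the diff; the object name and the sample predicate are illustrative.

import org.apache.spark.sql.jdbc.JdbcDialects

// Illustrative only: shows how the dialect-aware quoting added in this
// patch renders a pushed-down predicate on a case-sensitive column.
object QuoteIdentifierSketch {
  def main(args: Array[String]): Unit = {
    // Unrecognized URLs fall back to the default dialect, which wraps
    // identifiers in double quotes (the behavior the new tests assert).
    val dialect = JdbcDialects.get("jdbc:")

    // Before this patch the predicate was rendered as:  Id < 1
    // H2 folds the unquoted identifier to ID, so a column created as
    // "Id" cannot be resolved and the scan fails.
    // After the patch the column name goes through quoteIdentifier:
    val rendered = s"""${dialect.quoteIdentifier("Id")} < 1"""
    assert(rendered == """"Id" < 1""")

    // Dialects may override the quoting character; MySQL uses backticks,
    // which is why the existing schemaString test expects `order`.
    val mysql = JdbcDialects.get("jdbc:mysql://localhost:3306/temp")
    assert(mysql.quoteIdentifier("order") == "`order`")
  }
}

This is also why filterWhereClause now threads JdbcDialects.get(url) into compileFilter: the quoting character is dialect-specific, so it cannot be hard-coded at the call site.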