From 8b7e1952fd08a398ee5661c8f627b8c9c146cb65 Mon Sep 17 00:00:00 2001 From: tianhanhu Date: Fri, 5 May 2023 10:06:40 +0800 Subject: [PATCH] [SPARK-43040][SQL] Improve TimestampNTZ type support in JDBC data source ### What changes were proposed in this pull request? https://github.com/apache/spark/pull/36726 supports TimestampNTZ type in JDBC data source and https://github.com/apache/spark/pull/37013 applies a fix to pass more test cases with H2. The problem is that Java Timestamp is a poorly defined class and different JDBC drivers implement "getTimestamp" and "setTimestamp" with different expected behaviors in mind. The general conversion implementation would work with some JDBC dialects and their drivers but not others. This issue is discovered when testing with PostgreSQL database. This PR adds a `dialect` parameter to `makeGetter` for applying dialect specific conversions when reading a Java Timestamp into TimestampNTZType. `makeSetter` already has a `dialect` field and we will use that for converting back to Java Timestamp. ### Why are the changes needed? Fix TimestampNTZ support for PostgreSQL. Allows other JDBC dialects to provide dialect specific implementation for converting between Java Timestamp and Spark TimestampNTZType. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing unit test. I added new test cases for `PostgresIntegrationSuite` to cover TimestampNTZ read and writes. Closes #40678 from tianhanhu/SPARK-43040_jdbc_timestamp_ntz. Authored-by: tianhanhu Signed-off-by: Wenchen Fan --- .../sql/jdbc/PostgresIntegrationSuite.scala | 35 +++++++++++++++++ .../execution/datasources/jdbc/JDBCRDD.scala | 3 +- .../datasources/jdbc/JdbcUtils.scala | 38 +++++++++++++------ .../apache/spark/sql/jdbc/JdbcDialects.scala | 30 ++++++++++++++- .../spark/sql/jdbc/PostgresDialect.scala | 11 +++++- 5 files changed, 102 insertions(+), 15 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index ff5127ce350f5..f840876fc5d00 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc import java.math.{BigDecimal => JBigDecimal} import java.sql.{Connection, Date, Timestamp} import java.text.SimpleDateFormat +import java.time.LocalDateTime import java.util.Properties import org.apache.spark.sql.Column @@ -140,6 +141,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { "c0 money)").executeUpdate() conn.prepareStatement("INSERT INTO money_types VALUES " + "('$1,000.00')").executeUpdate() + + conn.prepareStatement(s"CREATE TABLE timestamp_ntz(v timestamp)").executeUpdate() + conn.prepareStatement(s"""INSERT INTO timestamp_ntz VALUES + |('2013-04-05 12:01:02'), + |('2013-04-05 18:01:02.123'), + |('2013-04-05 18:01:02.123456')""".stripMargin).executeUpdate() } test("Type mapping for various types") { @@ -381,4 +388,32 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(row(0).length === 1) assert(row(0).getString(0) === "$1,000.00") } + + test("SPARK-43040: timestamp_ntz read test") { + val prop = new Properties + prop.setProperty("preferTimestampNTZ", "true") + val df = sqlContext.read.jdbc(jdbcUrl, "timestamp_ntz", prop) + val row = df.collect() + assert(row.length === 3) + assert(row(0).length === 1) + assert(row(0) === Row(LocalDateTime.of(2013, 4, 5, 12, 1, 2))) + assert(row(1) === Row(LocalDateTime.of(2013, 4, 5, 18, 1, 2, 123000000))) + assert(row(2) === Row(LocalDateTime.of(2013, 4, 5, 18, 1, 2, 123456000))) + } + + test("SPARK-43040: timestamp_ntz roundtrip test") { + val prop = new Properties + prop.setProperty("preferTimestampNTZ", "true") + + val sparkQuery = """ + |select + | timestamp_ntz'2020-12-10 11:22:33' as col0 + """.stripMargin + + val df_expected = sqlContext.sql(sparkQuery) + df_expected.write.jdbc(jdbcUrl, "timestamp_ntz_roundtrip", prop) + + val df_actual = sqlContext.read.jdbc(jdbcUrl, "timestamp_ntz_roundtrip", prop) + assert(df_actual.collect()(0) == df_expected.collect()(0)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 70e29f5d7195c..e241951abe392 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -273,7 +273,8 @@ private[jdbc] class JDBCRDD( stmt.setFetchSize(options.fetchSize) stmt.setQueryTimeout(options.queryTimeout) rs = stmt.executeQuery() - val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics) + val rowsIterator = + JdbcUtils.resultSetToSparkInternalRows(rs, dialect, schema, inputMetrics) CompletionIterator[InternalRow, Iterator[InternalRow]]( new InterruptibleIterator(context, rowsIterator), close()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index fe53ba91d9592..d907ce6b100cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -38,12 +38,12 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, GenericArrayData} -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateTimeToMicros, localDateToDays, toJavaDate, toJavaTimestamp, toJavaTimestampNoRebase} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp} import org.apache.spark.sql.connector.catalog.{Identifier, TableChange} import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex} import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} +import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType, NoopDialect} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.unsafe.types.UTF8String @@ -316,21 +316,31 @@ object JdbcUtils extends Logging with SQLConfHelper { /** * Convert a [[ResultSet]] into an iterator of Catalyst Rows. */ - def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row] = { + def resultSetToRows( + resultSet: ResultSet, + schema: StructType): Iterator[Row] = { + resultSetToRows(resultSet, schema, NoopDialect) + } + + def resultSetToRows( + resultSet: ResultSet, + schema: StructType, + dialect: JdbcDialect): Iterator[Row] = { val inputMetrics = Option(TaskContext.get()).map(_.taskMetrics().inputMetrics).getOrElse(new InputMetrics) val fromRow = RowEncoder(schema).resolveAndBind().createDeserializer() - val internalRows = resultSetToSparkInternalRows(resultSet, schema, inputMetrics) + val internalRows = resultSetToSparkInternalRows(resultSet, dialect, schema, inputMetrics) internalRows.map(fromRow) } private[spark] def resultSetToSparkInternalRows( resultSet: ResultSet, + dialect: JdbcDialect, schema: StructType, inputMetrics: InputMetrics): Iterator[InternalRow] = { new NextIterator[InternalRow] { private[this] val rs = resultSet - private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema) + private[this] val getters: Array[JDBCValueGetter] = makeGetters(dialect, schema) private[this] val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType)) override protected def close(): Unit = { @@ -368,12 +378,17 @@ object JdbcUtils extends Logging with SQLConfHelper { * Creates `JDBCValueGetter`s according to [[StructType]], which can set * each value from `ResultSet` to each field of [[InternalRow]] correctly. */ - private def makeGetters(schema: StructType): Array[JDBCValueGetter] = { + private def makeGetters( + dialect: JdbcDialect, + schema: StructType): Array[JDBCValueGetter] = { val replaced = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema) - replaced.fields.map(sf => makeGetter(sf.dataType, sf.metadata)) + replaced.fields.map(sf => makeGetter(sf.dataType, dialect, sf.metadata)) } - private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match { + private def makeGetter( + dt: DataType, + dialect: JdbcDialect, + metadata: Metadata): JDBCValueGetter = dt match { case BooleanType => (rs: ResultSet, row: InternalRow, pos: Int) => row.setBoolean(pos, rs.getBoolean(pos + 1)) @@ -478,7 +493,8 @@ object JdbcUtils extends Logging with SQLConfHelper { (rs: ResultSet, row: InternalRow, pos: Int) => val t = rs.getTimestamp(pos + 1) if (t != null) { - row.setLong(pos, DateTimeUtils.fromJavaTimestampNoRebase(t)) + row.setLong(pos, + DateTimeUtils.localDateTimeToMicros(dialect.convertJavaTimestampToTimestampNTZ(t))) } else { row.update(pos, null) } @@ -596,8 +612,8 @@ object JdbcUtils extends Logging with SQLConfHelper { case TimestampNTZType => (stmt: PreparedStatement, row: Row, pos: Int) => - val micros = localDateTimeToMicros(row.getAs[java.time.LocalDateTime](pos)) - stmt.setTimestamp(pos + 1, toJavaTimestampNoRebase(micros)) + stmt.setTimestamp(pos + 1, + dialect.convertTimestampNTZToJavaTimestamp(row.getAs[java.time.LocalDateTime](pos))) case DateType => if (conf.datetimeJava8ApiEnabled) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index e7a74ee3aa9c6..93a311be2f867 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, Date, Driver, Statement, Timestamp} -import java.time.{Instant, LocalDate} +import java.time.{Instant, LocalDate, LocalDateTime} import java.util import scala.collection.mutable.ArrayBuilder @@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{localDateTimeToMicros, toJavaTimestampNoRebase} import org.apache.spark.sql.connector.catalog.{Identifier, TableChange} import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.connector.catalog.functions.UnboundFunction @@ -104,6 +105,31 @@ abstract class JdbcDialect extends Serializable with Logging { */ def getJDBCType(dt: DataType): Option[JdbcType] = None + /** + * Convert java.sql.Timestamp to a LocalDateTime representing the same wall-clock time as the + * value stored in a remote database. + * JDBC dialects should override this function to provide implementations that suite their + * JDBC drivers. + * @param t Timestamp returned from JDBC driver getTimestamp method. + * @return A LocalDateTime representing the same wall clock time as the timestamp in database. + */ + @Since("3.5.0") + def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime = { + DateTimeUtils.microsToLocalDateTime(DateTimeUtils.fromJavaTimestampNoRebase(t)) + } + + /** + * Converts a LocalDateTime representing a TimestampNTZ type to an + * instance of `java.sql.Timestamp`. + * @param ldt representing a TimestampNTZType. + * @return A Java Timestamp representing this LocalDateTime. + */ + @Since("3.5.0") + def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp = { + val micros = localDateTimeToMicros(ldt) + toJavaTimestampNoRebase(micros) + } + /** * Returns a factory for creating connections to the given JDBC URL. * In general, creating a connection has nothing to do with JDBC partition id. @@ -682,6 +708,6 @@ object JdbcDialects { /** * NOOP dialect object, always returning the neutral element. */ -private object NoopDialect extends JdbcDialect { +object NoopDialect extends JdbcDialect { override def canHandle(url : String): Boolean = true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index b53a0e66ba752..b42d575ae2d47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.jdbc -import java.sql.{Connection, SQLException, Types} +import java.sql.{Connection, SQLException, Timestamp, Types} +import java.time.LocalDateTime import java.util import java.util.Locale @@ -102,6 +103,14 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper { case _ => None } + override def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime = { + t.toLocalDateTime + } + + override def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp = { + Timestamp.valueOf(ldt) + } + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case StringType => Some(JdbcType("TEXT", Types.VARCHAR)) case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))