Skip to content

Commit

Permalink
fix: Mark cast from float/double to decimal as incompatible (#1372)
Browse files Browse the repository at this point in the history
* add failing test

* Mark cast from float/double to decimal as incompat

* update docs

* update cast tests

* link to issue

* fix regressions

* use unique table name in test

* use withTable

* address feedback
  • Loading branch information
andygrove authored Feb 7, 2025
1 parent 19c4405 commit 26b8d57
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 17 deletions.
4 changes: 2 additions & 2 deletions docs/source/user-guide/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,13 @@ The following cast operations are generally compatible with Spark except for the
| float | integer | |
| float | long | |
| float | double | |
| float | decimal | |
| float | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 |
| double | boolean | |
| double | byte | |
| double | short | |
| double | integer | |
| double | long | |
| double | float | |
| double | decimal | |
| double | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 |
| decimal | byte | |
| decimal | short | |
Expand Down Expand Up @@ -154,6 +152,8 @@ The following cast operations are not compatible with Spark for all inputs and a
|-|-|-|
| integer | decimal | No overflow check |
| long | decimal | No overflow check |
| float | decimal | There can be rounding differences |
| double | decimal | There can be rounding differences |
| string | float | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
| string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
| string | decimal | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,19 @@ object CometCast {
case DataTypes.BooleanType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
DataTypes.IntegerType | DataTypes.LongType =>
Compatible()
case _: DecimalType => Compatible()
case _: DecimalType =>
// https://github.com/apache/datafusion-comet/issues/1371
Incompatible(Some("There can be rounding differences"))
case _ => Unsupported
}

/**
 * Determines Comet's support level for casting from DoubleType to `toType`.
 *
 * Numeric and boolean targets are fully compatible with Spark. Decimal targets
 * are marked incompatible (enabled only via `COMET_CAST_ALLOW_INCOMPATIBLE`)
 * because results can differ from Spark by rounding; all other targets are
 * unsupported.
 */
private def canCastFromDouble(toType: DataType): SupportLevel = toType match {
  case DataTypes.BooleanType | DataTypes.FloatType | DataTypes.ByteType | DataTypes.ShortType |
      DataTypes.IntegerType | DataTypes.LongType =>
    Compatible()
  case _: DecimalType =>
    // https://github.com/apache/datafusion-comet/issues/1371
    Incompatible(Some("There can be rounding differences"))
  case _ => Unsupported
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ object ParquetGenerator {
}

/**
 * Options controlling random test-data generation for Parquet files.
 *
 * Defaults generate nullable primitive data (including negative zero for
 * floating-point columns) and skip complex types.
 *
 * @param allowNull whether generated values may include nulls
 * @param generateNegativeZero whether floating-point columns may include -0.0
 * @param generateArray whether to also generate array-typed columns
 * @param generateStruct whether to also generate struct-typed columns
 * @param generateMap whether to also generate map-typed columns
 */
case class DataGenOptions(
    allowNull: Boolean = true,
    generateNegativeZero: Boolean = true,
    generateArray: Boolean = false,
    generateStruct: Boolean = false,
    generateMap: Boolean = false)
18 changes: 16 additions & 2 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,17 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateFloats(), DataTypes.DoubleType)
}

// Ignored: cast from float to decimal is currently marked incompatible with
// Spark due to rounding differences.
// https://github.com/apache/datafusion-comet/issues/1371
ignore("cast FloatType to DecimalType(10,2)") {
  castTest(generateFloats(), DataTypes.createDecimalType(10, 2))
}

// Same cast, but explicitly opting in to incompatible casts; results are
// compared against Spark and may legitimately differ only in rounding.
test("cast FloatType to DecimalType(10,2) - allow incompat") {
  withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") {
    castTest(generateFloats(), DataTypes.createDecimalType(10, 2))
  }
}

test("cast FloatType to StringType") {
// https://github.com/apache/datafusion-comet/issues/312
val r = new Random(0)
Expand Down Expand Up @@ -401,10 +408,17 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateDoubles(), DataTypes.FloatType)
}

// Ignored: cast from double to decimal is currently marked incompatible with
// Spark due to rounding differences.
// https://github.com/apache/datafusion-comet/issues/1371
ignore("cast DoubleType to DecimalType(10,2)") {
  castTest(generateDoubles(), DataTypes.createDecimalType(10, 2))
}

// Same cast, but explicitly opting in to incompatible casts; results are
// compared against Spark and may legitimately differ only in rounding.
test("cast DoubleType to DecimalType(10,2) - allow incompat") {
  withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") {
    castTest(generateDoubles(), DataTypes.createDecimalType(10, 2))
  }
}

test("cast DoubleType to StringType") {
// https://github.com/apache/datafusion-comet/issues/312
val r = new Random(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,35 @@ import org.apache.spark.sql.internal.SQLConf

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}

/**
* Test suite dedicated to Comet native aggregate operator
*/
class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
import testImplicits._

test("avg decimal") {
  withTempDir { dir =>
    // Generate a Parquet file of random data with Comet disabled so the
    // fixture itself is produced by vanilla Spark.
    val parquetFile = new Path(dir.toURI.toString, "test.parquet").toString
    val rand = new Random(42)
    withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
      ParquetGenerator.makeParquetFile(rand, spark, parquetFile, 10000, DataGenOptions())
    }
    val viewName = "avg_decimal"
    withTable(viewName) {
      spark.read.parquet(parquetFile).coalesce(1).createOrReplaceTempView(viewName)
      // we fall back to Spark for avg on decimal due to the following issue
      // https://github.com/apache/datafusion-comet/issues/1371
      // once this is fixed, we should change this test to
      // checkSparkAnswerAndNumOfAggregates
      checkSparkAnswer(s"SELECT c1, avg(c7) FROM $viewName GROUP BY c1 ORDER BY c1")
    }
  }
}

test("stddev_pop should return NaN for some cases") {
withSQLConf(
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
Expand Down Expand Up @@ -867,10 +889,11 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {

withSQLConf(
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true",
CometConf.COMET_SHUFFLE_MODE.key -> "native") {
Seq(true, false).foreach { dictionaryEnabled =>
withSQLConf("parquet.enable.dictionary" -> dictionaryEnabled.toString) {
val table = "t1"
val table = s"final_decimal_avg_$dictionaryEnabled"
withTable(table) {
sql(s"create table $table(a decimal(38, 37), b INT) using parquet")
sql(s"insert into $table values(-0.0000000000000000000000000000000000002, 1)")
Expand All @@ -884,13 +907,13 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
sql(s"insert into $table values(0.13344406545919155429936259114971302408, 5)")
sql(s"insert into $table values(0.13344406545919155429936259114971302408, 5)")

checkSparkAnswerAndNumOfAggregates("SELECT b , AVG(a) FROM t1 GROUP BY b", 2)
checkSparkAnswerAndNumOfAggregates("SELECT AVG(a) FROM t1", 2)
checkSparkAnswerAndNumOfAggregates(s"SELECT b , AVG(a) FROM $table GROUP BY b", 2)
checkSparkAnswerAndNumOfAggregates(s"SELECT AVG(a) FROM $table", 2)
checkSparkAnswerAndNumOfAggregates(
"SELECT b, MIN(a), MAX(a), COUNT(a), SUM(a), AVG(a) FROM t1 GROUP BY b",
s"SELECT b, MIN(a), MAX(a), COUNT(a), SUM(a), AVG(a) FROM $table GROUP BY b",
2)
checkSparkAnswerAndNumOfAggregates(
"SELECT MIN(a), MAX(a), COUNT(a), SUM(a), AVG(a) FROM t1",
s"SELECT MIN(a), MAX(a), COUNT(a), SUM(a), AVG(a) FROM $table",
2)
}
}
Expand All @@ -915,7 +938,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
withSQLConf(
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
CometConf.COMET_SHUFFLE_MODE.key -> "native") {
val table = "t1"
val table = "avg_null_handling"
withTable(table) {
sql(s"create table $table(a double, b double) using parquet")
sql(s"insert into $table values(1, 1.0)")
Expand Down

0 comments on commit 26b8d57

Please sign in to comment.