fix: Unsigned type related bugs (#1095)

## Which issue does this PR close?

Closes #1067

## Rationale for this change

Bug fix. A few expressions were failing some unsigned-type-related tests.

## What changes are included in this PR?

 - For `u8`/`u16`, switched to `generate_cast_to_signed!` so that the full `i16`/`i32` width is copied (sign extension) instead of zero-padding the high bits; see the sketch below
 - `u64` becomes `Decimal(20, 0)`, but `round()` had a bug in its guard (`>` where `>=` was needed); a second sketch follows the `spark_round` hunk below
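
As context for the first bullet, here is a minimal byte-level sketch (standalone Rust, not Comet's actual macro-generated code, which operates on whole buffers): zero-padding the high byte turns a negative value into a large positive one, while narrowing through the signed type preserves it.

```rust
// Minimal sketch, not Comet's generated code: reading a Parquet i32
// physical value into the 2-byte slot Spark expects for a u8 column.
fn copy_zero_padded(src: i32, dst: &mut [u8; 2]) {
    // old behavior: keep the low byte and zero the high byte
    dst[0] = src as u8;
    dst[1] = 0;
}

fn copy_sign_extended(src: i32, dst: &mut [u8; 2]) {
    // new behavior: narrow through i16 so the sign bits are preserved
    dst.copy_from_slice(&(src as i16).to_le_bytes());
}

fn main() {
    let (mut old, mut new) = ([0u8; 2], [0u8; 2]);
    copy_zero_padded(-37, &mut old);   // 0xdb 0x00 -> reads back as 219
    copy_sign_extended(-37, &mut new); // 0xdb 0xff -> reads back as -37
    assert_eq!(i16::from_le_bytes(old), 219);
    assert_eq!(i16::from_le_bytes(new), -37);
}
```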

## How are these changes tested?

Put back tests for unsigned types
kazuyukitanimura authored Nov 19, 2024
1 parent 59da6ce commit ca3a529
Showing 6 changed files with 17 additions and 21 deletions.
9 changes: 5 additions & 4 deletions native/core/src/parquet/read/values.rs
@@ -476,7 +476,7 @@ make_int_variant_impl!(Int32ToDoubleType, copy_i32_to_f64, 8);
make_int_variant_impl!(FloatToDoubleType, copy_f32_to_f64, 8);

// unsigned type require double the width and zeroes are written for the second half
-// perhaps because they are implemented as the next size up signed type?
+// because they are implemented as the next size up signed type
make_int_variant_impl!(UInt8Type, copy_i32_to_u8, 2);
make_int_variant_impl!(UInt16Type, copy_i32_to_u16, 4);
make_int_variant_impl!(UInt32Type, copy_i32_to_u32, 8);
@@ -586,8 +586,6 @@ macro_rules! generate_cast_to_unsigned {
};
}

-generate_cast_to_unsigned!(copy_i32_to_u8, i32, u8, 0_u8);
-generate_cast_to_unsigned!(copy_i32_to_u16, i32, u16, 0_u16);
generate_cast_to_unsigned!(copy_i32_to_u32, i32, u32, 0_u32);

macro_rules! generate_cast_to_signed {
@@ -624,6 +622,9 @@ generate_cast_to_signed!(copy_i64_to_i64, i64, i64);
generate_cast_to_signed!(copy_i64_to_i128, i64, i128);
generate_cast_to_signed!(copy_u64_to_u128, u64, u128);
generate_cast_to_signed!(copy_f32_to_f64, f32, f64);
+// even for u8/u16, need to copy full i16/i32 width for Spark compatibility
+generate_cast_to_signed!(copy_i32_to_u8, i32, i16);
+generate_cast_to_signed!(copy_i32_to_u16, i32, i32);

// Shared implementation for variants of Binary type
macro_rules! make_plain_binary_impl {
@@ -1096,7 +1097,7 @@ mod test {
let source =
hex::decode("8a000000dbffffff1800000034ffffff300000001d000000abffffff37fffffff1000000")
.unwrap();
let expected = hex::decode("8a00db001800340030001d00ab003700f100").unwrap();
let expected = hex::decode("8a00dbff180034ff30001d00abff37fff100").unwrap();
let num = source.len() / 4;
let mut dest: Vec<u8> = vec![b' '; num * 2];
copy_i32_to_u8(source.as_bytes(), dest.as_mut_slice(), num);
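
As a sanity check against the updated test vector above, a hedged sketch of the narrowing copy that `generate_cast_to_signed!` now plausibly produces for `u8` (the real macro-generated function may differ in details such as trait bounds and buffer handling):

```rust
// Hedged sketch of a narrowing copy like the one generate_cast_to_signed!
// now generates for u8; the real function may differ in details.
fn copy_i32_to_u8(src: &[u8], dst: &mut [u8], num: usize) {
    for i in 0..num {
        let v = i32::from_le_bytes(src[i * 4..i * 4 + 4].try_into().unwrap());
        // copy the full i16 width so sign bits survive, instead of
        // writing the low byte and zero-padding the high byte
        dst[i * 2..i * 2 + 2].copy_from_slice(&(v as i16).to_le_bytes());
    }
}

fn main() {
    // first two i32 values from the test: 0x0000008a (138), 0xffffffdb (-37)
    let source = [0x8a, 0x00, 0x00, 0x00, 0xdb, 0xff, 0xff, 0xff];
    let mut dest = [0u8; 4];
    copy_i32_to_u8(&source, &mut dest, 2);
    // matches the start of the new expected string "8a00dbff..."
    assert_eq!(dest, [0x8a, 0x00, 0xdb, 0xff]);
}
```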
2 changes: 1 addition & 1 deletion native/spark-expr/src/scalar_funcs.rs
@@ -354,7 +354,7 @@ pub fn spark_round(
DataType::Int32 if *point < 0 => round_integer_array!(array, point, Int32Array, i32),
DataType::Int16 if *point < 0 => round_integer_array!(array, point, Int16Array, i16),
DataType::Int8 if *point < 0 => round_integer_array!(array, point, Int8Array, i8),
DataType::Decimal128(_, scale) if *scale > 0 => {
DataType::Decimal128(_, scale) if *scale >= 0 => {
let f = decimal_round_f(scale, point);
let (precision, scale) = get_precision_scale(data_type);
make_decimal_array(array, precision, scale, &f)
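
For the second bullet in the description, a minimal sketch of the guard change using a stand-in enum rather than DataFusion's real `DataType` (names and signature here are illustrative only): Spark reads Parquet `u64` as `Decimal128(20, 0)`, and the old `*scale > 0` guard excluded exactly that scale-0 case from the decimal rounding branch.

```rust
// Stand-in types for illustration; not DataFusion's real DataType enum.
enum DataType {
    Decimal128(u8, i8),
    Other,
}

fn pick_branch(dt: &DataType) -> &'static str {
    match dt {
        // fixed guard: `>=` lets scale-0 decimals (e.g. the Decimal(20, 0)
        // that u64 columns become) take the decimal rounding path
        DataType::Decimal128(_, scale) if *scale >= 0 => "decimal rounding",
        _ => "fallthrough",
    }
}

fn main() {
    assert_eq!(pick_branch(&DataType::Decimal128(20, 0)), "decimal rounding");
    assert_eq!(pick_branch(&DataType::Other), "fallthrough");
}
```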
5 changes: 1 addition & 4 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -861,10 +861,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
// primitives
checkSparkAnswerAndOperator(
"SELECT CAST(struct(_1, _2, _3, _4, _5, _6, _7, _8) as string) FROM tbl")
-// TODO: enable tests for unsigned ints (_9, _10, _11, _12) once
-// https://github.com/apache/datafusion-comet/issues/1067 is resolved
-// checkSparkAnswerAndOperator(
-// "SELECT CAST(struct(_9, _10, _11, _12) as string) FROM tbl")
+checkSparkAnswerAndOperator("SELECT CAST(struct(_9, _10, _11, _12) as string) FROM tbl")
// decimals
// TODO add _16 when https://github.com/apache/datafusion-comet/issues/1068 is resolved
checkSparkAnswerAndOperator("SELECT CAST(struct(_15, _17) as string) FROM tbl")
14 changes: 4 additions & 10 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -119,10 +119,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
withParquetTable(path.toString, "tbl") {
-// TODO: enable test for unsigned ints
-checkSparkAnswerAndOperator(
-"select _1, _2, _3, _4, _5, _6, _7, _8, _13, _14, _15, _16, _17, " +
-"_18, _19, _20 FROM tbl WHERE _2 > 100")
+checkSparkAnswerAndOperator("select * FROM tbl WHERE _2 > 100")
}
}
}
@@ -1115,7 +1112,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 100)
withParquetTable(path.toString, "tbl") {
-Seq(2, 3, 4, 5, 6, 7, 15, 16, 17).foreach { col =>
+Seq(2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 15, 16, 17).foreach { col =>
checkSparkAnswerAndOperator(s"SELECT abs(_${col}) FROM tbl")
}
}
@@ -1239,9 +1236,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
withParquetTable(path.toString, "tbl") {
for (s <- Seq(-5, -1, 0, 1, 5, -1000, 1000, -323, -308, 308, -15, 15, -16, 16, null)) {
// array tests
-// TODO: enable test for unsigned ints (_9, _10, _11, _12)
// TODO: enable test for floats (_6, _7, _8, _13)
for (c <- Seq(2, 3, 4, 5, 15, 16, 17)) {
for (c <- Seq(2, 3, 4, 5, 9, 10, 11, 12, 15, 16, 17)) {
checkSparkAnswerAndOperator(s"select _${c}, round(_${c}, ${s}) FROM tbl")
}
// scalar tests
@@ -1452,9 +1448,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)

withParquetTable(path.toString, "tbl") {
-// _9 and _10 (uint8 and uint16) not supported
checkSparkAnswerAndOperator(
"SELECT hex(_1), hex(_2), hex(_3), hex(_4), hex(_5), hex(_6), hex(_7), hex(_8), hex(_11), hex(_12), hex(_13), hex(_14), hex(_15), hex(_16), hex(_17), hex(_18), hex(_19), hex(_20) FROM tbl")
"SELECT hex(_1), hex(_2), hex(_3), hex(_4), hex(_5), hex(_6), hex(_7), hex(_8), hex(_9), hex(_10), hex(_11), hex(_12), hex(_13), hex(_14), hex(_15), hex(_16), hex(_17), hex(_18), hex(_19), hex(_20) FROM tbl")
}
}
}
@@ -2334,7 +2329,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
checkSparkAnswerAndOperator(
spark.sql("SELECT array_append((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
}

}
}
}
@@ -750,6 +750,10 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper {
$"_6",
$"_7",
$"_8",
$"_9",
$"_10",
$"_11",
$"_12",
$"_13",
$"_14",
$"_15",
@@ -434,8 +434,8 @@ abstract class ParquetReadSuite extends CometTestBase {
i.toFloat,
i.toDouble,
i.toString * 48,
-java.lang.Byte.toUnsignedInt((-i).toByte),
-java.lang.Short.toUnsignedInt((-i).toShort),
+(-i).toByte,
+(-i).toShort,
java.lang.Integer.toUnsignedLong(-i),
new BigDecimal(UnsignedLong.fromLongBits((-i).toLong).bigIntegerValue()),
i.toString,
