feat: add spark_signed_integer_remainder native function for compatibility with Spark #1416

Status: Draft · wants to merge 2 commits into main
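For context: Spark defines `MIN_VALUE % -1` as `0` for its signed integral types, while two's-complement remainder overflows on exactly that operand pair, so Rust's plain `%` cannot be used directly in a native kernel. A minimal standalone sketch of the mismatch (illustration only, not code from this PR):

```rust
use std::hint::black_box;

fn main() {
    // black_box keeps the compiler from rejecting the overflow at compile time.
    let l = black_box(i64::MIN);
    let r = black_box(-1_i64);
    // Plain `l % r` panics at runtime ("attempt to calculate the remainder
    // with overflow"); checked_rem surfaces the same overflow as None.
    assert_eq!(l.checked_rem(r), None);
    // The Spark-compatible special case this PR adds, in miniature:
    let spark_compatible = if l == i64::MIN && r == -1 { 0 } else { l % r };
    assert_eq!(spark_compatible, 0);
}
```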
16 changes: 16 additions & 0 deletions native/core/src/execution/planner.rs
@@ -935,6 +935,22 @@ impl PhysicalPlanner {
data_type,
)))
}
(DataFusionOperator::Modulo, Ok(l), Ok(r))
if l.is_signed_integer() && r.is_signed_integer() =>
{
let data_type = return_type.map(to_arrow_datatype).unwrap();
let fun_expr = create_comet_physical_fun(
"signed_integer_remainder",
data_type.clone(),
&self.session_ctx.state(),
)?;
Ok(Arc::new(ScalarFunctionExpr::new(
"signed_integer_remainder",
fun_expr,
vec![left, right],
data_type,
)))
}
_ => Ok(Arc::new(BinaryExpr::new(left, op, right))),
}
}
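Note that the guard `l.is_signed_integer() && r.is_signed_integer()` reroutes only signed-integer modulo to the new scalar function; floating-point and all other remainders still fall through to the plain `BinaryExpr` fallback arm.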
4 changes: 4 additions & 0 deletions native/spark-expr/Cargo.toml
@@ -74,6 +74,10 @@ harness = false
name = "decimal_div"
harness = false

[[bench]]
name = "signed_integer_remainder"
harness = false

[[bench]]
name = "aggregate"
harness = false
50 changes: 50 additions & 0 deletions native/spark-expr/benches/signed_integer_remainder.rs
@@ -0,0 +1,50 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_array::builder::Int64Builder;
use arrow_schema::DataType;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use datafusion_comet_spark_expr::spark_signed_integer_remainder;
use datafusion_expr_common::columnar_value::ColumnarValue;
use std::sync::Arc;

fn criterion_benchmark(c: &mut Criterion) {
// create input data
let mut c1 = Int64Builder::new();
let mut c2 = Int64Builder::new();
for i in 0..1000 {
c1.append_value(99999999 + i);
c2.append_value(88888888 - i);
c1.append_value(i64::MIN);
c2.append_value(-1);
}
let c1 = Arc::new(c1.finish());
let c2 = Arc::new(c2.finish());

let args = [ColumnarValue::Array(c1), ColumnarValue::Array(c2)];
c.bench_function("signed_integer_remainder", |b| {
b.iter(|| {
black_box(spark_signed_integer_remainder(
black_box(&args),
black_box(&DataType::Int64),
))
})
});
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
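The benchmark input interleaves ordinary operands with the `i64::MIN % -1` overflow pair, so the special-case branch in the kernel is hit on every other row; it should be runnable on its own with `cargo bench --bench signed_integer_remainder` from the `native/spark-expr` crate.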
11 changes: 9 additions & 2 deletions native/spark-expr/src/comet_scalar_funcs.rs
@@ -18,8 +18,8 @@
use crate::hash_funcs::*;
use crate::{
spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_floor, spark_hex,
spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round, spark_unhex,
spark_unscaled_value, SparkChrFunc,
spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round,
spark_signed_integer_remainder, spark_unhex, spark_unscaled_value, SparkChrFunc,
};
use arrow_schema::DataType;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
@@ -90,6 +90,13 @@ pub fn create_comet_physical_fun(
"decimal_div" => {
make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
}
"signed_integer_remainder" => {
make_comet_scalar_udf!(
"signed_integer_remainder",
spark_signed_integer_remainder,
data_type
)
}
"murmur3_hash" => {
let func = Arc::new(spark_murmur3_hash);
make_comet_scalar_udf!("murmur3_hash", func, without data_type)
4 changes: 2 additions & 2 deletions native/spark-expr/src/lib.rs
@@ -68,8 +68,8 @@ pub use hash_funcs::*;
pub use json_funcs::ToJson;
pub use math_funcs::{
create_negate_expr, spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_make_decimal,
spark_round, spark_unhex, spark_unscaled_value, CheckOverflow, NegativeExpr,
NormalizeNaNAndZero,
spark_round, spark_signed_integer_remainder, spark_unhex, spark_unscaled_value, CheckOverflow,
NegativeExpr, NormalizeNaNAndZero,
};
pub use string_funcs::*;

2 changes: 2 additions & 0 deletions native/spark-expr/src/math_funcs/mod.rs
@@ -21,6 +21,7 @@ mod floor;
pub(crate) mod hex;
pub mod internal;
mod negative;
mod remainder;
mod round;
pub(crate) mod unhex;
mod utils;
@@ -31,5 +31,6 @@ pub use floor::spark_floor;
pub use hex::spark_hex;
pub use internal::*;
pub use negative::{create_negate_expr, NegativeExpr};
pub use remainder::spark_signed_integer_remainder;
pub use round::spark_round;
pub use unhex::spark_unhex;
105 changes: 105 additions & 0 deletions native/spark-expr/src/math_funcs/remainder.rs
@@ -0,0 +1,105 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_array::{Array, ArrayRef, Int16Array, Int32Array, Int64Array, Int8Array};
use arrow_schema::DataType;
use datafusion_common::DataFusionError;
use datafusion_expr_common::columnar_value::ColumnarValue;
use std::ops::Rem;
use std::sync::Arc;

macro_rules! signed_integer_rem {
($left:expr, $right:expr, $array_type:ty, $min_val:expr) => {{
let left = $left.as_any().downcast_ref::<$array_type>().unwrap();
let right = $right.as_any().downcast_ref::<$array_type>().unwrap();
let result: $array_type =
arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
if l == $min_val && r == -1 {
Ok(0)
} else {
Ok(l.rem(r))
}
})?;
Ok(ColumnarValue::Array(Arc::new(result)))
}};
}

/// Spark-compatible `remainder` function for signed integers.
/// COMET-1412: when the left operand is the minimum value of its type and the
/// right operand is -1, the result is 0 (as in Spark) instead of overflowing.
pub fn spark_signed_integer_remainder(
args: &[ColumnarValue],
data_type: &DataType,
) -> Result<ColumnarValue, DataFusionError> {
let left = &args[0];
let right = &args[1];

let (left, right): (ArrayRef, ArrayRef) = match (left, right) {
(ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l), Arc::clone(r)),
(ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
(l.to_array_of_size(r.len())?, Arc::clone(r))
}
(ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => {
(Arc::clone(l), r.to_array_of_size(l.len())?)
}
(ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => (l.to_array()?, r.to_array()?),
};
match (left.data_type(), right.data_type(), data_type) {
(DataType::Int8, DataType::Int8, DataType::Int8) => {
signed_integer_rem!(left, right, Int8Array, i8::MIN)
}
(DataType::Int16, DataType::Int16, DataType::Int16) => {
signed_integer_rem!(left, right, Int16Array, i16::MIN)
}
(DataType::Int32, DataType::Int32, DataType::Int32) => {
signed_integer_rem!(left, right, Int32Array, i32::MIN)
}
(DataType::Int64, DataType::Int64, DataType::Int64) => {
signed_integer_rem!(left, right, Int64Array, i64::MIN)
}
_ => Err(DataFusionError::Internal(format!(
"Invalid data type for spark_signed_integer_remainder operation: {:?} {:?}",
left.data_type(),
right.data_type()
))),
}
}

#[cfg(test)]
mod test {
use crate::math_funcs::remainder::spark_signed_integer_remainder;
use arrow_array::Int8Array;
use arrow_schema::DataType;
use datafusion_common::cast::as_int8_array;
use datafusion_expr_common::columnar_value::ColumnarValue;
use std::sync::Arc;

#[test]
fn test_spark_signed_integer_remainder() -> datafusion_common::Result<()> {
let args = vec![
ColumnarValue::Array(Arc::new(Int8Array::from(vec![9i8, i8::MIN]))),
ColumnarValue::Array(Arc::new(Int8Array::from(vec![5i8, -1]))),
];
let ColumnarValue::Array(result) = spark_signed_integer_remainder(&args, &DataType::Int8)?
else {
unreachable!()
};
let results = as_int8_array(&result)?;
let expected = Int8Array::from(vec![4, 0]);
assert_eq!(results, &expected);
Ok(())
}
}
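Beyond the unit test above, a sketch of the scalar-broadcast path (`to_array_of_size`), assuming the crate paths used elsewhere in this PR:

```rust
use arrow_array::Int32Array;
use arrow_schema::DataType;
use datafusion_comet_spark_expr::spark_signed_integer_remainder;
use datafusion_common::{cast::as_int32_array, ScalarValue};
use datafusion_expr_common::columnar_value::ColumnarValue;
use std::sync::Arc;

fn main() -> datafusion_common::Result<()> {
    // Array dividend, scalar divisor: the scalar is broadcast to the
    // array's length before the typed kernel runs.
    let args = [
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![i32::MIN, 7, -9]))),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(-1))),
    ];
    let ColumnarValue::Array(result) =
        spark_signed_integer_remainder(&args, &DataType::Int32)?
    else {
        unreachable!()
    };
    // Every x % -1 is 0, including the otherwise-overflowing i32::MIN case.
    assert_eq!(as_int32_array(&result)?, &Int32Array::from(vec![0, 0, 0]));
    Ok(())
}
```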
14 changes: 14 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2641,4 +2641,18 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("COMET-1412: test smallest signed integer value % -1") {
Seq[(String, Any)](
("short", Short.MinValue),
("int", Int.MinValue),
("long", Long.MinValue),
("double", Double.MinValue)).foreach { case (t, v) =>
withTable("t1") {
sql(s"create table t1(c1 $t, c2 short) using parquet")
sql(s"insert into t1 values($v, -1), (52, 10), (10, 0)")
checkSparkAnswerAndOperator("select c1 % c2, c1 % 1 from t1 order by c1")
}
}
}

}
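The `double` row is presumably a control case: floating-point remainder does not overflow, and since the planner guard only matches signed integers, `Double.MinValue % -1` should still take the default `BinaryExpr` path while the integral rows exercise the new native function.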