Skip to content

Commit

Permalink
chore: extract static invoke expressions to folders based on spark grouping (#1217)
Browse files Browse the repository at this point in the history

* extract static invoke expressions to folders based on spark grouping

* Update native/spark-expr/src/static_invoke/mod.rs

Co-authored-by: Andy Grove <agrove@apache.org>

---------

Co-authored-by: Andy Grove <agrove@apache.org>
  • Loading branch information
rluvaton and andygrove authored Jan 6, 2025
1 parent e39ffa6 commit 5c389d1
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 58 deletions.
5 changes: 3 additions & 2 deletions native/spark-expr/src/comet_scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ use crate::scalar_funcs::hash_expressions::{
};
use crate::scalar_funcs::{
spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_floor, spark_hex,
spark_isnan, spark_make_decimal, spark_murmur3_hash, spark_read_side_padding, spark_round,
spark_unhex, spark_unscaled_value, spark_xxhash64, SparkChrFunc,
spark_isnan, spark_make_decimal, spark_murmur3_hash, spark_round, spark_unhex,
spark_unscaled_value, spark_xxhash64, SparkChrFunc,
};
use crate::spark_read_side_padding;
use arrow_schema::DataType;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_expr::registry::FunctionRegistry;
Expand Down
2 changes: 2 additions & 0 deletions native/spark-expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ mod list;
mod regexp;
pub mod scalar_funcs;
mod schema_adapter;
mod static_invoke;
pub use schema_adapter::SparkSchemaAdapterFactory;
pub use static_invoke::*;

pub mod spark_hash;
mod stddev;
Expand Down
59 changes: 3 additions & 56 deletions native/spark-expr/src/scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,23 @@ use arrow::datatypes::IntervalDayTime;
use arrow::{
array::{
ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
Int64Array, Int64Builder, Int8Array,
},
datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
};
use arrow_array::builder::{GenericStringBuilder, IntervalDayTimeBuilder};
use arrow_array::builder::IntervalDayTimeBuilder;
use arrow_array::types::{Int16Type, Int32Type, Int8Type};
use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Datum, Decimal128Array};
use arrow_schema::{ArrowError, DataType, DECIMAL128_MAX_PRECISION};
use datafusion::physical_expr_common::datum;
use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
use datafusion_common::{
cast::as_generic_string_array, exec_err, internal_err, DataFusionError,
Result as DataFusionResult, ScalarValue,
exec_err, internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
};
use num::{
integer::{div_ceil, div_floor},
BigInt, Signed, ToPrimitive,
};
use std::fmt::Write;
use std::{cmp::min, sync::Arc};

mod unhex;
Expand Down Expand Up @@ -390,57 +388,6 @@ pub fn spark_round(
}
}

/// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
///
/// Dispatches on the input array's data type and delegates to
/// `spark_read_side_padding_internal` with the matching offset width.
/// Returns `DataFusionError::Internal` for unsupported types or argument shapes.
pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    match args {
        // Expected shape: one string array plus a non-null Int32 scalar pad length.
        [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
            match array.data_type() {
                // i32 offsets for Utf8, i64 offsets for LargeUtf8.
                DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length),
                DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(array, *length),
                // TODO: handle Dictionary types
                other => Err(DataFusionError::Internal(format!(
                    "Unsupported data type {other:?} for function read_side_padding",
                ))),
            }
        }
        // Anything else (wrong arity, null length, scalar string input) is rejected.
        other => Err(DataFusionError::Internal(format!(
            "Unsupported arguments {other:?} for function read_side_padding",
        ))),
    }
}

fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
array: &ArrayRef,
length: i32,
) -> Result<ColumnarValue, DataFusionError> {
let string_array = as_generic_string_array::<T>(array)?;
let length = 0.max(length) as usize;
let space_string = " ".repeat(length);

let mut builder =
GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);

for string in string_array.iter() {
match string {
Some(string) => {
// It looks Spark's UTF8String is closer to chars rather than graphemes
// https://stackoverflow.com/a/46290728
let char_len = string.chars().count();
if length <= char_len {
builder.append_value(string);
} else {
// write_str updates only the value buffer, not null nor offset buffer
// This is convenient for concatenating str(s)
builder.write_str(string)?;
builder.append_value(&space_string[char_len..]);
}
}
_ => builder.append_null(),
}
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}

// Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
// Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, in order to
// get enough scale that matches with Spark behavior, it requires to widen s1 to s2 + s3 + 1. Since
Expand Down
20 changes: 20 additions & 0 deletions native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

mod read_side_padding;

pub use read_side_padding::spark_read_side_padding;
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{ArrayRef, OffsetSizeTrait};
use arrow_array::builder::GenericStringBuilder;
use arrow_array::Array;
use arrow_schema::DataType;
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
use std::fmt::Write;
use std::sync::Arc;

/// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
/// Similar to DataFusion `rpad`, but does not truncate when the string is
/// already longer than `length`.
///
/// Expects exactly two arguments — a string array and a non-null `Int32`
/// scalar pad length — and dispatches to the generic implementation based on
/// the array's offset width. Any other argument shape or data type yields a
/// `DataFusionError::Internal`.
pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    // Guard clause: reject anything that is not (array, Int32 scalar).
    let [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] =
        args
    else {
        return Err(DataFusionError::Internal(format!(
            "Unsupported arguments {args:?} for function read_side_padding",
        )));
    };

    match array.data_type() {
        // i32 offsets for Utf8, i64 offsets for LargeUtf8.
        DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length),
        DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(array, *length),
        // TODO: handle Dictionary types
        other => Err(DataFusionError::Internal(format!(
            "Unsupported data type {other:?} for function read_side_padding",
        ))),
    }
}

fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
array: &ArrayRef,
length: i32,
) -> Result<ColumnarValue, DataFusionError> {
let string_array = as_generic_string_array::<T>(array)?;
let length = 0.max(length) as usize;
let space_string = " ".repeat(length);

let mut builder =
GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);

for string in string_array.iter() {
match string {
Some(string) => {
// It looks Spark's UTF8String is closer to chars rather than graphemes
// https://stackoverflow.com/a/46290728
let char_len = string.chars().count();
if length <= char_len {
builder.append_value(string);
} else {
// write_str updates only the value buffer, not null nor offset buffer
// This is convenient for concatenating str(s)
builder.write_str(string)?;
builder.append_value(&space_string[char_len..]);
}
}
_ => builder.append_null(),
}
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
20 changes: 20 additions & 0 deletions native/spark-expr/src/static_invoke/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

mod char_varchar_utils;

pub use char_varchar_utils::spark_read_side_padding;

0 comments on commit 5c389d1

Please sign in to comment.