fix: Optimize read_side_padding #772

Merged · 8 commits · Aug 8, 2024
1 change: 0 additions & 1 deletion native/Cargo.lock

Some generated files are not rendered by default.

@@ -21,8 +21,8 @@ use datafusion_comet_spark_expr::scalar_funcs::hash_expressions::{
 };
 use datafusion_comet_spark_expr::scalar_funcs::{
     spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
-    spark_murmur3_hash, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, spark_xxhash64,
-    SparkChrFunc,
+    spark_murmur3_hash, spark_read_side_padding, spark_round, spark_unhex, spark_unscaled_value,
+    spark_xxhash64, SparkChrFunc,
 };
 use datafusion_common::{DataFusionError, Result as DataFusionResult};
 use datafusion_expr::registry::FunctionRegistry;
@@ -67,9 +67,9 @@ pub fn create_comet_physical_fun(
         "floor" => {
             make_comet_scalar_udf!("floor", spark_floor, data_type)
         }
-        "rpad" => {
-            let func = Arc::new(spark_rpad);
-            make_comet_scalar_udf!("rpad", func, without data_type)
+        "read_side_padding" => {
+            let func = Arc::new(spark_read_side_padding);
+            make_comet_scalar_udf!("read_side_padding", func, without data_type)
         }
         "round" => {
             make_comet_scalar_udf!("round", spark_round, data_type)
15 changes: 10 additions & 5 deletions native/core/src/execution/datafusion/planner.rs
@@ -1724,11 +1724,16 @@ impl PhysicalPlanner {
 
         let data_type = match expr.return_type.as_ref().map(to_arrow_datatype) {
             Some(t) => t,
-            None => self
-                .session_ctx
-                .udf(fun_name)?
-                .inner()
-                .return_type(&input_expr_types)?,
+            None => {
+                let fun_name = match fun_name.as_str() {
+                    "read_side_padding" => "rpad", // use the same return type as rpad
+                    other => other,
+                };
+                self.session_ctx
+                    .udf(fun_name)?
+                    .inner()
+                    .return_type(&input_expr_types)?
+            }
         };
 
         let fun_expr =
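Viewed on its own, the fallback above just rewrites the lookup name before consulting the session registry: `read_side_padding` pads the same way as `rpad` and returns the same type, so it can reuse `rpad`'s registered return type. A minimal, hypothetical Rust sketch of just that name mapping (`return_type_lookup_name` is not part of the PR; the real planner passes the result to `session_ctx.udf(...)`):

```rust
// Hypothetical sketch; not code from this PR.
fn return_type_lookup_name(fun_name: &str) -> &str {
    match fun_name {
        // read_side_padding returns the same type as rpad (Utf8/LargeUtf8),
        // so its return type can be resolved from rpad's registered UDF.
        "read_side_padding" => "rpad",
        other => other,
    }
}

fn main() {
    assert_eq!(return_type_lookup_name("read_side_padding"), "rpad");
    assert_eq!(return_type_lookup_name("floor"), "floor");
}
```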
1 change: 0 additions & 1 deletion native/spark-expr/Cargo.toml
@@ -41,7 +41,6 @@ chrono-tz = { workspace = true }
 num = { workspace = true }
 regex = { workspace = true }
 thiserror = { workspace = true }
-unicode-segmentation = "1.11.0"
 
 [dev-dependencies]
 arrow-data = {workspace = true}
62 changes: 32 additions & 30 deletions native/spark-expr/src/scalar_funcs.rs
@@ -15,15 +15,14 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{cmp::min, sync::Arc};
-
 use arrow::{
     array::{
-        ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, GenericStringArray,
-        Int16Array, Int32Array, Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
+        ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, Int16Array, Int32Array,
+        Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
     },
     datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
 };
+use arrow_array::builder::GenericStringBuilder;
 use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array};
 use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
 use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
@@ -35,7 +34,8 @@ use num::{
     integer::{div_ceil, div_floor},
     BigInt, Signed, ToPrimitive,
 };
-use unicode_segmentation::UnicodeSegmentation;
+use std::fmt::Write;
+use std::{cmp::min, sync::Arc};
 
 mod unhex;
 pub use unhex::spark_unhex;
@@ -387,52 +387,54 @@ pub fn spark_round(
 }
 
 /// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
-pub fn spark_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
     match args {
         [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
-            match args[0].data_type() {
-                DataType::Utf8 => spark_rpad_internal::<i32>(array, *length),
-                DataType::LargeUtf8 => spark_rpad_internal::<i64>(array, *length),
+            match array.data_type() {
+                DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length),
+                DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(array, *length),
                 // TODO: handle Dictionary types
                 other => Err(DataFusionError::Internal(format!(
-                    "Unsupported data type {other:?} for function rpad",
+                    "Unsupported data type {other:?} for function read_side_padding",
                 ))),
             }
         }
         other => Err(DataFusionError::Internal(format!(
-            "Unsupported arguments {other:?} for function rpad",
+            "Unsupported arguments {other:?} for function read_side_padding",
         ))),
     }
 }
 
-fn spark_rpad_internal<T: OffsetSizeTrait>(
+fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
     array: &ArrayRef,
     length: i32,
 ) -> Result<ColumnarValue, DataFusionError> {
     let string_array = as_generic_string_array::<T>(array)?;
+    let length = 0.max(length) as usize;
+    let space_string = " ".repeat(length);
 
+    let mut builder =
+        GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
+
-    let result = string_array
-        .iter()
-        .map(|string| match string {
+    for string in string_array.iter() {
+        match string {
             Some(string) => {
-                let length = if length < 0 { 0 } else { length as usize };
-                if length == 0 {
-                    Ok(Some("".to_string()))
+                // It looks Spark's UTF8String is closer to chars rather than graphemes
+                // https://stackoverflow.com/a/46290728
Comment on lines +422 to +423
Member: Can you add a unit test for that?
Contributor Author: added
+                let char_len = string.chars().count();
+                if length <= char_len {
+                    builder.append_value(string);
Comment on lines +425 to +426
Member: If the required length is less than the string's length, don't we need to take a substring of it? Spark's RPad does that.
Member: The current implementation already has this issue.
Contributor Author: Line 389 has this existing comment:
/// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
Perhaps I should rename this method, since it is not used for rpad.
                 } else {
-                    let graphemes = string.graphemes(true).collect::<Vec<&str>>();
-                    if length < graphemes.len() {
-                        Ok(Some(string.to_string()))
-                    } else {
-                        let mut s = string.to_string();
-                        s.push_str(" ".repeat(length - graphemes.len()).as_str());
-                        Ok(Some(s))
-                    }
+                    // write_str updates only the value buffer, not null nor offset buffer
+                    // This is convenient for concatenating str(s)
+                    builder.write_str(string)?;
+                    builder.append_value(&space_string[char_len..]);
                 }
             }
-            _ => Ok(None),
-        })
-        .collect::<Result<GenericStringArray<T>, DataFusionError>>()?;
-    Ok(ColumnarValue::Array(Arc::new(result)))
+            _ => builder.append_null(),
+        }
+    }
+    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
 }
 
 // Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
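The `write_str` comment in the new implementation is the heart of the optimization: `GenericStringBuilder` implements `std::fmt::Write`, and `write_str` appends bytes to the value buffer without committing an offset or validity entry, so a string and its padding are concatenated into a single array element with no intermediate `String` allocation. A minimal standalone sketch of that builder behavior (illustrative only, not code from this PR):

```rust
use arrow_array::builder::GenericStringBuilder;
use arrow_array::Array;
use std::fmt::Write;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut builder = GenericStringBuilder::<i32>::new();

    // write_str extends only the value buffer; the current element stays open.
    builder.write_str("ab")?;
    // append_value finishes that same element, so the bytes concatenate.
    builder.append_value("  ");

    let array = builder.finish();
    assert_eq!(array.len(), 1);         // one element, not two
    assert_eq!(array.value(0), "ab  "); // "ab" padded out to four chars
    Ok(())
}
```

This is also why the PR builds `space_string` once and slices `&space_string[char_len..]` per row instead of allocating fresh padding for every value.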
@@ -2178,7 +2178,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         }
 
         // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
-        // char types. Use rpad to achieve the behavior.
+        // char types.
         // See https://github.com/apache/spark/pull/38151
         case s: StaticInvoke
           if s.staticObject.isInstanceOf[Class[CharVarcharCodegenUtils]] &&
@@ -2194,7 +2194,7 @@
 
           if (argsExpr.forall(_.isDefined)) {
             val builder = ExprOuterClass.ScalarFunc.newBuilder()
-            builder.setFunc("rpad")
+            builder.setFunc("read_side_padding")
             argsExpr.foreach(arg => builder.addArgs(arg.get))
 
             Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
7 changes: 7 additions & 0 deletions spark/src/test/resources/tpcds-micro-benchmarks/char_type.sql
@@ -0,0 +1,7 @@
+SELECT
+cd_gender
+FROM customer_demographics
+WHERE
+cd_gender = 'M' AND
+cd_marital_status = 'S' AND
+cd_education_status = 'College'

Contributor: how is this test related to rpad? 🤔
Contributor Author: They are related, as their schema types are CHAR().
14 changes: 14 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1911,6 +1911,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("readSidePadding") {
+    // https://stackoverflow.com/a/46290728
+    val table = "test"
+    withTable(table) {
+      sql(s"create table $table(col1 CHAR(2)) using parquet")
+      sql(s"insert into $table values('é')") // unicode 'e\\u{301}'
+      sql(s"insert into $table values('é')") // unicode '\\u{e9}'
+      sql(s"insert into $table values('')")
+      sql(s"insert into $table values('ab')")
+
+      checkSparkAnswerAndOperator(s"SELECT * FROM $table")
+    }
+  }
+
   test("isnan") {
     Seq("true", "false").foreach { dictionary =>
       withSQLConf("parquet.enable.dictionary" -> dictionary) {
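The two 'é' rows above pin down the chars-versus-graphemes question from the review thread: one is the precomposed code point U+00E9, the other is 'e' plus a combining accent, and read-side padding must treat them differently even though they render identically. A std-only Rust sketch of the distinction (illustrative, assuming the char-counting behavior described in the diff):

```rust
fn main() {
    let precomposed = "\u{e9}";  // 'é' as a single Unicode scalar value
    let combining = "e\u{301}";  // 'e' followed by a combining acute accent

    // char-based length, as spark_read_side_padding counts it:
    assert_eq!(precomposed.chars().count(), 1); // CHAR(2) pads with one space
    assert_eq!(combining.chars().count(), 2);   // CHAR(2) adds no padding

    // UTF-8 byte lengths differ from both counts:
    assert_eq!(precomposed.len(), 2);
    assert_eq!(combining.len(), 3);
}
```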
@@ -63,6 +63,7 @@ object CometTPCDSMicroBenchmark extends CometTPCQueryBenchmarkBase {
       "agg_sum_integers_no_grouping",
       "case_when_column_or_null",
       "case_when_scalar",
+      "char_type",
       "filter_highly_selective",
       "filter_less_selective",
       "if_column_or_null",