Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extended log.rs tests for unary/binary and f32/f64 casting #13034

Merged
merged 3 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 82 additions & 1 deletion datafusion/functions/src/math/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,13 +261,94 @@ mod tests {

use super::*;

use arrow::array::{Float32Array, Float64Array};
use arrow::array::{Float32Array, Float64Array, Int64Array};
use arrow::compute::SortOptions;
use datafusion_common::cast::{as_float32_array, as_float64_array};
use datafusion_common::DFSchema;
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::simplify::SimplifyContext;

#[test]
#[should_panic]
fn test_log_invalid_base_type() {
let args = [
ColumnarValue::Array(Arc::new(Float64Array::from(vec![
10.0, 100.0, 1000.0, 10000.0,
]))), // num
ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))),
];

let _ = LogFunc::new().invoke(&args);
}

#[test]
fn test_log_invalid_value() {
let args = [
ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num
];

let result = LogFunc::new().invoke(&args);
result.expect_err("expected error");
}

#[test]
fn test_log_f64_unary() {
let args = [
ColumnarValue::Array(Arc::new(Float64Array::from(vec![
10.0, 100.0, 1000.0, 10000.0,
]))), // num
];

let result = LogFunc::new()
.invoke(&args)
.expect("failed to initialize function log");

match result {
ColumnarValue::Array(arr) => {
let floats = as_float64_array(&arr)
.expect("failed to convert result to a Float64Array");

assert_eq!(floats.len(), 4);
assert!((floats.value(0) - 1.0).abs() < 1e-10);
assert!((floats.value(1) - 2.0).abs() < 1e-10);
assert!((floats.value(2) - 3.0).abs() < 1e-10);
assert!((floats.value(3) - 4.0).abs() < 1e-10);
}
ColumnarValue::Scalar(_) => {
panic!("Expected an array value")
}
}
}

#[test]
fn test_log_f32_unary() {
let args = [
ColumnarValue::Array(Arc::new(Float32Array::from(vec![
10.0, 100.0, 1000.0, 10000.0,
]))), // num
];

let result = LogFunc::new()
.invoke(&args)
.expect("failed to initialize function log");

match result {
ColumnarValue::Array(arr) => {
let floats = as_float32_array(&arr)
.expect("failed to convert result to a Float64Array");

assert_eq!(floats.len(), 4);
assert!((floats.value(0) - 1.0).abs() < 1e-10);
assert!((floats.value(1) - 2.0).abs() < 1e-10);
assert!((floats.value(2) - 3.0).abs() < 1e-10);
assert!((floats.value(3) - 4.0).abs() < 1e-10);
}
ColumnarValue::Scalar(_) => {
panic!("Expected an array value")
}
}
}

#[test]
fn test_log_f64() {
let args = [
Expand Down
31 changes: 31 additions & 0 deletions datafusion/sqllogictest/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,37 @@ select log(a, 64) a, log(b), log(10, b) from signed_integers;
NaN 2 2
NaN 4 4

# log overloaded base 10 float64 and float32 casting scalar
query RR rowsort
select log(arrow_cast(10, 'Float64')) a ,log(arrow_cast(100, 'Float32')) b;
----
1 2

# log overloaded base 10 float64 and float32 casting with columns
query RR rowsort
select log(arrow_cast(a, 'Float64')), log(arrow_cast(b, 'Float32')) from signed_integers;
----
0.301029995664 NaN
0.602059991328 NULL
NaN 2
NaN 4

# log float64 and float32 casting scalar
query RR rowsort
select log(2,arrow_cast(8, 'Float64')) a, log(2,arrow_cast(16, 'Float32')) b;
----
3 4

# log float64 and float32 casting with columns
query RR rowsort
select log(2,arrow_cast(a, 'Float64')), log(4,arrow_cast(b, 'Float32')) from signed_integers;
----
1 NaN
2 NULL
NaN 3.321928
NaN 6.643856


## log10

# log10 scalar function
Expand Down