Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug with handling of null values in dictionaries #70

Merged
merged 12 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;

use datafusion::arrow::array::{
downcast_array, AnyDictionaryArray, Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, LargeStringArray,
PrimitiveArray, RunArray, StringArray, StringViewArray,
PrimitiveArray, PrimitiveBuilder, RunArray, StringArray, StringViewArray,
};
use datafusion::arrow::compute::kernels::cast;
use datafusion::arrow::compute::take;
Expand Down Expand Up @@ -245,6 +245,34 @@ fn invoke_array_array<R: InvokeResult>(
}
}

/// Null out every key that points at a null dictionary value.
///
/// Only the keys are rewritten; `values` is passed through unchanged. For example,
/// keys = `[0, 1, 2, 3]`, values = `[null, "a", null, "b"]`
/// becomes
/// keys = `[null, 1, null, 3]`, values = `[null, "a", null, "b"]`
/// so no valid key refers to a null entry of `values` any more.
///
/// Arrow / `DataFusion` assumes that dictionary values do not contain nulls, nulls are encoded by the keys.
/// Not following this invariant causes invalid dictionary arrays to be built later on inside of `DataFusion`
/// when arrays are concatenated and such.
fn remap_dictionary_key_nulls(keys: PrimitiveArray<Int64Type>, values: ArrayRef) -> DictionaryArray<Int64Type> {
    // fast path: no nulls in values, so no key can be pointing at one
    if values.null_count() == 0 {
        return DictionaryArray::new(keys, values);
    }

    // Exactly one output key is produced per input key, so size the builder up front.
    let mut new_keys_builder = PrimitiveBuilder::<Int64Type>::with_capacity(keys.len());

    for key in &keys {
        match key {
            // Keep only keys whose referenced value is valid.
            Some(k) if !values.is_null(k.as_usize()) => new_keys_builder.append_value(k),
            // Already-null keys and keys referencing a null value both become null.
            _ => new_keys_builder.append_null(),
        }
    }

    DictionaryArray::new(new_keys_builder.finish(), values)
}

fn invoke_array_scalars<R: InvokeResult>(
json_array: &ArrayRef,
path: &[JsonPath],
Expand Down Expand Up @@ -281,7 +309,7 @@ fn invoke_array_scalars<R: InvokeResult>(
let type_ids = values.as_union().type_ids();
keys = mask_dictionary_keys(&keys, type_ids);
}
Ok(Arc::new(DictionaryArray::new(keys, values)))
Ok(Arc::new(remap_dictionary_key_nulls(keys, values)))
} else {
// this is what cast would do under the hood to unpack a dictionary into an array of its values
Ok(take(&values, json_array.keys(), None)?)
Expand Down
66 changes: 64 additions & 2 deletions tests/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, RecordBatch};
use datafusion::arrow::datatypes::{Field, Int8Type, Schema};
use datafusion::arrow::array::{Array, ArrayRef, DictionaryArray, RecordBatch};
use datafusion::arrow::datatypes::{Field, Int64Type, Int8Type, Schema};
use datafusion::arrow::{array::StringDictionaryBuilder, datatypes::DataType};
use datafusion::assert_batches_eq;
use datafusion::common::ScalarValue;
Expand Down Expand Up @@ -1280,6 +1280,68 @@ async fn test_dict_haystack() {
assert_batches_eq!(expected, &batches);
}

/// Panics if any non-null key of the dictionary `array` refers to a null slot in its values.
///
/// `array` is downcast to `DictionaryArray<Int64Type>`; the downcast is unwrapped, so any other
/// array type panics. On failure the collected keys and the values array are printed before
/// panicking.
fn check_for_null_dictionary_values(array: &dyn Array) {
    let dict = array.as_any().downcast_ref::<DictionaryArray<Int64Type>>().unwrap();

    // Gather the non-null keys as `usize` indices into the values array.
    let mut keys = Vec::new();
    for key in dict.keys().iter().flatten() {
        keys.push(usize::try_from(key).unwrap());
    }

    let values_array = dict.values();
    // no non-null keys should point to a null value
    let null_value_referenced =
        (0..values_array.len()).any(|idx| values_array.is_null(idx) && keys.contains(&idx));

    if null_value_referenced {
        println!("keys: {:?}", keys);
        println!("values: {:?}", values_array);
        panic!("keys should not contain null values");
    }
}

/// Test that we don't output nulls in dictionary values.
/// This can cause issues with arrow-rs and DataFusion; they expect nulls to be in keys.
///
/// Runs `json_get` and `json_get_str` against the context returned by `build_dict_schema`,
/// checks the rendered output, and then verifies for every returned batch that no non-null
/// dictionary key in column 0 points at a null value.
#[tokio::test]
async fn test_dict_get_no_null_values() {
    let ctx = build_dict_schema().await;

    // `json_get` case.
    let sql = "select json_get(x, 'baz') v from data";
    let expected = [
        "+------------+",
        "| v |",
        "+------------+",
        "| |",
        "| {str=fizz} |",
        "| |",
        "| {str=abcd} |",
        "| |",
        "| {str=fizz} |",
        "| {str=fizz} |",
        "| {str=fizz} |",
        "| {str=fizz} |",
        "| |",
        "+------------+",
    ];
    // `sql` is already a `&str`, so it is passed directly without an extra borrow.
    let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap();
    assert_batches_eq!(expected, &batches);
    for batch in &batches {
        check_for_null_dictionary_values(batch.column(0).as_ref());
    }

    // `json_get_str` case.
    let sql = "select json_get_str(x, 'baz') v from data";
    let expected = [
        "+------+", "| v |", "+------+", "| |", "| fizz |", "| |", "| abcd |", "| |", "| fizz |",
        "| fizz |", "| fizz |", "| fizz |", "| |", "+------+",
    ];
    let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap();
    assert_batches_eq!(expected, &batches);
    for batch in &batches {
        check_for_null_dictionary_values(batch.column(0).as_ref());
    }
}

#[tokio::test]
async fn test_dict_haystack_filter() {
let sql = "select json_data v from dicts where json_get(json_data, 'foo') is not null";
Expand Down