Skip to content

Commit

Permalink
fix: Correctly handle take on dense union of a single selected type (#…
Browse files Browse the repository at this point in the history
…6209)

* fix: use filter instead of filter_primitive

* fix: remove pub(crate) from filter_primitive

* fix: run cargo fmt

* fix: clippy
  • Loading branch information
gstvg authored Aug 8, 2024
1 parent b90c799 commit 12ff1ea
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
5 changes: 1 addition & 4 deletions arrow-select/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -552,10 +552,7 @@ fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate)
}

/// `filter` implementation for primitive arrays
pub(crate) fn filter_primitive<T>(
array: &PrimitiveArray<T>,
predicate: &FilterPredicate,
) -> PrimitiveArray<T>
fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
where
T: ArrowPrimitiveType,
{
Expand Down
27 changes: 21 additions & 6 deletions arrow-select/src/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};

use num::{One, Zero};

use crate::filter::{filter_primitive, FilterBuilder};

/// Take elements by index from [Array], creating a new [Array] from those indexes.
///
/// ```text
Expand Down Expand Up @@ -251,13 +249,12 @@ fn take_impl<IndexType: ArrowPrimitiveType>(
let children = fields.iter()
.map(|(field_type_id, _)| {
let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);
let predicate = FilterBuilder::new(&mask).build();

let indices = filter_primitive(&offsets, &predicate);
let indices = crate::filter::filter(&offsets, &mask)?;

let values = values.child(field_type_id);

take_impl(values, &indices)
take_impl(values, indices.as_primitive::<Int32Type>())
})
.collect::<Result<_, _>>()?;

Expand Down Expand Up @@ -885,7 +882,7 @@ mod tests {
use super::*;
use arrow_array::builder::*;
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use arrow_schema::{Field, Fields, TimeUnit};
use arrow_schema::{Field, Fields, TimeUnit, UnionFields};

fn test_take_decimal_arrays(
data: Vec<Option<i128>>,
Expand Down Expand Up @@ -2308,4 +2305,22 @@ mod tests {
take(&union, &indices, None).unwrap().to_data()
);
}

#[test]
fn test_take_union_dense_all_match_issue_6206() {
let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));

let array = UnionArray::try_new(
fields,
ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
Some(ScalarBuffer::from_iter(0_i32..5)),
vec![ints],
)
.unwrap();

let indicies = Int64Array::from(vec![0, 2, 4]);
let array = take(&array, &indicies, None).unwrap();
assert_eq!(array.len(), 3);
}
}

0 comments on commit 12ff1ea

Please sign in to comment.