diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index 65ccbe1e01a0..8e06b07f5ef4 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -552,10 +552,7 @@ fn filter_native(values: &[T], predicate: &FilterPredicate) } /// `filter` implementation for primitive arrays -pub(crate) fn filter_primitive( - array: &PrimitiveArray, - predicate: &FilterPredicate, -) -> PrimitiveArray +fn filter_primitive(array: &PrimitiveArray, predicate: &FilterPredicate) -> PrimitiveArray where T: ArrowPrimitiveType, { diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs index b66133ac71f0..ed7179fd36ce 100644 --- a/arrow-select/src/take.rs +++ b/arrow-select/src/take.rs @@ -31,8 +31,6 @@ use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode}; use num::{One, Zero}; -use crate::filter::{filter_primitive, FilterBuilder}; - /// Take elements by index from [Array], creating a new [Array] from those indexes. /// /// ```text @@ -251,13 +249,12 @@ fn take_impl( let children = fields.iter() .map(|(field_type_id, _)| { let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id); - let predicate = FilterBuilder::new(&mask).build(); - let indices = filter_primitive(&offsets, &predicate); + let indices = crate::filter::filter(&offsets, &mask)?; let values = values.child(field_type_id); - take_impl(values, &indices) + take_impl(values, indices.as_primitive::()) }) .collect::>()?; @@ -885,7 +882,7 @@ mod tests { use super::*; use arrow_array::builder::*; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; - use arrow_schema::{Field, Fields, TimeUnit}; + use arrow_schema::{Field, Fields, TimeUnit, UnionFields}; fn test_take_decimal_arrays( data: Vec>, @@ -2308,4 +2305,22 @@ mod tests { take(&union, &indices, None).unwrap().to_data() ); } + + #[test] + fn test_take_union_dense_all_match_issue_6206() { + let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]); + let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])); + + let array = UnionArray::try_new( + fields, + ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]), + Some(ScalarBuffer::from_iter(0_i32..5)), + vec![ints], + ) + .unwrap(); + + let indicies = Int64Array::from(vec![0, 2, 4]); + let array = take(&array, &indicies, None).unwrap(); + assert_eq!(array.len(), 3); + } }