Skip to content

Commit

Permalink
Fix take with nullable indices on arrays with patches (#2336)
Browse files Browse the repository at this point in the history
If an array is non-nullable, taking from it with nullable indices can produce a
nullable array. However, the same take might not affect the patch values, leaving
the patches non-nullable. We need to ensure that after a take the patch values
have the same nullability as the target array.
  • Loading branch information
robert3005 authored Feb 12, 2025
1 parent 9d1e36f commit ea4a432
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 29 deletions.
33 changes: 19 additions & 14 deletions encodings/alp/src/alp/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ use vortex_scalar::Scalar;
use crate::{match_each_alp_float_ptype, ALPArray, ALPEncoding, ALPFloat};

impl ComputeVTable for ALPEncoding {
fn compare_fn(&self) -> Option<&dyn CompareFn<Array>> {
Some(self)
}

fn filter_fn(&self) -> Option<&dyn FilterFn<Array>> {
Some(self)
}
Expand All @@ -28,10 +32,6 @@ impl ComputeVTable for ALPEncoding {
fn take_fn(&self) -> Option<&dyn TakeFn<Array>> {
Some(self)
}

fn compare_fn(&self) -> Option<&dyn CompareFn<Array>> {
Some(self)
}
}

impl ScalarAtFn<ALPArray> for ALPEncoding {
Expand Down Expand Up @@ -60,16 +60,21 @@ impl ScalarAtFn<ALPArray> for ALPEncoding {

impl TakeFn<ALPArray> for ALPEncoding {
fn take(&self, array: &ALPArray, indices: &Array) -> VortexResult<Array> {
Ok(ALPArray::try_new(
take(array.encoded(), indices)?,
array.exponents(),
array
.patches()
.map(|p| p.take(indices))
.transpose()?
.flatten(),
)?
.into_array())
let taken_encoded = take(array.encoded(), indices)?;
let taken_patches = array
.patches()
.map(|p| p.take(indices))
.transpose()?
.flatten()
.map(|p| {
p.cast_values(
&array
.dtype()
.with_nullability(taken_encoded.dtype().nullability()),
)
})
.transpose()?;
Ok(ALPArray::try_new(taken_encoded, array.exponents(), taken_patches)?.into_array())
}
}

Expand Down
13 changes: 5 additions & 8 deletions encodings/alp/src/alp_rd/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@ impl ALPRDArray {
PType::try_from(left_parts.dtype()).vortex_expect("left_parts dtype must be uint");

// we enforce right_parts to be non-nullable uint
if right_parts.dtype().is_nullable() {
vortex_bail!("right_parts dtype must be non-nullable");
}
if !right_parts.dtype().is_unsigned_int() || right_parts.dtype().is_nullable() {
vortex_bail!(MismatchedTypes: "non-nullable uint", right_parts.dtype());
}
Expand All @@ -82,11 +79,11 @@ impl ALPRDArray {

let patches = left_parts_patches
.map(|patches| {
if patches.values().dtype().is_nullable() {
vortex_bail!("patches must be non-nullable: {}", patches.values());
if !patches.values().all_valid()? {
vortex_bail!("patches must be all valid: {}", patches.values());
}
let metadata =
patches.to_metadata(left_parts.len(), &left_parts.dtype().as_nonnullable());
let patches = patches.cast_values(left_parts.dtype())?;
let metadata = patches.to_metadata(left_parts.len(), left_parts.dtype());
let (_, _, indices, values) = patches.into_parts();
children.push(indices);
children.push(values);
Expand Down Expand Up @@ -146,7 +143,7 @@ impl ALPRDArray {
/// The dtype of the patches of the left parts of the array.
#[inline]
fn left_parts_patches_dtype(&self) -> DType {
DType::Primitive(self.metadata().left_parts_ptype, Nullability::NonNullable)
DType::Primitive(self.metadata().left_parts_ptype, self.dtype().nullability())
}

/// The leftmost (most significant) bits of the floating point values stored in the array.
Expand Down
48 changes: 44 additions & 4 deletions encodings/alp/src/alp_rd/compute/take.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,31 @@
use vortex_array::compute::{take, TakeFn};
use vortex_array::compute::{fill_null, take, TakeFn};
use vortex_array::{Array, IntoArray};
use vortex_error::VortexResult;
use vortex_scalar::{Scalar, ScalarValue};

use crate::{ALPRDArray, ALPRDEncoding};

impl TakeFn<ALPRDArray> for ALPRDEncoding {
fn take(&self, array: &ALPRDArray, indices: &Array) -> VortexResult<Array> {
let taken_left_parts = take(array.left_parts(), indices)?;
let left_parts_exceptions = array
.left_parts_patches()
.map(|patches| patches.take(indices))
.transpose()?
.flatten();
.flatten()
.map(|p| {
let values_dtype = p
.values()
.dtype()
.with_nullability(taken_left_parts.dtype().nullability());
p.cast_values(&values_dtype)
})
.transpose()?;

let taken_left_parts = take(array.left_parts(), indices)?;
let right_parts = fill_null(
take(array.right_parts(), indices)?,
Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)),
)?;
Ok(ALPRDArray::try_new(
if taken_left_parts.dtype().is_nullable() {
array.dtype().as_nullable()
Expand All @@ -21,7 +34,7 @@ impl TakeFn<ALPRDArray> for ALPRDEncoding {
},
taken_left_parts,
array.left_parts_dict(),
take(array.right_parts(), indices)?,
right_parts,
array.right_bit_width(),
left_parts_exceptions,
)?
Expand Down Expand Up @@ -59,4 +72,31 @@ mod test {

assert_eq!(taken.as_slice::<T>(), &[a, outlier]);
}

/// Takes from an ALP-RD encoded array that carries left-parts patches using indices that
/// contain a null, and checks that valid positions keep their values while the null
/// position is reported as invalid.
#[rstest]
#[case(0.1f32, 0.2f32, 3e25f32)]
#[case(0.1f64, 0.2f64, 3e100f64)]
fn take_with_nulls<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
    // The encoder only sees `a` and `b`; the assertions below confirm that encoding the
    // full input (including `outlier`) yields left-parts patches.
    let input = PrimitiveArray::from_iter([a, b, outlier]);
    let encoded = RDEncoder::new(&[a, b]).encode(&input);

    let patches = encoded.left_parts_patches();
    assert!(patches.is_some());
    assert!(patches.unwrap().dtype().is_unsigned_int());

    // Nullable indices: two valid positions followed by a null.
    let indices = PrimitiveArray::from_option_iter([Some(0), Some(2), None]);
    let taken = take(encoded.as_ref(), indices.as_ref())
        .unwrap()
        .into_primitive()
        .unwrap();

    let values = taken.as_slice::<T>();
    assert_eq!(values[0], a);
    assert_eq!(values[1], outlier);

    // Only validity is checked for the null slot; its stored value is not asserted.
    let validity = taken.validity_mask().unwrap();
    assert!(!validity.value(2));
}
}
22 changes: 20 additions & 2 deletions encodings/fastlanes/src/bitpacking/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ fn take_primitive<T: NativePType + BitPacking, I: NativePType>(
}
if let Some(patches) = array.patches() {
if let Some(patches) = patches.take(indices)? {
return unpatched_taken.patch(patches);
let cast_patches = patches.cast_values(unpatched_taken.dtype())?;
return unpatched_taken.patch(cast_patches);
}
}

Expand All @@ -131,6 +132,7 @@ mod test {
use vortex_array::{IntoArray, IntoArrayVariant};
use vortex_buffer::{buffer, Buffer};

use crate::bitpacking::compute::take::take_primitive;
use crate::BitPackedArray;

#[test]
Expand Down Expand Up @@ -217,12 +219,28 @@ mod test {
let start =
BitPackedArray::encode(&buffer![1i32, 2i32, 3i32, 4i32].into_array(), 1).unwrap();

let taken_primitive = super::take_primitive::<u32, u64>(
let taken_primitive = take_primitive::<u32, u64>(
&start,
&PrimitiveArray::from_iter([0u64, 1, 2, 3]),
Validity::NonNullable,
)
.unwrap();
assert_eq!(taken_primitive.as_slice::<i32>(), &[1i32, 2, 3, 4]);
}

/// Takes from a bit-packed array with patches using nullable indices and checks that the
/// valid positions carry the expected values and that exactly the null index is invalid.
#[test]
fn take_nullable_with_nullables() {
    // Bit width 1 cannot represent every input value, so the encoded array is expected to
    // carry patches (NOTE(review): inferred from the bit width — confirm against `encode`).
    let start =
        BitPackedArray::encode(&buffer![1i32, 2i32, 3i32, 4i32].into_array(), 1).unwrap();

    let indices = PrimitiveArray::from_option_iter([Some(0u64), Some(1), None, Some(3)]);
    let taken_primitive = take(&start, indices)
        .unwrap()
        .into_primitive()
        .unwrap();

    // Compare only the valid slots: whatever value sits under the null slot is an
    // implementation detail of how a null index is read and must not be pinned by a test.
    let values = taken_primitive.as_slice::<i32>();
    assert_eq!(values.len(), 4);
    assert_eq!(values[..2], [1i32, 2]);
    assert_eq!(values[3], 4);

    // The null index (position 2) must be the one and only invalid entry.
    assert!(!taken_primitive.validity_mask().unwrap().value(2));
    assert_eq!(taken_primitive.invalid_count().unwrap(), 1);
}
}
24 changes: 23 additions & 1 deletion vortex-array/src/patches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::aliases::hash_map::HashMap;
use crate::array::PrimitiveArray;
use crate::compute::{
filter, scalar_at, search_sorted, search_sorted_usize, search_sorted_usize_many, slice, take,
SearchResult, SearchSortedSide,
try_cast, SearchResult, SearchSortedSide,
};
use crate::variants::PrimitiveArrayTrait;
use crate::{Array, IntoArray, IntoArrayVariant};
Expand Down Expand Up @@ -110,6 +110,19 @@ impl Patches {
offset,
array_len
);
Self::new_unchecked(array_len, offset, indices, values)
}

/// Construct new patches without validating any of the arguments
///
/// # Safety
///
/// Users have to assert that
/// * Indices and values have the same length
/// * Indices is an unsigned integer type
/// * Indices must be sorted
/// * Last value in indices is smaller than array_len
pub fn new_unchecked(array_len: usize, offset: usize, indices: Array, values: Array) -> Self {
Self {
array_len,
offset,
Expand Down Expand Up @@ -180,6 +193,15 @@ impl Patches {
})
}

/// Consumes these patches and returns new patches whose values are cast to `values_dtype`.
///
/// The array length, offset and indices are carried over unchanged. The result is built
/// with `new_unchecked`, so no argument validation is repeated.
///
/// # Errors
///
/// Returns the error produced by `try_cast` if the values cannot be cast.
pub fn cast_values(self, values_dtype: &DType) -> VortexResult<Self> {
    let recast_values = try_cast(self.values, values_dtype)?;
    let recast = Self::new_unchecked(self.array_len, self.offset, self.indices, recast_values);
    Ok(recast)
}

/// Get the patched value at a given index if it exists.
pub fn get_patched(&self, index: usize) -> VortexResult<Option<Scalar>> {
if let Some(patch_idx) = self.search_index(index)?.to_found() {
Expand Down

0 comments on commit ea4a432

Please sign in to comment.