Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix take with nullable indices on arrays with patches #2336

Merged
merged 7 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions encodings/alp/src/alp/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ use vortex_scalar::Scalar;
use crate::{match_each_alp_float_ptype, ALPArray, ALPEncoding, ALPFloat};

impl ComputeVTable for ALPEncoding {
fn compare_fn(&self) -> Option<&dyn CompareFn<Array>> {
Some(self)
}

fn filter_fn(&self) -> Option<&dyn FilterFn<Array>> {
Some(self)
}
Expand All @@ -28,10 +32,6 @@ impl ComputeVTable for ALPEncoding {
fn take_fn(&self) -> Option<&dyn TakeFn<Array>> {
Some(self)
}

fn compare_fn(&self) -> Option<&dyn CompareFn<Array>> {
Some(self)
}
}

impl ScalarAtFn<ALPArray> for ALPEncoding {
Expand Down Expand Up @@ -60,16 +60,21 @@ impl ScalarAtFn<ALPArray> for ALPEncoding {

impl TakeFn<ALPArray> for ALPEncoding {
fn take(&self, array: &ALPArray, indices: &Array) -> VortexResult<Array> {
Ok(ALPArray::try_new(
take(array.encoded(), indices)?,
array.exponents(),
array
.patches()
.map(|p| p.take(indices))
.transpose()?
.flatten(),
)?
.into_array())
let taken_encoded = take(array.encoded(), indices)?;
let taken_patches = array
.patches()
.map(|p| p.take(indices))
.transpose()?
.flatten()
.map(|p| {
p.cast_values(
&array
.dtype()
.with_nullability(taken_encoded.dtype().nullability()),
)
})
.transpose()?;
Ok(ALPArray::try_new(taken_encoded, array.exponents(), taken_patches)?.into_array())
}
}

Expand Down
13 changes: 5 additions & 8 deletions encodings/alp/src/alp_rd/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@ impl ALPRDArray {
PType::try_from(left_parts.dtype()).vortex_expect("left_parts dtype must be uint");

// we enforce right_parts to be non-nullable uint
if right_parts.dtype().is_nullable() {
vortex_bail!("right_parts dtype must be non-nullable");
}
if !right_parts.dtype().is_unsigned_int() || right_parts.dtype().is_nullable() {
vortex_bail!(MismatchedTypes: "non-nullable uint", right_parts.dtype());
}
Expand All @@ -82,11 +79,11 @@ impl ALPRDArray {

let patches = left_parts_patches
.map(|patches| {
if patches.values().dtype().is_nullable() {
vortex_bail!("patches must be non-nullable: {}", patches.values());
if !patches.values().all_valid()? {
vortex_bail!("patches must be all valid: {}", patches.values());
}
let metadata =
patches.to_metadata(left_parts.len(), &left_parts.dtype().as_nonnullable());
let patches = patches.cast_values(left_parts.dtype())?;
let metadata = patches.to_metadata(left_parts.len(), left_parts.dtype());
let (_, _, indices, values) = patches.into_parts();
children.push(indices);
children.push(values);
Expand Down Expand Up @@ -146,7 +143,7 @@ impl ALPRDArray {
/// The dtype of the patches of the left parts of the array.
#[inline]
fn left_parts_patches_dtype(&self) -> DType {
DType::Primitive(self.metadata().left_parts_ptype, Nullability::NonNullable)
DType::Primitive(self.metadata().left_parts_ptype, self.dtype().nullability())
}

/// The leftmost (most significant) bits of the floating point values stored in the array.
Expand Down
48 changes: 44 additions & 4 deletions encodings/alp/src/alp_rd/compute/take.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,31 @@
use vortex_array::compute::{take, TakeFn};
use vortex_array::compute::{fill_null, take, TakeFn};
use vortex_array::{Array, IntoArray};
use vortex_error::VortexResult;
use vortex_scalar::{Scalar, ScalarValue};

use crate::{ALPRDArray, ALPRDEncoding};

impl TakeFn<ALPRDArray> for ALPRDEncoding {
fn take(&self, array: &ALPRDArray, indices: &Array) -> VortexResult<Array> {
let taken_left_parts = take(array.left_parts(), indices)?;
let left_parts_exceptions = array
.left_parts_patches()
.map(|patches| patches.take(indices))
.transpose()?
.flatten();
.flatten()
.map(|p| {
let values_dtype = p
.values()
.dtype()
.with_nullability(taken_left_parts.dtype().nullability());
p.cast_values(&values_dtype)
})
.transpose()?;

let taken_left_parts = take(array.left_parts(), indices)?;
let right_parts = fill_null(
take(array.right_parts(), indices)?,
Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)),
)?;
Ok(ALPRDArray::try_new(
Comment on lines +25 to +28
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ALP-RD should not require the right parts to be non-nullable; it could instead fill the null slots with 0 itself (the corresponding left parts will be null anyway).

if taken_left_parts.dtype().is_nullable() {
array.dtype().as_nullable()
Expand All @@ -21,7 +34,7 @@ impl TakeFn<ALPRDArray> for ALPRDEncoding {
},
taken_left_parts,
array.left_parts_dict(),
take(array.right_parts(), indices)?,
right_parts,
array.right_bit_width(),
left_parts_exceptions,
)?
Expand Down Expand Up @@ -59,4 +72,31 @@ mod test {

assert_eq!(taken.as_slice::<T>(), &[a, outlier]);
}

/// Takes from an ALP-RD array with nullable indices and checks that the
/// selected values come back intact while the null index yields an invalid slot.
#[rstest]
#[case(0.1f32, 0.2f32, 3e25f32)]
#[case(0.1f64, 0.2f64, 3e100f64)]
fn take_with_nulls<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
    // The encoder only sees `a` and `b`; the assertions below confirm that
    // encoding the three-element array produces left-parts patches.
    let source = PrimitiveArray::from_iter([a, b, outlier]);
    let rd_array = RDEncoder::new(&[a, b]).encode(&source);

    let patches = rd_array.left_parts_patches();
    assert!(patches.is_some());
    assert!(patches.unwrap().dtype().is_unsigned_int());

    // Indices: first element, the outlier, and a null.
    let indices = PrimitiveArray::from_option_iter([Some(0), Some(2), None]);
    let taken = take(rd_array.as_ref(), indices.as_ref())
        .unwrap()
        .into_primitive()
        .unwrap();

    let taken_values = taken.as_slice::<T>();
    assert_eq!(taken_values[0], a);
    assert_eq!(taken_values[1], outlier);
    // The slot produced by the null index must be marked invalid.
    assert!(!taken.validity_mask().unwrap().value(2));
}
}
22 changes: 20 additions & 2 deletions encodings/fastlanes/src/bitpacking/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ fn take_primitive<T: NativePType + BitPacking, I: NativePType>(
}
if let Some(patches) = array.patches() {
if let Some(patches) = patches.take(indices)? {
return unpatched_taken.patch(patches);
let cast_patches = patches.cast_values(unpatched_taken.dtype())?;
return unpatched_taken.patch(cast_patches);
}
}

Expand All @@ -131,6 +132,7 @@ mod test {
use vortex_array::{IntoArray, IntoArrayVariant};
use vortex_buffer::{buffer, Buffer};

use crate::bitpacking::compute::take::take_primitive;
use crate::BitPackedArray;

#[test]
Expand Down Expand Up @@ -217,12 +219,28 @@ mod test {
let start =
BitPackedArray::encode(&buffer![1i32, 2i32, 3i32, 4i32].into_array(), 1).unwrap();

let taken_primitive = super::take_primitive::<u32, u64>(
let taken_primitive = take_primitive::<u32, u64>(
&start,
&PrimitiveArray::from_iter([0u64, 1, 2, 3]),
Validity::NonNullable,
)
.unwrap();
assert_eq!(taken_primitive.as_slice::<i32>(), &[1i32, 2, 3, 4]);
}

/// Takes from a bit-packed array using indices that contain a null and checks
/// both the raw values and the number of invalid slots in the result.
#[test]
fn take_nullable_with_nullables() {
    // NOTE(review): the second argument to `encode` is presumably the bit width;
    // a width of 1 would push the larger values into patches — confirm.
    let packed =
        BitPackedArray::encode(&buffer![1i32, 2i32, 3i32, 4i32].into_array(), 1).unwrap();
    let nullable_indices =
        PrimitiveArray::from_option_iter([Some(0u64), Some(1), None, Some(3)]);

    let result = take(&packed, nullable_indices)
        .unwrap()
        .into_primitive()
        .unwrap();

    // The raw slice is compared in full, including the slot under the null index.
    assert_eq!(result.as_slice::<i32>(), &[1i32, 2, 1, 4]);
    // Exactly one slot (the one taken through the null index) is invalid.
    assert_eq!(result.invalid_count().unwrap(), 1);
}
}
24 changes: 23 additions & 1 deletion vortex-array/src/patches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::aliases::hash_map::HashMap;
use crate::array::PrimitiveArray;
use crate::compute::{
filter, scalar_at, search_sorted, search_sorted_usize, search_sorted_usize_many, slice, take,
SearchResult, SearchSortedSide,
try_cast, SearchResult, SearchSortedSide,
};
use crate::variants::PrimitiveArrayTrait;
use crate::{Array, IntoArray, IntoArrayVariant};
Expand Down Expand Up @@ -110,6 +110,19 @@ impl Patches {
offset,
array_len
);
Self::new_unchecked(array_len, offset, indices, values)
}

/// Construct new patches without validating any of the arguments.
///
/// # Safety
///
/// Callers must ensure that:
/// * Indices and values have the same length
/// * Indices have an unsigned integer type
/// * Indices are sorted
/// * The last value in indices is smaller than array_len
pub fn new_unchecked(array_len: usize, offset: usize, indices: Array, values: Array) -> Self {
Self {
array_len,
offset,
Expand Down Expand Up @@ -180,6 +193,15 @@ impl Patches {
})
}

/// Consume these patches and return them with the values cast to `values_dtype`.
///
/// The array length, offset and indices are carried over unchanged; only the
/// values array is passed through `try_cast`.
///
/// # Errors
///
/// Returns the error from `try_cast` if the values cannot be cast.
pub fn cast_values(self, values_dtype: &DType) -> VortexResult<Self> {
    let cast = try_cast(self.values, values_dtype)?;
    // Only the values changed, so the patches are rebuilt through the
    // non-validating constructor.
    Ok(Self::new_unchecked(
        self.array_len,
        self.offset,
        self.indices,
        cast,
    ))
}

/// Get the patched value at a given index if it exists.
pub fn get_patched(&self, index: usize) -> VortexResult<Option<Scalar>> {
if let Some(patch_idx) = self.search_index(index)?.to_found() {
Expand Down
Loading