Skip to content

Commit

Permalink
feat: mask (#1900)
Browse files Browse the repository at this point in the history
Mask sets entries of an array to null. I like the analogy to light: the
array is a sequence of lights (each value might be a different
wavelength). Null is represented by the absence oflight. Placing a mask
(i.e. a piece of plastic with slits) over the array causes those values
where the mask is present (i.e. "on", "true") to be dark.

An example in pseudo-code:

```rust
a = [1, 2, 3, 4, 5]
a_mask = [t, f, f, t, f]
mask(a, a_mask) == [null, 2, 3, null, 5]
```

Specializations
---------------

I only fallback to Arrow for two of the core arrays:

- Sparse. I was skeptical that I could do better than decompressing and
applying it.
- Constant. If the mask is sparse, SparseArray might be a good choice. I
didn't investigate.

For the non-core arrays, I'm missing the following. I'm not clear that I
can beat decompression forrun end. The others are easy enough but some
amount of typing and testing.

- fastlanes
- fsst
- roaring
- runend
- runend-bool
- zigzag

Naming
------

Pandas also calls this operation
[`mask`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.mask.html)
but accepts an optional second argument which is an array of values to
use instead of null (which makes Pandas' mask more like an `if_else`).

Arrow-rs calls this
[`nullif`](https://arrow.apache.org/rust/arrow/compute/fn.nullif.html).

Arrow-cpp has [`if_else(condition, consequent,
alternate)`](https://arrow.apache.org/docs/cpp/compute.html#cpp-compute-scalar-selections)
and [`replace_with_mask(array, mask,
replacements)`](https://arrow.apache.org/docs/cpp/compute.html#replace-functions)
both of which can implement our `mask` by passing a `NullArray` as the
third argument.
  • Loading branch information
danking authored Feb 19, 2025
1 parent f420bca commit eadc1fe
Show file tree
Hide file tree
Showing 32 changed files with 1,220 additions and 84 deletions.
58 changes: 58 additions & 0 deletions encodings/alp/src/alp_rd/compute/mask.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
use vortex_array::compute::{mask, MaskFn};
use vortex_array::{Array, IntoArray};
use vortex_error::VortexResult;
use vortex_mask::Mask;

use crate::{ALPRDArray, ALPRDEncoding};

impl MaskFn<ALPRDArray> for ALPRDEncoding {
fn mask(&self, array: &ALPRDArray, filter_mask: Mask) -> VortexResult<Array> {
Ok(ALPRDArray::try_new(
array.dtype().as_nullable(),
mask(&array.left_parts(), filter_mask)?,
array.left_parts_dict(),
array.right_parts(),
array.right_bit_width(),
array.left_parts_patches(),
)?
.into_array())
}
}

#[cfg(test)]
mod tests {
use rstest::rstest;
use vortex_array::array::PrimitiveArray;
use vortex_array::compute::test_harness::test_mask;
use vortex_array::IntoArray as _;

use crate::{ALPRDFloat, RDEncoder};

#[rstest]
#[case(0.1f32, 0.2f32, 3e25f32)]
#[case(0.1f64, 0.2f64, 3e100f64)]
fn test_mask_simple<T: ALPRDFloat>(#[case] a: T, #[case] b: T, #[case] outlier: T) {
test_mask(
RDEncoder::new(&[a, b])
.encode(&PrimitiveArray::from_iter([a, b, outlier, b, outlier]))
.into_array(),
);
}

#[rstest]
#[case(0.1f32, 3e25f32)]
#[case(0.5f64, 1e100f64)]
fn test_mask_with_nulls<T: ALPRDFloat>(#[case] a: T, #[case] outlier: T) {
test_mask(
RDEncoder::new(&[a])
.encode(&PrimitiveArray::from_option_iter([
Some(a),
None,
Some(outlier),
Some(a),
None,
]))
.into_array(),
);
}
}
7 changes: 6 additions & 1 deletion encodings/alp/src/alp_rd/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use vortex_array::compute::{FilterFn, ScalarAtFn, SliceFn, TakeFn};
use vortex_array::compute::{FilterFn, MaskFn, ScalarAtFn, SliceFn, TakeFn};
use vortex_array::vtable::ComputeVTable;
use vortex_array::Array;

use crate::ALPRDEncoding;

mod filter;
mod mask;
mod scalar_at;
mod slice;
mod take;
Expand All @@ -14,6 +15,10 @@ impl ComputeVTable for ALPRDEncoding {
Some(self)
}

fn mask_fn(&self) -> Option<&dyn MaskFn<Array>> {
Some(self)
}

fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<Array>> {
Some(self)
}
Expand Down
22 changes: 21 additions & 1 deletion encodings/bytebool/src/compute.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use num_traits::AsPrimitive;
use vortex_array::compute::{FillForwardFn, ScalarAtFn, SliceFn, TakeFn};
use vortex_array::compute::{FillForwardFn, MaskFn, ScalarAtFn, SliceFn, TakeFn};
use vortex_array::validity::Validity;
use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::vtable::ComputeVTable;
Expand All @@ -16,6 +16,10 @@ impl ComputeVTable for ByteBoolEncoding {
None
}

fn mask_fn(&self) -> Option<&dyn MaskFn<Array>> {
Some(self)
}

fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<Array>> {
Some(self)
}
Expand All @@ -29,6 +33,13 @@ impl ComputeVTable for ByteBoolEncoding {
}
}

impl MaskFn<ByteBoolArray> for ByteBoolEncoding {
fn mask(&self, array: &ByteBoolArray, mask: Mask) -> VortexResult<Array> {
ByteBoolArray::try_new(array.buffer().clone(), array.validity().mask(&mask)?)
.map(IntoArray::into_array)
}
}

impl ScalarAtFn<ByteBoolArray> for ByteBoolEncoding {
fn scalar_at(&self, array: &ByteBoolArray, index: usize) -> VortexResult<Scalar> {
Ok(Scalar::bool(
Expand Down Expand Up @@ -139,6 +150,7 @@ impl FillForwardFn<ByteBoolArray> for ByteBoolEncoding {

#[cfg(test)]
mod tests {
use vortex_array::compute::test_harness::test_mask;
use vortex_array::compute::{compare, scalar_at, slice, Operator};

use super::*;
Expand Down Expand Up @@ -211,4 +223,12 @@ mod tests {
let s = scalar_at(&arr, 4).unwrap();
assert!(s.is_null());
}

#[test]
fn test_mask_byte_bool() {
test_mask(ByteBoolArray::from(vec![true, false, true, true, false]).into_array());
test_mask(
ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]).into_array(),
);
}
}
4 changes: 4 additions & 0 deletions encodings/dict/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ name = "dict_compare"
harness = false
required-features = ["test-harness"]

[[bench]]
name = "dict_mask"
harness = false

[[bench]]
name = "chunked_dict_array_builder"
harness = false
Expand Down
59 changes: 59 additions & 0 deletions encodings/dict/benches/dict_mask.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#![allow(clippy::unwrap_used)]

use divan::Bencher;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng as _};
use vortex_array::array::PrimitiveArray;
use vortex_array::compute::mask;
use vortex_array::IntoArray as _;
use vortex_dict::DictArray;
use vortex_mask::Mask;

fn main() {
divan::main();
}

fn filter_mask(len: usize, fraction_masked: f64, rng: &mut StdRng) -> Mask {
let indices = (0..len)
.filter(|_| rng.gen_bool(fraction_masked))
.collect::<Vec<usize>>();
Mask::from_indices(len, indices)
}

#[divan::bench(args = [
(0.9, 0.9),
(0.9, 0.5),
(0.9, 0.1),
(0.9, 0.01),
(0.5, 0.9),
(0.5, 0.5),
(0.5, 0.1),
(0.5, 0.01),
(0.1, 0.9),
(0.1, 0.5),
(0.1, 0.1),
(0.1, 0.01),
(0.01, 0.9),
(0.01, 0.5),
(0.01, 0.1),
(0.01, 0.01),
])]
fn bench_dict_mask(bencher: Bencher, (fraction_valid, fraction_masked): (f64, f64)) {
let mut rng = StdRng::seed_from_u64(0);

let len = 65_535;
let codes = PrimitiveArray::from_iter((0..len).map(|_| {
if rng.gen_bool(fraction_valid) {
1u64
} else {
0u64
}
}))
.into_array();
let values = PrimitiveArray::from_option_iter([None, Some(42i32)]).into_array();
let array = DictArray::try_new(codes, values).unwrap().into_array();
let filter_mask = filter_mask(len, fraction_masked, &mut rng);
bencher
.with_inputs(|| (&array, filter_mask.clone()))
.bench_values(|(array, filter_mask)| mask(array, filter_mask).unwrap());
}
36 changes: 35 additions & 1 deletion encodings/dict/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ impl SliceFn<DictArray> for DictEncoding {
#[cfg(test)]
mod test {
use vortex_array::accessor::ArrayAccessor;
use vortex_array::array::{ConstantArray, PrimitiveArray, VarBinViewArray};
use vortex_array::array::{ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray};
use vortex_array::compute::test_harness::test_mask;
use vortex_array::compute::{compare, scalar_at, slice, Operator};
use vortex_array::{Array, IntoArray, IntoArrayVariant};
use vortex_dtype::{DType, Nullability};
Expand Down Expand Up @@ -198,4 +199,37 @@ mod test {
Scalar::bool(true, Nullability::Nullable)
);
}

#[test]
fn test_mask_dict_array() {
let array = dict_encode(&PrimitiveArray::from_iter([2, 0, 2, 0, 10]).into_array())
.unwrap()
.into_array();
test_mask(array);

let array = dict_encode(
&PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)])
.into_array(),
)
.unwrap()
.into_array();
test_mask(array);

let array = dict_encode(
&VarBinArray::from_iter(
[
Some("hello"),
None,
Some("hello"),
Some("good"),
Some("good"),
],
DType::Utf8(Nullability::Nullable),
)
.into_array(),
)
.unwrap()
.into_array();
test_mask(array);
}
}
38 changes: 36 additions & 2 deletions encodings/sparse/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,14 @@ impl FilterFn<SparseArray> for SparseEncoding {
mod test {
use rstest::{fixture, rstest};
use vortex_array::array::PrimitiveArray;
use vortex_array::compute::test_harness::test_binary_numeric;
use vortex_array::compute::{filter, search_sorted, slice, SearchResult, SearchSortedSide};
use vortex_array::compute::test_harness::{test_binary_numeric, test_mask};
use vortex_array::compute::{
filter, search_sorted, slice, try_cast, SearchResult, SearchSortedSide,
};
use vortex_array::validity::Validity;
use vortex_array::{Array, IntoArray, IntoArrayVariant};
use vortex_buffer::buffer;
use vortex_dtype::{DType, Nullability, PType};
use vortex_mask::Mask;
use vortex_scalar::Scalar;

Expand Down Expand Up @@ -223,4 +226,35 @@ mod test {
fn test_sparse_binary_numeric(array: Array) {
test_binary_numeric::<i32>(array)
}

#[test]
fn test_mask_sparse_array() {
let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
test_mask(
SparseArray::try_new(
buffer![1u64, 2, 4].into_array(),
try_cast(
buffer![100i32, 200, 300].into_array(),
null_fill_value.dtype(),
)
.unwrap(),
5,
null_fill_value,
)
.unwrap()
.into_array(),
);

let ten_fill_value = Scalar::from(10i32);
test_mask(
SparseArray::try_new(
buffer![1u64, 2, 4].into_array(),
buffer![100i32, 200, 300].into_array(),
5,
ten_fill_value,
)
.unwrap()
.into_array(),
)
}
}
10 changes: 3 additions & 7 deletions vortex-array/src/array/bool/compute/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,9 @@ impl CastFn<BoolArray> for BoolEncoding {
vortex_bail!("Cannot cast {} to {}", array.dtype(), dtype);
}

// If the types are the same, return the array,
// otherwise set the array nullability as the dtype nullability.
if dtype.is_nullable() || array.all_valid()? {
Ok(BoolArray::new(array.boolean_buffer(), dtype.nullability()).into_array())
} else {
vortex_bail!("Cannot cast null array to non-nullable type");
}
let new_nullability = dtype.nullability();
let new_validity = array.validity().cast_nullability(new_nullability)?;
BoolArray::try_new(array.boolean_buffer(), new_validity).map(IntoArray::into_array)
}
}

Expand Down
13 changes: 13 additions & 0 deletions vortex-array/src/array/bool/compute/mask.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
use vortex_error::VortexResult;
use vortex_mask::Mask;

use crate::array::{BoolArray, BoolEncoding};
use crate::compute::MaskFn;
use crate::{Array, IntoArray};

impl MaskFn<BoolArray> for BoolEncoding {
fn mask(&self, array: &BoolArray, mask: Mask) -> VortexResult<Array> {
BoolArray::try_new(array.boolean_buffer(), array.validity().mask(&mask)?)
.map(IntoArray::into_array)
}
}
9 changes: 7 additions & 2 deletions vortex-array/src/array/bool/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::array::BoolEncoding;
use crate::compute::{
BinaryBooleanFn, CastFn, FillForwardFn, FillNullFn, FilterFn, InvertFn, MinMaxFn, ScalarAtFn,
SliceFn, TakeFn, ToArrowFn,
BinaryBooleanFn, CastFn, FillForwardFn, FillNullFn, FilterFn, InvertFn, MaskFn, MinMaxFn,
ScalarAtFn, SliceFn, TakeFn, ToArrowFn,
};
use crate::vtable::ComputeVTable;
use crate::Array;
Expand All @@ -12,6 +12,7 @@ mod fill_null;
pub mod filter;
mod flatten;
mod invert;
mod mask;
mod min_max;
mod scalar_at;
mod slice;
Expand Down Expand Up @@ -43,6 +44,10 @@ impl ComputeVTable for BoolEncoding {
Some(self)
}

fn mask_fn(&self) -> Option<&dyn MaskFn<Array>> {
Some(self)
}

fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<Array>> {
Some(self)
}
Expand Down
6 changes: 6 additions & 0 deletions vortex-array/src/array/bool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ mod tests {
use vortex_dtype::Nullability;

use crate::array::{BoolArray, PrimitiveArray};
use crate::compute::test_harness::test_mask;
use crate::compute::{scalar_at, slice};
use crate::patches::Patches;
use crate::validity::Validity;
Expand Down Expand Up @@ -374,4 +375,9 @@ mod tests {
let (values, _byte_bit_offset) = arr.into_bool().unwrap().into_boolean_builder();
assert_eq!(values.as_slice(), &[254, 127]);
}

#[test]
fn test_mask_primitive_array() {
test_mask(BoolArray::from_iter([true, false, true, true, false]).into_array());
}
}
Loading

0 comments on commit eadc1fe

Please sign in to comment.