diff --git a/encodings/runend/src/compress.rs b/encodings/runend/src/compress.rs index ca14dd4cd4..720b474dcb 100644 --- a/encodings/runend/src/compress.rs +++ b/encodings/runend/src/compress.rs @@ -5,7 +5,10 @@ use vortex_array::validity::Validity; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{Array, IntoArray, IntoArrayVariant}; use vortex_buffer::{buffer, Buffer, BufferMut}; -use vortex_dtype::{match_each_integer_ptype, match_each_native_ptype, NativePType, Nullability}; +use vortex_dtype::{ + match_each_integer_ptype, match_each_native_ptype, match_each_native_simd_ptype, NativePType, + Nullability, +}; use vortex_error::{VortexExpect, VortexResult}; use vortex_mask::Mask; use vortex_scalar::Scalar; @@ -143,11 +146,11 @@ pub fn runend_decode_primitive( offset: usize, length: usize, ) -> VortexResult { - match_each_native_ptype!(values.ptype(), |$P| { + match_each_native_simd_ptype!(values.ptype(), |$V| { match_each_integer_ptype!(ends.ptype(), |$E| { - runend_decode_typed_primitive( + runend_decode_typed_primitive::<$V, 64>( trimmed_ends_iter(ends.as_slice::<$E>(), offset, length), - values.as_slice::<$P>(), + values.as_slice::<$V>(), values.validity_mask()?, values.dtype().nullability(), length, @@ -173,71 +176,46 @@ pub fn runend_decode_bools( }) } -use std::arch::aarch64::*; +use std::simd; -pub fn runend_decode_typed_primitive( +use vortex_buffer::Alignment; + +pub(crate) fn runend_decode_typed_primitive( run_ends: impl Iterator, values: &[T], values_validity: Mask, values_nullability: Nullability, length: usize, -) -> VortexResult { +) -> VortexResult +where + T: simd::SimdElement + NativePType, + simd::LaneCount: simd::SupportedLaneCount, +{ Ok(match values_validity { Mask::AllTrue(_) => { - #[cfg(target_arch = "aarch64")] - { - let mut decoded: BufferMut = BufferMut::with_capacity(length + 32); - let mut current_pos = 0; - - macro_rules! neon_fill_run { - ($type:ty, $neon_dup:ident, $neon_store:ident, $vector_size:expr) => { - for (end, &value) in run_ends.zip_eq(values) { - unsafe { - let neon_val = $neon_dup(*((&value as *const T) as *const $type)); - - while current_pos < end { - let dst = decoded.as_mut_ptr().add(current_pos) as *mut $type; - $neon_store(dst, neon_val); - current_pos += $vector_size; - } - - current_pos = end; - } - } - }; - } - - match (size_of::(), std::any::type_name::()) { - (8, "f64") => neon_fill_run!(f64, vdupq_n_f64, vst1q_f64, 2), - (4, "f32") => neon_fill_run!(f32, vdupq_n_f32, vst1q_f32, 4), - - (8, "i64") => neon_fill_run!(i64, vdupq_n_s64, vst1q_s64, 2), - (4, "i32") => neon_fill_run!(i32, vdupq_n_s32, vst1q_s32, 4), - (2, "i16") => neon_fill_run!(i16, vdupq_n_s16, vst1q_s16, 8), - (1, "i8") => neon_fill_run!(i8, vdupq_n_s8, vst1q_s8, 16), - - (8, "u64") => neon_fill_run!(u64, vdupq_n_u64, vst1q_u64, 2), - (4, "u32") => neon_fill_run!(u32, vdupq_n_u32, vst1q_u32, 4), - (2, "u16") => neon_fill_run!(u16, vdupq_n_u16, vst1q_u16, 8), - (1, "u8") => neon_fill_run!(u8, vdupq_n_u8, vst1q_u8, 16), - - _ => { - // Fallback - for (end, &value) in run_ends.zip_eq(values) { - while current_pos < end { - decoded.push(value); - current_pos += 1; - } - } + let mut decoded = BufferMut::::with_capacity_aligned( + length.next_multiple_of(size_of::>() * 2), + Alignment::of::>(), + ); + let mask = simd::Mask::from_bitmask(u64::MAX); + let mut current_pos = 0; + + for (end, &value) in run_ends.zip_eq(values) { + let simd_val = simd::Simd::::splat(value); + while current_pos <= end { + unsafe { + simd_val.store_select_ptr(decoded.as_mut_ptr().add(current_pos), mask); } + current_pos += LANE_COUNT; } + current_pos = end; + } - unsafe { - decoded.set_len(length); - } - - PrimitiveArray::new(decoded, values_nullability.into()) + unsafe { + decoded.set_len(length); } + + PrimitiveArray::new(decoded.freeze(), values_nullability.into()) } Mask::AllFalse(_) => PrimitiveArray::new(Buffer::::zeroed(length), Validity::AllInvalid), Mask::Values(mask) => { diff --git a/encodings/runend/src/lib.rs b/encodings/runend/src/lib.rs index d4a77c993c..9f1c1a8f0b 100644 --- a/encodings/runend/src/lib.rs +++ b/encodings/runend/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(portable_simd)] pub use array::*; mod array;