Skip to content

Commit

Permalink
feat: use portable simd
Browse files Browse the repository at this point in the history
  • Loading branch information
0ax1 committed Feb 16, 2025
1 parent 16e6902 commit 1a3a5d3
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 57 deletions.
92 changes: 35 additions & 57 deletions encodings/runend/src/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use vortex_array::validity::Validity;
use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::{Array, IntoArray, IntoArrayVariant};
use vortex_buffer::{buffer, Buffer, BufferMut};
use vortex_dtype::{match_each_integer_ptype, match_each_native_ptype, NativePType, Nullability};
use vortex_dtype::{
match_each_integer_ptype, match_each_native_ptype, match_each_native_simd_ptype, NativePType,
Nullability,
};
use vortex_error::{VortexExpect, VortexResult};
use vortex_mask::Mask;
use vortex_scalar::Scalar;
Expand Down Expand Up @@ -143,11 +146,11 @@ pub fn runend_decode_primitive(
offset: usize,
length: usize,
) -> VortexResult<PrimitiveArray> {
match_each_native_ptype!(values.ptype(), |$P| {
match_each_native_simd_ptype!(values.ptype(), |$V| {
match_each_integer_ptype!(ends.ptype(), |$E| {
runend_decode_typed_primitive(
runend_decode_typed_primitive::<$V, 64>(
trimmed_ends_iter(ends.as_slice::<$E>(), offset, length),
values.as_slice::<$P>(),
values.as_slice::<$V>(),
values.validity_mask()?,
values.dtype().nullability(),
length,
Expand All @@ -173,71 +176,46 @@ pub fn runend_decode_bools(
})
}

use std::arch::aarch64::*;
use std::simd;

pub fn runend_decode_typed_primitive<T: NativePType>(
use vortex_buffer::Alignment;

pub(crate) fn runend_decode_typed_primitive<T, const LANE_COUNT: usize>(
run_ends: impl Iterator<Item = usize>,
values: &[T],
values_validity: Mask,
values_nullability: Nullability,
length: usize,
) -> VortexResult<PrimitiveArray> {
) -> VortexResult<PrimitiveArray>
where
T: simd::SimdElement + NativePType,
simd::LaneCount<LANE_COUNT>: simd::SupportedLaneCount,
{
Ok(match values_validity {
Mask::AllTrue(_) => {
#[cfg(target_arch = "aarch64")]
{
let mut decoded: BufferMut<T> = BufferMut::with_capacity(length + 32);
let mut current_pos = 0;

macro_rules! neon_fill_run {
($type:ty, $neon_dup:ident, $neon_store:ident, $vector_size:expr) => {
for (end, &value) in run_ends.zip_eq(values) {
unsafe {
let neon_val = $neon_dup(*((&value as *const T) as *const $type));

while current_pos < end {
let dst = decoded.as_mut_ptr().add(current_pos) as *mut $type;
$neon_store(dst, neon_val);
current_pos += $vector_size;
}

current_pos = end;
}
}
};
}

match (size_of::<T>(), std::any::type_name::<T>()) {
(8, "f64") => neon_fill_run!(f64, vdupq_n_f64, vst1q_f64, 2),
(4, "f32") => neon_fill_run!(f32, vdupq_n_f32, vst1q_f32, 4),

(8, "i64") => neon_fill_run!(i64, vdupq_n_s64, vst1q_s64, 2),
(4, "i32") => neon_fill_run!(i32, vdupq_n_s32, vst1q_s32, 4),
(2, "i16") => neon_fill_run!(i16, vdupq_n_s16, vst1q_s16, 8),
(1, "i8") => neon_fill_run!(i8, vdupq_n_s8, vst1q_s8, 16),

(8, "u64") => neon_fill_run!(u64, vdupq_n_u64, vst1q_u64, 2),
(4, "u32") => neon_fill_run!(u32, vdupq_n_u32, vst1q_u32, 4),
(2, "u16") => neon_fill_run!(u16, vdupq_n_u16, vst1q_u16, 8),
(1, "u8") => neon_fill_run!(u8, vdupq_n_u8, vst1q_u8, 16),

_ => {
// Fallback
for (end, &value) in run_ends.zip_eq(values) {
while current_pos < end {
decoded.push(value);
current_pos += 1;
}
}
let mut decoded = BufferMut::<T>::with_capacity_aligned(
length.next_multiple_of(size_of::<simd::Simd<T, LANE_COUNT>>() * 2),
Alignment::of::<simd::Simd<T, LANE_COUNT>>(),
);
let mask = simd::Mask::from_bitmask(u64::MAX);
let mut current_pos = 0;

for (end, &value) in run_ends.zip_eq(values) {
let simd_val = simd::Simd::<T, LANE_COUNT>::splat(value);
while current_pos <= end {
unsafe {
simd_val.store_select_ptr(decoded.as_mut_ptr().add(current_pos), mask);
}
current_pos += LANE_COUNT;
}
current_pos = end;
}

unsafe {
decoded.set_len(length);
}

PrimitiveArray::new(decoded, values_nullability.into())
unsafe {
decoded.set_len(length);
}

PrimitiveArray::new(decoded.freeze(), values_nullability.into())
}
Mask::AllFalse(_) => PrimitiveArray::new(Buffer::<T>::zeroed(length), Validity::AllInvalid),
Mask::Values(mask) => {
Expand Down
1 change: 1 addition & 0 deletions encodings/runend/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#![feature(portable_simd)]
pub use array::*;

mod array;
Expand Down

0 comments on commit 1a3a5d3

Please sign in to comment.