Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize is_ascii for str and [u8] further #130733

Merged
merged 2 commits into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions library/core/benches/ascii/is_ascii.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ macro_rules! benches {
// Ensure we benchmark cases where the functions are called with strings
// that are not perfectly aligned or have a length which is not a
// multiple of size_of::<usize>() (or both)
benches!(mod unaligned_head MEDIUM[1..] $($name $arg $body)+);
benches!(mod unaligned_tail MEDIUM[..(MEDIUM.len() - 1)] $($name $arg $body)+);
benches!(mod unaligned_both MEDIUM[1..(MEDIUM.len() - 1)] $($name $arg $body)+);
benches!(mod unaligned_head_medium MEDIUM[1..] $($name $arg $body)+);
benches!(mod unaligned_tail_medium MEDIUM[..(MEDIUM.len() - 1)] $($name $arg $body)+);
benches!(mod unaligned_both_medium MEDIUM[1..(MEDIUM.len() - 1)] $($name $arg $body)+);
benches!(mod unaligned_head_long LONG[1..] $($name $arg $body)+);
benches!(mod unaligned_tail_long LONG[..(LONG.len() - 1)] $($name $arg $body)+);
benches!(mod unaligned_both_long LONG[1..(LONG.len() - 1)] $($name $arg $body)+);
};

(mod $mod_name: ident $input: ident [$range: expr] $($name: ident $arg: ident $body: block)+) => {
Expand Down Expand Up @@ -49,6 +52,44 @@ benches! {
fn case03_align_to_unrolled(bytes: &[u8]) {
is_ascii_align_to_unrolled(bytes)
}

fn case04_while_loop(bytes: &[u8]) {
// Benchmark case: ASCII check written as plain `while` loops over indices.
// Evaluates to `true` when every byte of `bytes` is ASCII and `false`
// otherwise. NOTE(review): no return type is written here; presumably the
// surrounding `benches!` macro supplies the full signature -- confirm
// against the macro definition.
//
// Process chunks of 32 bytes at a time in the fast path to enable
// auto-vectorization and use of `pmovmskb`. Two 128-bit vector registers
// can be OR'd together and then the resulting vector can be tested for
// non-ASCII bytes.
const CHUNK_SIZE: usize = 32;

// Index of the next unread byte; shared by the chunked loop and the tail loop.
let mut i = 0;

// Only enter the fast path while a full chunk is still in bounds.
while i + CHUNK_SIZE <= bytes.len() {
let chunk_end = i + CHUNK_SIZE;

// Get LLVM to produce a `pmovmskb` instruction on x86-64 which
// creates a mask from the most significant bit of each byte.
// ASCII bytes are less than 128 (0x80), so their most significant
// bit is unset.
//
// `count` tallies the ASCII bytes in this chunk. It is compared against
// `CHUNK_SIZE as u8` below, and the maximum tally (32) fits in a `u8`.
let mut count = 0;
while i < chunk_end {
count += bytes[i].is_ascii() as u8;
i += 1;
}

// If every byte in the chunk is ASCII (<= 127), `count` equals the
// chunk size; any shortfall means a non-ASCII byte was seen.
if count != CHUNK_SIZE as u8 {
return false;
}
}

// Process the remaining `bytes.len() % CHUNK_SIZE` bytes.
// The result is accumulated with `&=` (no early exit), so every tail byte
// is always inspected.
let mut is_ascii = true;
while i < bytes.len() {
is_ascii &= bytes[i].is_ascii();
i += 1;
}

is_ascii
}
}

// These are separate since it's easier to debug errors if they don't go through
Expand Down
70 changes: 58 additions & 12 deletions library/core/src/slice/ascii.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
use core::ascii::EscapeDefault;

use crate::fmt::{self, Write};
#[cfg(not(all(target_arch = "x86_64", target_feature = "sse2")))]
use crate::intrinsics::const_eval_select;
use crate::{ascii, iter, mem, ops};
use crate::{ascii, iter, ops};

#[cfg(not(test))]
impl [u8] {
Expand Down Expand Up @@ -308,14 +309,6 @@ impl<'a> fmt::Debug for EscapeAscii<'a> {
}
}

/// Returns `true` if any byte in the word `v` is nonascii (>= 128). Snarfed
/// from `../str/mod.rs`, which does something similar for utf8 validation.
#[inline]
const fn contains_nonascii(v: usize) -> bool {
const NONASCII_MASK: usize = usize::repeat_u8(0x80);
(NONASCII_MASK & v) != 0
}

/// ASCII test *without* the chunk-at-a-time optimizations.
///
/// This is carefully structured to produce nice small code -- it's smaller in
Expand Down Expand Up @@ -346,6 +339,7 @@ pub const fn is_ascii_simple(mut bytes: &[u8]) -> bool {
///
/// If any of these loads produces something for which `contains_nonascii`
/// (a helper nested inside this function's body) returns true, then we know the answer is false.
#[cfg(not(all(target_arch = "x86_64", target_feature = "sse2")))]
#[inline]
#[rustc_allow_const_fn_unstable(const_eval_select)] // fallback impl has same behavior
const fn is_ascii(s: &[u8]) -> bool {
Expand All @@ -356,7 +350,14 @@ const fn is_ascii(s: &[u8]) -> bool {
if const {
is_ascii_simple(s)
} else {
const USIZE_SIZE: usize = mem::size_of::<usize>();
/// Returns `true` if any byte in the word `v` is nonascii (>= 128). Snarfed
/// from `../str/mod.rs`, which does something similar for utf8 validation.
const fn contains_nonascii(v: usize) -> bool {
const NONASCII_MASK: usize = usize::repeat_u8(0x80);
(NONASCII_MASK & v) != 0
}

const USIZE_SIZE: usize = size_of::<usize>();

let len = s.len();
let align_offset = s.as_ptr().align_offset(USIZE_SIZE);
Expand All @@ -366,7 +367,7 @@ const fn is_ascii(s: &[u8]) -> bool {
//
// We also do this for architectures where `size_of::<usize>()` isn't
// sufficient alignment for `usize`, because it's a weird edge case.
if len < USIZE_SIZE || len < align_offset || USIZE_SIZE < mem::align_of::<usize>() {
if len < USIZE_SIZE || len < align_offset || USIZE_SIZE < align_of::<usize>() {
return is_ascii_simple(s);
}

Expand Down Expand Up @@ -400,7 +401,7 @@ const fn is_ascii(s: &[u8]) -> bool {
// have alignment information it should have given a `usize::MAX` for
// `align_offset` earlier, sending things through the scalar path instead of
// this one, so this check should pass if it's reachable.
debug_assert!(word_ptr.is_aligned_to(mem::align_of::<usize>()));
debug_assert!(word_ptr.is_aligned_to(align_of::<usize>()));

// Read subsequent words until the last aligned word, excluding the last
// aligned word by itself to be done in tail check later, to ensure that
Expand Down Expand Up @@ -435,3 +436,48 @@ const fn is_ascii(s: &[u8]) -> bool {
}
)
}

/// ASCII test optimized to use the `pmovmskb` instruction available on `x86-64`
/// platforms.
///
/// Other platforms are not likely to benefit from this code structure, so they
/// use SWAR techniques to test for ASCII in `usize`-sized chunks.
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
#[inline]
const fn is_ascii(bytes: &[u8]) -> bool {
// Process chunks of 32 bytes at a time in the fast path to enable
// auto-vectorization and use of `pmovmskb`. Two 128-bit vector registers
// can be OR'd together and then the resulting vector can be tested for
// non-ASCII bytes.
const CHUNK_SIZE: usize = 32;

let mut i = 0;

while i + CHUNK_SIZE <= bytes.len() {
let chunk_end = i + CHUNK_SIZE;

// Get LLVM to produce a `pmovmskb` instruction on x86-64 which
// creates a mask from the most significant bit of each byte.
// ASCII bytes are less than 128 (0x80), so their most significant
// bit is unset.
let mut count = 0;
while i < chunk_end {
count += bytes[i].is_ascii() as u8;
i += 1;
}

// All bytes should be <= 127 so count is equal to chunk size.
if count != CHUNK_SIZE as u8 {
return false;
}
}

// Process the remaining `bytes.len() % CHUNK_SIZE` bytes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My brain sees this and thinks as_chunks::<CHUNK_SIZE>(), but I guess that's not really worth doing right now when we can't loop over slice iterators in const fn anyway. Maybe in the Glorious Const Future™.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can just call is_ascii_simple in const, so there's no reason to do any const-specific adjustments IMO.

let mut is_ascii = true;
while i < bytes.len() {
is_ascii &= bytes[i].is_ascii();
i += 1;
}

is_ascii
}
16 changes: 16 additions & 0 deletions tests/codegen/slice-is-ascii.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//@ only-x86_64
//@ compile-flags: -C opt-level=3
#![crate_type = "lib"]

/// Check that the fast-path of `is_ascii` uses a `pmovmskb` instruction.
/// Platforms lacking an equivalent instruction use other techniques for
/// optimizing `is_ascii`.
// CHECK-LABEL: @is_ascii_autovectorized
#[no_mangle]
pub fn is_ascii_autovectorized(s: &[u8]) -> bool {
// CHECK: load <32 x i8>
// CHECK-NEXT: icmp slt <32 x i8>
// CHECK-NEXT: bitcast <32 x i1>
// CHECK-NEXT: icmp eq i32
Comment on lines +12 to +14
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sure makes me tempted to write it with portable simd, but I don't know where we are on that just yet, so probably not something to do this PR.

s.is_ascii()
}
Loading