Skip to content

Commit

Permalink
Auto merge of #130733 - okaneco:is_ascii, r=scottmcm
Browse files Browse the repository at this point in the history
Optimize `is_ascii` for `str` and `[u8]` further

Replace the existing optimized function with one that enables auto-vectorization.

This is especially beneficial on x86-64 as `pmovmskb` can be emitted with careful structuring of the code. The instruction can detect non-ASCII characters one vector register width at a time instead of the current `usize` at a time check.

The resulting implementation is completely safe.

`case00_libcore` is the current implementation, `case04_while_loop` is this PR.
```
benchmarks:
    ascii::is_ascii_slice::long::case00_libcore                             22.25/iter  +/- 1.09
    ascii::is_ascii_slice::long::case04_while_loop                           6.78/iter  +/- 0.92
    ascii::is_ascii_slice::medium::case00_libcore                            2.81/iter  +/- 0.39
    ascii::is_ascii_slice::medium::case04_while_loop                         1.56/iter  +/- 0.78
    ascii::is_ascii_slice::short::case00_libcore                             5.55/iter  +/- 0.85
    ascii::is_ascii_slice::short::case04_while_loop                          3.75/iter  +/- 0.22
    ascii::is_ascii_slice::unaligned_both_long::case00_libcore              26.59/iter  +/- 0.66
    ascii::is_ascii_slice::unaligned_both_long::case04_while_loop            5.78/iter  +/- 0.16
    ascii::is_ascii_slice::unaligned_both_medium::case00_libcore             2.97/iter  +/- 0.32
    ascii::is_ascii_slice::unaligned_both_medium::case04_while_loop          2.41/iter  +/- 0.10
    ascii::is_ascii_slice::unaligned_head_long::case00_libcore              23.71/iter  +/- 0.79
    ascii::is_ascii_slice::unaligned_head_long::case04_while_loop            7.83/iter  +/- 1.31
    ascii::is_ascii_slice::unaligned_head_medium::case00_libcore             3.69/iter  +/- 0.54
    ascii::is_ascii_slice::unaligned_head_medium::case04_while_loop          7.05/iter  +/- 0.32
    ascii::is_ascii_slice::unaligned_tail_long::case00_libcore              24.44/iter  +/- 1.41
    ascii::is_ascii_slice::unaligned_tail_long::case04_while_loop            5.12/iter  +/- 0.18
    ascii::is_ascii_slice::unaligned_tail_medium::case00_libcore             3.24/iter  +/- 0.40
    ascii::is_ascii_slice::unaligned_tail_medium::case04_while_loop          2.86/iter  +/- 0.14

```

`unaligned_head_medium` is the main regression in the benchmarks. It is a 32 byte string being sliced `bytes[1..]`.

The first commit can be used to run the benchmarks against the current core implementation.

Previous implementation was done in #74066

---

Two potential drawbacks of this implementation are that it increases instruction count and may regress other platforms/architectures. The benches here may also be too artificial to glean much insight from.
https://rust.godbolt.org/z/G9znGfY36
  • Loading branch information
bors committed Dec 22, 2024
2 parents 00bf74d + 1b5c02b commit c113247
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 15 deletions.
47 changes: 44 additions & 3 deletions library/core/benches/ascii/is_ascii.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ macro_rules! benches {
// Ensure we benchmark cases where the functions are called with strings
// that are not perfectly aligned or have a length which is not a
// multiple of size_of::<usize>() (or both)
benches!(mod unaligned_head MEDIUM[1..] $($name $arg $body)+);
benches!(mod unaligned_tail MEDIUM[..(MEDIUM.len() - 1)] $($name $arg $body)+);
benches!(mod unaligned_both MEDIUM[1..(MEDIUM.len() - 1)] $($name $arg $body)+);
benches!(mod unaligned_head_medium MEDIUM[1..] $($name $arg $body)+);
benches!(mod unaligned_tail_medium MEDIUM[..(MEDIUM.len() - 1)] $($name $arg $body)+);
benches!(mod unaligned_both_medium MEDIUM[1..(MEDIUM.len() - 1)] $($name $arg $body)+);
benches!(mod unaligned_head_long LONG[1..] $($name $arg $body)+);
benches!(mod unaligned_tail_long LONG[..(LONG.len() - 1)] $($name $arg $body)+);
benches!(mod unaligned_both_long LONG[1..(LONG.len() - 1)] $($name $arg $body)+);
};

(mod $mod_name: ident $input: ident [$range: expr] $($name: ident $arg: ident $body: block)+) => {
Expand Down Expand Up @@ -49,6 +52,44 @@ benches! {
fn case03_align_to_unrolled(bytes: &[u8]) {
is_ascii_align_to_unrolled(bytes)
}

fn case04_while_loop(bytes: &[u8]) {
// Process chunks of 32 bytes at a time in the fast path to enable
// auto-vectorization and use of `pmovmskb`. Two 128-bit vector registers
// can be OR'd together and then the resulting vector can be tested for
// non-ASCII bytes.
const CHUNK_SIZE: usize = 32;

let mut i = 0;

while i + CHUNK_SIZE <= bytes.len() {
let chunk_end = i + CHUNK_SIZE;

// Get LLVM to produce a `pmovmskb` instruction on x86-64 which
// creates a mask from the most significant bit of each byte.
// ASCII bytes are less than 128 (0x80), so their most significant
// bit is unset.
let mut count = 0;
while i < chunk_end {
count += bytes[i].is_ascii() as u8;
i += 1;
}

// All bytes should be <= 127 so count is equal to chunk size.
if count != CHUNK_SIZE as u8 {
return false;
}
}

// Process the remaining `bytes.len() % N` bytes.
let mut is_ascii = true;
while i < bytes.len() {
is_ascii &= bytes[i].is_ascii();
i += 1;
}

is_ascii
}
}

// These are separate since it's easier to debug errors if they don't go through
Expand Down
70 changes: 58 additions & 12 deletions library/core/src/slice/ascii.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
use core::ascii::EscapeDefault;

use crate::fmt::{self, Write};
#[cfg(not(all(target_arch = "x86_64", target_feature = "sse2")))]
use crate::intrinsics::const_eval_select;
use crate::{ascii, iter, mem, ops};
use crate::{ascii, iter, ops};

#[cfg(not(test))]
impl [u8] {
Expand Down Expand Up @@ -328,14 +329,6 @@ impl<'a> fmt::Debug for EscapeAscii<'a> {
}
}

/// Returns `true` if any byte in the word `v` is nonascii (>= 128). Snarfed
/// from `../str/mod.rs`, which does something similar for utf8 validation.
#[inline]
const fn contains_nonascii(v: usize) -> bool {
const NONASCII_MASK: usize = usize::repeat_u8(0x80);
(NONASCII_MASK & v) != 0
}

/// ASCII test *without* the chunk-at-a-time optimizations.
///
/// This is carefully structured to produce nice small code -- it's smaller in
Expand Down Expand Up @@ -366,6 +359,7 @@ pub const fn is_ascii_simple(mut bytes: &[u8]) -> bool {
///
/// If any of these loads produces something for which `contains_nonascii`
/// (above) returns true, then we know the answer is false.
#[cfg(not(all(target_arch = "x86_64", target_feature = "sse2")))]
#[inline]
#[rustc_allow_const_fn_unstable(const_eval_select)] // fallback impl has same behavior
const fn is_ascii(s: &[u8]) -> bool {
Expand All @@ -376,7 +370,14 @@ const fn is_ascii(s: &[u8]) -> bool {
if const {
is_ascii_simple(s)
} else {
const USIZE_SIZE: usize = mem::size_of::<usize>();
/// Returns `true` if any byte in the word `v` is nonascii (>= 128). Snarfed
/// from `../str/mod.rs`, which does something similar for utf8 validation.
const fn contains_nonascii(v: usize) -> bool {
const NONASCII_MASK: usize = usize::repeat_u8(0x80);
(NONASCII_MASK & v) != 0
}

const USIZE_SIZE: usize = size_of::<usize>();

let len = s.len();
let align_offset = s.as_ptr().align_offset(USIZE_SIZE);
Expand All @@ -386,7 +387,7 @@ const fn is_ascii(s: &[u8]) -> bool {
//
// We also do this for architectures where `size_of::<usize>()` isn't
// sufficient alignment for `usize`, because it's a weird edge case.
if len < USIZE_SIZE || len < align_offset || USIZE_SIZE < mem::align_of::<usize>() {
if len < USIZE_SIZE || len < align_offset || USIZE_SIZE < align_of::<usize>() {
return is_ascii_simple(s);
}

Expand Down Expand Up @@ -420,7 +421,7 @@ const fn is_ascii(s: &[u8]) -> bool {
// have alignment information it should have given a `usize::MAX` for
// `align_offset` earlier, sending things through the scalar path instead of
// this one, so this check should pass if it's reachable.
debug_assert!(word_ptr.is_aligned_to(mem::align_of::<usize>()));
debug_assert!(word_ptr.is_aligned_to(align_of::<usize>()));

// Read subsequent words until the last aligned word, excluding the last
// aligned word by itself to be done in tail check later, to ensure that
Expand Down Expand Up @@ -455,3 +456,48 @@ const fn is_ascii(s: &[u8]) -> bool {
}
)
}

/// ASCII test optimized to use the `pmovmskb` instruction available on `x86-64`
/// platforms.
///
/// Other platforms are not likely to benefit from this code structure, so they
/// use SWAR techniques to test for ASCII in `usize`-sized chunks.
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
#[inline]
const fn is_ascii(bytes: &[u8]) -> bool {
// Process chunks of 32 bytes at a time in the fast path to enable
// auto-vectorization and use of `pmovmskb`. Two 128-bit vector registers
// can be OR'd together and then the resulting vector can be tested for
// non-ASCII bytes.
const CHUNK_SIZE: usize = 32;

let mut i = 0;

while i + CHUNK_SIZE <= bytes.len() {
let chunk_end = i + CHUNK_SIZE;

// Get LLVM to produce a `pmovmskb` instruction on x86-64 which
// creates a mask from the most significant bit of each byte.
// ASCII bytes are less than 128 (0x80), so their most significant
// bit is unset.
let mut count = 0;
while i < chunk_end {
count += bytes[i].is_ascii() as u8;
i += 1;
}

// All bytes should be <= 127 so count is equal to chunk size.
if count != CHUNK_SIZE as u8 {
return false;
}
}

// Process the remaining `bytes.len() % N` bytes.
let mut is_ascii = true;
while i < bytes.len() {
is_ascii &= bytes[i].is_ascii();
i += 1;
}

is_ascii
}
16 changes: 16 additions & 0 deletions tests/codegen/slice-is-ascii.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//@ only-x86_64
//@ compile-flags: -C opt-level=3
#![crate_type = "lib"]

/// Check that the fast-path of `is_ascii` uses a `pmovmskb` instruction.
/// Platforms lacking an equivalent instruction use other techniques for
/// optimizing `is_ascii`.
// CHECK-LABEL: @is_ascii_autovectorized
#[no_mangle]
pub fn is_ascii_autovectorized(s: &[u8]) -> bool {
// CHECK: load <32 x i8>
// CHECK-NEXT: icmp slt <32 x i8>
// CHECK-NEXT: bitcast <32 x i1>
// CHECK-NEXT: icmp eq i32
s.is_ascii()
}

0 comments on commit c113247

Please sign in to comment.