Skip to content

Commit

Permalink
Optimize is_ascii for str and [u8] further
Browse files Browse the repository at this point in the history
Replace the existing optimized function with one that enables
use of vector instructions.
This is especially beneficial on x86-64 as `pmovmskb` can be
emitted with careful structuring of the code. The instruction
can detect non-ASCII characters a vector register width at a time
instead of the current `usize` at a time check.
This results in a completely safe implementation.

Remove previous implementation's alignment test
  • Loading branch information
okaneco committed Sep 23, 2024
1 parent 8caa7d6 commit bc78676
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 152 deletions.
127 changes: 29 additions & 98 deletions library/core/src/slice/ascii.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use core::ascii::EscapeDefault;

use crate::fmt::{self, Write};
use crate::{ascii, iter, mem, ops};
use crate::{ascii, iter, ops};

#[cfg(not(test))]
impl [u8] {
Expand Down Expand Up @@ -297,14 +297,6 @@ impl<'a> fmt::Debug for EscapeAscii<'a> {
}
}

/// Returns `true` if any byte in the word `v` is nonascii (>= 128). Snarfed
/// from `../str/mod.rs`, which does something similar for utf8 validation.
#[inline]
const fn contains_nonascii(v: usize) -> bool {
const NONASCII_MASK: usize = usize::repeat_u8(0x80);
(NONASCII_MASK & v) != 0
}

/// ASCII test *without* the chunk-at-a-time optimizations.
///
/// This is carefully structured to produce nice small code -- it's smaller in
Expand All @@ -323,100 +315,39 @@ pub const fn is_ascii_simple(mut bytes: &[u8]) -> bool {
bytes.is_empty()
}

/// Optimized ASCII test that will use usize-at-a-time operations instead of
/// byte-at-a-time operations (when possible).
///
/// The algorithm we use here is pretty simple. If `s` is too short, we just
/// check each byte and be done with it. Otherwise:
///
/// - Read the first word with an unaligned load.
/// - Align the pointer, read subsequent words until end with aligned loads.
/// - Read the last `usize` from `s` with an unaligned load.
///
/// If any of these loads produces something for which `contains_nonascii`
/// (above) returns true, then we know the answer is false.
#[inline]
const fn is_ascii(s: &[u8]) -> bool {
const USIZE_SIZE: usize = mem::size_of::<usize>();

let len = s.len();
let align_offset = s.as_ptr().align_offset(USIZE_SIZE);

// If we wouldn't gain anything from the word-at-a-time implementation, fall
// back to a scalar loop.
//
// We also do this for architectures where `size_of::<usize>()` isn't
// sufficient alignment for `usize`, because it's a weird edge case.
if len < USIZE_SIZE || len < align_offset || USIZE_SIZE < mem::align_of::<usize>() {
return is_ascii_simple(s);
}

// We always read the first word unaligned, which means `align_offset` is
// 0, we'd read the same value again for the aligned read.
let offset_to_aligned = if align_offset == 0 { USIZE_SIZE } else { align_offset };

let start = s.as_ptr();
// SAFETY: We verify `len < USIZE_SIZE` above.
let first_word = unsafe { (start as *const usize).read_unaligned() };

if contains_nonascii(first_word) {
return false;
}
// We checked this above, somewhat implicitly. Note that `offset_to_aligned`
// is either `align_offset` or `USIZE_SIZE`, both of are explicitly checked
// above.
debug_assert!(offset_to_aligned <= len);

// SAFETY: word_ptr is the (properly aligned) usize ptr we use to read the
// middle chunk of the slice.
let mut word_ptr = unsafe { start.add(offset_to_aligned) as *const usize };

// `byte_pos` is the byte index of `word_ptr`, used for loop end checks.
let mut byte_pos = offset_to_aligned;

// Paranoia check about alignment, since we're about to do a bunch of
// unaligned loads. In practice this should be impossible barring a bug in
// `align_offset` though.
// While this method is allowed to spuriously fail in CTFE, if it doesn't
// have alignment information it should have given a `usize::MAX` for
// `align_offset` earlier, sending things through the scalar path instead of
// this one, so this check should pass if it's reachable.
debug_assert!(word_ptr.is_aligned_to(mem::align_of::<usize>()));

// Read subsequent words until the last aligned word, excluding the last
// aligned word by itself to be done in tail check later, to ensure that
// tail is always one `usize` at most to extra branch `byte_pos == len`.
while byte_pos < len - USIZE_SIZE {
// Sanity check that the read is in bounds
debug_assert!(byte_pos + USIZE_SIZE <= len);
// And that our assumptions about `byte_pos` hold.
debug_assert!(matches!(
word_ptr.cast::<u8>().guaranteed_eq(start.wrapping_add(byte_pos)),
// These are from the same allocation, so will hopefully always be
// known to match even in CTFE, but if it refuses to compare them
// that's ok since it's just a debug check anyway.
None | Some(true),
));
const fn is_ascii(bytes: &[u8]) -> bool {
// Constant chosen to enable `pmovmskb` instruction on x86-64
const N: usize = 32;

let mut i = 0;

while i + N <= bytes.len() {
let chunk_end = i + N;

// Get LLVM to produce a `pmovmskb` instruction on x86-64 which
// creates a mask from the most significant bit of each byte.
// ASCII bytes are less than 128 (0x80), so their most significant
// bit is unset. Thus, detecting non-ASCII bytes can be done in one
// instruction.
let mut count = 0;
while i < chunk_end {
count += (bytes[i] <= 127) as u8;
i += 1;
}

// SAFETY: We know `word_ptr` is properly aligned (because of
// `align_offset`), and we know that we have enough bytes between `word_ptr` and the end
let word = unsafe { word_ptr.read() };
if contains_nonascii(word) {
// All bytes should be <= 127 so count is equal to chunk size.
if count != N as u8 {
return false;
}

byte_pos += USIZE_SIZE;
// SAFETY: We know that `byte_pos <= len - USIZE_SIZE`, which means that
// after this `add`, `word_ptr` will be at most one-past-the-end.
word_ptr = unsafe { word_ptr.add(1) };
}

// Sanity check to ensure there really is only one `usize` left. This should
// be guaranteed by our loop condition.
debug_assert!(byte_pos <= len && len - byte_pos <= USIZE_SIZE);

// SAFETY: This relies on `len >= USIZE_SIZE`, which we check at the start.
let last_word = unsafe { (start.add(len - USIZE_SIZE) as *const usize).read_unaligned() };
// Process the remaining `bytes.len() % N` bytes.
let mut is_ascii = true;
while i < bytes.len() {
is_ascii &= bytes[i] <= 127;
i += 1;
}

!contains_nonascii(last_word)
is_ascii
}
54 changes: 0 additions & 54 deletions library/core/tests/ascii.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,60 +361,6 @@ fn test_is_ascii_control() {
);
}

// `is_ascii` does a good amount of pointer manipulation and has
// alignment-dependent computation. This is all sanity-checked via
// `debug_assert!`s, so we test various sizes/alignments thoroughly versus an
// "obviously correct" baseline function.
#[test]
fn test_is_ascii_align_size_thoroughly() {
// The "obviously-correct" baseline mentioned above.
fn is_ascii_baseline(s: &[u8]) -> bool {
s.iter().all(|b| b.is_ascii())
}

// Helper to repeat `l` copies of `b0` followed by `l` copies of `b1`.
fn repeat_concat(b0: u8, b1: u8, l: usize) -> Vec<u8> {
use core::iter::repeat;
repeat(b0).take(l).chain(repeat(b1).take(l)).collect()
}

// Miri is too slow
let iter = if cfg!(miri) { 0..20 } else { 0..100 };

for i in iter {
#[cfg(not(miri))]
let cases = &[
b"a".repeat(i),
b"\0".repeat(i),
b"\x7f".repeat(i),
b"\x80".repeat(i),
b"\xff".repeat(i),
repeat_concat(b'a', 0x80u8, i),
repeat_concat(0x80u8, b'a', i),
];

#[cfg(miri)]
let cases = &[b"a".repeat(i), b"\x80".repeat(i), repeat_concat(b'a', 0x80u8, i)];

for case in cases {
for pos in 0..=case.len() {
// Potentially misaligned head
let prefix = &case[pos..];
assert_eq!(is_ascii_baseline(prefix), prefix.is_ascii(),);

// Potentially misaligned tail
let suffix = &case[..case.len() - pos];

assert_eq!(is_ascii_baseline(suffix), suffix.is_ascii(),);

// Both head and tail are potentially misaligned
let mid = &case[(pos / 2)..(case.len() - (pos / 2))];
assert_eq!(is_ascii_baseline(mid), mid.is_ascii(),);
}
}
}
}

#[test]
fn ascii_const() {
// test that the `is_ascii` methods of `char` and `u8` are usable in a const context
Expand Down

0 comments on commit bc78676

Please sign in to comment.