From 62ae19b5ac61d4fd7be06859fa9758e32a19da2e Mon Sep 17 00:00:00 2001 From: oxalica Date: Fri, 7 Feb 2025 08:11:42 -0500 Subject: [PATCH 01/10] Rewrite UTF-8 validation in shift-based DFA This gives plenty of performance increase on validating strings with many non-ASCII codepoints, which is the normal case for almost every non-English content. Shift-based DFA algorithm does not use SIMD instructions and does not rely on the branch predictor to get a good performance, thus is good as a general, default, architecture-agnostic implementation. There is still a bypass for ASCII-only strings to benefits from auto-vectorization, if the target supports. We use z3 to find a state mapping that only need u32 transition table. This shrinks the table size to 1KiB comparing a u64 states for less cache pressure, and produces faster code on platforms that only support 32-bit shift. Though, it does not affect the throughput on 64-bit platforms when the table is already fully in cache. --- library/core/src/str/solve_dfa.py | 78 +++++++ library/core/src/str/validations.rs | 328 ++++++++++++++++++---------- 2 files changed, 286 insertions(+), 120 deletions(-) create mode 100755 library/core/src/str/solve_dfa.py diff --git a/library/core/src/str/solve_dfa.py b/library/core/src/str/solve_dfa.py new file mode 100755 index 0000000000000..d10767ba9c83d --- /dev/null +++ b/library/core/src/str/solve_dfa.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# Use z3 to solve UTF-8 validation DFA for offset and transition table, +# in order to encode transition table into u32. +# We minimize the output variables in the solution to make it deterministic. +# Ref: +# See more detail explanation in `./validations.rs`. +# +# It is expected to find a solution in <30s on a modern machine, and the +# solution is appended to the end of this file. +from z3 import * + +STATE_CNT = 9 + +# The transition table. +# A value X on column Y means state Y should transition to state X on some +# input bytes. 
We assign state 0 as ERROR and state 1 as ACCEPT (initial). +# Eg. first line: for input byte 00..=7F, transition S1 -> S1, others -> S0. +TRANSITIONS = [ + # 0 1 2 3 4 5 6 7 8 + # First bytes + ((0, 1, 0, 0, 0, 0, 0, 0, 0), "00-7F"), + ((0, 2, 0, 0, 0, 0, 0, 0, 0), "C2-DF"), + ((0, 3, 0, 0, 0, 0, 0, 0, 0), "E0"), + ((0, 4, 0, 0, 0, 0, 0, 0, 0), "E1-EC, EE-EF"), + ((0, 5, 0, 0, 0, 0, 0, 0, 0), "ED"), + ((0, 6, 0, 0, 0, 0, 0, 0, 0), "F0"), + ((0, 7, 0, 0, 0, 0, 0, 0, 0), "F1-F3"), + ((0, 8, 0, 0, 0, 0, 0, 0, 0), "F4"), + # Continuation bytes + ((0, 0, 1, 0, 2, 2, 0, 4, 4), "80-8F"), + ((0, 0, 1, 0, 2, 2, 4, 4, 0), "90-9F"), + ((0, 0, 1, 2, 2, 0, 4, 4, 0), "A0-BF"), + # Illegal + ((0, 0, 0, 0, 0, 0, 0, 0, 0), "C0-C1, F5-FF"), +] + +o = Optimize() +offsets = [BitVec(f"o{i}", 32) for i in range(STATE_CNT)] +trans_table = [BitVec(f"t{i}", 32) for i in range(len(TRANSITIONS))] + +# Add some guiding constraints to make solving faster. +o.add(offsets[0] == 0) +o.add(trans_table[-1] == 0) + +for i in range(len(offsets)): + # Do not over-shift. It's not necessary but makes solving faster. + o.add(offsets[i] < 32 - 5) + for j in range(i): + o.add(offsets[i] != offsets[j]) +for trans, (targets, _) in zip(trans_table, TRANSITIONS): + for src, tgt in enumerate(targets): + o.add((LShR(trans, offsets[src]) & 31) == offsets[tgt]) + +# Minimize ordered outputs to get a unique solution. 
+goal = Concat(*offsets, *trans_table) +o.minimize(goal) +print(o.check()) +print("Offset[]= ", [o.model()[i].as_long() for i in offsets]) +print("Transitions:") +for (_, label), v in zip(TRANSITIONS, [o.model()[i].as_long() for i in trans_table]): + print(f"{label:14} => {v:#10x}, // {v:032b}") + +# Output should be deterministic: +# sat +# Offset[]= [0, 6, 16, 19, 1, 25, 11, 18, 24] +# Transitions: +# 00-7F => 0x180, // 00000000000000000000000110000000 +# C2-DF => 0x400, // 00000000000000000000010000000000 +# E0 => 0x4c0, // 00000000000000000000010011000000 +# E1-EC, EE-EF => 0x40, // 00000000000000000000000001000000 +# ED => 0x640, // 00000000000000000000011001000000 +# F0 => 0x2c0, // 00000000000000000000001011000000 +# F1-F3 => 0x480, // 00000000000000000000010010000000 +# F4 => 0x600, // 00000000000000000000011000000000 +# 80-8F => 0x21060020, // 00100001000001100000000000100000 +# 90-9F => 0x20060820, // 00100000000001100000100000100000 +# A0-BF => 0x860820, // 00000000100001100000100000100000 +# C0-C1, F5-FF => 0x0, // 00000000000000000000000000000000 diff --git a/library/core/src/str/validations.rs b/library/core/src/str/validations.rs index 8174e4ff97dfc..aa5bd19896266 100644 --- a/library/core/src/str/validations.rs +++ b/library/core/src/str/validations.rs @@ -1,7 +1,7 @@ //! Operations related to UTF-8 validation. use super::Utf8Error; -use crate::intrinsics::const_eval_select; +use crate::intrinsics::{const_eval_select, unlikely}; /// Returns the initial codepoint accumulator for the first byte. /// The first byte is special, only want bottom 5 bits for width 2, 4 bits @@ -111,140 +111,228 @@ where Some(ch) } -const NONASCII_MASK: usize = usize::repeat_u8(0x80); +// The shift-based DFA algorithm for UTF-8 validation. 
+// Ref: +// +// In short, we encode DFA transitions in an array `TRANS_TABLE` such that: +// ``` +// TRANS_TABLE[next_byte] = +// OFFSET[target_state1] << OFFSET[source_state1] | +// OFFSET[target_state2] << OFFSET[source_state2] | +// ... +// ``` +// Where `OFFSET[]` is a compile-time map from each state to a distinct 0..32 value. +// +// To execute the DFA: +// ``` +// let state = OFFSET[initial_state]; +// for byte in .. { +// state = TRANS_TABLE[byte] >> (state & ((1 << BITS_PER_STATE) - 1)); +// } +// ``` +// By choosing `BITS_PER_STATE = 5` and `state: u32`, we can replace the masking by `wrapping_shr` +// and it becomes free on modern ISAs, including x86, x86_64 and ARM. +// +// ``` +// // shrx state, qword ptr [table_addr + 8 * byte], state # On x86-64-v3 +// state = TRANS_TABLE[byte].wrapping_shr(state); +// ``` +// +// The DFA is directly derived from UTF-8 syntax from the RFC3629: +// . +// We assign S0 as ERROR and S1 as ACCEPT. DFA starts at S1. +// Syntax are annotated with DFA states in angle bracket as following: +// +// UTF8-char = (UTF8-1 / UTF8-2 / UTF8-3 / UTF8-4) +// UTF8-1 = %x00-7F +// UTF8-2 = %xC2-DF UTF8-tail +// UTF8-3 = %xE0 %xA0-BF UTF8-tail / +// (%xE1-EC / %xEE-EF) 2( UTF8-tail ) / +// %xED %x80-9F UTF8-tail +// UTF8-4 = %xF0 %x90-BF 2( UTF8-tail ) / +// %xF1-F3 UTF8-tail 2( UTF8-tail ) / +// %xF4 %x80-8F 2( UTF8-tail ) +// UTF8-tail = %x80-BF # Inlined into above usages. +// +// You may notice that encoding 9 states with 5bits per state into 32bit seems impossible, +// but we exploit overlapping bits to find a possible `OFFSET[]` and `TRANS_TABLE[]` solution. +// The SAT solver to find such (minimal) solution is in `./solve_dfa.py`. +// The solution is also appended to the end of that file and is verifiable. +const BITS_PER_STATE: u32 = 5; +const STATE_MASK: u32 = (1 << BITS_PER_STATE) - 1; +const STATE_CNT: usize = 9; +const ST_ERROR: u32 = OFFSETS[0]; +const ST_ACCEPT: u32 = OFFSETS[1]; +// See the end of `./solve_dfa.py`. 
+const OFFSETS: [u32; STATE_CNT] = [0, 6, 16, 19, 1, 25, 11, 18, 24]; -/// Returns `true` if any byte in the word `x` is nonascii (>= 128). +static TRANS_TABLE: [u32; 256] = { + let mut table = [0u32; 256]; + let mut b = 0; + while b < 256 { + // See the end of `./solve_dfa.py`. + table[b] = match b as u8 { + 0x00..=0x7F => 0x180, + 0xC2..=0xDF => 0x400, + 0xE0 => 0x4C0, + 0xE1..=0xEC | 0xEE..=0xEF => 0x40, + 0xED => 0x640, + 0xF0 => 0x2C0, + 0xF1..=0xF3 => 0x480, + 0xF4 => 0x600, + 0x80..=0x8F => 0x21060020, + 0x90..=0x9F => 0x20060820, + 0xA0..=0xBF => 0x860820, + 0xC0..=0xC1 | 0xF5..=0xFF => 0x0, + }; + b += 1; + } + table +}; + +#[inline(always)] +const fn next_state(st: u32, byte: u8) -> u32 { + TRANS_TABLE[byte as usize].wrapping_shr(st) +} + +/// Check if `byte` is a valid UTF-8 first byte, assuming it must be a valid first or +/// continuation byte. +#[inline(always)] +const fn is_utf8_first_byte(byte: u8) -> bool { + byte as i8 >= 0b1100_0000u8 as i8 +} + +/// # Safety +/// The caller must ensure `bytes[..i]` is a valid UTF-8 prefix and `st` is the DFA state after +/// executing on `bytes[..i]`. #[inline] -const fn contains_nonascii(x: usize) -> bool { - (x & NONASCII_MASK) != 0 +const unsafe fn resolve_error_location(st: u32, bytes: &[u8], i: usize) -> (usize, u8) { + // There are two cases: + // 1. [valid UTF-8..] | *here + // The previous state must be ACCEPT for the case 1, and `valid_up_to = i`. + // 2. [valid UTF-8..] | valid first byte, [valid continuation byte...], *here + // `valid_up_to` is at the latest non-continuation byte, which must exist and + // be in range `(i-3)..i`. + if st & STATE_MASK == ST_ACCEPT { + (i, 1) + // SAFETY: UTF-8 first byte must exist if we are in an intermediate state. + // We use pointer here because `get_unchecked` is not const fn. + } else if is_utf8_first_byte(unsafe { bytes.as_ptr().add(i - 1).read() }) { + (i - 1, 1) + // SAFETY: Same as above. 
+ } else if is_utf8_first_byte(unsafe { bytes.as_ptr().add(i - 2).read() }) { + (i - 2, 2) + } else { + (i - 3, 3) + } +} + +// The simpler but slower algorithm to run DFA with error handling. +// +// # Safety +// The caller must ensure `bytes[..i]` is a valid UTF-8 prefix and `st` is the DFA state after +// executing on `bytes[..i]`. +const unsafe fn run_with_error_handling( + st: &mut u32, + bytes: &[u8], + mut i: usize, +) -> Result<(), Utf8Error> { + while i < bytes.len() { + let new_st = next_state(*st, bytes[i]); + if unlikely(new_st & STATE_MASK == ST_ERROR) { + // SAFETY: Guaranteed by the caller. + let (valid_up_to, error_len) = unsafe { resolve_error_location(*st, bytes, i) }; + return Err(Utf8Error { valid_up_to, error_len: Some(error_len) }); + } + *st = new_st; + i += 1; + } + Ok(()) } /// Walks through `v` checking that it's a valid UTF-8 sequence, /// returning `Ok(())` in that case, or, if it is invalid, `Err(err)`. #[inline(always)] #[rustc_allow_const_fn_unstable(const_eval_select)] // fallback impl has same behavior -pub(super) const fn run_utf8_validation(v: &[u8]) -> Result<(), Utf8Error> { - let mut index = 0; - let len = v.len(); - - const USIZE_BYTES: usize = size_of::(); +pub(super) const fn run_utf8_validation(bytes: &[u8]) -> Result<(), Utf8Error> { + const_eval_select((bytes,), run_utf8_validation_const, run_utf8_validation_rt) +} - let ascii_block_size = 2 * USIZE_BYTES; - let blocks_end = if len >= ascii_block_size { len - ascii_block_size + 1 } else { 0 }; - // Below, we safely fall back to a slower codepath if the offset is `usize::MAX`, - // so the end-to-end behavior is the same at compiletime and runtime. - let align = const_eval_select!( - @capture { v: &[u8] } -> usize: - if const { - usize::MAX - } else { - v.as_ptr().align_offset(USIZE_BYTES) +#[inline] +const fn run_utf8_validation_const(bytes: &[u8]) -> Result<(), Utf8Error> { + let mut st = ST_ACCEPT; + // SAFETY: Start at empty string with valid state ACCEPT. 
+ match unsafe { run_with_error_handling(&mut st, bytes, 0) } { + Err(err) => Err(err), + Ok(()) => { + if st & STATE_MASK == ST_ACCEPT { + Ok(()) + } else { + // SAFETY: `st` is the last state after execution without encountering any error. + let (valid_up_to, _) = unsafe { resolve_error_location(st, bytes, bytes.len()) }; + Err(Utf8Error { valid_up_to, error_len: None }) + } } - ); + } +} - while index < len { - let old_offset = index; - macro_rules! err { - ($error_len: expr) => { - return Err(Utf8Error { valid_up_to: old_offset, error_len: $error_len }) - }; - } +#[inline] +fn run_utf8_validation_rt(bytes: &[u8]) -> Result<(), Utf8Error> { + const MAIN_CHUNK_SIZE: usize = 16; + const ASCII_CHUNK_SIZE: usize = 16; + const { assert!(ASCII_CHUNK_SIZE % MAIN_CHUNK_SIZE == 0) }; - macro_rules! next { - () => {{ - index += 1; - // we needed data, but there was none: error! - if index >= len { - err!(None) - } - v[index] - }}; - } + let mut st = ST_ACCEPT; + let mut i = 0usize; - let first = v[index]; - if first >= 128 { - let w = utf8_char_width(first); - // 2-byte encoding is for codepoints \u{0080} to \u{07ff} - // first C2 80 last DF BF - // 3-byte encoding is for codepoints \u{0800} to \u{ffff} - // first E0 A0 80 last EF BF BF - // excluding surrogates codepoints \u{d800} to \u{dfff} - // ED A0 80 to ED BF BF - // 4-byte encoding is for codepoints \u{10000} to \u{10ffff} - // first F0 90 80 80 last F4 8F BF BF - // - // Use the UTF-8 syntax from the RFC - // - // https://tools.ietf.org/html/rfc3629 - // UTF8-1 = %x00-7F - // UTF8-2 = %xC2-DF UTF8-tail - // UTF8-3 = %xE0 %xA0-BF UTF8-tail / %xE1-EC 2( UTF8-tail ) / - // %xED %x80-9F UTF8-tail / %xEE-EF 2( UTF8-tail ) - // UTF8-4 = %xF0 %x90-BF 2( UTF8-tail ) / %xF1-F3 3( UTF8-tail ) / - // %xF4 %x80-8F 2( UTF8-tail ) - match w { - 2 => { - if next!() as i8 >= -64 { - err!(Some(1)) - } - } - 3 => { - match (first, next!()) { - (0xE0, 0xA0..=0xBF) - | (0xE1..=0xEC, 0x80..=0xBF) - | (0xED, 0x80..=0x9F) - | 
(0xEE..=0xEF, 0x80..=0xBF) => {} - _ => err!(Some(1)), - } - if next!() as i8 >= -64 { - err!(Some(2)) - } - } - 4 => { - match (first, next!()) { - (0xF0, 0x90..=0xBF) | (0xF1..=0xF3, 0x80..=0xBF) | (0xF4, 0x80..=0x8F) => {} - _ => err!(Some(1)), - } - if next!() as i8 >= -64 { - err!(Some(2)) - } - if next!() as i8 >= -64 { - err!(Some(3)) - } - } - _ => err!(Some(1)), - } - index += 1; - } else { - // Ascii case, try to skip forward quickly. - // When the pointer is aligned, read 2 words of data per iteration - // until we find a word containing a non-ascii byte. - if align != usize::MAX && align.wrapping_sub(index) % USIZE_BYTES == 0 { - let ptr = v.as_ptr(); - while index < blocks_end { - // SAFETY: since `align - index` and `ascii_block_size` are - // multiples of `USIZE_BYTES`, `block = ptr.add(index)` is - // always aligned with a `usize` so it's safe to dereference - // both `block` and `block.add(1)`. - unsafe { - let block = ptr.add(index) as *const usize; - // break if there is a nonascii byte - let zu = contains_nonascii(*block); - let zv = contains_nonascii(*block.add(1)); - if zu || zv { - break; - } - } - index += ascii_block_size; - } - // step from the point where the wordwise loop stopped - while index < len && v[index] < 128 { - index += 1; - } - } else { - index += 1; + while i + MAIN_CHUNK_SIZE <= bytes.len() { + // Fast path: if the current state is ACCEPT, we can skip to the next non-ASCII chunk. + // We also did a quick inspection on the first byte to avoid getting into this path at all + // when handling strings with almost no ASCII, eg. Chinese scripts. + // SAFETY: `i` is in bound. + if st == ST_ACCEPT && unsafe { *bytes.get_unchecked(i) } < 0x80 { + // SAFETY: `i` is in bound. + let rest = unsafe { bytes.get_unchecked(i..) }; + let mut ascii_chunks = rest.array_chunks::(); + let ascii_rest_chunk_cnt = ascii_chunks.len(); + let pos = ascii_chunks + .position(|chunk| { + // NB. 
Always traverse the whole chunk to enable vectorization, instead of `.any()`. + // LLVM will be fear of memory traps and fallback if loop has short-circuit. + #[expect(clippy::unnecessary_fold)] + let has_non_ascii = chunk.iter().fold(false, |acc, &b| acc || (b >= 0x80)); + has_non_ascii + }) + .unwrap_or(ascii_rest_chunk_cnt); + i += pos * ASCII_CHUNK_SIZE; + if i + MAIN_CHUNK_SIZE > bytes.len() { + break; } } + + // SAFETY: `i` and `i + MAIN_CHUNK_SIZE` are in bound by loop invariant. + let chunk = unsafe { &*bytes.as_ptr().add(i).cast::<[u8; MAIN_CHUNK_SIZE]>() }; + let mut new_st = st; + for &b in chunk { + new_st = next_state(new_st, b); + } + if unlikely(new_st & STATE_MASK == ST_ERROR) { + // Discard the current chunk erronous result, and reuse the trailing chunk handling to + // report the error location. + break; + } + + st = new_st; + i += MAIN_CHUNK_SIZE; + } + + // SAFETY: `st` is the last state after executing `bytes[..i]` without encountering any error. + unsafe { run_with_error_handling(&mut st, bytes, i)? }; + + if unlikely(st & STATE_MASK != ST_ACCEPT) { + // SAFETY: Same as above. + let (valid_up_to, _) = unsafe { resolve_error_location(st, bytes, bytes.len()) }; + return Err(Utf8Error { valid_up_to, error_len: None }); } Ok(()) From ad775c202e9ddcf645ee610dbbddebdc79389243 Mon Sep 17 00:00:00 2001 From: oxalica Date: Sat, 8 Feb 2025 18:58:21 -0500 Subject: [PATCH 02/10] Align transition table and fit it in a single page 1. To reduce the cache footprint. 2. To avoid additional cost when access across pages. --- library/core/src/str/validations.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/library/core/src/str/validations.rs b/library/core/src/str/validations.rs index aa5bd19896266..8726c1d1f9cd7 100644 --- a/library/core/src/str/validations.rs +++ b/library/core/src/str/validations.rs @@ -166,7 +166,11 @@ const ST_ACCEPT: u32 = OFFSETS[1]; // See the end of `./solve_dfa.py`. 
const OFFSETS: [u32; STATE_CNT] = [0, 6, 16, 19, 1, 25, 11, 18, 24]; -static TRANS_TABLE: [u32; 256] = { +// Keep the whole table in a single page. +#[repr(align(1024))] +struct TransitionTable([u32; 256]); + +static TRANS_TABLE: TransitionTable = { let mut table = [0u32; 256]; let mut b = 0; while b < 256 { @@ -187,12 +191,12 @@ static TRANS_TABLE: [u32; 256] = { }; b += 1; } - table + TransitionTable(table) }; #[inline(always)] const fn next_state(st: u32, byte: u8) -> u32 { - TRANS_TABLE[byte as usize].wrapping_shr(st) + TRANS_TABLE.0[byte as usize].wrapping_shr(st) } /// Check if `byte` is a valid UTF-8 first byte, assuming it must be a valid first or From c17a4e1a04ace8a4184de5caa903ed095bec7287 Mon Sep 17 00:00:00 2001 From: oxalica Date: Mon, 10 Feb 2025 01:38:56 -0500 Subject: [PATCH 03/10] Reuse core::str::from_utf8 in lossy UTF-8 parsing --- library/core/src/str/lossy.rs | 105 +++++---------------------- library/coretests/tests/str_lossy.rs | 3 + 2 files changed, 22 insertions(+), 86 deletions(-) diff --git a/library/core/src/str/lossy.rs b/library/core/src/str/lossy.rs index ed2cefc59a51c..3acb05a0344a4 100644 --- a/library/core/src/str/lossy.rs +++ b/library/core/src/str/lossy.rs @@ -1,5 +1,4 @@ use super::from_utf8_unchecked; -use super::validations::utf8_char_width; use crate::fmt; use crate::fmt::{Formatter, Write}; use crate::iter::FusedIterator; @@ -197,93 +196,27 @@ impl<'a> Iterator for Utf8Chunks<'a> { return None; } - const TAG_CONT_U8: u8 = 128; - fn safe_get(xs: &[u8], i: usize) -> u8 { - *xs.get(i).unwrap_or(&0) - } - - let mut i = 0; - let mut valid_up_to = 0; - while i < self.source.len() { - // SAFETY: `i < self.source.len()` per previous line. 
- // For some reason the following are both significantly slower: - // while let Some(&byte) = self.source.get(i) { - // while let Some(byte) = self.source.get(i).copied() { - let byte = unsafe { *self.source.get_unchecked(i) }; - i += 1; - - if byte < 128 { - // This could be a `1 => ...` case in the match below, but for - // the common case of all-ASCII inputs, we bypass loading the - // sizeable UTF8_CHAR_WIDTH table into cache. - } else { - let w = utf8_char_width(byte); - - match w { - 2 => { - if safe_get(self.source, i) & 192 != TAG_CONT_U8 { - break; - } - i += 1; - } - 3 => { - match (byte, safe_get(self.source, i)) { - (0xE0, 0xA0..=0xBF) => (), - (0xE1..=0xEC, 0x80..=0xBF) => (), - (0xED, 0x80..=0x9F) => (), - (0xEE..=0xEF, 0x80..=0xBF) => (), - _ => break, - } - i += 1; - if safe_get(self.source, i) & 192 != TAG_CONT_U8 { - break; - } - i += 1; - } - 4 => { - match (byte, safe_get(self.source, i)) { - (0xF0, 0x90..=0xBF) => (), - (0xF1..=0xF3, 0x80..=0xBF) => (), - (0xF4, 0x80..=0x8F) => (), - _ => break, - } - i += 1; - if safe_get(self.source, i) & 192 != TAG_CONT_U8 { - break; - } - i += 1; - if safe_get(self.source, i) & 192 != TAG_CONT_U8 { - break; - } - i += 1; - } - _ => break, - } + match super::from_utf8(self.source) { + Ok(valid) => { + // Truncate the slice, no need to touch the pointer. + self.source = &self.source[..0]; + Some(Utf8Chunk { valid, invalid: &[] }) + } + Err(err) => { + let valid_up_to = err.valid_up_to(); + let error_len = err.error_len().unwrap_or(self.source.len() - valid_up_to); + // SAFETY: `valid_up_to` is the valid UTF-8 string length, so is in bound. + let (valid, remaining) = unsafe { self.source.split_at_unchecked(valid_up_to) }; + // SAFETY: `error_len` is the errornous byte sequence length, so is in bound. + let (invalid, after_invalid) = unsafe { remaining.split_at_unchecked(error_len) }; + self.source = after_invalid; + Some(Utf8Chunk { + // SAFETY: All bytes up to `valid_up_to` are valid UTF-8. 
+ valid: unsafe { from_utf8_unchecked(valid) }, + invalid, + }) } - - valid_up_to = i; } - - // SAFETY: `i <= self.source.len()` because it is only ever incremented - // via `i += 1` and in between every single one of those increments, `i` - // is compared against `self.source.len()`. That happens either - // literally by `i < self.source.len()` in the while-loop's condition, - // or indirectly by `safe_get(self.source, i) & 192 != TAG_CONT_U8`. The - // loop is terminated as soon as the latest `i += 1` has made `i` no - // longer less than `self.source.len()`, which means it'll be at most - // equal to `self.source.len()`. - let (inspected, remaining) = unsafe { self.source.split_at_unchecked(i) }; - self.source = remaining; - - // SAFETY: `valid_up_to <= i` because it is only ever assigned via - // `valid_up_to = i` and `i` only increases. - let (valid, invalid) = unsafe { inspected.split_at_unchecked(valid_up_to) }; - - Some(Utf8Chunk { - // SAFETY: All bytes up to `valid_up_to` are valid UTF-8. - valid: unsafe { from_utf8_unchecked(valid) }, - invalid, - }) } } diff --git a/library/coretests/tests/str_lossy.rs b/library/coretests/tests/str_lossy.rs index 6e70ea3e28574..2c977b40bfa66 100644 --- a/library/coretests/tests/str_lossy.rs +++ b/library/coretests/tests/str_lossy.rs @@ -58,6 +58,9 @@ fn chunks() { ("foo\u{10000}bar", b""), ); + // incomplete + assert_chunks!(b"bar\xF1\x80\x80", ("bar", b"\xF1\x80\x80")); + // surrogates assert_chunks!( b"\xED\xA0\x80foo\xED\xBF\xBFbar", From 46a4782d33468c569d7d9001580c6fd13f098934 Mon Sep 17 00:00:00 2001 From: oxalica Date: Wed, 12 Feb 2025 18:37:29 -0500 Subject: [PATCH 04/10] Process partial chunks at beginning and remove `unlikely` hints Hope to have a better latency on short strings and/or the immediate-fail path. 
--- library/core/src/str/validations.rs | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/library/core/src/str/validations.rs b/library/core/src/str/validations.rs index 8726c1d1f9cd7..3e173ee93a039 100644 --- a/library/core/src/str/validations.rs +++ b/library/core/src/str/validations.rs @@ -1,7 +1,7 @@ //! Operations related to UTF-8 validation. use super::Utf8Error; -use crate::intrinsics::{const_eval_select, unlikely}; +use crate::intrinsics::const_eval_select; /// Returns the initial codepoint accumulator for the first byte. /// The first byte is special, only want bottom 5 bits for width 2, 4 bits @@ -243,7 +243,7 @@ const unsafe fn run_with_error_handling( ) -> Result<(), Utf8Error> { while i < bytes.len() { let new_st = next_state(*st, bytes[i]); - if unlikely(new_st & STATE_MASK == ST_ERROR) { + if new_st & STATE_MASK == ST_ERROR { // SAFETY: Guaranteed by the caller. let (valid_up_to, error_len) = unsafe { resolve_error_location(*st, bytes, i) }; return Err(Utf8Error { valid_up_to, error_len: Some(error_len) }); @@ -287,7 +287,9 @@ fn run_utf8_validation_rt(bytes: &[u8]) -> Result<(), Utf8Error> { const { assert!(ASCII_CHUNK_SIZE % MAIN_CHUNK_SIZE == 0) }; let mut st = ST_ACCEPT; - let mut i = 0usize; + let mut i = bytes.len() % MAIN_CHUNK_SIZE; + // SAFETY: Start at initial state ACCEPT. + unsafe { run_with_error_handling(&mut st, &bytes[..i], 0)? }; while i + MAIN_CHUNK_SIZE <= bytes.len() { // Fast path: if the current state is ACCEPT, we can skip to the next non-ASCII chunk. @@ -320,20 +322,16 @@ fn run_utf8_validation_rt(bytes: &[u8]) -> Result<(), Utf8Error> { for &b in chunk { new_st = next_state(new_st, b); } - if unlikely(new_st & STATE_MASK == ST_ERROR) { - // Discard the current chunk erronous result, and reuse the trailing chunk handling to - // report the error location. 
- break; + if new_st & STATE_MASK == ST_ERROR { + // SAFETY: `st` is the last state after executing `bytes[..i]` without encountering any error. + return unsafe { run_with_error_handling(&mut st, bytes, i) }; } st = new_st; i += MAIN_CHUNK_SIZE; } - // SAFETY: `st` is the last state after executing `bytes[..i]` without encountering any error. - unsafe { run_with_error_handling(&mut st, bytes, i)? }; - - if unlikely(st & STATE_MASK != ST_ACCEPT) { + if st & STATE_MASK != ST_ACCEPT { // SAFETY: Same as above. let (valid_up_to, _) = unsafe { resolve_error_location(st, bytes, bytes.len()) }; return Err(Utf8Error { valid_up_to, error_len: None }); From dbba13b45968b5e3edb952b3c47ceb902cf9f864 Mon Sep 17 00:00:00 2001 From: oxalica Date: Thu, 13 Feb 2025 07:42:56 -0500 Subject: [PATCH 05/10] Use UTF-8 error length enum to reduce register spill When using `error_len: Option`, `Result<(), Utf8Error>` will be returned on stack and produces suboptimal stack suffling operations. It causes 50%-200% latency increase on the error path. --- library/core/src/str/error.rs | 35 ++++++++++++++++++++++++----- library/core/src/str/lossy.rs | 7 ++++-- library/core/src/str/validations.rs | 31 +++++++++++++------------ 3 files changed, 51 insertions(+), 22 deletions(-) diff --git a/library/core/src/str/error.rs b/library/core/src/str/error.rs index 4c8231a2286e1..70a6f54a858cc 100644 --- a/library/core/src/str/error.rs +++ b/library/core/src/str/error.rs @@ -42,11 +42,24 @@ use crate::fmt; /// } /// } /// ``` -#[derive(Copy, Eq, PartialEq, Clone, Debug)] +#[derive(Copy, Eq, PartialEq, Clone)] #[stable(feature = "rust1", since = "1.0.0")] pub struct Utf8Error { pub(super) valid_up_to: usize, - pub(super) error_len: Option, + // Use a single value instead of tagged enum `Option` to make `Result<(), Utf8Error>` fits + // in two machine words, so `run_utf8_validation` does not need to returns values on stack on + // x86(_64). 
Register spill is very expensive on `run_utf8_validation` and can give up to 200% + // latency penalty on the error path. + pub(super) error_len: Utf8ErrorLen, +} + +#[derive(Copy, Eq, PartialEq, Clone)] +#[repr(u8)] +pub(super) enum Utf8ErrorLen { + Eof = 0, + One, + Two, + Three, } impl Utf8Error { @@ -100,18 +113,28 @@ impl Utf8Error { #[must_use] #[inline] pub const fn error_len(&self) -> Option { - // FIXME(const-hack): This should become `map` again, once it's `const` match self.error_len { - Some(len) => Some(len as usize), - None => None, + Utf8ErrorLen::Eof => None, + // FIXME(136972): Direct `match` gives suboptimal codegen involving two table lookups. + len => Some(len as usize), } } } +#[stable(feature = "rust1", since = "1.0.0")] +impl fmt::Debug for Utf8Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Utf8Error") + .field("valid_up_to", &self.valid_up_to) + .field("error_len", &self.error_len()) + .finish() + } +} + #[stable(feature = "rust1", since = "1.0.0")] impl fmt::Display for Utf8Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if let Some(error_len) = self.error_len { + if let Some(error_len) = self.error_len() { write!( f, "invalid utf-8 sequence of {} bytes from index {}", diff --git a/library/core/src/str/lossy.rs b/library/core/src/str/lossy.rs index 3acb05a0344a4..82176084751ed 100644 --- a/library/core/src/str/lossy.rs +++ b/library/core/src/str/lossy.rs @@ -1,4 +1,5 @@ use super::from_utf8_unchecked; +use super::validations::run_utf8_validation; use crate::fmt; use crate::fmt::{Formatter, Write}; use crate::iter::FusedIterator; @@ -196,8 +197,10 @@ impl<'a> Iterator for Utf8Chunks<'a> { return None; } - match super::from_utf8(self.source) { - Ok(valid) => { + match run_utf8_validation(self.source) { + Ok(()) => { + // SAFETY: The whole `source` is valid in UTF-8. + let valid = unsafe { from_utf8_unchecked(&self.source) }; // Truncate the slice, no need to touch the pointer. 
self.source = &self.source[..0]; Some(Utf8Chunk { valid, invalid: &[] }) diff --git a/library/core/src/str/validations.rs b/library/core/src/str/validations.rs index 3e173ee93a039..6aed9bd5354d5 100644 --- a/library/core/src/str/validations.rs +++ b/library/core/src/str/validations.rs @@ -1,6 +1,7 @@ //! Operations related to UTF-8 validation. use super::Utf8Error; +use super::error::Utf8ErrorLen; use crate::intrinsics::const_eval_select; /// Returns the initial codepoint accumulator for the first byte. @@ -210,25 +211,26 @@ const fn is_utf8_first_byte(byte: u8) -> bool { /// The caller must ensure `bytes[..i]` is a valid UTF-8 prefix and `st` is the DFA state after /// executing on `bytes[..i]`. #[inline] -const unsafe fn resolve_error_location(st: u32, bytes: &[u8], i: usize) -> (usize, u8) { +const unsafe fn resolve_error_location(st: u32, bytes: &[u8], i: usize) -> Utf8Error { // There are two cases: // 1. [valid UTF-8..] | *here // The previous state must be ACCEPT for the case 1, and `valid_up_to = i`. // 2. [valid UTF-8..] | valid first byte, [valid continuation byte...], *here // `valid_up_to` is at the latest non-continuation byte, which must exist and // be in range `(i-3)..i`. - if st & STATE_MASK == ST_ACCEPT { - (i, 1) + let (valid_up_to, error_len) = if st & STATE_MASK == ST_ACCEPT { + (i, Utf8ErrorLen::One) // SAFETY: UTF-8 first byte must exist if we are in an intermediate state. // We use pointer here because `get_unchecked` is not const fn. } else if is_utf8_first_byte(unsafe { bytes.as_ptr().add(i - 1).read() }) { - (i - 1, 1) + (i - 1, Utf8ErrorLen::One) // SAFETY: Same as above. } else if is_utf8_first_byte(unsafe { bytes.as_ptr().add(i - 2).read() }) { - (i - 2, 2) + (i - 2, Utf8ErrorLen::Two) } else { - (i - 3, 3) - } + (i - 3, Utf8ErrorLen::Three) + }; + Utf8Error { valid_up_to, error_len } } // The simpler but slower algorithm to run DFA with error handling. 
@@ -245,8 +247,7 @@ const unsafe fn run_with_error_handling( let new_st = next_state(*st, bytes[i]); if new_st & STATE_MASK == ST_ERROR { // SAFETY: Guaranteed by the caller. - let (valid_up_to, error_len) = unsafe { resolve_error_location(*st, bytes, i) }; - return Err(Utf8Error { valid_up_to, error_len: Some(error_len) }); + return Err(unsafe { resolve_error_location(*st, bytes, i) }); } *st = new_st; i += 1; @@ -256,7 +257,7 @@ const unsafe fn run_with_error_handling( /// Walks through `v` checking that it's a valid UTF-8 sequence, /// returning `Ok(())` in that case, or, if it is invalid, `Err(err)`. -#[inline(always)] +#[inline] #[rustc_allow_const_fn_unstable(const_eval_select)] // fallback impl has same behavior pub(super) const fn run_utf8_validation(bytes: &[u8]) -> Result<(), Utf8Error> { const_eval_select((bytes,), run_utf8_validation_const, run_utf8_validation_rt) @@ -273,8 +274,9 @@ const fn run_utf8_validation_const(bytes: &[u8]) -> Result<(), Utf8Error> { Ok(()) } else { // SAFETY: `st` is the last state after execution without encountering any error. - let (valid_up_to, _) = unsafe { resolve_error_location(st, bytes, bytes.len()) }; - Err(Utf8Error { valid_up_to, error_len: None }) + let mut err = unsafe { resolve_error_location(st, bytes, bytes.len()) }; + err.error_len = Utf8ErrorLen::Eof; + Err(err) } } } @@ -333,8 +335,9 @@ fn run_utf8_validation_rt(bytes: &[u8]) -> Result<(), Utf8Error> { if st & STATE_MASK != ST_ACCEPT { // SAFETY: Same as above. 
- let (valid_up_to, _) = unsafe { resolve_error_location(st, bytes, bytes.len()) }; - return Err(Utf8Error { valid_up_to, error_len: None }); + let mut err = unsafe { resolve_error_location(st, bytes, bytes.len()) }; + err.error_len = Utf8ErrorLen::Eof; + return Err(err); } Ok(()) From b69c448886c3d599301f531a4c50572f00f6f442 Mon Sep 17 00:00:00 2001 From: oxalica Date: Fri, 14 Feb 2025 15:49:11 -0500 Subject: [PATCH 06/10] Prefer pure functions over mutable arguments --- library/core/src/str/validations.rs | 38 ++++++++++++++--------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/library/core/src/str/validations.rs b/library/core/src/str/validations.rs index 6aed9bd5354d5..07bc442c9b4f8 100644 --- a/library/core/src/str/validations.rs +++ b/library/core/src/str/validations.rs @@ -234,25 +234,27 @@ const unsafe fn resolve_error_location(st: u32, bytes: &[u8], i: usize) -> Utf8E } // The simpler but slower algorithm to run DFA with error handling. +// Returns the final state after execution on the whole slice. // // # Safety // The caller must ensure `bytes[..i]` is a valid UTF-8 prefix and `st` is the DFA state after // executing on `bytes[..i]`. +#[inline] const unsafe fn run_with_error_handling( - st: &mut u32, + mut st: u32, bytes: &[u8], mut i: usize, -) -> Result<(), Utf8Error> { +) -> Result { while i < bytes.len() { - let new_st = next_state(*st, bytes[i]); + let new_st = next_state(st, bytes[i]); if new_st & STATE_MASK == ST_ERROR { // SAFETY: Guaranteed by the caller. 
- return Err(unsafe { resolve_error_location(*st, bytes, i) }); + return Err(unsafe { resolve_error_location(st, bytes, i) }); } - *st = new_st; + st = new_st; i += 1; } - Ok(()) + Ok(st) } /// Walks through `v` checking that it's a valid UTF-8 sequence, @@ -265,19 +267,15 @@ pub(super) const fn run_utf8_validation(bytes: &[u8]) -> Result<(), Utf8Error> { #[inline] const fn run_utf8_validation_const(bytes: &[u8]) -> Result<(), Utf8Error> { - let mut st = ST_ACCEPT; // SAFETY: Start at empty string with valid state ACCEPT. - match unsafe { run_with_error_handling(&mut st, bytes, 0) } { + match unsafe { run_with_error_handling(ST_ACCEPT, bytes, 0) } { Err(err) => Err(err), - Ok(()) => { - if st & STATE_MASK == ST_ACCEPT { - Ok(()) - } else { - // SAFETY: `st` is the last state after execution without encountering any error. - let mut err = unsafe { resolve_error_location(st, bytes, bytes.len()) }; - err.error_len = Utf8ErrorLen::Eof; - Err(err) - } + Ok(st) if st & STATE_MASK == ST_ACCEPT => Ok(()), + Ok(st) => { + // SAFETY: `st` is the last state after execution without encountering any error. + let mut err = unsafe { resolve_error_location(st, bytes, bytes.len()) }; + err.error_len = Utf8ErrorLen::Eof; + Err(err) } } } @@ -288,10 +286,9 @@ fn run_utf8_validation_rt(bytes: &[u8]) -> Result<(), Utf8Error> { const ASCII_CHUNK_SIZE: usize = 16; const { assert!(ASCII_CHUNK_SIZE % MAIN_CHUNK_SIZE == 0) }; - let mut st = ST_ACCEPT; let mut i = bytes.len() % MAIN_CHUNK_SIZE; // SAFETY: Start at initial state ACCEPT. - unsafe { run_with_error_handling(&mut st, &bytes[..i], 0)? }; + let mut st = unsafe { run_with_error_handling(ST_ACCEPT, &bytes[..i], 0)? }; while i + MAIN_CHUNK_SIZE <= bytes.len() { // Fast path: if the current state is ACCEPT, we can skip to the next non-ASCII chunk. 
@@ -326,7 +323,8 @@ fn run_utf8_validation_rt(bytes: &[u8]) -> Result<(), Utf8Error> { } if new_st & STATE_MASK == ST_ERROR { // SAFETY: `st` is the last state after executing `bytes[..i]` without encountering any error. - return unsafe { run_with_error_handling(&mut st, bytes, i) }; + // And we know the next chunk must fail the validation. + return Err(unsafe { run_with_error_handling(st, bytes, i).unwrap_err_unchecked() }); } st = new_st; From 16203e29e300c9cfe24aa8e164b3dfca8fe8fabc Mon Sep 17 00:00:00 2001 From: oxalica Date: Sat, 15 Feb 2025 23:10:28 -0500 Subject: [PATCH 07/10] Fix comments, use `u8::is_ascii` and simplify --- library/core/src/str/validations.rs | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/library/core/src/str/validations.rs b/library/core/src/str/validations.rs index 07bc442c9b4f8..daa19193a7b72 100644 --- a/library/core/src/str/validations.rs +++ b/library/core/src/str/validations.rs @@ -135,7 +135,11 @@ where // and it becomes free on modern ISAs, including x86, x86_64 and ARM. // // ``` -// // shrx state, qword ptr [table_addr + 8 * byte], state # On x86-64-v3 +// // On x86-64-v3: (more instructions on ordinary x86_64 but with same cycles-per-byte) +// // shrx state, qword ptr [TRANS_TABLE + 4 * byte], state +// // On aarch64/ARMv8: +// // ldr temp, [TRANS_TABLE, byte, lsl 2] +// // lsr state, temp, state // state = TRANS_TABLE[byte].wrapping_shr(state); // ``` // @@ -290,27 +294,28 @@ fn run_utf8_validation_rt(bytes: &[u8]) -> Result<(), Utf8Error> { // SAFETY: Start at initial state ACCEPT. let mut st = unsafe { run_with_error_handling(ST_ACCEPT, &bytes[..i], 0)? }; - while i + MAIN_CHUNK_SIZE <= bytes.len() { + while i < bytes.len() { // Fast path: if the current state is ACCEPT, we can skip to the next non-ASCII chunk. // We also did a quick inspection on the first byte to avoid getting into this path at all // when handling strings with almost no ASCII, eg. Chinese scripts. 
// SAFETY: `i` is in bound. - if st == ST_ACCEPT && unsafe { *bytes.get_unchecked(i) } < 0x80 { + if st == ST_ACCEPT && unsafe { bytes.get_unchecked(i).is_ascii() } { // SAFETY: `i` is in bound. let rest = unsafe { bytes.get_unchecked(i..) }; let mut ascii_chunks = rest.array_chunks::<ASCII_CHUNK_SIZE>(); let ascii_rest_chunk_cnt = ascii_chunks.len(); let pos = ascii_chunks .position(|chunk| { - // NB. Always traverse the whole chunk to enable vectorization, instead of `.any()`. - // LLVM will be fear of memory traps and fallback if loop has short-circuit. + // NB. Always traverse the whole chunk instead of `.all()`, to persuade LLVM to + // vectorize this check. + // We also do not use `<[u8]>::is_ascii` which is unnecessarily complex here. #[expect(clippy::unnecessary_fold)] - let has_non_ascii = chunk.iter().fold(false, |acc, &b| acc || (b >= 0x80)); - has_non_ascii + let all_ascii = chunk.iter().fold(true, |acc, b| acc && b.is_ascii()); + !all_ascii }) .unwrap_or(ascii_rest_chunk_cnt); i += pos * ASCII_CHUNK_SIZE; - if i + MAIN_CHUNK_SIZE > bytes.len() { + if i >= bytes.len() { break; } } From bc57db5337902707b986974c7fb5c1b0082e1502 Mon Sep 17 00:00:00 2001 From: oxalica Date: Fri, 7 Mar 2025 20:43:38 -0500 Subject: [PATCH 08/10] Add a benchmark of UTF-8 validation for ASCII --- library/coretests/benches/str.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/library/coretests/benches/str.rs b/library/coretests/benches/str.rs index 2f7d9d56a70b7..49b2447e05d6b 100644 --- a/library/coretests/benches/str.rs +++ b/library/coretests/benches/str.rs @@ -11,3 +11,8 @@ mod iter; fn str_validate_emoji(b: &mut Bencher) { b.iter(|| str::from_utf8(black_box(corpora::emoji::LARGE.as_bytes()))); } + +#[bench] +fn str_validate_ascii(b: &mut Bencher) { + b.iter(|| str::from_utf8(black_box(corpora::en::LARGE.as_bytes()))); +} From ff0055905d2a5e7a47d9581c9ba6eea3082119e4 Mon Sep 17 00:00:00 2001 From: oxalica Date: Fri, 28 Mar 2025 02:18:52 -0400 Subject: [PATCH 09/10] Only use
non-chunked UTF8 validation on `optimize_for_size` --- library/core/src/str/validations.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/library/core/src/str/validations.rs b/library/core/src/str/validations.rs index daa19193a7b72..96f52aa7d77ee 100644 --- a/library/core/src/str/validations.rs +++ b/library/core/src/str/validations.rs @@ -263,10 +263,14 @@ const unsafe fn run_with_error_handling( /// Walks through `v` checking that it's a valid UTF-8 sequence, /// returning `Ok(())` in that case, or, if it is invalid, `Err(err)`. -#[inline] +#[cfg_attr(not(feature = "optimize_for_size"), inline)] #[rustc_allow_const_fn_unstable(const_eval_select)] // fallback impl has same behavior pub(super) const fn run_utf8_validation(bytes: &[u8]) -> Result<(), Utf8Error> { - const_eval_select((bytes,), run_utf8_validation_const, run_utf8_validation_rt) + if cfg!(feature = "optimize_for_size") { + run_utf8_validation_const(bytes) + } else { + const_eval_select((bytes,), run_utf8_validation_const, run_utf8_validation_rt) + } } #[inline] From e3bef6ed5025eb36be6c7579bcd093296cf66534 Mon Sep 17 00:00:00 2001 From: oxalica Date: Fri, 28 Mar 2025 13:50:32 -0400 Subject: [PATCH 10/10] Fix missed state masking --- library/core/src/str/validations.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/core/src/str/validations.rs b/library/core/src/str/validations.rs index 96f52aa7d77ee..cac47ab20f884 100644 --- a/library/core/src/str/validations.rs +++ b/library/core/src/str/validations.rs @@ -303,7 +303,7 @@ fn run_utf8_validation_rt(bytes: &[u8]) -> Result<(), Utf8Error> { // We also did a quick inspection on the first byte to avoid getting into this path at all // when handling strings with almost no ASCII, eg. Chinese scripts. // SAFETY: `i` is in bound. 
- if st == ST_ACCEPT && unsafe { bytes.get_unchecked(i).is_ascii() } { + if st & STATE_MASK == ST_ACCEPT && unsafe { bytes.get_unchecked(i).is_ascii() } { // SAFETY: `i` is in bound. let rest = unsafe { bytes.get_unchecked(i..) }; let mut ascii_chunks = rest.array_chunks::<ASCII_CHUNK_SIZE>();