diff --git a/library/core/src/str/lossy.rs b/library/core/src/str/lossy.rs index 6ec1c93908fc7..9e7f1a60da95e 100644 --- a/library/core/src/str/lossy.rs +++ b/library/core/src/str/lossy.rs @@ -51,19 +51,43 @@ impl<'a> Iterator for Utf8LossyChunksIter<'a> { } const TAG_CONT_U8: u8 = 128; - fn safe_get(xs: &[u8], i: usize) -> u8 { - *xs.get(i).unwrap_or(&0) + + /// Gets the byte at `current`, returning zero if `current` is past `end` + /// + /// # Safety + /// + /// `current` must be a valid pointer to an initialized byte + /// + unsafe fn try_get(current: *const u8, end: *const u8) -> u8 { + // SAFETY: current is a valid pointer + if current < end { unsafe { *current } } else { 0 } + } + + /// Checks if the byte at `current` bitand 192 is `TAG_CONT_U8`, returning + /// false if not or if `current` is past `end` + /// + /// # Safety + /// + /// `current` must be a valid pointer to an initalized byte + /// + unsafe fn shouldnt_continue(current: *const u8, end: *const u8) -> bool { + // SAFETY: current is a valid pointer + unsafe { current < end && *current & 192 != TAG_CONT_U8 } } - let mut i = 0; - let mut valid_up_to = 0; - while i < self.source.len() { - // SAFETY: `i < self.source.len()` per previous line. + let length = self.source.len(); + let mut current = self.source.as_ptr(); + // SAFETY: current + length is one past the end of the allocation + let (start, end, mut valid_up_to) = unsafe { (current, current.add(length), current) }; + + while current < end { + // SAFETY: `current < end` per previous line. // For some reason the following are both significantly slower: // while let Some(&byte) = self.source.get(i) { // while let Some(byte) = self.source.get(i).copied() { - let byte = unsafe { *self.source.get_unchecked(i) }; - i += 1; + let byte = unsafe { *current }; + // SAFETY: This will be at most one past the end of the slice (and then equal to end) + current = unsafe { current.add(1) }; if byte < 128 { // This could be a `1 => ...` case in the match below, but for @@ -72,51 +96,70 @@ impl<'a> Iterator for Utf8LossyChunksIter<'a> { } else { let w = utf8_char_width(byte); - match w { - 2 => { - if safe_get(self.source, i) & 192 != TAG_CONT_U8 { - break; - } - i += 1; - } - 3 => { - match (byte, safe_get(self.source, i)) { - (0xE0, 0xA0..=0xBF) => (), - (0xE1..=0xEC, 0x80..=0xBF) => (), - (0xED, 0x80..=0x9F) => (), - (0xEE..=0xEF, 0x80..=0xBF) => (), - _ => break, - } - i += 1; - if safe_get(self.source, i) & 192 != TAG_CONT_U8 { - break; - } - i += 1; - } - 4 => { - match (byte, safe_get(self.source, i)) { - (0xF0, 0x90..=0xBF) => (), - (0xF1..=0xF3, 0x80..=0xBF) => (), - (0xF4, 0x80..=0x8F) => (), - _ => break, + // SAFETY: All pointers will be at most one past the end of + // the slice (and then equal to end) + unsafe { + match w { + 2 => { + if shouldnt_continue(current, end) { + break; + } + + current = current.add(1); } - i += 1; - if safe_get(self.source, i) & 192 != TAG_CONT_U8 { - break; + + 3 => { + match (byte, try_get(current, end)) { + (0xE0, 0xA0..=0xBF) + | (0xE1..=0xEC, 0x80..=0xBF) + | (0xED, 0x80..=0x9F) + | (0xEE..=0xEF, 0x80..=0xBF) => {} + _ => break, + } + current = current.add(1); + + if shouldnt_continue(current, end) { + break; + } + current = current.add(1); } - i += 1; - if safe_get(self.source, i) & 192 != TAG_CONT_U8 { - break; + + 4 => { + match (byte, try_get(current, end)) { + (0xF0, 0x90..=0xBF) + | (0xF1..=0xF3, 0x80..=0xBF) + | (0xF4, 0x80..=0x8F) => {} + _ => break, + } + current = current.add(1); + + if shouldnt_continue(current, end) { + break; + } + current = current.add(1); + + if shouldnt_continue(current, end) { + break; + } + current = current.add(1); } - i += 1; + + _ => break, } - _ => break, } } - valid_up_to = i; + valid_up_to = current; } + // SAFETY: Both pointers come from the same allocation + let idx = unsafe { current.offset_from(start) as usize }; + debug_assert!(idx <= length); + + // SAFETY: Both pointers come from the same allocation + let valid_up_to = unsafe { valid_up_to.offset_from(start) as usize }; + debug_assert!(valid_up_to <= length); + // SAFETY: `i <= self.source.len()` because it is only ever incremented // via `i += 1` and in between every single one of those increments, `i` // is compared against `self.source.len()`. That happens either @@ -125,7 +168,7 @@ impl<'a> Iterator for Utf8LossyChunksIter<'a> { // loop is terminated as soon as the latest `i += 1` has made `i` no // longer less than `self.source.len()`, which means it'll be at most // equal to `self.source.len()`. - let (inspected, remaining) = unsafe { self.source.split_at_unchecked(i) }; + let (inspected, remaining) = unsafe { self.source.split_at_unchecked(idx) }; self.source = remaining; // SAFETY: `valid_up_to <= i` because it is only ever assigned via