Skip to content

Commit e4d7cdd

Browse files
committed
Repair invalid UTF-8 issue, more tests
1 parent ad4c5a5 commit e4d7cdd

File tree

2 files changed

+90
-22
lines changed

2 files changed

+90
-22
lines changed

cbits/is-valid-utf8.c

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,14 @@ SUCH DAMAGE.
3535
#include <string.h>
3636

3737
#ifdef __x86_64__
38+
#include <cpuid.h>
3839
#include <emmintrin.h>
3940
#include <immintrin.h>
40-
#include <cpuid.h>
41-
#if (__GNUC__ >= 7 || __GNUC__ == 6 && __GNUC_MINOR__ >= 3 || defined(__clang_major__)) && !defined(__STDC_NO_ATOMICS__)
42-
#include <tmmintrin.h>
41+
#if (__GNUC__ >= 7 || __GNUC__ == 6 && __GNUC_MINOR__ >= 3 || \
42+
defined(__clang_major__)) && \
43+
!defined(__STDC_NO_ATOMICS__)
4344
#include <stdatomic.h>
45+
#include <tmmintrin.h>
4446
#else
4547
// This is needed to support CentOS 7, which has a very old GCC.
4648
#define CRUFTY_GCC
@@ -64,7 +66,8 @@ static inline uint64_t read_uint64(const uint64_t *p) {
6466
return r;
6567
}
6668

67-
static inline int is_valid_utf8_fallback(uint8_t const *const src, size_t const len) {
69+
static inline int is_valid_utf8_fallback(uint8_t const *const src,
70+
size_t const len) {
6871
uint8_t const *ptr = (uint8_t const *)src;
6972
// This is 'one past the end' to make loop termination and bounds checks
7073
// easier.
@@ -83,10 +86,11 @@ static inline int is_valid_utf8_fallback(uint8_t const *const src, size_t const
8386
// Non-ASCII bytes have a set MSB. Thus, if we AND with 0x80 in every
8487
// 'lane', we will get 0 if everything is ASCII, and something else
8588
// otherwise.
86-
uint64_t results[4] = {to_little_endian(read_uint64(big_ptr)) & high_bits_mask,
87-
to_little_endian(read_uint64((big_ptr + 1))) & high_bits_mask,
88-
to_little_endian(read_uint64((big_ptr + 2))) & high_bits_mask,
89-
to_little_endian(read_uint64((big_ptr + 3))) & high_bits_mask};
89+
uint64_t results[4] = {
90+
to_little_endian(read_uint64(big_ptr)) & high_bits_mask,
91+
to_little_endian(read_uint64((big_ptr + 1))) & high_bits_mask,
92+
to_little_endian(read_uint64((big_ptr + 2))) & high_bits_mask,
93+
to_little_endian(read_uint64((big_ptr + 3))) & high_bits_mask};
9094
if (results[0] == 0) {
9195
ptr += 8;
9296
if (results[1] == 0) {
@@ -331,10 +335,26 @@ static int8_t const ef_fe_lookup[16] = {
331335
};
332336

333337
__attribute__((target("ssse3"))) static inline bool
334-
is_ascii_sse2(__m128i const *src) {
338+
is_ascii_sse2(__m128i const *src, __m128i const prev_first_len) {
339+
// Check if we have ASCII, and also that we don't have to treat the prior
340+
// block as special.
341+
// First, verify that we didn't see any non-ASCII bytes in the first half of
342+
// the stride.
343+
__m128i const first_half_clean = _mm_or_si128(src[0], src[1]);
344+
// Then do the same for the second half of the stride.
345+
__m128i const second_half_clean = _mm_or_si128(src[2], src[3]);
346+
// Check cleanliness of the entire stride.
347+
__m128i const stride_clean =
348+
_mm_or_si128(first_half_clean, second_half_clean);
349+
// Finally, check that we didn't have any leftover marker bytes in the
350+
// previous block: these are indicated by non-zeroes in prev_first_len. In
351+
// order to trigger a failure, we have to have non-zeros set the high bit of
352+
// the lane: we do this by doing a greater-than comparison with a block of
353+
// zeroes.
354+
__m128i const no_prior_dirt =
355+
_mm_cmpgt_epi8(prev_first_len, _mm_setzero_si128());
335356
// OR together everything, then check for a high bit anywhere.
336-
__m128i const ored =
337-
_mm_or_si128(_mm_or_si128(src[0], src[1]), _mm_or_si128(src[2], src[3]));
357+
__m128i const ored = _mm_or_si128(stride_clean, no_prior_dirt);
338358
return (_mm_movemask_epi8(ored) == 0);
339359
}
340360

@@ -415,7 +435,7 @@ is_valid_utf8_ssse3(uint8_t const *const src, size_t const len) {
415435
_mm_loadu_si128(big_ptr), _mm_loadu_si128(big_ptr + 1),
416436
_mm_loadu_si128(big_ptr + 2), _mm_loadu_si128(big_ptr + 3)};
417437
// Check if we have ASCII.
418-
if (is_ascii_sse2(inputs)) {
438+
if (is_ascii_sse2(inputs, prev_first_len)) {
419439
// Prev_first_len cheaply.
420440
prev_first_len =
421441
_mm_shuffle_epi8(first_len_tbl, high_nibbles_of(inputs[3]));
@@ -598,10 +618,26 @@ is_valid_utf8_avx2(uint8_t const *const src, size_t const len) {
598618
__m256i const inputs[4] = {
599619
_mm256_loadu_si256(big_ptr), _mm256_loadu_si256(big_ptr + 1),
600620
_mm256_loadu_si256(big_ptr + 2), _mm256_loadu_si256(big_ptr + 3)};
601-
// Check if we have ASCII.
602-
bool is_ascii = _mm256_movemask_epi8(_mm256_or_si256(
603-
_mm256_or_si256(inputs[0], inputs[1]),
604-
_mm256_or_si256(inputs[2], inputs[3]))) == 0;
621+
// Check if we have ASCII, and also that we don't have to treat the prior
622+
// block as special.
623+
// First, verify that we didn't see any non-ASCII bytes in the first half of
624+
// the stride.
625+
__m256i const first_half_clean = _mm256_or_si256(inputs[0], inputs[1]);
626+
// Then do the same for the second half of the stride.
627+
__m256i const second_half_clean = _mm256_or_si256(inputs[2], inputs[3]);
628+
// Check cleanliness of the entire stride.
629+
__m256i const stride_clean =
630+
_mm256_or_si256(first_half_clean, second_half_clean);
631+
// Finally, check that we didn't have any leftover marker bytes in the
632+
// previous block: these are indicated by non-zeroes in prev_first_len.
633+
// In order to trigger a failure, we have to have non-zeros set the high bit
634+
// of the lane: we do this by doing a greater-than comparison with a block
635+
// of zeroes.
636+
__m256i const no_prior_dirt =
637+
_mm256_cmpgt_epi8(prev_first_len, _mm256_setzero_si256());
638+
// Combine all checks together, and check if any high bits are set.
639+
bool is_ascii =
640+
_mm256_movemask_epi8(_mm256_or_si256(stride_clean, no_prior_dirt)) == 0;
605641
if (is_ascii) {
606642
// Prev_first_len cheaply
607643
prev_first_len =
@@ -683,7 +719,7 @@ static inline bool has_avx2() {
683719
}
684720
#endif
685721

686-
typedef int (*is_valid_utf8_t) (uint8_t const *const, size_t const);
722+
typedef int (*is_valid_utf8_t)(uint8_t const *const, size_t const);
687723

688724
int bytestring_is_valid_utf8(uint8_t const *const src, size_t const len) {
689725
if (len == 0) {
@@ -693,7 +729,10 @@ int bytestring_is_valid_utf8(uint8_t const *const src, size_t const len) {
693729
static _Atomic is_valid_utf8_t s_impl = (is_valid_utf8_t)NULL;
694730
is_valid_utf8_t impl = atomic_load_explicit(&s_impl, memory_order_relaxed);
695731
if (!impl) {
696-
impl = has_avx2() ? is_valid_utf8_avx2 : (has_ssse3() ? is_valid_utf8_ssse3 : (has_sse2() ? is_valid_utf8_sse2 : is_valid_utf8_fallback));
732+
impl = has_avx2() ? is_valid_utf8_avx2
733+
: (has_ssse3() ? is_valid_utf8_ssse3
734+
: (has_sse2() ? is_valid_utf8_sse2
735+
: is_valid_utf8_fallback));
697736
atomic_store_explicit(&s_impl, impl, memory_order_relaxed);
698737
}
699738
return (*impl)(src, len);

tests/IsValidUtf8.hs

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,18 @@ import qualified Data.ByteString.Short as SBS
88
import qualified Data.ByteString as B
99
import Data.Char (chr, ord)
1010
import Data.Word (Word8)
11-
import GHC.Exts (fromList)
12-
import Test.QuickCheck (Property, forAll, (===))
11+
import Control.Monad (guard)
12+
import Numeric (showHex)
13+
import GHC.Exts (fromList, fromListN, toList)
14+
import Test.QuickCheck (Property, forAll, (===), forAllShrinkShow)
1315
import Test.QuickCheck.Arbitrary (Arbitrary (arbitrary, shrink))
1416
import Test.QuickCheck.Gen (oneof, Gen, choose, vectorOf, listOf1, sized, resize,
15-
elements)
17+
elements, chooseEnum)
1618
import Test.Tasty (testGroup, adjustOption, TestTree)
1719
import Test.Tasty.QuickCheck (testProperty, QuickCheckTests)
1820

1921
testSuite :: TestTree
20-
testSuite = testGroup "UTF-8 validation" $ [
22+
testSuite = testGroup "UTF-8 validation" [
2123
adjustOption (max testCount) . testProperty "Valid UTF-8 ByteString" $ goValidBS,
2224
adjustOption (max testCount) . testProperty "Invalid UTF-8 ByteString" $ goInvalidBS,
2325
adjustOption (max testCount) . testProperty "Valid UTF-8 ShortByteString" $ goValidSBS,
@@ -40,6 +42,7 @@ checkRegressions :: [TestTree]
4042
checkRegressions = [
4143
testProperty "Too high code point" $
4244
not $ B.isValidUtf8 tooHigh,
45+
testProperty "Invalid byte at end of ASCII block" badBlockEnd,
4346
testProperty "Invalid byte between spaces" $
4447
not $ B.isValidUtf8 byteBetweenSpaces,
4548
testProperty "Two invalid bytes between spaces" $
@@ -62,8 +65,34 @@ checkRegressions = [
6265
threeBytesBetweenSpaces :: ByteString
6366
threeBytesBetweenSpaces = fromList $ replicate 125 32 ++ [242, 134, 159] ++ replicate 128 32
6467

68+
badBlockEnd :: Property
69+
badBlockEnd =
70+
forAllShrinkShow genBadBlock shrinkBadBlock showBadBlock $ \(BadBlock bs) ->
71+
not . B.isValidUtf8 $ bs
72+
6573
-- Helpers
6674

75+
-- A 128-byte sequence with a single bad byte at the end, with the rest being
76+
-- ASCII
77+
newtype BadBlock = BadBlock ByteString
78+
79+
genBadBlock :: Gen BadBlock
80+
genBadBlock = do
81+
asciiBytes <- vectorOf 127 $ chooseEnum (0, 127)
82+
pure . BadBlock . fromListN 128 $ asciiBytes <> [216]
83+
84+
shrinkBadBlock :: BadBlock -> [BadBlock]
85+
shrinkBadBlock (BadBlock bs) = BadBlock <$> do
86+
let asList = init . toList $ bs
87+
init' <- fromList <$> traverse shrink asList
88+
guard (B.length init' == 127)
89+
pure $ init' <> B.singleton 216
90+
91+
-- Display as hex instead of ASCII-ish
92+
showBadBlock :: BadBlock -> String
93+
showBadBlock (BadBlock bs) = let asList = toList bs in
94+
foldr showHex "" asList
95+
6796
data Utf8Sequence =
6897
One Word8 |
6998
Two Word8 Word8 |

0 commit comments

Comments
 (0)