Skip to content

Commit 8ca5425

Browse files
committed
Fix NEON ASCII check
1 parent 4680898 commit 8ca5425

File tree

1 file changed

+67
-57
lines changed

1 file changed

+67
-57
lines changed

cbits/aarch64/is-valid-utf8.c

Lines changed: 67 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ SUCH DAMAGE.
2929
*/
3030
#pragma GCC push_options
3131
#pragma GCC optimize("-O2")
32+
#include <arm_neon.h>
3233
#include <stdbool.h>
33-
#include <stdint.h>
3434
#include <stddef.h>
35-
#include <arm_neon.h>
35+
#include <stdint.h>
3636

3737
// Fallback (for tails).
3838
static inline int is_valid_utf8_fallback(uint8_t const *const src,
@@ -102,29 +102,52 @@ static inline int is_valid_utf8_fallback(uint8_t const *const src,
102102
}
103103

104104
static uint8_t const first_len_lookup[16] = {
105-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 3,
105+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 3,
106106
};
107107

108108
static uint8_t const first_range_lookup[16] = {
109-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8,
109+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8,
110110
};
111111

112112
static uint8_t const range_min_lookup[16] = {
113-
0x00, 0x80, 0x80, 0x80, 0xA0, 0x80, 0x90, 0x80,
114-
0xC2, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
113+
0x00, 0x80, 0x80, 0x80, 0xA0, 0x80, 0x90, 0x80,
114+
0xC2, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
115115
};
116116

117117
static uint8_t const range_max_lookup[16] = {
118-
0x7F, 0xBF, 0xBF, 0xBF, 0xBF, 0x9F, 0xBF, 0x8F,
119-
0xF4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
118+
0x7F, 0xBF, 0xBF, 0xBF, 0xBF, 0x9F, 0xBF, 0x8F,
119+
0xF4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
120120
};
121121

122122
static uint8_t const range_adjust_lookup[32] = {
123-
2, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
124-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0,
123+
2, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
124+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0,
125125
};
126126

127-
static bool is_ascii (uint8x16_t const * const inputs) {
127+
/*
 * Decide whether a 64-byte stride may take the ASCII fast path.
 *
 * inputs         - four consecutive 16-byte vectors making up the stride.
 * prev_first_len - per-lane lengths computed for the 16-byte block that
 *                  precedes this stride; a non-zero lane marks a leftover
 *                  multi-byte marker byte in that block.
 *
 * Returns true only when no byte of the stride has its high-order bit set
 * AND every lane of prev_first_len is zero. Returns false otherwise, in which
 * case the caller must run the full per-block validation.
 */
static bool is_ascii(uint8x16_t const *const inputs,
                     uint8x16_t const prev_first_len) {
  // A non-ASCII byte has its high-order bit set. OR preserves set bits, so
  // folding the four vectors together keeps any such bit visible per lane.
  uint8x16_t const first_half = vorrq_u8(inputs[0], inputs[1]);
  uint8x16_t const second_half = vorrq_u8(inputs[2], inputs[3]);
  uint8x16_t const stride_bits = vorrq_u8(first_half, second_half);
  // Keep only the high-order bits; any non-zero lane means non-ASCII input.
  uint8x16_t const high_bits = vandq_u8(stride_bits, vdupq_n_u8(0x80));
  // Lanes of prev_first_len that are non-zero become all-ones, so leftover
  // marker bytes from the previous block also force the slow path. This must
  // be the 128-bit (q) comparison: vcgt_u8 only accepts 64-bit uint8x8_t
  // operands and does not type-check against uint8x16_t.
  uint8x16_t const prior_dirt = vcgtq_u8(prev_first_len, vdupq_n_u8(0x00));
  // The stride is clean exactly when the combined vector is all-zero; check
  // that by inspecting it as two 64-bit halves.
  uint64x2_t const combined =
      vreinterpretq_u64_u8(vorrq_u8(high_bits, prior_dirt));
  return !(vgetq_lane_u64(combined, 0) || vgetq_lane_u64(combined, 1));
}
137161

138-
static void check_block_neon(uint8x16_t const prev_input,
139-
uint8x16_t const prev_first_len,
140-
uint8x16_t* errors,
141-
uint8x16_t const first_range_tbl,
142-
uint8x16_t const range_min_tbl,
143-
uint8x16_t const range_max_tbl,
144-
uint8x16x2_t const range_adjust_tbl,
145-
uint8x16_t const all_ones,
146-
uint8x16_t const all_twos,
147-
uint8x16_t const all_e0s,
148-
uint8x16_t const input,
149-
uint8x16_t const first_len) {
162+
static void
163+
check_block_neon(uint8x16_t const prev_input, uint8x16_t const prev_first_len,
164+
uint8x16_t *errors, uint8x16_t const first_range_tbl,
165+
uint8x16_t const range_min_tbl, uint8x16_t const range_max_tbl,
166+
uint8x16x2_t const range_adjust_tbl, uint8x16_t const all_ones,
167+
uint8x16_t const all_twos, uint8x16_t const all_e0s,
168+
uint8x16_t const input, uint8x16_t const first_len) {
150169
// Get the high 4-bits of the input.
151170
uint8x16_t const high_nibbles = vshrq_n_u8(input, 4);
152171
// Set range index to 8 for bytes in [C0, FF] by lookup (first byte).
@@ -182,20 +201,20 @@ static void check_block_neon(uint8x16_t const prev_input,
182201
errors[1] = vorrq_u8(errors[1], vcgtq_u8(input, maxv));
183202
}
184203

185-
int bytestring_is_valid_utf8(uint8_t const * const src, size_t const len) {
204+
int bytestring_is_valid_utf8(uint8_t const *const src, size_t const len) {
186205
if (len == 0) {
187206
return 1;
188207
}
189208
// We step 64 bytes at a time.
190209
size_t const big_strides = len / 64;
191210
size_t const remaining = len % 64;
192-
uint8_t const * ptr = (uint8_t const *)src;
211+
uint8_t const *ptr = (uint8_t const *)src;
193212
// Tracking state
194213
uint8x16_t prev_input = vdupq_n_u8(0);
195214
uint8x16_t prev_first_len = vdupq_n_u8(0);
196215
uint8x16_t errors[2] = {
197-
vdupq_n_u8(0),
198-
vdupq_n_u8(0),
216+
vdupq_n_u8(0),
217+
vdupq_n_u8(0),
199218
};
200219
// Load our lookup tables.
201220
uint8x16_t const first_len_tbl = vld1q_u8(first_len_lookup);
@@ -209,40 +228,33 @@ int bytestring_is_valid_utf8(uint8_t const * const src, size_t const len) {
209228
uint8x16_t const all_e0s = vdupq_n_u8(0xE0);
210229
for (size_t i = 0; i < big_strides; i++) {
211230
// Load 64 bytes
212-
uint8x16_t const inputs[4] = {
213-
vld1q_u8(ptr),
214-
vld1q_u8(ptr + 16),
215-
vld1q_u8(ptr + 32),
216-
vld1q_u8(ptr + 48)
217-
};
231+
uint8x16_t const inputs[4] = {vld1q_u8(ptr), vld1q_u8(ptr + 16),
232+
vld1q_u8(ptr + 32), vld1q_u8(ptr + 48)};
218233
// Check if we have ASCII
219-
if (is_ascii(inputs)) {
234+
if (is_ascii(inputs, prev_first_len)) {
220235
// Prev_first_len cheaply.
221236
prev_first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[3], 4));
222237
} else {
223-
uint8x16_t first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[0], 4));
224-
check_block_neon(prev_input, prev_first_len, errors,
225-
first_range_tbl, range_min_tbl, range_max_tbl,
226-
range_adjust_tbl, all_ones, all_twos, all_e0s,
227-
inputs[0], first_len);
238+
uint8x16_t first_len =
239+
vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[0], 4));
240+
check_block_neon(prev_input, prev_first_len, errors, first_range_tbl,
241+
range_min_tbl, range_max_tbl, range_adjust_tbl, all_ones,
242+
all_twos, all_e0s, inputs[0], first_len);
228243
prev_first_len = first_len;
229244
first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[1], 4));
230-
check_block_neon(inputs[0], prev_first_len, errors,
231-
first_range_tbl, range_min_tbl, range_max_tbl,
232-
range_adjust_tbl, all_ones, all_twos, all_e0s,
233-
inputs[1], first_len);
245+
check_block_neon(inputs[0], prev_first_len, errors, first_range_tbl,
246+
range_min_tbl, range_max_tbl, range_adjust_tbl, all_ones,
247+
all_twos, all_e0s, inputs[1], first_len);
234248
prev_first_len = first_len;
235249
first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[2], 4));
236-
check_block_neon(inputs[1], prev_first_len, errors,
237-
first_range_tbl, range_min_tbl, range_max_tbl,
238-
range_adjust_tbl, all_ones, all_twos, all_e0s,
239-
inputs[2], first_len);
250+
check_block_neon(inputs[1], prev_first_len, errors, first_range_tbl,
251+
range_min_tbl, range_max_tbl, range_adjust_tbl, all_ones,
252+
all_twos, all_e0s, inputs[2], first_len);
240253
prev_first_len = first_len;
241254
first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[3], 4));
242-
check_block_neon(inputs[2], prev_first_len, errors,
243-
first_range_tbl, range_min_tbl, range_max_tbl,
244-
range_adjust_tbl, all_ones, all_twos, all_e0s,
245-
inputs[3], first_len);
255+
check_block_neon(inputs[2], prev_first_len, errors, first_range_tbl,
256+
range_min_tbl, range_max_tbl, range_adjust_tbl, all_ones,
257+
all_twos, all_e0s, inputs[3], first_len);
246258
prev_first_len = first_len;
247259
}
248260
// Set prev_input based on last block.
@@ -260,19 +272,17 @@ int bytestring_is_valid_utf8(uint8_t const * const src, size_t const len) {
260272
vst1q_lane_u32(&token, vreinterpretq_u32_u8(prev_input), 3);
261273
// We cast this pointer to avoid a redundant check against < 127, as any such
262274
// value would be negative in signed form.
263-
int8_t const * token_ptr = (int8_t const *)&token;
275+
int8_t const *token_ptr = (int8_t const *)&token;
264276
ptrdiff_t lookahead = 0;
265277
if (token_ptr[3] > (int8_t)0xBF) {
266278
lookahead = 1;
267-
}
268-
else if (token_ptr[2] > (int8_t)0xBF) {
279+
} else if (token_ptr[2] > (int8_t)0xBF) {
269280
lookahead = 2;
270-
}
271-
else if (token_ptr[1] > (int8_t)0xBF) {
281+
} else if (token_ptr[1] > (int8_t)0xBF) {
272282
lookahead = 3;
273283
}
274284
// Finish the job.
275-
uint8_t const * const small_ptr = ptr - lookahead;
285+
uint8_t const *const small_ptr = ptr - lookahead;
276286
size_t const small_len = remaining + lookahead;
277287
return is_valid_utf8_fallback(small_ptr, small_len);
278288
}

0 commit comments

Comments
 (0)