Skip to content

Commit 9ffa065

Browse files
authored
Improvements to UTF-8 statistics truncation (#6870)
* fix a few edge cases with utf-8 incrementing * add todo * simplify truncation * add another test * note case where string should render right to left * rework entirely, also avoid UTF8 processing if not required by the schema * more consistent naming * modify some tests to truncate in the middle of a multibyte char * add test and docstring * document truncate_min_value too
1 parent fc6936a commit 9ffa065

File tree

1 file changed

+236
-57
lines changed
  • parquet/src/column/writer

1 file changed

+236
-57
lines changed

parquet/src/column/writer/mod.rs

Lines changed: 236 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -878,24 +878,67 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> {
878878
}
879879
}
880880

881+
/// Returns `true` if this column's logical type is a UTF-8 string.
882+
fn is_utf8(&self) -> bool {
883+
self.get_descriptor().logical_type() == Some(LogicalType::String)
884+
|| self.get_descriptor().converted_type() == ConvertedType::UTF8
885+
}
886+
887+
/// Truncates a binary statistic to at most `truncation_length` bytes.
888+
///
889+
/// If truncation is not possible, returns `data`.
890+
///
891+
/// The `bool` in the returned tuple indicates whether truncation occurred or not.
892+
///
893+
/// UTF-8 Note:
894+
/// If the column type indicates UTF-8, and `data` contains valid UTF-8, then the result will
895+
/// also remain valid UTF-8, but may be less than `truncation_length` bytes to avoid splitting
896+
/// on non-character boundaries.
881897
fn truncate_min_value(&self, truncation_length: Option<usize>, data: &[u8]) -> (Vec<u8>, bool) {
882898
truncation_length
883899
.filter(|l| data.len() > *l)
884-
.and_then(|l| match str::from_utf8(data) {
885-
Ok(str_data) => truncate_utf8(str_data, l),
886-
Err(_) => Some(data[..l].to_vec()),
887-
})
900+
.and_then(|l|
901+
// don't do extra work if this column isn't UTF-8
902+
if self.is_utf8() {
903+
match str::from_utf8(data) {
904+
Ok(str_data) => truncate_utf8(str_data, l),
905+
Err(_) => Some(data[..l].to_vec()),
906+
}
907+
} else {
908+
Some(data[..l].to_vec())
909+
}
910+
)
888911
.map(|truncated| (truncated, true))
889912
.unwrap_or_else(|| (data.to_vec(), false))
890913
}
891914

915+
/// Truncates a binary statistic to at most `truncation_length` bytes, and then increment the
916+
/// final byte(s) to yield a valid upper bound. This may result in a result of less than
917+
/// `truncation_length` bytes if the last byte(s) overflows.
918+
///
919+
/// If truncation is not possible, returns `data`.
920+
///
921+
/// The `bool` in the returned tuple indicates whether truncation occurred or not.
922+
///
923+
/// UTF-8 Note:
924+
/// If the column type indicates UTF-8, and `data` contains valid UTF-8, then the result will
925+
/// also remain valid UTF-8 (but again may be less than `truncation_length` bytes). If `data`
926+
/// does not contain valid UTF-8, then truncation will occur as if the column is non-string
927+
/// binary.
892928
fn truncate_max_value(&self, truncation_length: Option<usize>, data: &[u8]) -> (Vec<u8>, bool) {
893929
truncation_length
894930
.filter(|l| data.len() > *l)
895-
.and_then(|l| match str::from_utf8(data) {
896-
Ok(str_data) => truncate_utf8(str_data, l).and_then(increment_utf8),
897-
Err(_) => increment(data[..l].to_vec()),
898-
})
931+
.and_then(|l|
932+
// don't do extra work if this column isn't UTF-8
933+
if self.is_utf8() {
934+
match str::from_utf8(data) {
935+
Ok(str_data) => truncate_and_increment_utf8(str_data, l),
936+
Err(_) => increment(data[..l].to_vec()),
937+
}
938+
} else {
939+
increment(data[..l].to_vec())
940+
}
941+
)
899942
.map(|truncated| (truncated, true))
900943
.unwrap_or_else(|| (data.to_vec(), false))
901944
}
@@ -1418,13 +1461,50 @@ fn compare_greater_byte_array_decimals(a: &[u8], b: &[u8]) -> bool {
14181461
(a[1..]) > (b[1..])
14191462
}
14201463

1421-
/// Truncate a UTF8 slice to the longest prefix that is still a valid UTF8 string,
1422-
/// while being less than `length` bytes and non-empty
1464+
/// Truncate a UTF-8 slice to the longest prefix that is still a valid UTF-8 string,
1465+
/// while being less than `length` bytes and non-empty. Returns `None` if truncation
1466+
/// is not possible within those constraints.
1467+
///
1468+
/// The caller guarantees that data.len() > length.
14231469
fn truncate_utf8(data: &str, length: usize) -> Option<Vec<u8>> {
14241470
let split = (1..=length).rfind(|x| data.is_char_boundary(*x))?;
14251471
Some(data.as_bytes()[..split].to_vec())
14261472
}
14271473

1474+
/// Truncate a UTF-8 slice and increment its final character. The returned value is the
1475+
/// longest such slice that is still a valid UTF-8 string while being less than `length`
1476+
/// bytes and non-empty. Returns `None` if no such transformation is possible.
1477+
///
1478+
/// The caller guarantees that data.len() > length.
1479+
fn truncate_and_increment_utf8(data: &str, length: usize) -> Option<Vec<u8>> {
1480+
// UTF-8 is max 4 bytes, so start search 3 back from desired length
1481+
let lower_bound = length.saturating_sub(3);
1482+
let split = (lower_bound..=length).rfind(|x| data.is_char_boundary(*x))?;
1483+
increment_utf8(data.get(..split)?)
1484+
}
1485+
1486+
/// Increment the final character in a UTF-8 string in such a way that the returned result
1487+
/// is still a valid UTF-8 string. The returned string may be shorter than the input if the
1488+
/// last character(s) cannot be incremented (due to overflow or producing invalid code points).
1489+
/// Returns `None` if the string cannot be incremented.
1490+
///
1491+
/// Note that this implementation will not promote an N-byte code point to (N+1) bytes.
1492+
fn increment_utf8(data: &str) -> Option<Vec<u8>> {
1493+
for (idx, original_char) in data.char_indices().rev() {
1494+
let original_len = original_char.len_utf8();
1495+
if let Some(next_char) = char::from_u32(original_char as u32 + 1) {
1496+
// do not allow increasing byte width of incremented char
1497+
if next_char.len_utf8() == original_len {
1498+
let mut result = data.as_bytes()[..idx + original_len].to_vec();
1499+
next_char.encode_utf8(&mut result[idx..]);
1500+
return Some(result);
1501+
}
1502+
}
1503+
}
1504+
1505+
None
1506+
}
1507+
14281508
/// Try and increment the bytes from right to left.
14291509
///
14301510
/// Returns `None` if all bytes are set to `u8::MAX`.
@@ -1441,29 +1521,15 @@ fn increment(mut data: Vec<u8>) -> Option<Vec<u8>> {
14411521
None
14421522
}
14431523

1444-
/// Try and increment the the string's bytes from right to left, returning when the result
1445-
/// is a valid UTF8 string. Returns `None` when it can't increment any byte.
1446-
fn increment_utf8(mut data: Vec<u8>) -> Option<Vec<u8>> {
1447-
for idx in (0..data.len()).rev() {
1448-
let original = data[idx];
1449-
let (byte, overflow) = original.overflowing_add(1);
1450-
if !overflow {
1451-
data[idx] = byte;
1452-
if str::from_utf8(&data).is_ok() {
1453-
return Some(data);
1454-
}
1455-
data[idx] = original;
1456-
}
1457-
}
1458-
1459-
None
1460-
}
1461-
14621524
#[cfg(test)]
14631525
mod tests {
1464-
use crate::file::properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH;
1526+
use crate::{
1527+
file::{properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH, writer::SerializedFileWriter},
1528+
schema::parser::parse_message_type,
1529+
};
1530+
use core::str;
14651531
use rand::distributions::uniform::SampleUniform;
1466-
use std::sync::Arc;
1532+
use std::{fs::File, sync::Arc};
14671533

14681534
use crate::column::{
14691535
page::PageReader,
@@ -3140,39 +3206,69 @@ mod tests {
31403206

31413207
#[test]
31423208
fn test_increment_utf8() {
3209+
let test_inc = |o: &str, expected: &str| {
3210+
if let Ok(v) = String::from_utf8(increment_utf8(o).unwrap()) {
3211+
// Got the expected result...
3212+
assert_eq!(v, expected);
3213+
// and it's greater than the original string
3214+
assert!(*v > *o);
3215+
// Also show that BinaryArray level comparison works here
3216+
let mut greater = ByteArray::new();
3217+
greater.set_data(Bytes::from(v));
3218+
let mut original = ByteArray::new();
3219+
original.set_data(Bytes::from(o.as_bytes().to_vec()));
3220+
assert!(greater > original);
3221+
} else {
3222+
panic!("Expected incremented UTF8 string to also be valid.");
3223+
}
3224+
};
3225+
31433226
// Basic ASCII case
3144-
let v = increment_utf8("hello".as_bytes().to_vec()).unwrap();
3145-
assert_eq!(&v, "hellp".as_bytes());
3227+
test_inc("hello", "hellp");
3228+
3229+
// 1-byte ending in max 1-byte
3230+
test_inc("a\u{7f}", "b");
31463231

3147-
// Also show that BinaryArray level comparison works here
3148-
let mut greater = ByteArray::new();
3149-
greater.set_data(Bytes::from(v));
3150-
let mut original = ByteArray::new();
3151-
original.set_data(Bytes::from("hello".as_bytes().to_vec()));
3152-
assert!(greater > original);
3232+
// 1-byte max should not truncate as it would need 2-byte code points
3233+
assert!(increment_utf8("\u{7f}\u{7f}").is_none());
31533234

31543235
// UTF8 string
3155-
let s = "❤️🧡💛💚💙💜";
3156-
let v = increment_utf8(s.as_bytes().to_vec()).unwrap();
3236+
test_inc("❤️🧡💛💚💙💜", "❤️🧡💛💚💙💝");
31573237

3158-
if let Ok(new) = String::from_utf8(v) {
3159-
assert_ne!(&new, s);
3160-
assert_eq!(new, "❤️🧡💛💚💙💝");
3161-
assert!(new.as_bytes().last().unwrap() > s.as_bytes().last().unwrap());
3162-
} else {
3163-
panic!("Expected incremented UTF8 string to also be valid.")
3164-
}
3238+
// 2-byte without overflow
3239+
test_inc("éééé", "éééê");
31653240

3166-
// Max UTF8 character - should be a No-Op
3167-
let s = char::MAX.to_string();
3168-
assert_eq!(s.len(), 4);
3169-
let v = increment_utf8(s.as_bytes().to_vec());
3170-
assert!(v.is_none());
3241+
// 2-byte that overflows lowest byte
3242+
test_inc("\u{ff}\u{ff}", "\u{ff}\u{100}");
3243+
3244+
// 2-byte ending in max 2-byte
3245+
test_inc("a\u{7ff}", "b");
3246+
3247+
// Max 2-byte should not truncate as it would need 3-byte code points
3248+
assert!(increment_utf8("\u{7ff}\u{7ff}").is_none());
3249+
3250+
// 3-byte without overflow [U+800, U+800] -> [U+800, U+801] (note that these
3251+
// characters should render right to left).
3252+
test_inc("ࠀࠀ", "ࠀࠁ");
3253+
3254+
// 3-byte ending in max 3-byte
3255+
test_inc("a\u{ffff}", "b");
3256+
3257+
// Max 3-byte should not truncate as it would need 4-byte code points
3258+
assert!(increment_utf8("\u{ffff}\u{ffff}").is_none());
31713259

3172-
// Handle multi-byte UTF8 characters
3173-
let s = "a\u{10ffff}";
3174-
let v = increment_utf8(s.as_bytes().to_vec());
3175-
assert_eq!(&v.unwrap(), "b\u{10ffff}".as_bytes());
3260+
// 4-byte without overflow
3261+
test_inc("𐀀𐀀", "𐀀𐀁");
3262+
3263+
// 4-byte ending in max unicode
3264+
test_inc("a\u{10ffff}", "b");
3265+
3266+
// Max 4-byte should not truncate
3267+
assert!(increment_utf8("\u{10ffff}\u{10ffff}").is_none());
3268+
3269+
// Skip over surrogate pair range (0xD800..=0xDFFF)
3270+
//test_inc("a\u{D7FF}", "a\u{e000}");
3271+
test_inc("a\u{D7FF}", "b");
31763272
}
31773273

31783274
#[test]
@@ -3182,7 +3278,6 @@ mod tests {
31823278
let r = truncate_utf8(data, data.as_bytes().len()).unwrap();
31833279
assert_eq!(r.len(), data.as_bytes().len());
31843280
assert_eq!(&r, data.as_bytes());
3185-
println!("len is {}", data.len());
31863281

31873282
// We slice it away from the UTF8 boundary
31883283
let r = truncate_utf8(data, 13).unwrap();
@@ -3192,6 +3287,90 @@ mod tests {
31923287
// One multi-byte code point, and a length shorter than it, so we can't slice it
31933288
let r = truncate_utf8("\u{0836}", 1);
31943289
assert!(r.is_none());
3290+
3291+
// Test truncate and increment for max bounds on UTF-8 statistics
3292+
// 7-bit (i.e. ASCII)
3293+
let r = truncate_and_increment_utf8("yyyyyyyyy", 8).unwrap();
3294+
assert_eq!(&r, "yyyyyyyz".as_bytes());
3295+
3296+
// 2-byte without overflow
3297+
let r = truncate_and_increment_utf8("ééééé", 7).unwrap();
3298+
assert_eq!(&r, "ééê".as_bytes());
3299+
3300+
// 2-byte that overflows lowest byte
3301+
let r = truncate_and_increment_utf8("\u{ff}\u{ff}\u{ff}\u{ff}\u{ff}", 8).unwrap();
3302+
assert_eq!(&r, "\u{ff}\u{ff}\u{ff}\u{100}".as_bytes());
3303+
3304+
// max 2-byte should not truncate as it would need 3-byte code points
3305+
let r = truncate_and_increment_utf8("߿߿߿߿߿", 8);
3306+
assert!(r.is_none());
3307+
3308+
// 3-byte without overflow [U+800, U+800, U+800, U+800] -> [U+800, U+801] (note that these
3309+
// characters should render right to left).
3310+
let r = truncate_and_increment_utf8("ࠀࠀࠀࠀ", 8).unwrap();
3311+
assert_eq!(&r, "ࠀࠁ".as_bytes());
3312+
3313+
// max 3-byte should not truncate as it would need 4-byte code points
3314+
let r = truncate_and_increment_utf8("\u{ffff}\u{ffff}\u{ffff}", 8);
3315+
assert!(r.is_none());
3316+
3317+
// 4-byte without overflow
3318+
let r = truncate_and_increment_utf8("𐀀𐀀𐀀𐀀", 9).unwrap();
3319+
assert_eq!(&r, "𐀀𐀁".as_bytes());
3320+
3321+
// max 4-byte should not truncate
3322+
let r = truncate_and_increment_utf8("\u{10ffff}\u{10ffff}", 8);
3323+
assert!(r.is_none());
3324+
}
3325+
3326+
#[test]
3327+
// Check fallback truncation of statistics that should be UTF-8, but aren't
3328+
// (see https://github.com/apache/arrow-rs/pull/6870).
3329+
fn test_byte_array_truncate_invalid_utf8_statistics() {
3330+
let message_type = "
3331+
message test_schema {
3332+
OPTIONAL BYTE_ARRAY a (UTF8);
3333+
}
3334+
";
3335+
let schema = Arc::new(parse_message_type(message_type).unwrap());
3336+
3337+
// Create Vec<ByteArray> containing non-UTF8 bytes
3338+
let data = vec![ByteArray::from(vec![128u8; 32]); 7];
3339+
let def_levels = [1, 1, 1, 1, 0, 1, 0, 1, 0, 1];
3340+
let file: File = tempfile::tempfile().unwrap();
3341+
let props = Arc::new(
3342+
WriterProperties::builder()
3343+
.set_statistics_enabled(EnabledStatistics::Chunk)
3344+
.set_statistics_truncate_length(Some(8))
3345+
.build(),
3346+
);
3347+
3348+
let mut writer = SerializedFileWriter::new(&file, schema, props).unwrap();
3349+
let mut row_group_writer = writer.next_row_group().unwrap();
3350+
3351+
let mut col_writer = row_group_writer.next_column().unwrap().unwrap();
3352+
col_writer
3353+
.typed::<ByteArrayType>()
3354+
.write_batch(&data, Some(&def_levels), None)
3355+
.unwrap();
3356+
col_writer.close().unwrap();
3357+
row_group_writer.close().unwrap();
3358+
let file_metadata = writer.close().unwrap();
3359+
assert!(file_metadata.row_groups[0].columns[0].meta_data.is_some());
3360+
let stats = file_metadata.row_groups[0].columns[0]
3361+
.meta_data
3362+
.as_ref()
3363+
.unwrap()
3364+
.statistics
3365+
.as_ref()
3366+
.unwrap();
3367+
assert!(!stats.is_max_value_exact.unwrap());
3368+
// Truncation of invalid UTF-8 should fall back to binary truncation, so last byte should
3369+
// be incremented by 1.
3370+
assert_eq!(
3371+
stats.max_value,
3372+
Some([128, 128, 128, 128, 128, 128, 128, 129].to_vec())
3373+
);
31953374
}
31963375

31973376
#[test]

0 commit comments

Comments
 (0)