Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 69 additions & 210 deletions arrow-row/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,44 @@ pub unsafe fn decode<R: RunEndIndexType>(
mod tests {
use crate::{RowConverter, SortField};
use arrow_array::cast::AsArray;
use arrow_array::types::{Int16Type, Int32Type, Int64Type};
use arrow_array::types::{Int16Type, Int32Type, Int64Type, RunEndIndexType};
use arrow_array::{Array, Int64Array, PrimitiveArray, RunArray, StringArray};
use arrow_schema::{DataType, SortOptions};
use std::sync::Arc;

/// Round-trips `array` through a `RowConverter` and asserts that the decoded
/// column equals the input.
///
/// The converter is built for a single `DataType::RunEndEncoded` column whose
/// non-nullable `run_ends` field has type `run_end_type` and whose nullable
/// `values` field has type `values_type`. When `sort_options` is `Some`, the
/// sort field is created with those options; otherwise `SortField::new` is used.
///
/// Panics if conversion fails in either direction, if the decoded column is
/// not a `RunArray<R>`, or if it is not equal to `array`.
fn assert_roundtrip<R: RunEndIndexType>(
    array: &RunArray<R>,
    run_end_type: DataType,
    values_type: DataType,
    sort_options: Option<SortOptions>,
) {
    // Build the run-end encoded type once; both constructors take the same type.
    let encoded_type = DataType::RunEndEncoded(
        Arc::new(arrow_schema::Field::new("run_ends", run_end_type, false)),
        Arc::new(arrow_schema::Field::new("values", values_type, true)),
    );

    let sort_field = match sort_options {
        Some(options) => SortField::new_with_options(encoded_type, options),
        None => SortField::new(encoded_type),
    };

    let converter = RowConverter::new(vec![sort_field]).unwrap();

    // Encode to the row format, then decode back into columns.
    let rows = converter
        .convert_columns(&[Arc::new(array.clone())])
        .unwrap();
    let decoded_columns = converter.convert_rows(&rows).unwrap();

    let decoded = decoded_columns[0]
        .as_any()
        .downcast_ref::<RunArray<R>>()
        .unwrap();

    assert_eq!(array, decoded);
}

#[test]
fn test_run_end_encoded_supports_datatype() {
// Test that the RowConverter correctly supports run-end encoded arrays
Expand All @@ -183,24 +216,7 @@ mod tests {
let array: RunArray<Int16Type> =
RunArray::try_new(&PrimitiveArray::from(run_ends), &values).unwrap();

let converter = RowConverter::new(vec![SortField::new(DataType::RunEndEncoded(
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int16, false)),
Arc::new(arrow_schema::Field::new("values", DataType::Int64, true)),
))])
.unwrap();

let rows = converter
.convert_columns(&[Arc::new(array.clone())])
.unwrap();

let arrays = converter.convert_rows(&rows).unwrap();
let result = arrays[0]
.as_any()
.downcast_ref::<RunArray<Int16Type>>()
.unwrap();

assert_eq!(array.run_ends().values(), result.run_ends().values());
assert_eq!(array.values().as_ref(), result.values().as_ref());
assert_roundtrip(&array, DataType::Int16, DataType::Int64, None);
}

#[test]
Expand All @@ -213,24 +229,7 @@ mod tests {
let array: RunArray<Int32Type> =
RunArray::try_new(&PrimitiveArray::from(run_ends), &values).unwrap();

let converter = RowConverter::new(vec![SortField::new(DataType::RunEndEncoded(
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
Arc::new(arrow_schema::Field::new("values", DataType::Int64, true)),
))])
.unwrap();

let rows = converter
.convert_columns(&[Arc::new(array.clone())])
.unwrap();

let arrays = converter.convert_rows(&rows).unwrap();
let result = arrays[0]
.as_any()
.downcast_ref::<RunArray<Int32Type>>()
.unwrap();

assert_eq!(array.run_ends().values(), result.run_ends().values());
assert_eq!(array.values().as_ref(), result.values().as_ref());
assert_roundtrip(&array, DataType::Int32, DataType::Int64, None);
}

#[test]
Expand All @@ -243,24 +242,7 @@ mod tests {
let array: RunArray<Int64Type> =
RunArray::try_new(&PrimitiveArray::from(run_ends), &values).unwrap();

let converter = RowConverter::new(vec![SortField::new(DataType::RunEndEncoded(
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int64, false)),
Arc::new(arrow_schema::Field::new("values", DataType::Int64, true)),
))])
.unwrap();

let rows = converter
.convert_columns(&[Arc::new(array.clone())])
.unwrap();

let arrays = converter.convert_rows(&rows).unwrap();
let result = arrays[0]
.as_any()
.downcast_ref::<RunArray<Int64Type>>()
.unwrap();

assert_eq!(array.run_ends().values(), result.run_ends().values());
assert_eq!(array.values().as_ref(), result.values().as_ref());
assert_roundtrip(&array, DataType::Int64, DataType::Int64, None);
}

#[test]
Expand All @@ -269,24 +251,7 @@ mod tests {

let array: RunArray<Int32Type> = vec!["b", "b", "a"].into_iter().collect();

let converter = RowConverter::new(vec![SortField::new(DataType::RunEndEncoded(
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
))])
.unwrap();

let rows = converter
.convert_columns(&[Arc::new(array.clone())])
.unwrap();

let arrays = converter.convert_rows(&rows).unwrap();
let result = arrays[0]
.as_any()
.downcast_ref::<RunArray<Int32Type>>()
.unwrap();

assert_eq!(array.run_ends().values(), result.run_ends().values());
assert_eq!(array.values().as_ref(), result.values().as_ref());
assert_roundtrip(&array, DataType::Int32, DataType::Utf8, None);
}

#[test]
Expand All @@ -297,24 +262,7 @@ mod tests {
.into_iter()
.collect();

let converter = RowConverter::new(vec![SortField::new(DataType::RunEndEncoded(
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
))])
.unwrap();

let rows = converter
.convert_columns(&[Arc::new(array.clone())])
.unwrap();

let arrays = converter.convert_rows(&rows).unwrap();
let result = arrays[0]
.as_any()
.downcast_ref::<RunArray<Int32Type>>()
.unwrap();

assert_eq!(array.run_ends().values(), result.run_ends().values());
assert_eq!(array.values().as_ref(), result.values().as_ref());
assert_roundtrip(&array, DataType::Int32, DataType::Utf8, None);
}

#[test]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any reason not to refactor the remaining tests too?

  • test_run_end_encoded_ascending_descending_round_trip
  • test_run_end_encoded_sort_configurations_basic

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just missed it — added! Thanks for catching that!

Expand All @@ -331,98 +279,26 @@ mod tests {
.unwrap();

// Test ascending order
let converter_asc = RowConverter::new(vec![SortField::new_with_options(
DataType::RunEndEncoded(
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
),
SortOptions {
assert_roundtrip(
&run_array_asc,
DataType::Int32,
DataType::Utf8,
Some(SortOptions {
descending: false,
nulls_first: true,
},
)])
.unwrap();

let rows_asc = converter_asc
.convert_columns(&[Arc::new(run_array_asc.clone())])
.unwrap();
let arrays_asc = converter_asc.convert_rows(&rows_asc).unwrap();
let result_asc = arrays_asc[0]
.as_any()
.downcast_ref::<RunArray<Int32Type>>()
.unwrap();

// Verify round-trip correctness for ascending
assert_eq!(run_array_asc.len(), result_asc.len());
for i in 0..run_array_asc.len() {
let orig_physical = run_array_asc.get_physical_index(i);
let result_physical = result_asc.get_physical_index(i);

let orig_values = run_array_asc
.values()
.as_any()
.downcast_ref::<arrow_array::StringArray>()
.unwrap();
let result_values = result_asc
.values()
.as_any()
.downcast_ref::<arrow_array::StringArray>()
.unwrap();

assert_eq!(
orig_values.value(orig_physical),
result_values.value(result_physical),
"Ascending sort value mismatch at index {}",
i
);
}
}),
);

// Test descending order
let converter_desc = RowConverter::new(vec![SortField::new_with_options(
DataType::RunEndEncoded(
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
),
SortOptions {
assert_roundtrip(
&run_array_asc,
DataType::Int32,
DataType::Utf8,
Some(SortOptions {
descending: true,
nulls_first: true,
},
)])
.unwrap();

let rows_desc = converter_desc
.convert_columns(&[Arc::new(run_array_asc.clone())])
.unwrap();
let arrays_desc = converter_desc.convert_rows(&rows_desc).unwrap();
let result_desc = arrays_desc[0]
.as_any()
.downcast_ref::<RunArray<Int32Type>>()
.unwrap();

// Verify round-trip correctness for descending
assert_eq!(run_array_asc.len(), result_desc.len());
for i in 0..run_array_asc.len() {
let orig_physical = run_array_asc.get_physical_index(i);
let result_physical = result_desc.get_physical_index(i);

let orig_values = run_array_asc
.values()
.as_any()
.downcast_ref::<arrow_array::StringArray>()
.unwrap();
let result_values = result_desc
.values()
.as_any()
.downcast_ref::<arrow_array::StringArray>()
.unwrap();

assert_eq!(
orig_values.value(orig_physical),
result_values.value(result_physical),
"Descending sort value mismatch at index {}",
i
);
}
}),
);
}

#[test]
Expand All @@ -431,44 +307,27 @@ mod tests {

let test_array: RunArray<Int32Type> = vec!["test"].into_iter().collect();

let converter_asc = RowConverter::new(vec![SortField::new_with_options(
DataType::RunEndEncoded(
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
),
SortOptions {
// Test ascending order
assert_roundtrip(
&test_array,
DataType::Int32,
DataType::Utf8,
Some(SortOptions {
descending: false,
nulls_first: true,
},
)])
.unwrap();
}),
);

let converter_desc = RowConverter::new(vec![SortField::new_with_options(
DataType::RunEndEncoded(
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
),
SortOptions {
// Test descending order
assert_roundtrip(
&test_array,
DataType::Int32,
DataType::Utf8,
Some(SortOptions {
descending: true,
nulls_first: true,
},
)])
.unwrap();

let rows_test_asc = converter_asc
.convert_columns(&[Arc::new(test_array.clone())])
.unwrap();
let rows_test_desc = converter_desc
.convert_columns(&[Arc::new(test_array.clone())])
.unwrap();

// Convert back to verify both configurations work
let result_test_asc = converter_asc.convert_rows(&rows_test_asc).unwrap();
let result_test_desc = converter_desc.convert_rows(&rows_test_desc).unwrap();

// Both should successfully reconstruct the original
assert_eq!(result_test_asc.len(), 1);
assert_eq!(result_test_desc.len(), 1);
}),
);
}

#[test]
Expand Down
Loading