diff --git a/arrow-row/src/run.rs b/arrow-row/src/run.rs index 01f5002dc2a1..ff7c0ffe54eb 100644 --- a/arrow-row/src/run.rs +++ b/arrow-row/src/run.rs @@ -159,11 +159,44 @@ pub unsafe fn decode( mod tests { use crate::{RowConverter, SortField}; use arrow_array::cast::AsArray; - use arrow_array::types::{Int16Type, Int32Type, Int64Type}; + use arrow_array::types::{Int16Type, Int32Type, Int64Type, RunEndIndexType}; use arrow_array::{Array, Int64Array, PrimitiveArray, RunArray, StringArray}; use arrow_schema::{DataType, SortOptions}; use std::sync::Arc; + fn assert_roundtrip( + array: &RunArray, + run_end_type: DataType, + values_type: DataType, + sort_options: Option, + ) { + let sort_field = if let Some(options) = sort_options { + SortField::new_with_options( + DataType::RunEndEncoded( + Arc::new(arrow_schema::Field::new("run_ends", run_end_type, false)), + Arc::new(arrow_schema::Field::new("values", values_type, true)), + ), + options, + ) + } else { + SortField::new(DataType::RunEndEncoded( + Arc::new(arrow_schema::Field::new("run_ends", run_end_type, false)), + Arc::new(arrow_schema::Field::new("values", values_type, true)), + )) + }; + + let converter = RowConverter::new(vec![sort_field]).unwrap(); + + let rows = converter + .convert_columns(&[Arc::new(array.clone())]) + .unwrap(); + + let arrays = converter.convert_rows(&rows).unwrap(); + let result = arrays[0].as_any().downcast_ref::>().unwrap(); + + assert_eq!(array, result); + } + #[test] fn test_run_end_encoded_supports_datatype() { // Test that the RowConverter correctly supports run-end encoded arrays @@ -183,24 +216,7 @@ mod tests { let array: RunArray = RunArray::try_new(&PrimitiveArray::from(run_ends), &values).unwrap(); - let converter = RowConverter::new(vec![SortField::new(DataType::RunEndEncoded( - Arc::new(arrow_schema::Field::new("run_ends", DataType::Int16, false)), - Arc::new(arrow_schema::Field::new("values", DataType::Int64, true)), - ))]) - .unwrap(); - - let rows = converter - .convert_columns(&[Arc::new(array.clone())]) - .unwrap(); - - let arrays = converter.convert_rows(&rows).unwrap(); - let result = arrays[0] - .as_any() - .downcast_ref::>() - .unwrap(); - - assert_eq!(array.run_ends().values(), result.run_ends().values()); - assert_eq!(array.values().as_ref(), result.values().as_ref()); + assert_roundtrip(&array, DataType::Int16, DataType::Int64, None); } #[test] @@ -213,24 +229,7 @@ mod tests { let array: RunArray = RunArray::try_new(&PrimitiveArray::from(run_ends), &values).unwrap(); - let converter = RowConverter::new(vec![SortField::new(DataType::RunEndEncoded( - Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), - Arc::new(arrow_schema::Field::new("values", DataType::Int64, true)), - ))]) - .unwrap(); - - let rows = converter - .convert_columns(&[Arc::new(array.clone())]) - .unwrap(); - - let arrays = converter.convert_rows(&rows).unwrap(); - let result = arrays[0] - .as_any() - .downcast_ref::>() - .unwrap(); - - assert_eq!(array.run_ends().values(), result.run_ends().values()); - assert_eq!(array.values().as_ref(), result.values().as_ref()); + assert_roundtrip(&array, DataType::Int32, DataType::Int64, None); } #[test] @@ -243,24 +242,7 @@ mod tests { let array: RunArray = RunArray::try_new(&PrimitiveArray::from(run_ends), &values).unwrap(); - let converter = RowConverter::new(vec![SortField::new(DataType::RunEndEncoded( - Arc::new(arrow_schema::Field::new("run_ends", DataType::Int64, false)), - Arc::new(arrow_schema::Field::new("values", DataType::Int64, true)), - ))]) - .unwrap(); - - let rows = converter - .convert_columns(&[Arc::new(array.clone())]) - .unwrap(); - - let arrays = converter.convert_rows(&rows).unwrap(); - let result = arrays[0] - .as_any() - .downcast_ref::>() - .unwrap(); - - assert_eq!(array.run_ends().values(), result.run_ends().values()); - assert_eq!(array.values().as_ref(), result.values().as_ref()); + assert_roundtrip(&array, DataType::Int64, DataType::Int64, None); } #[test] @@ -269,24 +251,7 @@ mod tests { let array: RunArray = vec!["b", "b", "a"].into_iter().collect(); - let converter = RowConverter::new(vec![SortField::new(DataType::RunEndEncoded( - Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), - Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), - ))]) - .unwrap(); - - let rows = converter - .convert_columns(&[Arc::new(array.clone())]) - .unwrap(); - - let arrays = converter.convert_rows(&rows).unwrap(); - let result = arrays[0] - .as_any() - .downcast_ref::>() - .unwrap(); - - assert_eq!(array.run_ends().values(), result.run_ends().values()); - assert_eq!(array.values().as_ref(), result.values().as_ref()); + assert_roundtrip(&array, DataType::Int32, DataType::Utf8, None); } #[test] @@ -297,24 +262,7 @@ mod tests { .into_iter() .collect(); - let converter = RowConverter::new(vec![SortField::new(DataType::RunEndEncoded( - Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), - Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), - ))]) - .unwrap(); - - let rows = converter - .convert_columns(&[Arc::new(array.clone())]) - .unwrap(); - - let arrays = converter.convert_rows(&rows).unwrap(); - let result = arrays[0] - .as_any() - .downcast_ref::>() - .unwrap(); - - assert_eq!(array.run_ends().values(), result.run_ends().values()); - assert_eq!(array.values().as_ref(), result.values().as_ref()); + assert_roundtrip(&array, DataType::Int32, DataType::Utf8, None); } #[test] @@ -331,98 +279,26 @@ mod tests { .unwrap(); // Test ascending order - let converter_asc = RowConverter::new(vec![SortField::new_with_options( - DataType::RunEndEncoded( - Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), - Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), - ), - SortOptions { + assert_roundtrip( + &run_array_asc, + DataType::Int32, + DataType::Utf8, + Some(SortOptions { descending: false, nulls_first: true, - }, - )]) - .unwrap(); - - let rows_asc = converter_asc - .convert_columns(&[Arc::new(run_array_asc.clone())]) - .unwrap(); - let arrays_asc = converter_asc.convert_rows(&rows_asc).unwrap(); - let result_asc = arrays_asc[0] - .as_any() - .downcast_ref::>() - .unwrap(); - - // Verify round-trip correctness for ascending - assert_eq!(run_array_asc.len(), result_asc.len()); - for i in 0..run_array_asc.len() { - let orig_physical = run_array_asc.get_physical_index(i); - let result_physical = result_asc.get_physical_index(i); - - let orig_values = run_array_asc - .values() - .as_any() - .downcast_ref::() - .unwrap(); - let result_values = result_asc - .values() - .as_any() - .downcast_ref::() - .unwrap(); - - assert_eq!( - orig_values.value(orig_physical), - result_values.value(result_physical), - "Ascending sort value mismatch at index {}", - i - ); - } + }), + ); // Test descending order - let converter_desc = RowConverter::new(vec![SortField::new_with_options( - DataType::RunEndEncoded( - Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), - Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), - ), - SortOptions { + assert_roundtrip( + &run_array_asc, + DataType::Int32, + DataType::Utf8, + Some(SortOptions { descending: true, nulls_first: true, - }, - )]) - .unwrap(); - - let rows_desc = converter_desc - .convert_columns(&[Arc::new(run_array_asc.clone())]) - .unwrap(); - let arrays_desc = converter_desc.convert_rows(&rows_desc).unwrap(); - let result_desc = arrays_desc[0] - .as_any() - .downcast_ref::>() - .unwrap(); - - // Verify round-trip correctness for descending - assert_eq!(run_array_asc.len(), result_desc.len()); - for i in 0..run_array_asc.len() { - let orig_physical = run_array_asc.get_physical_index(i); - let result_physical = result_desc.get_physical_index(i); - - let orig_values = run_array_asc - .values() - .as_any() - .downcast_ref::() - .unwrap(); - let result_values = result_desc - .values() - .as_any() - .downcast_ref::() - .unwrap(); - - assert_eq!( - orig_values.value(orig_physical), - result_values.value(result_physical), - "Descending sort value mismatch at index {}", - i - ); - } + }), + ); } #[test] @@ -431,44 +307,27 @@ mod tests { let test_array: RunArray = vec!["test"].into_iter().collect(); - let converter_asc = RowConverter::new(vec![SortField::new_with_options( - DataType::RunEndEncoded( - Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), - Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), - ), - SortOptions { + // Test ascending order + assert_roundtrip( + &test_array, + DataType::Int32, + DataType::Utf8, + Some(SortOptions { descending: false, nulls_first: true, - }, - )]) - .unwrap(); + }), + ); - let converter_desc = RowConverter::new(vec![SortField::new_with_options( - DataType::RunEndEncoded( - Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), - Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), - ), - SortOptions { + // Test descending order + assert_roundtrip( + &test_array, + DataType::Int32, + DataType::Utf8, + Some(SortOptions { descending: true, nulls_first: true, - }, - )]) - .unwrap(); - - let rows_test_asc = converter_asc - .convert_columns(&[Arc::new(test_array.clone())]) - .unwrap(); - let rows_test_desc = converter_desc - .convert_columns(&[Arc::new(test_array.clone())]) - .unwrap(); - - // Convert back to verify both configurations work - let result_test_asc = converter_asc.convert_rows(&rows_test_asc).unwrap(); - let result_test_desc = converter_desc.convert_rows(&rows_test_desc).unwrap(); - - // Both should successfully reconstruct the original - assert_eq!(result_test_asc.len(), 1); - assert_eq!(result_test_desc.len(), 1); + }), + ); } #[test]