From 0ec13aeef7709cdfd2f5521a217d55053cb7de45 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 8 Apr 2025 19:20:30 -0500 Subject: [PATCH 01/16] so close to compiling --- arrow-data/src/data.rs | 2 + arrow-data/src/equal/mod.rs | 1 + arrow-data/src/transform/mod.rs | 3 ++ arrow-integration-test/src/datatype.rs | 3 ++ arrow-ipc/src/convert.rs | 22 +++++++- arrow-schema/src/datatype.rs | 5 ++ arrow-schema/src/extension/mod.rs | 57 ++++++++++++++++++++ arrow-schema/src/field.rs | 5 +- parquet/src/arrow/arrow_reader/statistics.rs | 6 ++- parquet/src/arrow/schema/mod.rs | 6 +++ 10 files changed, 105 insertions(+), 5 deletions(-) diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs index 4c117184de79..c1b689d932f6 100644 --- a/arrow-data/src/data.rs +++ b/arrow-data/src/data.rs @@ -151,6 +151,7 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff } } } + DataType::Extension(extension) => new_buffers(extension.storage_type(), capacity), } } @@ -1664,6 +1665,7 @@ pub fn layout(data_type: &DataType) -> DataTypeLayout { } } DataType::Dictionary(key_type, _value_type) => layout(key_type), + DataType::Extension(extension) => layout(extension.storage_type()), } } diff --git a/arrow-data/src/equal/mod.rs b/arrow-data/src/equal/mod.rs index f24179b61700..e35ac30b62f9 100644 --- a/arrow-data/src/equal/mod.rs +++ b/arrow-data/src/equal/mod.rs @@ -123,6 +123,7 @@ fn equal_values( DataType::Float16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::Map(_, _) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::RunEndEncoded(_, _) => run_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::Extension(_) => unimplemented!("Extension not implemented"), } } diff --git a/arrow-data/src/transform/mod.rs b/arrow-data/src/transform/mod.rs index 93b79e6a5eb8..c898bc1920f7 100644 --- a/arrow-data/src/transform/mod.rs +++ b/arrow-data/src/transform/mod.rs @@ -276,6 +276,7 @@ fn build_extend(array: &ArrayData) -> Extend { UnionMode::Dense => union::build_extend_dense(array), }, DataType::RunEndEncoded(_, _) => todo!(), + DataType::Extension(_) => unimplemented!("Extension not implemented") } } @@ -332,6 +333,7 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { UnionMode::Dense => union::extend_nulls_dense, }, DataType::RunEndEncoded(_, _) => todo!(), + DataType::Extension(_) => unimplemented!("ListView/LargeListView not implemented") }) } @@ -590,6 +592,7 @@ impl<'a> MutableArrayData<'a> { MutableArrayData::new(child_arrays, use_nulls, array_capacity) }) .collect::>(), + DataType::Extension(_) => unimplemented!("Extension not implemented") }; // Get the dictionary if any, and if it is a concatenation of multiple diff --git a/arrow-integration-test/src/datatype.rs b/arrow-integration-test/src/datatype.rs index 24e02c8430c7..6a8a3efba067 100644 --- a/arrow-integration-test/src/datatype.rs +++ b/arrow-integration-test/src/datatype.rs @@ -345,6 +345,9 @@ pub fn data_type_to_json(data_type: &DataType) -> serde_json::Value { json!({"name": "map", "keysSorted": keys_sorted}) } DataType::RunEndEncoded(_, _) => todo!(), + DataType::Extension(extension) => { + data_type_to_json(extension.storage_type()) + } } } diff --git a/arrow-ipc/src/convert.rs b/arrow-ipc/src/convert.rs index 79dd1726ed70..2201d09b01ec 100644 --- a/arrow-ipc/src/convert.rs +++ b/arrow-ipc/src/convert.rs @@ -514,7 +514,24 @@ pub(crate) fn build_field<'a>( ) -> WIPOffset> { // Optional custom metadata. let mut fb_metadata = None; - if !field.metadata().is_empty() { + + // Handle extension type metadata if applicable + if let DataType::Extension(extension) = field.data_type() { + let mut field_metadata = HashMap::from([ + ( + "ARROW:extension:name".to_string(), + extension.extension_name().to_string(), + ), + ( + "ARROW:extension:metadata".to_string(), + extension.serialized_metadata(), + ), + ]); + + for (k, v) in field.metadata() { + field_metadata.insert(k.clone(), v.clone()); + } + } else if !field.metadata().is_empty() { fb_metadata = Some(metadata_to_fb(fbb, field.metadata())); }; @@ -883,6 +900,9 @@ pub(crate) fn get_fb_field_type<'a>( children: Some(fbb.create_vector(&children[..])), } } + DataType::Extension(extension) => { + get_fb_field_type(extension.storage_type(), dictionary_tracker, fbb) + } } } diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs index 5c9073c4eeb6..e668d097fa91 100644 --- a/arrow-schema/src/datatype.rs +++ b/arrow-schema/src/datatype.rs @@ -19,6 +19,7 @@ use std::fmt; use std::str::FromStr; use std::sync::Arc; +use crate::extension::DynExtensionType; use crate::{ArrowError, Field, FieldRef, Fields, UnionFields}; /// Datatypes supported by this implementation of Apache Arrow. @@ -411,6 +412,8 @@ pub enum DataType { /// These child arrays are prescribed the standard names of "run_ends" and "values" /// respectively. RunEndEncoded(FieldRef, FieldRef), + /// An ExtensionType + Extension(Arc), } /// An absolute length of time in seconds, milliseconds, microseconds or nanoseconds. @@ -689,6 +692,7 @@ impl DataType { DataType::Union(_, _) => None, DataType::Dictionary(_, _) => None, DataType::RunEndEncoded(_, _) => None, + DataType::Extension(_) => None, } } @@ -740,6 +744,7 @@ impl DataType { run_ends.size() - std::mem::size_of_val(run_ends) + values.size() - std::mem::size_of_val(values) } + DataType::Extension(extension) => extension.size(), } } diff --git a/arrow-schema/src/extension/mod.rs b/arrow-schema/src/extension/mod.rs index c5119873af0c..fb2cc3e44838 100644 --- a/arrow-schema/src/extension/mod.rs +++ b/arrow-schema/src/extension/mod.rs @@ -25,6 +25,10 @@ mod canonical; pub use canonical::*; use crate::{ArrowError, DataType}; +use std::any::Any; +use std::cmp::Ordering; +use std::fmt::Debug; +use std::hash::{Hash, Hasher}; /// The metadata key for the string name identifying an [`ExtensionType`]. pub const EXTENSION_TYPE_NAME_KEY: &str = "ARROW:extension:name"; @@ -258,3 +262,56 @@ pub trait ExtensionType: Sized { /// this extension type. fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result; } + +/// dyn-compatible ExtensionType +pub trait DynExtensionType: Debug { + /// For dyn-compatible comparison methods + fn as_any(&self) -> &dyn Any; + + /// Because DataType implements sized + fn size(&self) -> usize; + + /// Concrete storage type for this extension + fn storage_type(&self) -> &DataType; + + /// Name of the extension + fn extension_name(&self) -> &'static str; + + /// Extension metadata + fn serialized_metadata(&self) -> String; + + /// Because DataType implement Eq + fn extension_equals(&self, other: &dyn Any) -> bool; + + /// Because DataType implements Hash + fn extension_hash(&self, hasher: &dyn Hasher); + + /// Because DataType implements Ord + fn exension_cmp(&self, other: &dyn Any) -> Ordering; +} + +impl PartialEq for dyn DynExtensionType + Send + Sync { + fn eq(&self, other: &Self) -> bool { + self.extension_equals(other.as_any()) + } +} + +impl Eq for dyn DynExtensionType + Send + Sync {} + +impl Hash for dyn DynExtensionType + Send + Sync { + fn hash(&self, state: &mut H) { + self.extension_hash(state); + } +} + +impl PartialOrd for dyn DynExtensionType + Send + Sync { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for dyn DynExtensionType + Send + Sync { + fn cmp(&self, other: &Self) -> Ordering { + self.exension_cmp(other.as_any()) + } +} diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs index dbd671a62a3a..cb385a59a0e5 100644 --- a/arrow-schema/src/field.rs +++ b/arrow-schema/src/field.rs @@ -727,7 +727,7 @@ impl Field { DataType::Null => { self.nullable = true; self.data_type = from.data_type.clone(); - } + }, | DataType::Boolean | DataType::Int8 | DataType::Int16 @@ -761,7 +761,8 @@ impl Field { | DataType::LargeUtf8 | DataType::Utf8View | DataType::Decimal128(_, _) - | DataType::Decimal256(_, _) => { + | DataType::Decimal256(_, _) + | DataType::Extension(_) => { if from.data_type == DataType::Null { self.nullable = true; } else if self.data_type != from.data_type { diff --git a/parquet/src/arrow/arrow_reader/statistics.rs b/parquet/src/arrow/arrow_reader/statistics.rs index 09f8ec7cc274..26c17afad0e6 100644 --- a/parquet/src/arrow/arrow_reader/statistics.rs +++ b/parquet/src/arrow/arrow_reader/statistics.rs @@ -535,7 +535,8 @@ macro_rules! get_statistics { DataType::LargeListView(_) | DataType::Struct(_) | DataType::Union(_, _) | - DataType::RunEndEncoded(_, _) => { + DataType::RunEndEncoded(_, _) | + DataType::Extension(_) => { let len = $iterator.count(); // don't know how to extract statistics, so return a null array Ok(new_null_array($data_type, len)) @@ -1056,7 +1057,8 @@ macro_rules! get_data_page_statistics { DataType::Struct(_) | DataType::Union(_, _) | DataType::Map(_, _) | - DataType::RunEndEncoded(_, _) => { + DataType::RunEndEncoded(_, _) | + DataType::Extension(_) => { let len = $iterator.count(); // don't know how to extract statistics, so return a null array Ok(new_null_array($data_type, len)) diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index 89c42f5eaf92..2519145c9a66 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -767,6 +767,12 @@ fn arrow_to_parquet_type(field: &Field, coerce_types: bool) -> Result { DataType::RunEndEncoded(_, _) => Err(arrow_err!( "Converting RunEndEncodedType to parquet not supported", )), + DataType::Extension(extension) => arrow_to_parquet_type( + &field + .clone() + .with_data_type(extension.storage_type().clone()), + coerce_types, + ), } } From a7e3ec60014be95f2586de5618a5a79a1c9b1d23 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 8 Apr 2025 21:55:14 -0500 Subject: [PATCH 02/16] let arrow csv compile --- arrow-schema/src/extension/mod.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/arrow-schema/src/extension/mod.rs b/arrow-schema/src/extension/mod.rs index fb2cc3e44838..285cc42b44dc 100644 --- a/arrow-schema/src/extension/mod.rs +++ b/arrow-schema/src/extension/mod.rs @@ -29,6 +29,7 @@ use std::any::Any; use std::cmp::Ordering; use std::fmt::Debug; use std::hash::{Hash, Hasher}; +use std::panic::RefUnwindSafe; /// The metadata key for the string name identifying an [`ExtensionType`]. pub const EXTENSION_TYPE_NAME_KEY: &str = "ARROW:extension:name"; @@ -264,7 +265,7 @@ pub trait ExtensionType: Sized { } /// dyn-compatible ExtensionType -pub trait DynExtensionType: Debug { +pub trait DynExtensionType: Debug + RefUnwindSafe { /// For dyn-compatible comparison methods fn as_any(&self) -> &dyn Any; From e01cc11827a65a8d87a5467124f5fb7759c2f745 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 8 Apr 2025 23:00:49 -0500 Subject: [PATCH 03/16] don't serde Extension --- arrow-schema/src/datatype.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs index e668d097fa91..87cdd6056540 100644 --- a/arrow-schema/src/datatype.rs +++ b/arrow-schema/src/datatype.rs @@ -413,6 +413,7 @@ pub enum DataType { /// respectively. RunEndEncoded(FieldRef, FieldRef), /// An ExtensionType + #[cfg_attr(feature = "serde", serde(skip))] Extension(Arc), } From 7ddf616657fafa32c64b8cd0d66b9a56181d9f71 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 8 Apr 2025 23:03:22 -0500 Subject: [PATCH 04/16] format --- arrow-data/src/transform/mod.rs | 6 +++--- arrow-integration-test/src/datatype.rs | 4 +--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/arrow-data/src/transform/mod.rs b/arrow-data/src/transform/mod.rs index c898bc1920f7..5616e9efea4e 100644 --- a/arrow-data/src/transform/mod.rs +++ b/arrow-data/src/transform/mod.rs @@ -276,7 +276,7 @@ fn build_extend(array: &ArrayData) -> Extend { UnionMode::Dense => union::build_extend_dense(array), }, DataType::RunEndEncoded(_, _) => todo!(), - DataType::Extension(_) => unimplemented!("Extension not implemented") + DataType::Extension(_) => unimplemented!("Extension not implemented"), } } @@ -333,7 +333,7 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { UnionMode::Dense => union::extend_nulls_dense, }, DataType::RunEndEncoded(_, _) => todo!(), - DataType::Extension(_) => unimplemented!("ListView/LargeListView not implemented") + DataType::Extension(_) => unimplemented!("ListView/LargeListView not implemented"), }) } @@ -592,7 +592,7 @@ impl<'a> MutableArrayData<'a> { MutableArrayData::new(child_arrays, use_nulls, array_capacity) }) .collect::>(), - DataType::Extension(_) => unimplemented!("Extension not implemented") + DataType::Extension(_) => unimplemented!("Extension not implemented"), }; // Get the dictionary if any, and if it is a concatenation of multiple diff --git a/arrow-integration-test/src/datatype.rs b/arrow-integration-test/src/datatype.rs index 6a8a3efba067..aa41f24ff3a6 100644 --- a/arrow-integration-test/src/datatype.rs +++ b/arrow-integration-test/src/datatype.rs @@ -345,9 +345,7 @@ pub fn data_type_to_json(data_type: &DataType) -> serde_json::Value { json!({"name": "map", "keysSorted": keys_sorted}) } DataType::RunEndEncoded(_, _) => todo!(), - DataType::Extension(extension) => { - data_type_to_json(extension.storage_type()) - } + DataType::Extension(extension) => data_type_to_json(extension.storage_type()), } } From 396442a0f3e7331ba2b6438983ce2256d2f60325 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 9 Apr 2025 22:17:39 -0500 Subject: [PATCH 05/16] demo scoping + opt-in import --- arrow-ipc/src/convert.rs | 21 ++++++- arrow-ipc/src/reader.rs | 22 +++++++- arrow-schema/src/extension/mod.rs | 93 ++++++++++++++++++++++++++++++- 3 files changed, 129 insertions(+), 7 deletions(-) diff --git a/arrow-ipc/src/convert.rs b/arrow-ipc/src/convert.rs index 2201d09b01ec..8f4da1de6a7b 100644 --- a/arrow-ipc/src/convert.rs +++ b/arrow-ipc/src/convert.rs @@ -18,6 +18,7 @@ //! Utilities for converting between IPC types and native Arrow types use arrow_buffer::Buffer; +use arrow_schema::extension::DynExtensionTypeFactory; use arrow_schema::*; use flatbuffers::{ FlatBufferBuilder, ForwardsUOffset, UnionWIPOffset, Vector, Verifiable, Verifier, @@ -194,8 +195,13 @@ impl From> for Field { } } -/// Deserialize an ipc [crate::Schema`] from flat buffers to an arrow [Schema]. +/// Deserialize an ipc [crate::Schema`] from flat buffers to an arrow [Schema] pub fn fb_to_schema(fb: crate::Schema) -> Schema { + fb_to_schema_with_extension_factory(fb, None).unwrap() +} + +/// Deserialize an ipc [crate::Schema`] from flat buffers to an arrow [Schema] with extension support +pub fn fb_to_schema_with_extension_factory(fb: crate::Schema, extension_factory: Option<&dyn DynExtensionTypeFactory>) -> Result { let mut fields: Vec = vec![]; let c_fields = fb.fields().unwrap(); let len = c_fields.len(); @@ -207,7 +213,15 @@ pub fn fb_to_schema(fb: crate::Schema) -> Schema { } _ => (), }; - fields.push(c_field.into()); + let field: Field = c_field.into(); + if let Some(factory) = extension_factory { + if let Some(extension) = factory.make_from_field(&field)? { + fields.push(field.clone().with_data_type(DataType::Extension(extension))); + continue; + } + } + + fields.push(field); } let mut metadata: HashMap = HashMap::default(); @@ -224,7 +238,8 @@ pub fn fb_to_schema(fb: crate::Schema) -> Schema { } } } - Schema::new_with_metadata(fields, metadata) + + Ok(Schema::new_with_metadata(fields, metadata)) } /// Try deserialize flat buffer format bytes into a schema diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index 83dc5702dc94..6e084f3d5e60 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -26,6 +26,7 @@ mod stream; +use arrow_schema::extension::DynExtensionTypeFactory; pub use stream::*; use flatbuffers::{VectorIter, VerifierOptions}; @@ -229,6 +230,12 @@ impl RecordBatchDecoder<'_> { .offset(0); self.create_array_from_builder(builder) } + Extension(extension) => self.create_array( + &field + .clone() + .with_data_type(extension.storage_type().clone()), + variadic_counts, + ), _ => { let field_node = self.next_node(field)?; let buffers = [self.next_buffer()?, self.next_buffer()?]; @@ -1173,7 +1180,7 @@ impl FileReader { /// Try to create a new file reader. /// /// There is no internal buffering. If buffered reads are needed you likely want to use - /// [`FileReader::try_new_buffered`] instead. + /// [`FileReader::try_new_buffered`] instead. /// /// # Errors /// @@ -1364,8 +1371,17 @@ impl StreamReader { /// An ['Err'](Result::Err) may be returned if the reader does not encounter a schema /// as the first message in the stream. pub fn try_new( + reader: R, + projection: Option>, + ) -> Result, ArrowError> { + Self::try_new_with_extension_factory(reader, projection, None) + } + + /// Create a stream reader with an extension factory + pub fn try_new_with_extension_factory( mut reader: R, projection: Option>, + extension_factory: Option<&dyn DynExtensionTypeFactory>, ) -> Result, ArrowError> { // determine metadata length let mut meta_size: [u8; 4] = [0; 4]; @@ -1389,7 +1405,9 @@ impl StreamReader { let ipc_schema: crate::Schema = message.header_as_schema().ok_or_else(|| { ArrowError::ParseError("Unable to read IPC message as schema".to_string()) })?; - let schema = crate::convert::fb_to_schema(ipc_schema); + + let schema = + crate::convert::fb_to_schema_with_extension_factory(ipc_schema, extension_factory)?; // Create an array of optional dictionary value arrays, one per field. let dictionaries_by_id = HashMap::new(); diff --git a/arrow-schema/src/extension/mod.rs b/arrow-schema/src/extension/mod.rs index 285cc42b44dc..7659f23d73c2 100644 --- a/arrow-schema/src/extension/mod.rs +++ b/arrow-schema/src/extension/mod.rs @@ -24,12 +24,13 @@ mod canonical; #[cfg(feature = "canonical_extension_types")] pub use canonical::*; -use crate::{ArrowError, DataType}; +use crate::{ArrowError, DataType, Field}; use std::any::Any; use std::cmp::Ordering; use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::panic::RefUnwindSafe; +use std::sync::Arc; /// The metadata key for the string name identifying an [`ExtensionType`]. pub const EXTENSION_TYPE_NAME_KEY: &str = "ARROW:extension:name"; @@ -285,7 +286,7 @@ pub trait DynExtensionType: Debug + RefUnwindSafe { fn extension_equals(&self, other: &dyn Any) -> bool; /// Because DataType implements Hash - fn extension_hash(&self, hasher: &dyn Hasher); + fn extension_hash(&self, hasher: &mut dyn Hasher); /// Because DataType implements Ord fn exension_cmp(&self, other: &dyn Any) -> Ordering; @@ -316,3 +317,91 @@ impl Ord for dyn DynExtensionType + Send + Sync { self.exension_cmp(other.as_any()) } } + +/// A way to create extension types for places where they might be imported +pub trait DynExtensionTypeFactory { + /// Create an extension type from name, storage type, and metadata + fn make_extension_type( + &self, + extension_name: &str, + storage_type: &DataType, + extension_metadata: Option<&String>, + ) -> Result>, ArrowError>; + + /// Create an extension type from a field + fn make_from_field(&self, field: &Field) -> Result>, ArrowError> { + if let Some(extension_name) = field.metadata().get("ARROW:extension:name") { + self.make_extension_type( + extension_name, + field.data_type(), + field.metadata().get("ARROW:extension:metadata"), + ) + } else { + Ok(None) + } + } +} + +/// Simple factory with registered types +pub struct CanonicalExtensionTypeFactory {} + +#[cfg(feature = "canonical_extension_types")] +impl DynExtensionType for Uuid { + fn as_any(&self) -> &dyn Any { + self + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn storage_type(&self) -> &DataType { + &DataType::FixedSizeBinary(16) + } + + fn extension_name(&self) -> &'static str { + Self::NAME + } + + fn serialized_metadata(&self) -> String { + "".to_string() + } + + fn extension_equals(&self, other: &dyn Any) -> bool { + other.downcast_ref::().is_some() + } + + fn extension_hash(&self, hasher: &mut dyn Hasher) { + hasher.write("arrow.uuid".as_bytes()); + } + + fn exension_cmp(&self, other: &dyn Any) -> Ordering { + if self.extension_equals(other) { + Ordering::Equal + } else { + // Fishy... + Ordering::Less + } + } +} + +#[cfg(feature = "canonical_extension_types")] +impl DynExtensionTypeFactory for CanonicalExtensionTypeFactory { + fn make_extension_type( + &self, + extension_name: &str, + storage_type: &DataType, + extension_metadata: Option<&String>, + ) -> Result>, ArrowError> { + match extension_name { + "arrow.uuid" => { + let uuid = Uuid::try_new( + storage_type, + Uuid::deserialize_metadata(extension_metadata.map(|s| s.as_str()))?, + )?; + Ok(Some(Arc::new(uuid))) + } + _ => Ok(None), + } + } +} From 1573540c66f714ea2361bf83431c5393087da2bb Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 10 Apr 2025 15:24:41 -0500 Subject: [PATCH 06/16] format --- arrow-ipc/src/convert.rs | 5 ++++- arrow-schema/src/extension/mod.rs | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/arrow-ipc/src/convert.rs b/arrow-ipc/src/convert.rs index 8f4da1de6a7b..40e10362111c 100644 --- a/arrow-ipc/src/convert.rs +++ b/arrow-ipc/src/convert.rs @@ -201,7 +201,10 @@ pub fn fb_to_schema(fb: crate::Schema) -> Schema { } /// Deserialize an ipc [crate::Schema`] from flat buffers to an arrow [Schema] with extension support -pub fn fb_to_schema_with_extension_factory(fb: crate::Schema, extension_factory: Option<&dyn DynExtensionTypeFactory>) -> Result { +pub fn fb_to_schema_with_extension_factory( + fb: crate::Schema, + extension_factory: Option<&dyn DynExtensionTypeFactory>, +) -> Result { let mut fields: Vec = vec![]; let c_fields = fb.fields().unwrap(); let len = c_fields.len(); diff --git a/arrow-schema/src/extension/mod.rs b/arrow-schema/src/extension/mod.rs index 7659f23d73c2..ed62dae7a7fc 100644 --- a/arrow-schema/src/extension/mod.rs +++ b/arrow-schema/src/extension/mod.rs @@ -329,7 +329,10 @@ pub trait DynExtensionTypeFactory { ) -> Result>, ArrowError>; /// Create an extension type from a field - fn make_from_field(&self, field: &Field) -> Result>, ArrowError> { + fn make_from_field( + &self, + field: &Field, + ) -> Result>, ArrowError> { if let Some(extension_name) = field.metadata().get("ARROW:extension:name") { self.make_extension_type( extension_name, From 9b3620fce2e9c12bed879019c54abf38fcefa60e Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 10 Apr 2025 16:06:49 -0500 Subject: [PATCH 07/16] plausible array --- arrow-array/src/array/extension_array.rs | 93 ++++++++++++++++++++++++ arrow-array/src/array/mod.rs | 4 + 2 files changed, 97 insertions(+) create mode 100644 arrow-array/src/array/extension_array.rs diff --git a/arrow-array/src/array/extension_array.rs b/arrow-array/src/array/extension_array.rs new file mode 100644 index 000000000000..47fd9932459d --- /dev/null +++ b/arrow-array/src/array/extension_array.rs @@ -0,0 +1,93 @@ +use std::{any::Any, sync::Arc}; + +use arrow_data::ArrayData; +use arrow_schema::{extension::DynExtensionType, ArrowError, DataType}; + +use super::{make_array, Array, ArrayRef}; + +/// Array type for DataType::Extension +#[derive(Debug)] +pub struct ExtensionArray { + data_type: DataType, + storage: ArrayRef, +} + +impl ExtensionArray { + /// Try to create a new ExtensionArray + pub fn try_new( + extension: Arc, + storage: ArrayRef, + ) -> Result { + Ok(Self { + data_type: DataType::Extension(extension), + storage, + }) + } + + /// Return the underlying storage array + pub fn storage(&self) -> &ArrayRef { + &self.storage + } +} + +impl From for ExtensionArray { + fn from(data: ArrayData) -> Self { + if let DataType::Extension(_) = data.data_type() { + Self { + data_type: data.data_type().clone(), + storage: Arc::new(make_array(data)) as ArrayRef, + } + } else { + panic!("{} is not Extension", data.data_type()) + } + } +} + +impl Array for ExtensionArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.storage.to_data() + } + + fn into_data(self) -> ArrayData { + self.storage.to_data() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(Self { + data_type: self.data_type.clone(), + storage: self.storage.slice(offset, length), + }) + } + + fn len(&self) -> usize { + self.storage.len() + } + + fn is_empty(&self) -> bool { + self.storage.is_empty() + } + + fn offset(&self) -> usize { + self.storage.offset() + } + + fn nulls(&self) -> Option<&arrow_buffer::NullBuffer> { + self.storage.nulls() + } + + fn get_buffer_memory_size(&self) -> usize { + self.storage.get_buffer_memory_size() + } + + fn get_array_memory_size(&self) -> usize { + self.storage.get_array_memory_size() + } +} diff --git a/arrow-array/src/array/mod.rs b/arrow-array/src/array/mod.rs index e41a3a1d719a..a5894f47e401 100644 --- a/arrow-array/src/array/mod.rs +++ b/arrow-array/src/array/mod.rs @@ -76,6 +76,9 @@ mod list_view_array; pub use list_view_array::*; +mod extension_array; +pub use extension_array::*; + use crate::iterator::ArrayIter; /// An array in the [arrow columnar format](https://arrow.apache.org/docs/format/Columnar.html) @@ -829,6 +832,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef { DataType::Null => Arc::new(NullArray::from(data)) as ArrayRef, DataType::Decimal128(_, _) => Arc::new(Decimal128Array::from(data)) as ArrayRef, DataType::Decimal256(_, _) => Arc::new(Decimal256Array::from(data)) as ArrayRef, + DataType::Extension(_) => Arc::new(ExtensionArray::from(data)) as ArrayRef, dt => panic!("Unexpected data type {dt:?}"), } } From 3d05ba88c1ca89dca84c65e3d8249aa5a494bbc0 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 10 Apr 2025 17:19:28 -0500 Subject: [PATCH 08/16] one kernel --- arrow-array/src/array/extension_array.rs | 41 ++++++++++++++++++--- arrow-schema/src/extension/mod.rs | 46 ++++++++++++++++++++++++ arrow-select/src/filter.rs | 29 +++++++++++++++ 3 files changed, 112 insertions(+), 4 deletions(-) diff --git a/arrow-array/src/array/extension_array.rs b/arrow-array/src/array/extension_array.rs index 47fd9932459d..b2f8cb7e2dd4 100644 --- a/arrow-array/src/array/extension_array.rs +++ b/arrow-array/src/array/extension_array.rs @@ -24,18 +24,42 @@ impl ExtensionArray { }) } + /// Create a new ExtensionArray + pub fn new(extension: Arc, storage: ArrayRef) -> Self { + Self::try_new(extension, storage).unwrap() + } + /// Return the underlying storage array pub fn storage(&self) -> &ArrayRef { &self.storage } + + /// Return a new array with new storage of the same type + pub fn with_storage(&self, new_storage: ArrayRef) -> Self { + assert_eq!(new_storage.data_type(), new_storage.data_type()); + Self { + data_type: self.data_type.clone(), + storage: new_storage, + } + } } impl From for ExtensionArray { fn from(data: ArrayData) -> Self { - if let DataType::Extension(_) = data.data_type() { + if let DataType::Extension(extension) = data.data_type() { + let storage_data = ArrayData::try_new( + extension.storage_type().clone(), + data.len(), + data.nulls().map(|b| b.buffer()).cloned(), + data.offset(), + data.buffers().to_vec(), + data.child_data().to_vec(), + ) + .unwrap(); + Self { data_type: data.data_type().clone(), - storage: Arc::new(make_array(data)) as ArrayRef, + storage: Arc::new(make_array(storage_data)) as ArrayRef, } } else { panic!("{} is not Extension", data.data_type()) @@ -49,11 +73,20 @@ impl Array for ExtensionArray { } fn to_data(&self) -> ArrayData { - self.storage.to_data() + let storage_data = self.storage.to_data(); + ArrayData::try_new( + self.data_type.clone(), + storage_data.len(), + storage_data.nulls().map(|b| b.buffer()).cloned(), + storage_data.offset(), + storage_data.buffers().to_vec(), + storage_data.child_data().to_vec(), + ) + .unwrap() } fn into_data(self) -> ArrayData { - self.storage.to_data() + self.to_data() } fn data_type(&self) -> &DataType { diff --git a/arrow-schema/src/extension/mod.rs b/arrow-schema/src/extension/mod.rs index ed62dae7a7fc..1dfd35363aef 100644 --- a/arrow-schema/src/extension/mod.rs +++ b/arrow-schema/src/extension/mod.rs @@ -408,3 +408,49 @@ impl DynExtensionTypeFactory for CanonicalExtensionTypeFactory { } } } + +/// Extension for tests +#[derive(Debug)] +pub struct TextExtension { + /// Arbitrary storage type + pub storage_type: DataType, +} + +impl DynExtensionType for TextExtension { + fn as_any(&self) -> &dyn Any { + self + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn storage_type(&self) -> &DataType { + &self.storage_type + } + + fn extension_name(&self) -> &'static str { + "arrow.rs.test" + } + + fn serialized_metadata(&self) -> String { + "".to_string() + } + + fn extension_equals(&self, other: &dyn Any) -> bool { + other.downcast_ref::().is_some() + } + + fn extension_hash(&self, hasher: &mut dyn Hasher) { + hasher.write("arrow.rs.test".as_bytes()); + } + + fn exension_cmp(&self, other: &dyn Any) -> Ordering { + if self.extension_equals(other) { + Ordering::Equal + } else { + // Fishy... + Ordering::Less + } + } +} diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index 7bb140d37f51..bdbb27de6fed 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -393,6 +393,11 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result { Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?)) } + DataType::Extension(_) => { + let extension_array: ExtensionArray = values.to_data().into(); + let storage_result = filter_array(extension_array.storage(), predicate)?; + Ok(Arc::new(extension_array.with_storage(storage_result))) + } _ => { let data = values.to_data(); // fallback to using MutableArrayData @@ -864,6 +869,7 @@ mod tests { use arrow_array::builder::*; use arrow_array::cast::as_run_array; use arrow_array::types::*; + use arrow_schema::extension::TextExtension; use rand::distr::uniform::{UniformSampler, UniformUsize}; use rand::distr::{Alphanumeric, StandardUniform}; use rand::prelude::*; @@ -2045,4 +2051,27 @@ mod tests { assert_eq!(result.to_data(), expected.to_data()); } + + #[test] + fn test_filter_extension() { + let predicate = BooleanArray::from(vec![true, false, true, false]); + let storage = Arc::new(StringArray::from(vec![ + "one banana", + "two banana", + "three banana", + "four", + ])); + let array = ExtensionArray::new( + Arc::new(TextExtension { + storage_type: DataType::Utf8, + }), + storage.clone(), + ); + let result_ref = filter(&array, &predicate).unwrap(); + assert_eq!(result_ref.data_type(), array.data_type()); + assert_eq!(result_ref.len(), 2); + + let result_array: ExtensionArray = array.to_data().into(); + assert_eq!(result_array.storage().to_data(), storage.to_data()); + } } From f63828e3a5196c24182ef80a15050d8026999012 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 10 Apr 2025 17:21:19 -0500 Subject: [PATCH 09/16] rat --- arrow-array/src/array/extension_array.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/arrow-array/src/array/extension_array.rs b/arrow-array/src/array/extension_array.rs index b2f8cb7e2dd4..f1cdaf5b6d48 100644 --- a/arrow-array/src/array/extension_array.rs +++ b/arrow-array/src/array/extension_array.rs @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + use std::{any::Any, sync::Arc}; use arrow_data::ArrayData; From 19b5976f7aa957b0857aa07569a1701323345839 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 10 Apr 2025 20:25:41 -0500 Subject: [PATCH 10/16] add interleave --- arrow-select/src/filter.rs | 5 +++-- arrow-select/src/interleave.rs | 38 ++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index bdbb27de6fed..94480fbbe2a7 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -2071,7 +2071,8 @@ mod tests { assert_eq!(result_ref.data_type(), array.data_type()); assert_eq!(result_ref.len(), 2); - let result_array: ExtensionArray = array.to_data().into(); - assert_eq!(result_array.storage().to_data(), storage.to_data()); + let result_array: ExtensionArray = result_ref.to_data().into(); + let expected = create_array!(Utf8, ["one banana", "three banana"]); + assert_eq!(result_array.storage().to_data(), expected.to_data()); } } diff --git a/arrow-select/src/interleave.rs b/arrow-select/src/interleave.rs index 5fc019da78f1..d3c7a36c0b28 100644 --- a/arrow-select/src/interleave.rs +++ b/arrow-select/src/interleave.rs @@ -93,6 +93,16 @@ pub fn interleave( return Ok(new_empty_array(data_type)); } + if let DataType::Extension(extension) = data_type { + let storage: Vec<_> = values.iter().map(|array| { + let extension_array: ExtensionArray = array.to_data().into(); + extension_array.storage().clone() + }).collect(); + let storage_ref: Vec<_> = storage.iter().map(|array| array.as_ref()).collect(); + let storage_result = interleave(&storage_ref, indices)?; + return Ok(Arc::new(ExtensionArray::new(extension.clone(), storage_result))); + } + downcast_primitive! { data_type => (primitive_helper, values, indices, data_type), DataType::Utf8 => interleave_bytes::(values, indices), @@ -369,6 +379,7 @@ pub fn interleave_record_batch( mod tests { use super::*; use arrow_array::builder::{Int32Builder, ListBuilder}; + use arrow_schema::extension::TextExtension; #[test] fn test_primitive() { @@ -729,4 +740,31 @@ mod tests { ] ); } + + + #[test] + fn test_interleave_extension() { + let indices = [(0, 0), (1, 3), (0, 2)]; + let storage = Arc::new(StringArray::from(vec![ + "one banana", + "two banana", + "three banana", + "four", + ])); + + let array = ExtensionArray::new( + Arc::new(TextExtension { + storage_type: DataType::Utf8, + }), + storage.clone(), + ); + + let result_ref = interleave(&[&array, &array], &indices).unwrap(); + assert_eq!(result_ref.data_type(), array.data_type()); + assert_eq!(result_ref.len(), 3); + + let result_array: ExtensionArray = result_ref.to_data().into(); + let expected_storage = create_array!(Utf8, ["one banana", "four", "three banana"]); + assert_eq!(**result_array.storage(), *expected_storage); + } } From 0e212b3e44a2256ddf4155bfbba0fee7bf9add55 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 10 Apr 2025 20:30:43 -0500 Subject: [PATCH 11/16] nullif --- arrow-select/src/nullif.rs | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/arrow-select/src/nullif.rs b/arrow-select/src/nullif.rs index dc729da7e6c3..95319272108f 100644 --- a/arrow-select/src/nullif.rs +++ b/arrow-select/src/nullif.rs @@ -113,12 +113,17 @@ pub fn nullif(left: &dyn Array, right: &BooleanArray) -> Result Date: Thu, 10 Apr 2025 20:32:07 -0500 Subject: [PATCH 12/16] fix typo in test extension name --- arrow-schema/src/extension/mod.rs | 4 ++-- arrow-select/src/filter.rs | 4 ++-- arrow-select/src/interleave.rs | 21 +++++++++++++-------- arrow-select/src/nullif.rs | 4 ++-- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/arrow-schema/src/extension/mod.rs b/arrow-schema/src/extension/mod.rs index 1dfd35363aef..1d29bfc89a21 100644 --- a/arrow-schema/src/extension/mod.rs +++ b/arrow-schema/src/extension/mod.rs @@ -411,12 +411,12 @@ impl DynExtensionTypeFactory for CanonicalExtensionTypeFactory { /// Extension for tests #[derive(Debug)] -pub struct TextExtension { +pub struct TestExtension { /// Arbitrary storage type pub storage_type: DataType, } -impl DynExtensionType for TextExtension { +impl DynExtensionType for TestExtension { fn as_any(&self) -> &dyn Any { self } diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index 94480fbbe2a7..91180b516020 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -869,7 +869,7 @@ mod tests { use arrow_array::builder::*; use arrow_array::cast::as_run_array; use arrow_array::types::*; - use arrow_schema::extension::TextExtension; + use arrow_schema::extension::TestExtension; use rand::distr::uniform::{UniformSampler, UniformUsize}; use rand::distr::{Alphanumeric, StandardUniform}; use rand::prelude::*; @@ -2062,7 +2062,7 @@ mod tests { "four", ])); let array = ExtensionArray::new( - Arc::new(TextExtension { + Arc::new(TestExtension { storage_type: DataType::Utf8, }), storage.clone(), diff --git a/arrow-select/src/interleave.rs b/arrow-select/src/interleave.rs index d3c7a36c0b28..39c317c18e25 100644 --- a/arrow-select/src/interleave.rs +++ b/arrow-select/src/interleave.rs @@ -94,13 +94,19 @@ pub fn interleave( } if let DataType::Extension(extension) = data_type { - let storage: Vec<_> = values.iter().map(|array| { - let extension_array: ExtensionArray = array.to_data().into(); - extension_array.storage().clone() - }).collect(); + let storage: Vec<_> = values + .iter() + .map(|array| { + let extension_array: ExtensionArray = array.to_data().into(); + extension_array.storage().clone() + }) + .collect(); let storage_ref: Vec<_> = storage.iter().map(|array| array.as_ref()).collect(); let storage_result = interleave(&storage_ref, indices)?; - return Ok(Arc::new(ExtensionArray::new(extension.clone(), storage_result))); + return Ok(Arc::new(ExtensionArray::new( + extension.clone(), + storage_result, + ))); } downcast_primitive! { @@ -379,7 +385,7 @@ pub fn interleave_record_batch( mod tests { use super::*; use arrow_array::builder::{Int32Builder, ListBuilder}; - use arrow_schema::extension::TextExtension; + use arrow_schema::extension::TestExtension; #[test] fn test_primitive() { @@ -741,7 +747,6 @@ mod tests { ); } - #[test] fn test_interleave_extension() { let indices = [(0, 0), (1, 3), (0, 2)]; @@ -753,7 +758,7 @@ mod tests { ])); let array = ExtensionArray::new( - Arc::new(TextExtension { + Arc::new(TestExtension { storage_type: DataType::Utf8, }), storage.clone(), diff --git a/arrow-select/src/nullif.rs b/arrow-select/src/nullif.rs index 95319272108f..caf4c6d6d919 100644 --- a/arrow-select/src/nullif.rs +++ b/arrow-select/src/nullif.rs @@ -123,7 +123,7 @@ mod tests { create_array, ExtensionArray, Int32Array, NullArray, StringArray, StructArray, }; use arrow_data::ArrayData; - use arrow_schema::extension::TextExtension; + use arrow_schema::extension::TestExtension; use arrow_schema::{Field, Fields}; use rand::{rng, Rng}; @@ -542,7 +542,7 @@ mod tests { "four", ])); let array = ExtensionArray::new( - Arc::new(TextExtension { + Arc::new(TestExtension { storage_type: DataType::Utf8, }), storage.clone(), From 6026ba4fbd5674c891cdca40c0562bc1a14aca3f Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 10 Apr 2025 20:36:56 -0500 Subject: [PATCH 13/16] take --- arrow-select/src/take.rs | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs index df03b85ff186..e71ff6848bb1 100644 --- a/arrow-select/src/take.rs +++ b/arrow-select/src/take.rs @@ -190,6 +190,12 @@ fn take_impl( values: &dyn Array, indices: &PrimitiveArray, ) -> Result { + if let DataType::Extension(_) = values.data_type() { + let extension_array: ExtensionArray = values.to_data().into(); + let storage_result = take_impl(extension_array.storage(), indices)?; + return Ok(Arc::new(extension_array.with_storage(storage_result))); + } + downcast_primitive_array! { values => Ok(Arc::new(take_primitive(values, indices)?)), DataType::Boolean => { @@ -949,7 +955,7 @@ mod tests { use super::*; use arrow_array::builder::*; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; - use arrow_schema::{Field, Fields, TimeUnit, UnionFields}; + use arrow_schema::{extension::TestExtension, Field, Fields, TimeUnit, UnionFields}; fn test_take_decimal_arrays( data: Vec>, @@ -2400,4 +2406,28 @@ mod tests { let array = take(&array, &indicies, None).unwrap(); assert_eq!(array.len(), 3); } + + #[test] + fn test_take_extension() { + let indices = Int32Array::from(vec![1, 3]); + let storage = Arc::new(StringArray::from(vec![ + "one banana", + "two banana", + "three banana", + "four", + ])); + let array = ExtensionArray::new( + Arc::new(TestExtension { + storage_type: DataType::Utf8, + }), + storage.clone(), + ); + let result_ref = take(&array, &indices, None).unwrap(); + assert_eq!(result_ref.data_type(), array.data_type()); + assert_eq!(result_ref.len(), 2); + + let result_array: ExtensionArray = result_ref.to_data().into(); + let expected = create_array!(Utf8, ["two banana", "four"]); + assert_eq!(result_array.storage().to_data(), expected.to_data()); + } } From d3f1a4eefd7372f3eefe583faaca58b684eaee2b Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 10 Apr 2025 20:47:58 -0500 Subject: [PATCH 14/16] new_null --- arrow-data/src/data.rs | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs index c1b689d932f6..7d01a27a6131 100644 --- a/arrow-data/src/data.rs +++ b/arrow-data/src/data.rs @@ -591,6 +591,12 @@ impl ArrayData { /// Returns a new [`ArrayData`] valid for `data_type` containing `len` null values pub fn new_null(data_type: &DataType, len: usize) -> Self { + if let DataType::Extension(extension) = data_type { + let mut storage_data = Self::new_null(extension.storage_type(), len); + storage_data.data_type = data_type.clone(); + return storage_data; + } + let bit_len = bit_util::ceil(len, 8); let zeroed = |len: usize| Buffer::from(MutableBuffer::from_len_zeroed(len)); @@ -2121,7 +2127,7 @@ impl From for ArrayDataBuilder { #[cfg(test)] mod tests { use super::*; - use arrow_schema::{Field, Fields}; + use arrow_schema::{extension::TestExtension, Field, Fields}; // See arrow/tests/array_data_validation.rs for test of array validation @@ -2450,4 +2456,15 @@ mod tests { assert!(array.is_null(i)); } } + + #[test] + fn test_data_extension() { + let data_type = DataType::Extension(Arc::new(TestExtension { + storage_type: DataType::Utf8, + })); + let array_null = ArrayData::new_null(&data_type, 3); + assert_eq!(array_null.len(), 3); + assert_eq!(array_null.data_type(), &data_type); + assert_eq!(array_null.null_count(), 3); + } } From 0259655aaff6f799a0653ea372f26c42ef47e8ab Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 10 Apr 2025 20:57:47 -0500 Subject: [PATCH 15/16] concat --- arrow-select/src/concat.rs | 41 +++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs index b48998478442..bc62548e5739 100644 --- a/arrow-select/src/concat.rs +++ b/arrow-select/src/concat.rs @@ -297,6 +297,22 @@ pub fn concat(arrays: &[&dyn Array]) -> Result { return Err(ArrowError::InvalidArgumentError(error_message)); } + if let DataType::Extension(extension) = d { + let storage: Vec<_> = arrays + .iter() + .map(|array| { + let extension_array: ExtensionArray = array.to_data().into(); + extension_array.storage().clone() + }) + .collect(); + let storage_ref: Vec<_> = storage.iter().map(|array| array.as_ref()).collect(); + let storage_result = concat(&storage_ref)?; + return Ok(Arc::new(ExtensionArray::new( + extension.clone(), + storage_result, + ))); + } + downcast_primitive! { d => (primitive_concat, arrays), DataType::Boolean => concat_boolean(arrays), @@ -374,7 +390,7 @@ pub fn concat_batches<'a>( mod tests { use super::*; use arrow_array::builder::{GenericListBuilder, StringDictionaryBuilder}; - use arrow_schema::{Field, Schema}; + use arrow_schema::{extension::TestExtension, Field, Schema}; use std::fmt::Debug; #[test] @@ -1267,4 +1283,27 @@ mod tests { "There are duplicates in the value list (the value list here is sorted which is only for the assertion)" ); } + + #[test] + fn test_concat_extension() { + let storage = Arc::new(StringArray::from(vec!["one banana", "two banana"])); + + let array = ExtensionArray::new( + Arc::new(TestExtension { + storage_type: DataType::Utf8, + }), + storage.clone(), + ); + + let result_ref = concat(&[&array, &array]).unwrap(); + assert_eq!(result_ref.data_type(), array.data_type()); + assert_eq!(result_ref.len(), 4); + + let result_array: ExtensionArray = result_ref.to_data().into(); + let expected_storage = create_array!( + Utf8, + ["one banana", "two banana", "one banana", "two banana"] + ); + assert_eq!(**result_array.storage(), *expected_storage); + } } From 783a303de00f755764b312f761d45f7d7bf80264 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 10 Apr 2025 21:54:42 -0500 Subject: [PATCH 16/16] zip --- arrow-select/src/zip.rs | 47 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/arrow-select/src/zip.rs b/arrow-select/src/zip.rs index 2efd2e749921..747f59baf30d 100644 --- a/arrow-select/src/zip.rs +++ b/arrow-select/src/zip.rs @@ -17,10 +17,12 @@ //! [`zip`]: Combine values from two arrays based on boolean mask +use std::sync::Arc; + use crate::filter::SlicesIterator; use arrow_array::*; use arrow_data::transform::MutableArrayData; -use arrow_schema::ArrowError; +use arrow_schema::{ArrowError, DataType}; /// Zip two arrays by some boolean mask. /// @@ -116,6 +118,16 @@ pub fn zip( )); } + if let DataType::Extension(extension) = truthy.data_type() { + let truthy_extension: ExtensionArray = truthy.to_data().into(); + let falsy_extension: ExtensionArray = falsy.to_data().into(); + let storage_result = zip(mask, truthy_extension.storage(), falsy_extension.storage())?; + return Ok(Arc::new(ExtensionArray::new( + extension.clone(), + storage_result, + ))); + } + let falsy = falsy.to_data(); let truthy = truthy.to_data(); @@ -168,6 +180,10 @@ pub fn zip( #[cfg(test)] mod test { + use std::sync::Arc; + + use arrow_schema::{extension::TestExtension, DataType}; + use super::*; #[test] @@ -279,4 +295,33 @@ mod test { let expected = Int32Array::from(vec![None, None, Some(42), Some(42), None]); assert_eq!(actual, &expected); } + + #[test] + fn test_zip_extension() { + let mask = BooleanArray::from(vec![true, false, true, false]); + let truthy_storage = Arc::new(StringArray::from(vec![ + "one banana", + "two banana", + "three banana", + "four", + ])); + let falsy_storage = Arc::new(StringArray::from(vec![ + "five banana", + "six banana", + "seven banana", + "more", + ])); + let extension = Arc::new(TestExtension { + storage_type: DataType::Utf8, + }); + let truthy = ExtensionArray::new(extension.clone(), truthy_storage.clone()); + let falsy = ExtensionArray::new(extension.clone(), falsy_storage.clone()); + let result_ref = zip(&mask, &truthy, &falsy).unwrap(); + assert_eq!(result_ref.data_type(), truthy.data_type()); + assert_eq!(result_ref.len(), 4); + + let result_array: ExtensionArray = result_ref.to_data().into(); + let expected = create_array!(Utf8, ["one banana", "six banana", "three banana", "more"]); + assert_eq!(result_array.storage().to_data(), expected.to_data()); + } }