diff --git a/arrow-json/src/lib.rs b/arrow-json/src/lib.rs index 6d7ab4400b6e..2e9d5fbba1ab 100644 --- a/arrow-json/src/lib.rs +++ b/arrow-json/src/lib.rs @@ -74,7 +74,7 @@ pub mod reader; pub mod writer; -pub use self::reader::{Reader, ReaderBuilder}; +pub use self::reader::{ArrayDecoder, DecoderFactory, Reader, ReaderBuilder, Tape, TapeElement}; pub use self::writer::{ ArrayWriter, Encoder, EncoderFactory, EncoderOptions, LineDelimitedWriter, Writer, WriterBuilder, diff --git a/arrow-json/src/reader/list_array.rs b/arrow-json/src/reader/list_array.rs index 1a1dee6a23d4..1a76bf8d3d14 100644 --- a/arrow-json/src/reader/list_array.rs +++ b/arrow-json/src/reader/list_array.rs @@ -24,6 +24,9 @@ use arrow_buffer::buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType}; use std::marker::PhantomData; +use std::sync::Arc; + +use super::DecoderFactory; pub struct ListArrayDecoder { data_type: DataType, @@ -39,6 +42,7 @@ impl ListArrayDecoder { strict_mode: bool, is_nullable: bool, struct_mode: StructMode, + decoder_factory: Option>, ) -> Result { let field = match &data_type { DataType::List(f) if !O::IS_LARGE => f, @@ -51,6 +55,7 @@ impl ListArrayDecoder { strict_mode, field.is_nullable(), struct_mode, + decoder_factory, )?; Ok(Self { diff --git a/arrow-json/src/reader/map_array.rs b/arrow-json/src/reader/map_array.rs index ee78373a551e..f9d610ed946d 100644 --- a/arrow-json/src/reader/map_array.rs +++ b/arrow-json/src/reader/map_array.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. 
+use std::sync::Arc; + use crate::reader::tape::{Tape, TapeElement}; use crate::reader::{make_decoder, ArrayDecoder}; use crate::StructMode; @@ -24,6 +26,8 @@ use arrow_buffer::ArrowNativeType; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType}; +use super::DecoderFactory; + pub struct MapArrayDecoder { data_type: DataType, keys: Box, @@ -38,6 +42,7 @@ impl MapArrayDecoder { strict_mode: bool, is_nullable: bool, struct_mode: StructMode, + decoder_factory: Option>, ) -> Result { let fields = match &data_type { DataType::Map(_, true) => { @@ -62,6 +67,7 @@ impl MapArrayDecoder { strict_mode, fields[0].is_nullable(), struct_mode, + decoder_factory.clone(), )?; let values = make_decoder( fields[1].data_type().clone(), @@ -69,6 +75,7 @@ impl MapArrayDecoder { strict_mode, fields[1].is_nullable(), struct_mode, + decoder_factory, )?; Ok(Self { diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index cd33e337be08..5815d4b236a2 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -146,6 +146,7 @@ use arrow_array::{downcast_integer, make_array, RecordBatch, RecordBatchReader, use arrow_data::ArrayData; use arrow_schema::{ArrowError, DataType, FieldRef, Schema, SchemaRef, TimeUnit}; pub use schema::*; +pub use tape::*; use crate::reader::boolean_array::BooleanArrayDecoder; use crate::reader::decimal_array::DecimalArrayDecoder; @@ -156,7 +157,6 @@ use crate::reader::primitive_array::PrimitiveArrayDecoder; use crate::reader::string_array::StringArrayDecoder; use crate::reader::string_view_array::StringViewArrayDecoder; use crate::reader::struct_array::StructArrayDecoder; -use crate::reader::tape::{Tape, TapeDecoder}; use crate::reader::timestamp_array::TimestampArrayDecoder; mod boolean_array; @@ -180,6 +180,7 @@ pub struct ReaderBuilder { strict_mode: bool, is_field: bool, struct_mode: StructMode, + decoder_factory: Option>, schema: SchemaRef, } @@ -201,6 +202,7 @@ impl ReaderBuilder { 
is_field: false, struct_mode: Default::default(), schema, + decoder_factory: None, } } @@ -242,6 +244,7 @@ impl ReaderBuilder { is_field: true, struct_mode: Default::default(), schema: Arc::new(Schema::new([field.into()])), + decoder_factory: None, } } @@ -281,6 +284,14 @@ impl ReaderBuilder { } } + /// Set an optional hook for customizing decoding behavior. + pub fn with_decoder_factory(self, decoder_factory: Arc) -> Self { + Self { + decoder_factory: Some(decoder_factory), + ..self + } + } + /// Create a [`Reader`] with the provided [`BufRead`] pub fn build(self, reader: R) -> Result, ArrowError> { Ok(Reader { @@ -305,6 +316,7 @@ impl ReaderBuilder { self.strict_mode, nullable, self.struct_mode, + self.decoder_factory, )?; let num_fields = self.schema.flattened_fields().len(); @@ -369,6 +381,95 @@ impl RecordBatchReader for Reader { } } +/// A trait to create custom decoders for specific data types. +/// +/// This allows overriding the default decoders for specific data types, +/// or adding new decoders for custom data types. 
+/// +/// # Examples +/// +/// ``` +/// use arrow_json::{ArrayDecoder, DecoderFactory, TapeElement, Tape, ReaderBuilder, StructMode}; +/// use arrow_schema::ArrowError; +/// use arrow_schema::{DataType, Field, Fields, Schema}; +/// use arrow_array::cast::AsArray; +/// use arrow_array::Array; +/// use arrow_array::builder::StringBuilder; +/// use arrow_data::ArrayData; +/// use std::sync::Arc; +/// +/// struct IncorrectStringAsNullDecoder {} +/// +/// impl ArrayDecoder for IncorrectStringAsNullDecoder { +/// fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { +/// let mut builder = StringBuilder::new(); +/// for p in pos { +/// match tape.get(*p) { +/// TapeElement::String(idx) => { +/// builder.append_value(tape.get_string(idx)); +/// } +/// _ => builder.append_null(), +/// } +/// } +/// Ok(builder.finish().into_data()) +/// } +/// } +/// +/// #[derive(Debug)] +/// struct IncorrectStringAsNullDecoderFactory; +/// +/// impl DecoderFactory for IncorrectStringAsNullDecoderFactory { +/// fn make_default_decoder<'a>( +/// &self, +/// data_type: DataType, +/// _coerce_primitive: bool, +/// _strict_mode: bool, +/// _is_nullable: bool, +/// _struct_mode: StructMode, +/// ) -> Result>, ArrowError> { +/// match data_type { +/// DataType::Utf8 => Ok(Some(Box::new(IncorrectStringAsNullDecoder {}))), +/// _ => Ok(None), +/// } +/// } +/// } +/// +/// let json = r#" +/// {"a": "a"} +/// {"a": 12} +/// "#; +/// let batch = ReaderBuilder::new(Arc::new(Schema::new(Fields::from(vec![Field::new( +/// "a", +/// DataType::Utf8, +/// true, +/// )])))) +/// .with_decoder_factory(Arc::new(IncorrectStringAsNullDecoderFactory)) +/// .build(json.as_bytes()) +/// .unwrap() +/// .next() +/// .unwrap() +/// .unwrap(); +/// +/// let values = batch.column(0).as_string::(); +/// assert_eq!(values.len(), 2); +/// assert_eq!(values.value(0), "a"); +/// assert!(values.is_null(1)); +/// ``` +pub trait DecoderFactory: std::fmt::Debug + Send + Sync { + /// Make a decoder that overrides the 
default decoder for a specific data type. + /// This can be used to override how e.g. errors in decoding are handled. + fn make_default_decoder( + &self, + _data_type: DataType, + _coerce_primitive: bool, + _strict_mode: bool, + _is_nullable: bool, + _struct_mode: StructMode, + ) -> Result>, ArrowError> { + Ok(None) + } +} + /// A low-level interface for reading JSON data from a byte stream /// /// See [`Reader`] for a higher-level interface for interface with [`BufRead`] @@ -668,7 +769,8 @@ impl Decoder { } } -trait ArrayDecoder: Send { +/// A trait to decode JSON values into arrow arrays +pub trait ArrayDecoder: Send { /// Decode elements from `tape` starting at the indexes contained in `pos` fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result; } @@ -685,7 +787,20 @@ fn make_decoder( strict_mode: bool, is_nullable: bool, struct_mode: StructMode, + decoder_factory: Option>, ) -> Result, ArrowError> { + if let Some(ref factory) = decoder_factory { + if let Some(decoder) = factory.make_default_decoder( + data_type.clone(), + coerce_primitive, + strict_mode, + is_nullable, + struct_mode, + )? { + return Ok(decoder); + } + } + downcast_integer!
{ data_type => (primitive_decoder, data_type), DataType::Null => Ok(Box::::default()), @@ -736,13 +851,13 @@ fn make_decoder( DataType::Utf8 => Ok(Box::new(StringArrayDecoder::::new(coerce_primitive))), DataType::Utf8View => Ok(Box::new(StringViewArrayDecoder::new(coerce_primitive))), DataType::LargeUtf8 => Ok(Box::new(StringArrayDecoder::::new(coerce_primitive))), - DataType::List(_) => Ok(Box::new(ListArrayDecoder::::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)), - DataType::LargeList(_) => Ok(Box::new(ListArrayDecoder::::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)), - DataType::Struct(_) => Ok(Box::new(StructArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)), + DataType::List(_) => Ok(Box::new(ListArrayDecoder::::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode, decoder_factory)?)), + DataType::LargeList(_) => Ok(Box::new(ListArrayDecoder::::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode, decoder_factory)?)), + DataType::Struct(_) => Ok(Box::new(StructArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode, decoder_factory)?)), DataType::Binary | DataType::LargeBinary | DataType::FixedSizeBinary(_) => { Err(ArrowError::JsonError(format!("{data_type} is not supported by JSON"))) } - DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)), + DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode, decoder_factory)?)), d => Err(ArrowError::NotYetImplemented(format!("Support for {d} in JSON reader"))) } } @@ -2808,4 +2923,68 @@ mod tests { "Json error: whilst decoding field 'a': failed to parse \"a\" as Int32".to_owned() ); } + + #[test] + fn test_decoder_factory() { + use arrow_array::builder; + + struct AlwaysNullStringArrayDecoder; + + impl ArrayDecoder for 
AlwaysNullStringArrayDecoder { + fn decode(&mut self, _tape: &Tape<'_>, pos: &[u32]) -> Result { + let mut builder = builder::StringBuilder::new(); + for _ in pos { + builder.append_null(); + } + Ok(builder.finish().into_data()) + } + } + + #[derive(Debug)] + struct AlwaysNullStringArrayDecoderFactory; + + impl DecoderFactory for AlwaysNullStringArrayDecoderFactory { + fn make_default_decoder<'a>( + &self, + data_type: DataType, + _coerce_primitive: bool, + _strict_mode: bool, + _is_nullable: bool, + _struct_mode: StructMode, + ) -> Result>, ArrowError> { + match data_type { + DataType::Utf8 => Ok(Some(Box::new(AlwaysNullStringArrayDecoder {}))), + _ => Ok(None), + } + } + } + + let buf = r#" + {"a": "1", "b": 2} + {"a": "hello", "b": 23} + "#; + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Int32, true), + ])); + + let batches = ReaderBuilder::new(schema.clone()) + .with_batch_size(2) + .with_decoder_factory(Arc::new(AlwaysNullStringArrayDecoderFactory)) + .build(Cursor::new(buf.as_bytes())) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!(batches.len(), 1); + + let col1 = batches[0].column(0).as_string::(); + assert_eq!(col1.null_count(), 2); + assert!(col1.is_null(0)); + assert!(col1.is_null(1)); + + let col2 = batches[0].column(1).as_primitive::(); + assert_eq!(col2.value(0), 2); + assert_eq!(col2.value(1), 23); + } } diff --git a/arrow-json/src/reader/struct_array.rs b/arrow-json/src/reader/struct_array.rs index b9408df77a43..d4c96cd42581 100644 --- a/arrow-json/src/reader/struct_array.rs +++ b/arrow-json/src/reader/struct_array.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. 
+use std::sync::Arc; + use crate::reader::tape::{Tape, TapeElement}; use crate::reader::{make_decoder, ArrayDecoder, StructMode}; use arrow_array::builder::BooleanBufferBuilder; @@ -22,6 +24,8 @@ use arrow_buffer::buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, Fields}; +use super::DecoderFactory; + pub struct StructArrayDecoder { data_type: DataType, decoders: Vec>, @@ -37,6 +41,7 @@ impl StructArrayDecoder { strict_mode: bool, is_nullable: bool, struct_mode: StructMode, + decoder_factory: Option>, ) -> Result { let decoders = struct_fields(&data_type) .iter() @@ -51,6 +56,7 @@ impl StructArrayDecoder { strict_mode, nullable, struct_mode, + decoder_factory.clone(), ) }) .collect::, ArrowError>>()?; diff --git a/arrow-json/src/reader/tape.rs b/arrow-json/src/reader/tape.rs index ed65baab9f2b..d3d593cd254f 100644 --- a/arrow-json/src/reader/tape.rs +++ b/arrow-json/src/reader/tape.rs @@ -338,6 +338,7 @@ impl TapeDecoder { } } + /// Decodes JSON data from the provided buffer, returning the number of bytes consumed pub fn decode(&mut self, buf: &[u8]) -> Result { let mut iter = BufIter::new(buf);