diff --git a/Cargo.lock b/Cargo.lock index 15214a1ac7..5e91bcccaf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3143,6 +3143,7 @@ dependencies = [ "futures", "iceberg", "iceberg-catalog-memory", + "parquet", "tempfile", "tokio", ] diff --git a/crates/iceberg/src/arrow/mod.rs b/crates/iceberg/src/arrow/mod.rs index 0369db0752..56caeaf559 100644 --- a/crates/iceberg/src/arrow/mod.rs +++ b/crates/iceberg/src/arrow/mod.rs @@ -19,7 +19,12 @@ mod schema; pub use schema::*; + +mod nan_val_cnt_visitor; +pub(crate) use nan_val_cnt_visitor::*; + pub(crate) mod delete_file_manager; + mod reader; pub(crate) mod record_batch_projector; pub(crate) mod record_batch_transformer; diff --git a/crates/iceberg/src/arrow/nan_val_cnt_visitor.rs b/crates/iceberg/src/arrow/nan_val_cnt_visitor.rs new file mode 100644 index 0000000000..db6279d9ca --- /dev/null +++ b/crates/iceberg/src/arrow/nan_val_cnt_visitor.rs @@ -0,0 +1,177 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains the visitor for counting NaN values in a given Arrow record batch. + +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::sync::Arc; + +use arrow_array::{ArrayRef, Float32Array, Float64Array, RecordBatch, StructArray}; +use arrow_schema::DataType; + +use crate::arrow::ArrowArrayAccessor; +use crate::spec::{ + visit_struct_with_partner, ListType, MapType, NestedFieldRef, PrimitiveType, Schema, SchemaRef, + SchemaWithPartnerVisitor, StructType, +}; +use crate::Result; + +macro_rules! cast_and_update_cnt_map { + ($t:ty, $col:ident, $self:ident, $field_id:ident) => { + let nan_val_cnt = $col + .as_any() + .downcast_ref::<$t>() + .unwrap() + .iter() + .filter(|value| value.map_or(false, |v| v.is_nan())) + .count() as u64; + + match $self.nan_value_counts.entry($field_id) { + Entry::Occupied(mut ele) => { + let total_nan_val_cnt = ele.get() + nan_val_cnt; + ele.insert(total_nan_val_cnt); + } + Entry::Vacant(v) => { + v.insert(nan_val_cnt); + } + }; + }; +} + +macro_rules!
count_float_nans { + ($col:ident, $self:ident, $field_id:ident) => { + match $col.data_type() { + DataType::Float32 => { + cast_and_update_cnt_map!(Float32Array, $col, $self, $field_id); + } + DataType::Float64 => { + cast_and_update_cnt_map!(Float64Array, $col, $self, $field_id); + } + _ => {} + } + }; +} + +/// Visitor which counts and keeps track of NaN value counts in the given record batch(es) +pub struct NanValueCountVisitor { + /// Stores field ID to NaN value count mapping + pub nan_value_counts: HashMap<i32, u64>, +} + +impl SchemaWithPartnerVisitor<ArrayRef> for NanValueCountVisitor { + type T = (); + + fn schema( + &mut self, + _schema: &Schema, + _partner: &ArrayRef, + _value: Self::T, + ) -> Result<Self::T> { + Ok(()) + } + + fn field( + &mut self, + _field: &NestedFieldRef, + _partner: &ArrayRef, + _value: Self::T, + ) -> Result<Self::T> { + Ok(()) + } + + fn r#struct( + &mut self, + _struct: &StructType, + _partner: &ArrayRef, + _results: Vec<Self::T>, + ) -> Result<Self::T> { + Ok(()) + } + + fn list(&mut self, _list: &ListType, _list_arr: &ArrayRef, _value: Self::T) -> Result<Self::T> { + Ok(()) + } + + fn map( + &mut self, + _map: &MapType, + _partner: &ArrayRef, + _key_value: Self::T, + _value: Self::T, + ) -> Result<Self::T> { + Ok(()) + } + + fn primitive(&mut self, _p: &PrimitiveType, _col: &ArrayRef) -> Result<Self::T> { + Ok(()) + } + + fn after_struct_field(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> { + let field_id = field.id; + count_float_nans!(partner, self, field_id); + Ok(()) + } + + fn after_list_element(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> { + let field_id = field.id; + count_float_nans!(partner, self, field_id); + Ok(()) + } + + fn after_map_key(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> { + let field_id = field.id; + count_float_nans!(partner, self, field_id); + Ok(()) + } + + fn after_map_value(&mut self, field: &NestedFieldRef, partner: &ArrayRef) -> Result<()> { + let field_id = field.id; + count_float_nans!(partner, self, field_id); + Ok(()) + } +} + +impl NanValueCountVisitor { + /// Creates a new instance of NanValueCountVisitor + pub fn new() -> Self { + Self { + nan_value_counts: HashMap::new(), + } + } + + /// Compute NaN value counts for the given schema and record batch + pub fn compute(&mut self, schema: SchemaRef, batch: RecordBatch) -> Result<()> { + let arrow_arr_partner_accessor = ArrowArrayAccessor {}; + + let struct_arr = Arc::new(StructArray::from(batch)) as ArrayRef; + visit_struct_with_partner( + schema.as_struct(), + &struct_arr, + self, + &arrow_arr_partner_accessor, + )?; + + Ok(()) + } +} + +impl Default for NanValueCountVisitor { + fn default() -> Self { + Self::new() + } +} diff --git a/crates/iceberg/src/arrow/value.rs b/crates/iceberg/src/arrow/value.rs index d78c4f4400..84b33d3ff6 100644 --- a/crates/iceberg/src/arrow/value.rs +++ b/crates/iceberg/src/arrow/value.rs @@ -425,7 +425,8 @@ impl SchemaWithPartnerVisitor<ArrayRef> for ArrowArrayToIcebergStructConverter { } } -struct ArrowArrayAccessor; +/// Partner type for accessing and walking arrow arrays alongside the iceberg schema +pub struct ArrowArrayAccessor; impl PartnerAccessor<ArrayRef> for ArrowArrayAccessor { fn struct_parner<'a>(&self, schema_partner: &'a ArrayRef) -> Result<&'a ArrayRef> { @@ -435,6 +436,7 @@ impl PartnerAccessor<ArrayRef> for ArrowArrayAccessor { "The schema partner is not a struct type", )); } + Ok(schema_partner) } @@ -452,6 +454,7 @@ impl PartnerAccessor<ArrayRef> for ArrowArrayAccessor { "The struct partner is not a struct array", ) })?; + let field_pos = struct_array .fields() .iter() @@
-466,6 +469,7 @@ impl PartnerAccessor<ArrayRef> for ArrowArrayAccessor { format!("Field id {} not found in struct array", field.id), ) })?; + Ok(struct_array.column(field_pos)) } diff --git a/crates/iceberg/src/writer/base_writer/data_file_writer.rs b/crates/iceberg/src/writer/base_writer/data_file_writer.rs index c4f39ba318..dea8fd423c 100644 --- a/crates/iceberg/src/writer/base_writer/data_file_writer.rs +++ b/crates/iceberg/src/writer/base_writer/data_file_writer.rs @@ -103,11 +103,13 @@ impl<B: FileWriterBuilder> CurrentFileStatus for DataFileWriter<B> { #[cfg(test)] mod test { + use std::collections::HashMap; use std::sync::Arc; use arrow_array::{Int32Array, StringArray}; use arrow_schema::{DataType, Field}; use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions}; + use parquet::arrow::PARQUET_FIELD_ID_META_KEY; use parquet::file::properties::WriterProperties; use tempfile::TempDir; @@ -153,8 +155,14 @@ mod test { .unwrap(); let arrow_schema = arrow_schema::Schema::new(vec![ - Field::new("foo", DataType::Int32, false), - Field::new("bar", DataType::Utf8, false), + Field::new("foo", DataType::Int32, false).with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + 3.to_string(), + )])), + Field::new("bar", DataType::Utf8, false).with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + 4.to_string(), + )])), ]); let batch = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![ Arc::new(Int32Array::from(vec![1, 2, 3])), @@ -224,8 +232,14 @@ mod test { .await?; let arrow_schema = arrow_schema::Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, false), + Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + 5.to_string(), + )])), + Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + 6.to_string(), + )])), ]); let batch = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![ Arc::new(Int32Array::from(vec![1, 2, 3])), diff --git a/crates/iceberg/src/writer/file_writer/parquet_writer.rs b/crates/iceberg/src/writer/file_writer/parquet_writer.rs index e14f5a2d4d..48d37eb9a9 100644 --- a/crates/iceberg/src/writer/file_writer/parquet_writer.rs +++ b/crates/iceberg/src/writer/file_writer/parquet_writer.rs @@ -40,7 +40,7 @@ use super::track_writer::TrackWriter; use super::{FileWriter, FileWriterBuilder}; use crate::arrow::{ get_parquet_stat_max_as_datum, get_parquet_stat_min_as_datum, ArrowFileReader, - DEFAULT_MAP_FIELD_NAME, + NanValueCountVisitor, DEFAULT_MAP_FIELD_NAME, }; use crate::io::{FileIO, FileWrite, OutputFile}; use crate::spec::{ @@ -99,6 +99,7 @@ impl<T: LocationGenerator, F: FileNameGenerator> FileWriterBuilder for ParquetWr written_size, current_row_num: 0, out_file, + nan_value_count_visitor: NanValueCountVisitor::new(), }) } } @@ -224,6 +225,7 @@ pub struct ParquetWriter { writer_properties: WriterProperties, written_size: Arc, current_row_num: usize, + nan_value_count_visitor: NanValueCountVisitor, } /// Used to aggregate min and max value of each column.
@@ -346,6 +348,8 @@ impl ParquetWriter { parquet_metadata, file_size_in_bytes, file_path, + // TODO: Implement nan_value_counts here + HashMap::new(), )?; builder.partition_spec_id(table_metadata.default_partition_spec_id()); let data_file = builder.build().unwrap(); @@ -384,6 +388,7 @@ impl ParquetWriter { metadata: Arc<FileMetaData>, written_size: usize, file_path: String, + nan_value_counts: HashMap<i32, u64>, ) -> Result<DataFileBuilder> { let index_by_parquet_path = { let mut visitor = IndexByParquetPathName::new(); @@ -438,6 +443,9 @@ impl ParquetWriter { .column_sizes(column_sizes) .value_counts(value_counts) .null_value_counts(null_value_counts) + .nan_value_counts(nan_value_counts) + // # NOTE: + // - We can ignore implementing distinct_counts due to this: https://lists.apache.org/thread/j52tsojv0x4bopxyzsp7m7bqt23n5fnd .lower_bounds(lower_bounds) .upper_bounds(upper_bounds) .split_offsets( @@ -461,6 +469,10 @@ impl FileWriter for ParquetWriter { self.current_row_num += batch.num_rows(); + let batch_c = batch.clone(); + self.nan_value_count_visitor + .compute(self.schema.clone(), batch_c)?; + // Lazy initialize the writer let writer = if let Some(writer) = &mut self.inner_writer { writer @@ -489,6 +501,7 @@ ) .with_source(err) })?; + Ok(()) } @@ -518,6 +531,7 @@ parquet_metadata, written_size as usize, self.out_file.location().to_string(), + self.nan_value_count_visitor.nan_value_counts, )?]) } } @@ -576,12 +590,13 @@ mod tests { use std::sync::Arc; use anyhow::Result; - use arrow_array::types::Int64Type; + use arrow_array::builder::{Float32Builder, Int32Builder, MapBuilder}; + use arrow_array::types::{Float32Type, Int64Type}; use arrow_array::{ - Array, ArrayRef, BooleanArray, Decimal128Array, Int32Array, Int64Array, ListArray, - RecordBatch, StructArray, + Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Float64Array, Int32Array, + Int64Array, ListArray, MapArray, RecordBatch, StructArray, }; - use arrow_schema::{DataType, SchemaRef as ArrowSchemaRef}; + use arrow_schema::{DataType, Field, Fields, SchemaRef as ArrowSchemaRef}; use arrow_select::concat::concat_batches; use parquet::arrow::PARQUET_FIELD_ID_META_KEY; use rust_decimal::Decimal; @@ -1572,4 +1587,586 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_nan_val_cnts_primitive_type() -> Result<()> { + let temp_dir = TempDir::new().unwrap(); + let file_io = FileIOBuilder::new_fs_io().build().unwrap(); + let location_gen = + MockLocationGenerator::new(temp_dir.path().to_str().unwrap().to_string()); + let file_name_gen = + DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet); + + // prepare data + let arrow_schema = { + let fields = vec![ + Field::new("col", arrow_schema::DataType::Float32, false).with_metadata( + HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "0".to_string())]), + ), + Field::new("col2", arrow_schema::DataType::Float64, false).with_metadata( + HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "1".to_string())]), + ), + ]; + Arc::new(arrow_schema::Schema::new(fields)) + }; + + let float_32_col = Arc::new(Float32Array::from_iter_values_with_nulls( + [1.0_f32, f32::NAN, 2.0, 2.0].into_iter(), + None, + )) as ArrayRef; + + let float_64_col = Arc::new(Float64Array::from_iter_values_with_nulls( + [1.0_f64, f64::NAN, 2.0, 2.0].into_iter(), + None, + )) as ArrayRef; + + let to_write = + RecordBatch::try_new(arrow_schema.clone(), vec![float_32_col, float_64_col]).unwrap(); + + // write data + let mut pw =
ParquetWriterBuilder::new( + WriterProperties::builder().build(), + Arc::new(to_write.schema().as_ref().try_into().unwrap()), + file_io.clone(), + location_gen, + file_name_gen, + ) + .build() + .await?; + + pw.write(&to_write).await?; + let res = pw.close().await?; + assert_eq!(res.len(), 1); + let data_file = res + .into_iter() + .next() + .unwrap() + // Put dummy field for build successfully. + .content(crate::spec::DataContentType::Data) + .partition(Struct::empty()) + .partition_spec_id(0) + .build() + .unwrap(); + + // check data file + assert_eq!(data_file.record_count(), 4); + assert_eq!(*data_file.value_counts(), HashMap::from([(0, 4), (1, 4)])); + assert_eq!( + *data_file.lower_bounds(), + HashMap::from([(0, Datum::float(1.0)), (1, Datum::double(1.0)),]) + ); + assert_eq!( + *data_file.upper_bounds(), + HashMap::from([(0, Datum::float(2.0)), (1, Datum::double(2.0)),]) + ); + assert_eq!( + *data_file.null_value_counts(), + HashMap::from([(0, 0), (1, 0)]) + ); + assert_eq!( + *data_file.nan_value_counts(), + HashMap::from([(0, 1), (1, 1)]) + ); + + // check the written file + let expect_batch = concat_batches(&arrow_schema, vec![&to_write]).unwrap(); + check_parquet_data_file(&file_io, &data_file, &expect_batch).await; + + Ok(()) + } + + #[tokio::test] + async fn test_nan_val_cnts_struct_type() -> Result<()> { + let temp_dir = TempDir::new().unwrap(); + let file_io = FileIOBuilder::new_fs_io().build().unwrap(); + let location_gen = + MockLocationGenerator::new(temp_dir.path().to_str().unwrap().to_string()); + let file_name_gen = + DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet); + + let schema_struct_float_fields = + Fields::from(vec![Field::new("col4", DataType::Float32, false) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "4".to_string(), + )]))]); + + let schema_struct_nested_float_fields = + Fields::from(vec![Field::new("col7", DataType::Float32, false) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "7".to_string(), + )]))]); + + let schema_struct_nested_fields = Fields::from(vec![Field::new( + "col6", + arrow_schema::DataType::Struct(schema_struct_nested_float_fields.clone()), + false, + ) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "6".to_string(), + )]))]); + + // prepare data + let arrow_schema = { + let fields = vec![ + Field::new( + "col3", + arrow_schema::DataType::Struct(schema_struct_float_fields.clone()), + false, + ) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "3".to_string(), + )])), + Field::new( + "col5", + arrow_schema::DataType::Struct(schema_struct_nested_fields.clone()), + false, + ) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "5".to_string(), + )])), + ]; + Arc::new(arrow_schema::Schema::new(fields)) + }; + + let float_32_col = Arc::new(Float32Array::from_iter_values_with_nulls( + [1.0_f32, f32::NAN, 2.0, 2.0].into_iter(), + None, + )) as ArrayRef; + + let struct_float_field_col = Arc::new(StructArray::new( + schema_struct_float_fields, + vec![float_32_col.clone()], + None, + )) as ArrayRef; + + let struct_nested_float_field_col = Arc::new(StructArray::new( + schema_struct_nested_fields, + vec![Arc::new(StructArray::new( + schema_struct_nested_float_fields, + vec![float_32_col.clone()], + None, + )) as ArrayRef], + None, + )) as ArrayRef; + + let to_write = RecordBatch::try_new(arrow_schema.clone(), vec![ + struct_float_field_col, + struct_nested_float_field_col, + ]) + 
.unwrap(); + + // write data + let mut pw = ParquetWriterBuilder::new( + WriterProperties::builder().build(), + Arc::new(to_write.schema().as_ref().try_into().unwrap()), + file_io.clone(), + location_gen, + file_name_gen, + ) + .build() + .await?; + + pw.write(&to_write).await?; + let res = pw.close().await?; + assert_eq!(res.len(), 1); + let data_file = res + .into_iter() + .next() + .unwrap() + // Put dummy field for build successfully. + .content(crate::spec::DataContentType::Data) + .partition(Struct::empty()) + .partition_spec_id(0) + .build() + .unwrap(); + + // check data file + assert_eq!(data_file.record_count(), 4); + assert_eq!(*data_file.value_counts(), HashMap::from([(4, 4), (7, 4)])); + assert_eq!( + *data_file.lower_bounds(), + HashMap::from([(4, Datum::float(1.0)), (7, Datum::float(1.0)),]) + ); + assert_eq!( + *data_file.upper_bounds(), + HashMap::from([(4, Datum::float(2.0)), (7, Datum::float(2.0)),]) + ); + assert_eq!( + *data_file.null_value_counts(), + HashMap::from([(4, 0), (7, 0)]) + ); + assert_eq!( + *data_file.nan_value_counts(), + HashMap::from([(4, 1), (7, 1)]) + ); + + // check the written file + let expect_batch = concat_batches(&arrow_schema, vec![&to_write]).unwrap(); + check_parquet_data_file(&file_io, &data_file, &expect_batch).await; + + Ok(()) + } + + #[tokio::test] + async fn test_nan_val_cnts_list_type() -> Result<()> { + let temp_dir = TempDir::new().unwrap(); + let file_io = FileIOBuilder::new_fs_io().build().unwrap(); + let location_gen = + MockLocationGenerator::new(temp_dir.path().to_str().unwrap().to_string()); + let file_name_gen = + DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet); + + let schema_list_float_field = Field::new("element", DataType::Float32, true).with_metadata( + HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "1".to_string())]), + ); + + let schema_struct_list_float_field = Field::new("element", DataType::Float32, true) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "4".to_string(), + )])); + + let schema_struct_list_field = Fields::from(vec![Field::new_list( + "col2", + schema_struct_list_float_field.clone(), + true, + ) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "3".to_string(), + )]))]); + + let arrow_schema = { + let fields = vec![ + Field::new_list("col0", schema_list_float_field.clone(), true).with_metadata( + HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "0".to_string())]), + ), + Field::new_struct("col1", schema_struct_list_field.clone(), true) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "2".to_string(), + )])) + .clone(), + // Field::new_large_list("col3", schema_large_list_float_field.clone(), true).with_metadata( + // HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "5".to_string())]), + // ).clone(), + ]; + Arc::new(arrow_schema::Schema::new(fields)) + }; + + let list_parts = ListArray::from_iter_primitive::<Float32Type, _, _>(vec![Some(vec![ + Some(1.0_f32), + Some(f32::NAN), + Some(2.0), + Some(2.0), + ])]) + .into_parts(); + + let list_float_field_col = Arc::new({ + let list_parts = list_parts.clone(); + ListArray::new( + { + if let DataType::List(field) = arrow_schema.field(0).data_type() { + field.clone() + } else { + unreachable!() + } + }, + list_parts.1, + list_parts.2, + list_parts.3, + ) + }) as ArrayRef; + + let struct_list_fields_schema = + if let DataType::Struct(fields) = arrow_schema.field(1).data_type() { + fields.clone() + } else { + unreachable!() + }; + + let
struct_list_float_field_col = Arc::new({ + ListArray::new( + { + if let DataType::List(field) = struct_list_fields_schema + .first() + .expect("could not find first list field") + .data_type() + { + field.clone() + } else { + unreachable!() + } + }, + list_parts.1, + list_parts.2, + list_parts.3, + ) + }) as ArrayRef; + + let struct_list_float_field_col = Arc::new(StructArray::new( + struct_list_fields_schema, + vec![struct_list_float_field_col.clone()], + None, + )) as ArrayRef; + + let to_write = RecordBatch::try_new(arrow_schema.clone(), vec![ + list_float_field_col, + struct_list_float_field_col, + // large_list_float_field_col, + ]) + .expect("Could not form record batch"); + + // write data + let mut pw = ParquetWriterBuilder::new( + WriterProperties::builder().build(), + Arc::new( + to_write + .schema() + .as_ref() + .try_into() + .expect("Could not convert iceberg schema"), + ), + file_io.clone(), + location_gen, + file_name_gen, + ) + .build() + .await?; + + pw.write(&to_write).await?; + let res = pw.close().await?; + assert_eq!(res.len(), 1); + let data_file = res + .into_iter() + .next() + .unwrap() + .content(crate::spec::DataContentType::Data) + .partition(Struct::empty()) + .partition_spec_id(0) + .build() + .unwrap(); + + // check data file + assert_eq!(data_file.record_count(), 1); + assert_eq!(*data_file.value_counts(), HashMap::from([(1, 4), (4, 4)])); + assert_eq!( + *data_file.lower_bounds(), + HashMap::from([(1, Datum::float(1.0)), (4, Datum::float(1.0))]) + ); + assert_eq!( + *data_file.upper_bounds(), + HashMap::from([(1, Datum::float(2.0)), (4, Datum::float(2.0))]) + ); + assert_eq!( + *data_file.null_value_counts(), + HashMap::from([(1, 0), (4, 0)]) + ); + assert_eq!( + *data_file.nan_value_counts(), + HashMap::from([(1, 1), (4, 1)]) + ); + + // check the written file + let expect_batch = concat_batches(&arrow_schema, vec![&to_write]).unwrap(); + check_parquet_data_file(&file_io, &data_file, &expect_batch).await; + + Ok(()) + } + + macro_rules! 
construct_map_arr { + ($map_key_field_schema:ident, $map_value_field_schema:ident) => {{ + let int_builder = Int32Builder::new(); + let float_builder = Float32Builder::with_capacity(4); + let mut builder = MapBuilder::new(None, int_builder, float_builder); + builder.keys().append_value(1); + builder.values().append_value(1.0_f32); + builder.append(true).unwrap(); + builder.keys().append_value(2); + builder.values().append_value(f32::NAN); + builder.append(true).unwrap(); + builder.keys().append_value(3); + builder.values().append_value(2.0); + builder.append(true).unwrap(); + builder.keys().append_value(4); + builder.values().append_value(2.0); + builder.append(true).unwrap(); + let array = builder.finish(); + + let (_field, offsets, entries, nulls, ordered) = array.into_parts(); + let new_struct_fields_schema = + Fields::from(vec![$map_key_field_schema, $map_value_field_schema]); + + let entries = { + let (_, arrays, nulls) = entries.into_parts(); + StructArray::new(new_struct_fields_schema.clone(), arrays, nulls) + }; + + let field = Arc::new(Field::new( + DEFAULT_MAP_FIELD_NAME, + DataType::Struct(new_struct_fields_schema), + false, + )); + + Arc::new(MapArray::new(field, offsets, entries, nulls, ordered)) + }}; + } + + #[tokio::test] + async fn test_nan_val_cnts_map_type() -> Result<()> { + let temp_dir = TempDir::new().unwrap(); + let file_io = FileIOBuilder::new_fs_io().build().unwrap(); + let location_gen = + MockLocationGenerator::new(temp_dir.path().to_str().unwrap().to_string()); + let file_name_gen = + DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet); + + let map_key_field_schema = + Field::new(MAP_KEY_FIELD_NAME, DataType::Int32, false).with_metadata(HashMap::from([ + (PARQUET_FIELD_ID_META_KEY.to_string(), "1".to_string()), + ])); + + let map_value_field_schema = + Field::new(MAP_VALUE_FIELD_NAME, DataType::Float32, true).with_metadata(HashMap::from( + [(PARQUET_FIELD_ID_META_KEY.to_string(), "2".to_string())], + )); + + let struct_map_key_field_schema = + Field::new(MAP_KEY_FIELD_NAME, DataType::Int32, false).with_metadata(HashMap::from([ + (PARQUET_FIELD_ID_META_KEY.to_string(), "6".to_string()), + ])); + + let struct_map_value_field_schema = + Field::new(MAP_VALUE_FIELD_NAME, DataType::Float32, true).with_metadata(HashMap::from( + [(PARQUET_FIELD_ID_META_KEY.to_string(), "7".to_string())], + )); + + let schema_struct_map_field = Fields::from(vec![Field::new_map( + "col3", + DEFAULT_MAP_FIELD_NAME, + struct_map_key_field_schema.clone(), + struct_map_value_field_schema.clone(), + false, + false, + ) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "5".to_string(), + )]))]); + + let arrow_schema = { + let fields = vec![ + Field::new_map( + "col0", + DEFAULT_MAP_FIELD_NAME, + map_key_field_schema.clone(), + map_value_field_schema.clone(), + false, + false, + ) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "0".to_string(), + )])), + Field::new_struct("col1", schema_struct_map_field.clone(), true) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "3".to_string(), + )])) + .clone(), + ]; + Arc::new(arrow_schema::Schema::new(fields)) + }; + + let map_array = construct_map_arr!(map_key_field_schema, map_value_field_schema); + + let struct_map_arr = + construct_map_arr!(struct_map_key_field_schema, struct_map_value_field_schema); + + let struct_list_float_field_col = Arc::new(StructArray::new( + schema_struct_map_field, + vec![struct_map_arr], + None, + )) as 
ArrayRef; + + let to_write = RecordBatch::try_new(arrow_schema.clone(), vec![ + map_array, + struct_list_float_field_col, + ]) + .expect("Could not form record batch"); + + // write data + let mut pw = ParquetWriterBuilder::new( + WriterProperties::builder().build(), + Arc::new( + to_write + .schema() + .as_ref() + .try_into() + .expect("Could not convert iceberg schema"), + ), + file_io.clone(), + location_gen, + file_name_gen, + ) + .build() + .await?; + + pw.write(&to_write).await?; + let res = pw.close().await?; + assert_eq!(res.len(), 1); + let data_file = res + .into_iter() + .next() + .unwrap() + .content(crate::spec::DataContentType::Data) + .partition(Struct::empty()) + .partition_spec_id(0) + .build() + .unwrap(); + + // check data file + assert_eq!(data_file.record_count(), 4); + assert_eq!( + *data_file.value_counts(), + HashMap::from([(1, 4), (2, 4), (6, 4), (7, 4)]) + ); + assert_eq!( + *data_file.lower_bounds(), + HashMap::from([ + (1, Datum::int(1)), + (2, Datum::float(1.0)), + (6, Datum::int(1)), + (7, Datum::float(1.0)) + ]) + ); + assert_eq!( + *data_file.upper_bounds(), + HashMap::from([ + (1, Datum::int(4)), + (2, Datum::float(2.0)), + (6, Datum::int(4)), + (7, Datum::float(2.0)) + ]) + ); + assert_eq!( + *data_file.null_value_counts(), + HashMap::from([(1, 0), (2, 0), (6, 0), (7, 0)]) + ); + assert_eq!( + *data_file.nan_value_counts(), + HashMap::from([(2, 1), (7, 1)]) + ); + + // check the written file + let expect_batch = concat_batches(&arrow_schema, vec![&to_write]).unwrap(); + check_parquet_data_file(&file_io, &data_file, &expect_batch).await; + + Ok(()) + } } diff --git a/crates/integrations/datafusion/Cargo.toml b/crates/integrations/datafusion/Cargo.toml index ccb9ca175b..c6c564574c 100644 --- a/crates/integrations/datafusion/Cargo.toml +++ b/crates/integrations/datafusion/Cargo.toml @@ -41,4 +41,5 @@ tokio = { workspace = true } [dev-dependencies] iceberg-catalog-memory = { workspace = true } +parquet = { workspace = true } tempfile = { workspace = true } diff --git a/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs index a45e21a586..10b92d54b1 100644 --- a/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs +++ b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs @@ -220,20 +220,31 @@ fn scalar_value_to_datum(value: &ScalarValue) -> Option { #[cfg(test)] mod tests { + use std::collections::HashMap; + use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::common::DFSchema; use datafusion::logical_expr::utils::split_conjunction; use datafusion::prelude::{Expr, SessionContext}; use iceberg::expr::{Predicate, Reference}; use iceberg::spec::Datum; + use parquet::arrow::PARQUET_FIELD_ID_META_KEY; use super::convert_filters_to_predicate; fn create_test_schema() -> DFSchema { let arrow_schema = Schema::new(vec![ - Field::new("foo", DataType::Int32, true), - Field::new("bar", DataType::Utf8, true), - Field::new("ts", DataType::Timestamp(TimeUnit::Second, None), true), + Field::new("foo", DataType::Int32, true).with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "1".to_string(), + )])), + Field::new("bar", DataType::Utf8, true).with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "2".to_string(), + )])), + Field::new("ts", DataType::Timestamp(TimeUnit::Second, None), true).with_metadata( + HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "3".to_string())]), + 
), ]); DFSchema::try_from_qualified_schema("my_table", &arrow_schema).unwrap() }
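Reviewer note (not part of the diff): below is a minimal sketch of how the new visitor is driven. It assumes the same setup as the tests above — Arrow fields carrying PARQUET_FIELD_ID_META_KEY metadata — and crate-internal access, since NanValueCountVisitor is re-exported as pub(crate) from arrow/mod.rs. ParquetWriter::write does the equivalent of this before handing each batch to the inner Arrow writer, and close() forwards the accumulated map to DataFileBuilder::nan_value_counts.

// Illustrative sketch only (crate-internal); not authoritative code from the PR.
use std::collections::HashMap;
use std::sync::Arc;

use arrow_array::{ArrayRef, Float32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use parquet::arrow::PARQUET_FIELD_ID_META_KEY;

use crate::arrow::NanValueCountVisitor;
use crate::spec::Schema;
use crate::Result;

fn nan_counts_for_batch() -> Result<HashMap<i32, u64>> {
    // One float column with field id 0; NaN counting only applies to Float32/Float64 columns.
    let arrow_schema = ArrowSchema::new(vec![Field::new("col", DataType::Float32, false)
        .with_metadata(HashMap::from([(
            PARQUET_FIELD_ID_META_KEY.to_string(),
            "0".to_string(),
        )]))]);
    let col = Arc::new(Float32Array::from(vec![1.0, f32::NAN, 2.0])) as ArrayRef;
    let batch =
        RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![col]).expect("valid batch");

    // Convert to an Iceberg schema (field ids come from the metadata above),
    // then walk the batch with the visitor.
    let schema: Schema = (&arrow_schema).try_into()?;
    let mut visitor = NanValueCountVisitor::new();
    visitor.compute(Arc::new(schema), batch)?;

    // Expect {0: 1} for the single NaN above.
    Ok(visitor.nan_value_counts)
}

Because the macro accumulates into the existing HashMap entry on each call, counts stay correct when several batches are written to the same file before close().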