diff --git a/crates/iceberg/src/writer/base_writer/data_file_writer.rs b/crates/iceberg/src/writer/base_writer/data_file_writer.rs
index 6f9c0a8920..940aa15842 100644
--- a/crates/iceberg/src/writer/base_writer/data_file_writer.rs
+++ b/crates/iceberg/src/writer/base_writer/data_file_writer.rs
@@ -100,16 +100,21 @@ impl CurrentFileStatus for DataFileWriter {
 mod test {
     use std::sync::Arc;
 
+    use arrow_array::{Int32Array, StringArray};
+    use arrow_schema::{DataType, Field};
+    use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions};
     use parquet::file::properties::WriterProperties;
     use tempfile::TempDir;
 
     use crate::io::FileIOBuilder;
-    use crate::spec::{DataContentType, DataFileFormat, Schema, Struct};
+    use crate::spec::{
+        DataContentType, DataFileFormat, Literal, NestedField, PrimitiveType, Schema, Struct, Type,
+    };
     use crate::writer::base_writer::data_file_writer::DataFileWriterBuilder;
     use crate::writer::file_writer::location_generator::test::MockLocationGenerator;
     use crate::writer::file_writer::location_generator::DefaultFileNameGenerator;
     use crate::writer::file_writer::ParquetWriterBuilder;
-    use crate::writer::{IcebergWriter, IcebergWriterBuilder};
+    use crate::writer::{IcebergWriter, IcebergWriterBuilder, RecordBatch};
     use crate::Result;
 
     #[tokio::test]
@@ -121,20 +126,124 @@ mod test {
         let file_name_gen =
             DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);
 
+        let schema = Schema::builder()
+            .with_schema_id(3)
+            .with_fields(vec![
+                NestedField::required(3, "foo", Type::Primitive(PrimitiveType::Int)).into(),
+                NestedField::required(4, "bar", Type::Primitive(PrimitiveType::String)).into(),
+            ])
+            .build()?;
+
         let pw = ParquetWriterBuilder::new(
             WriterProperties::builder().build(),
-            Arc::new(Schema::builder().build().unwrap()),
+            Arc::new(schema),
+            file_io.clone(),
+            location_gen,
+            file_name_gen,
+        );
+
+        let mut data_file_writer = DataFileWriterBuilder::new(pw, None).build().await.unwrap();
+
+        let data_files = data_file_writer.close().await.unwrap();
+        assert_eq!(data_files.len(), 1);
+
+        let data_file = &data_files[0];
+        assert_eq!(data_file.file_format, DataFileFormat::Parquet);
+        assert_eq!(data_file.content, DataContentType::Data);
+        assert_eq!(data_file.partition, Struct::empty());
+
+        let input_file = file_io.new_input(data_file.file_path.clone())?;
+        let input_content = input_file.read().await?;
+
+        let parquet_reader =
+            ArrowReaderMetadata::load(&input_content, ArrowReaderOptions::default())
+                .expect("Failed to load Parquet metadata");
+
+        let field_ids: Vec<i32> = parquet_reader
+            .parquet_schema()
+            .columns()
+            .iter()
+            .map(|col| col.self_type().get_basic_info().id())
+            .collect();
+
+        assert_eq!(field_ids, vec![3, 4]);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_parquet_writer_with_partition() -> Result<()> {
+        let temp_dir = TempDir::new().unwrap();
+        let file_io = FileIOBuilder::new_fs_io().build().unwrap();
+        let location_gen =
+            MockLocationGenerator::new(temp_dir.path().to_str().unwrap().to_string());
+        let file_name_gen = DefaultFileNameGenerator::new(
+            "test_partitioned".to_string(),
+            None,
+            DataFileFormat::Parquet,
+        );
+
+        let schema = Schema::builder()
+            .with_schema_id(5)
+            .with_fields(vec![
+                NestedField::required(5, "id", Type::Primitive(PrimitiveType::Int)).into(),
+                NestedField::required(6, "name", Type::Primitive(PrimitiveType::String)).into(),
+            ])
+            .build()?;
+
+        let partition_value = Struct::from_iter([Some(Literal::int(1))]);
+
+        let parquet_writer_builder = ParquetWriterBuilder::new(
+            WriterProperties::builder().build(),
+            Arc::new(schema.clone()),
             file_io.clone(),
             location_gen,
             file_name_gen,
         );
 
-        let mut data_file_writer = DataFileWriterBuilder::new(pw, None).build().await?;
-        let data_file = data_file_writer.close().await.unwrap();
-        assert_eq!(data_file.len(), 1);
-        assert_eq!(data_file[0].file_format, DataFileFormat::Parquet);
-        assert_eq!(data_file[0].content, DataContentType::Data);
-        assert_eq!(data_file[0].partition, Struct::empty());
+        let mut data_file_writer =
+            DataFileWriterBuilder::new(parquet_writer_builder, Some(partition_value.clone()))
+                .build()
+                .await?;
+
+        let arrow_schema = arrow_schema::Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("name", DataType::Utf8, false),
+        ]);
+        let batch = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![
+            Arc::new(Int32Array::from(vec![1, 2, 3])),
+            Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
+        ])?;
+        data_file_writer.write(batch).await?;
+
+        let data_files = data_file_writer.close().await.unwrap();
+        assert_eq!(data_files.len(), 1);
+
+        let data_file = &data_files[0];
+        assert_eq!(data_file.file_format, DataFileFormat::Parquet);
+        assert_eq!(data_file.content, DataContentType::Data);
+        assert_eq!(data_file.partition, partition_value);
+
+        let input_file = file_io.new_input(data_file.file_path.clone())?;
+        let input_content = input_file.read().await?;
+
+        let parquet_reader =
+            ArrowReaderMetadata::load(&input_content, ArrowReaderOptions::default())?;
+
+        let field_ids: Vec<i32> = parquet_reader
+            .parquet_schema()
+            .columns()
+            .iter()
+            .map(|col| col.self_type().get_basic_info().id())
+            .collect();
+        assert_eq!(field_ids, vec![5, 6]);
+
+        let field_names: Vec<&str> = parquet_reader
+            .parquet_schema()
+            .columns()
+            .iter()
+            .map(|col| col.name())
+            .collect();
+        assert_eq!(field_names, vec!["id", "name"]);
 
         Ok(())
     }