diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 50406972a7..a4ebf2bca8 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -470,9 +470,16 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
 
 
 def schema_to_pyarrow(
-    schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True
+    schema: Union[Schema, IcebergType],
+    metadata: Dict[bytes, bytes] = EMPTY_DICT,
+    include_field_ids: bool = True,
+    with_large_types: bool = True,
 ) -> pa.schema:
-    return visit(schema, _ConvertToArrowSchema(metadata, include_field_ids))
+    pyarrow_schema = visit(schema, _ConvertToArrowSchema(metadata, include_field_ids))
+    if with_large_types:
+        return _pyarrow_schema_ensure_large_types(pyarrow_schema)
+    else:
+        return pyarrow_schema
 
 
 class _ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType]):
@@ -504,7 +511,7 @@ def field(self, field: NestedField, field_result: pa.DataType) -> pa.Field:
 
     def list(self, list_type: ListType, element_result: pa.DataType) -> pa.DataType:
         element_field = self.field(list_type.element_field, element_result)
-        return pa.large_list(value_type=element_field)
+        return pa.list_(value_type=element_field)
 
     def map(self, map_type: MapType, key_result: pa.DataType, value_result: pa.DataType) -> pa.DataType:
         key_field = self.field(map_type.key_field, key_result)
@@ -548,13 +555,13 @@ def visit_timestamptz(self, _: TimestamptzType) -> pa.DataType:
         return pa.timestamp(unit="us", tz="UTC")
 
     def visit_string(self, _: StringType) -> pa.DataType:
-        return pa.large_string()
+        return pa.string()
 
     def visit_uuid(self, _: UUIDType) -> pa.DataType:
         return pa.binary(16)
 
     def visit_binary(self, _: BinaryType) -> pa.DataType:
-        return pa.large_binary()
+        return pa.binary()
 
 
 def _convert_scalar(value: Any, iceberg_type: IcebergType) -> pa.scalar:
@@ -958,19 +965,23 @@ def after_map_value(self, element: pa.Field) -> None:
 
 class _ConvertToLargeTypes(PyArrowSchemaVisitor[Union[pa.DataType, pa.Schema]]):
     def schema(self, schema: pa.Schema, struct_result: pa.StructType) -> pa.Schema:
-        return pa.schema(struct_result)
+        return pa.schema(list(struct_result))
 
     def struct(self, struct: pa.StructType, field_results: List[pa.Field]) -> pa.StructType:
         return pa.struct(field_results)
 
     def field(self, field: pa.Field, field_result: pa.DataType) -> pa.Field:
-        return field.with_type(field_result)
+        new_field = field.with_type(field_result)
+        return new_field
 
     def list(self, list_type: pa.ListType, element_result: pa.DataType) -> pa.DataType:
-        return pa.large_list(element_result)
+        element_field = self.field(list_type.value_field, element_result)
+        return pa.large_list(element_field)
 
     def map(self, map_type: pa.MapType, key_result: pa.DataType, value_result: pa.DataType) -> pa.DataType:
-        return pa.map_(key_result, value_result)
+        key_field = self.field(map_type.key_field, key_result)
+        value_field = self.field(map_type.item_field, value_result)
+        return pa.map_(key_type=key_field, item_type=value_field)
 
     def primitive(self, primitive: pa.DataType) -> pa.DataType:
         if primitive == pa.string():
@@ -1004,6 +1015,7 @@ def _task_to_record_batches(
     positional_deletes: Optional[List[ChunkedArray]],
     case_sensitive: bool,
     name_mapping: Optional[NameMapping] = None,
+    with_large_types: bool = True,
 ) -> Iterator[pa.RecordBatch]:
     _, _, path = PyArrowFileIO.parse_location(task.file.file_path)
     arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
@@ -1049,7 +1061,7 @@ def _task_to_record_batches(
                     arrow_table = pa.Table.from_batches([batch])
                     arrow_table = arrow_table.filter(pyarrow_filter)
                     batch = arrow_table.to_batches()[0]
-            yield to_requested_schema(projected_schema, file_project_schema, batch)
+            yield to_requested_schema(projected_schema, file_project_schema, batch, with_large_types=with_large_types)
             current_index += len(batch)
 
 
@@ -1062,11 +1074,22 @@ def _task_to_table(
     positional_deletes: Optional[List[ChunkedArray]],
     case_sensitive: bool,
     name_mapping: Optional[NameMapping] = None,
+    with_large_types: bool = True,
 ) -> pa.Table:
     batches = _task_to_record_batches(
-        fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping
+        fs,
+        task,
+        bound_row_filter,
+        projected_schema,
+        projected_field_ids,
+        positional_deletes,
+        case_sensitive,
+        name_mapping,
+        with_large_types,
+    )
+    return pa.Table.from_batches(
+        batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False, with_large_types=with_large_types)
     )
-    return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False))
 
 
 def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
@@ -1095,6 +1118,7 @@ def project_table(
     projected_schema: Schema,
     case_sensitive: bool = True,
     limit: Optional[int] = None,
+    with_large_types: bool = True,
 ) -> pa.Table:
     """Resolve the right columns based on the identifier.
 
@@ -1146,6 +1170,7 @@ def project_table(
            deletes_per_file.get(task.file.file_path),
            case_sensitive,
            table_metadata.name_mapping(),
+            with_large_types,
        )
        for task in tasks
    ]
@@ -1168,7 +1193,9 @@ def project_table(
    tables = [f.result() for f in completed_futures if f.result()]
 
    if len(tables) < 1:
-        return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema, include_field_ids=False))
+        return pa.Table.from_batches(
+            [], schema=schema_to_pyarrow(projected_schema, include_field_ids=False, with_large_types=with_large_types)
+        )
 
    result = pa.concat_tables(tables)
 
@@ -1186,6 +1213,7 @@ def project_batches(
     projected_schema: Schema,
     case_sensitive: bool = True,
     limit: Optional[int] = None,
+    with_large_types: bool = True,
 ) -> Iterator[pa.RecordBatch]:
     """Resolve the right columns based on the identifier.
 
@@ -1238,6 +1266,7 @@ def project_batches(
            deletes_per_file.get(task.file.file_path),
            case_sensitive,
            table_metadata.name_mapping(),
+            with_large_types,
        )
        for batch in batches:
            if limit is not None:
@@ -1248,8 +1277,12 @@
            total_row_count += len(batch)
 
 
-def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch) -> pa.RecordBatch:
-    struct_array = visit_with_partner(requested_schema, batch, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
+def to_requested_schema(
+    requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch, with_large_types: bool = True
+) -> pa.RecordBatch:
+    struct_array = visit_with_partner(
+        requested_schema, batch, ArrowProjectionVisitor(file_schema, with_large_types), ArrowAccessor(file_schema)
+    )
 
     arrays = []
     fields = []
@@ -1263,15 +1296,26 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
     file_schema: Schema
 
-    def __init__(self, file_schema: Schema):
+    def __init__(self, file_schema: Schema, with_large_types: bool = True):
         self.file_schema = file_schema
+        self.with_large_types = with_large_types
 
     def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
         file_field = self.file_schema.find_field(field.field_id)
         if field.field_type.is_primitive:
             if field.field_type != file_field.field_type:
-                return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=False))
-            elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=False)) != values.type:
+                return values.cast(
+                    schema_to_pyarrow(
+                        promote(file_field.field_type, field.field_type),
+                        include_field_ids=False,
+                        with_large_types=self.with_large_types,
+                    )
+                )
+            elif (
+                target_type := schema_to_pyarrow(
+                    field.field_type, include_field_ids=False, with_large_types=self.with_large_types
+                )
+            ) != values.type:
                 # if file_field and field_type (e.g. String) are the same
                 # but the pyarrow type of the array is different from the expected type
                 # (e.g. string vs larger_string), we want to cast the array to the larger type
@@ -1302,7 +1346,7 @@ def struct(
                 field_arrays.append(array)
                 fields.append(self._construct_field(field, array.type))
             elif field.optional:
-                arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=False)
+                arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=False, with_large_types=self.with_large_types)
                 field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))
                 fields.append(self._construct_field(field, arrow_type))
             else:
@@ -1320,7 +1364,10 @@ def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array:
            # https://github.com/apache/arrow/issues/38809
            list_array = pa.LargeListArray.from_arrays(list_array.offsets, value_array)
 
-            arrow_field = pa.large_list(self._construct_field(list_type.element_field, value_array.type))
+            if self.with_large_types:
+                arrow_field = pa.large_list(self._construct_field(list_type.element_field, value_array.type))
+            else:
+                arrow_field = pa.list_(self._construct_field(list_type.element_field, value_array.type))
            return list_array.cast(arrow_field)
        else:
            return None
@@ -1919,14 +1966,14 @@ def write_parquet(task: WriteTask) -> DataFile:
            file_schema = table_schema
 
        batches = [
-            to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch)
+            to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch, with_large_types=False)
            for batch in task.record_batches
        ]
        arrow_table = pa.Table.from_batches(batches)
        file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
        fo = io.new_output(file_path)
        with fo.create(overwrite=True) as fos:
-            with pq.ParquetWriter(fos, schema=file_schema.as_arrow(), **parquet_writer_kwargs) as writer:
+            with pq.ParquetWriter(fos, schema=file_schema.as_arrow(with_large_types=False), **parquet_writer_kwargs) as writer:
                writer.write(arrow_table, row_group_size=row_group_size)
        statistics = data_file_statistics_from_parquet_metadata(
            parquet_metadata=writer.writer.metadata,
diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py
index 77f1addbf5..e136200a13 100644
--- a/pyiceberg/schema.py
+++ b/pyiceberg/schema.py
@@ -182,11 +182,11 @@ def as_struct(self) -> StructType:
         """Return the schema as a struct."""
         return StructType(*self.fields)
 
-    def as_arrow(self) -> "pa.Schema":
+    def as_arrow(self, with_large_types: bool = False) -> "pa.Schema":
         """Return the schema as an Arrow schema."""
         from pyiceberg.io.pyarrow import schema_to_pyarrow
 
-        return schema_to_pyarrow(self)
+        return schema_to_pyarrow(self, with_large_types=with_large_types)
 
     def find_field(self, name_or_id: Union[str, int], case_sensitive: bool = True) -> NestedField:
         """Find a field using a field name or field ID.
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 8eea9859bc..4cb3d9134f 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -2010,7 +2010,7 @@ def plan_files(self) -> Iterable[FileScanTask]:
            for data_entry in data_entries
        ]
 
-    def to_arrow(self) -> pa.Table:
+    def to_arrow(self, with_large_types: bool = True) -> pa.Table:
        from pyiceberg.io.pyarrow import project_table
 
        return project_table(
@@ -2021,15 +2021,16 @@ def to_arrow(self) -> pa.Table:
            self.projection(),
            case_sensitive=self.case_sensitive,
            limit=self.limit,
+            with_large_types=with_large_types,
        )
 
-    def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
+    def to_arrow_batch_reader(self, with_large_types: bool = True) -> pa.RecordBatchReader:
        import pyarrow as pa
 
        from pyiceberg.io.pyarrow import project_batches, schema_to_pyarrow
 
        return pa.RecordBatchReader.from_batches(
-            schema_to_pyarrow(self.projection()),
+            schema_to_pyarrow(self.projection(), include_field_ids=False, with_large_types=with_large_types),
            project_batches(
                self.plan_files(),
                self.table_metadata,
@@ -2038,6 +2039,7 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
                self.projection(),
                case_sensitive=self.case_sensitive,
                limit=self.limit,
+                with_large_types=with_large_types,
            ),
        )
 
diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py
index 9251d717f8..fb16a483eb 100644
--- a/tests/catalog/test_sql.py
+++ b/tests/catalog/test_sql.py
@@ -39,7 +39,11 @@
     TableAlreadyExistsError,
 )
 from pyiceberg.io import FSSPEC_FILE_IO, PY_IO_IMPL
-from pyiceberg.io.pyarrow import _dataframe_to_data_files, schema_to_pyarrow
+from pyiceberg.io.pyarrow import (
+    _dataframe_to_data_files,
+    _pyarrow_schema_ensure_large_types,
+    schema_to_pyarrow,
+)
 from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC
 from pyiceberg.schema import Schema
 from pyiceberg.table.snapshots import Operation
@@ -1549,3 +1553,152 @@ def test_table_exists(catalog: SqlCatalog, table_schema_simple: Schema, table_id
 
     # Act and Assert for a non-existing table
     assert catalog.table_exists(("non", "exist")) is False
+
+
+@pytest.mark.parametrize(
+    "catalog",
+    [
+        lazy_fixture("catalog_memory"),
+        lazy_fixture("catalog_sqlite"),
+    ],
+)
+@pytest.mark.parametrize(
+    "table_identifier",
+    [
+        lazy_fixture("random_table_identifier"),
+        lazy_fixture("random_hierarchical_identifier"),
+        lazy_fixture("random_table_identifier_with_catalog"),
+    ],
+)
+def test_read_arrow_table(catalog: SqlCatalog, table_schema_simple: Schema, table_identifier: Identifier) -> None:
+    table_identifier_nocatalog = catalog.identifier_to_tuple_without_catalog(table_identifier)
+    namespace = Catalog.namespace_from(table_identifier_nocatalog)
+    catalog.create_namespace(namespace)
+    table = catalog.create_table(table_identifier, table_schema_simple)
+
+    df = pa.Table.from_pydict(
+        {
+            "foo": ["a"],
+            "bar": [1],
+            "baz": [True],
+        },
+        schema=schema_to_pyarrow(table_schema_simple, with_large_types=False),
+    )
+
+    table.append(df)
+
+    # read back the data
+    assert df == table.scan().to_arrow(with_large_types=False)
+
+
+@pytest.mark.parametrize(
+    "catalog",
+    [
+        lazy_fixture("catalog_memory"),
+        lazy_fixture("catalog_sqlite"),
+    ],
+)
+@pytest.mark.parametrize(
+    "table_identifier",
+    [
+        lazy_fixture("random_table_identifier"),
+        lazy_fixture("random_hierarchical_identifier"),
+        lazy_fixture("random_table_identifier_with_catalog"),
+    ],
+)
+def test_read_arrow_table_large_types(catalog: SqlCatalog, table_schema_simple: Schema, table_identifier: Identifier) -> None:
+    table_identifier_nocatalog = catalog.identifier_to_tuple_without_catalog(table_identifier)
+    namespace = Catalog.namespace_from(table_identifier_nocatalog)
+    catalog.create_namespace(namespace)
+    table = catalog.create_table(table_identifier, table_schema_simple)
+
+    df = pa.Table.from_pydict(
+        {
+            "foo": ["a"],
+            "bar": [1],
+            "baz": [True],
+        },
+        schema=schema_to_pyarrow(table_schema_simple, with_large_types=False),
+    )
+
+    table.append(df)
+
+    # read back the data
+    assert df.cast(_pyarrow_schema_ensure_large_types(df.schema)) == table.scan().to_arrow(with_large_types=True)
+
+
+@pytest.mark.parametrize(
+    "catalog",
+    [
+        lazy_fixture("catalog_memory"),
+        lazy_fixture("catalog_sqlite"),
+    ],
+)
+@pytest.mark.parametrize(
+    "table_identifier",
+    [
+        lazy_fixture("random_table_identifier"),
+        lazy_fixture("random_hierarchical_identifier"),
+        lazy_fixture("random_table_identifier_with_catalog"),
+    ],
+)
+def test_read_arrow_batch_reader(catalog: SqlCatalog, table_schema_simple: Schema, table_identifier: Identifier) -> None:
+    table_identifier_nocatalog = catalog.identifier_to_tuple_without_catalog(table_identifier)
+    namespace = Catalog.namespace_from(table_identifier_nocatalog)
+    catalog.create_namespace(namespace)
+    table = catalog.create_table(table_identifier, table_schema_simple)
+
+    df = pa.Table.from_pydict(
+        {
+            "foo": ["a"],
+            "bar": [1],
+            "baz": [True],
+        },
+        schema=schema_to_pyarrow(table_schema_simple, with_large_types=False),
+    )
+
+    table.append(df)
+
+    # read back the data
+    assert df == table.scan().to_arrow_batch_reader(with_large_types=False).read_all()
+
+
+@pytest.mark.parametrize(
+    "catalog",
+    [
+        lazy_fixture("catalog_memory"),
+        lazy_fixture("catalog_sqlite"),
+    ],
+)
+@pytest.mark.parametrize(
+    "table_identifier",
+    [
+        lazy_fixture("random_table_identifier"),
+        lazy_fixture("random_hierarchical_identifier"),
+        lazy_fixture("random_table_identifier_with_catalog"),
+    ],
+)
+def test_read_arrow_batch_reader_large_types(
+    catalog: SqlCatalog, table_schema_simple: Schema, table_identifier: Identifier
+) -> None:
+    table_identifier_nocatalog = catalog.identifier_to_tuple_without_catalog(table_identifier)
+    namespace = Catalog.namespace_from(table_identifier_nocatalog)
+    catalog.create_namespace(namespace)
+    table = catalog.create_table(table_identifier, table_schema_simple)
+
+    df = pa.Table.from_pydict(
+        {
+            "foo": ["a"],
+            "bar": [1],
+            "baz": [True],
+        },
+        schema=schema_to_pyarrow(table_schema_simple, with_large_types=False),
+    )
+
+    table.append(df)
+
+    # read back the data
+    assert (
+        df.cast(_pyarrow_schema_ensure_large_types(df.schema))
+        == table.scan().to_arrow_batch_reader(with_large_types=True).read_all()
+    )
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index ecb946a98b..dfbcb3b7fc 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -486,17 +486,17 @@ def test_timestamptz_type_to_pyarrow() -> None:
 
 def test_string_type_to_pyarrow() -> None:
     iceberg_type = StringType()
-    assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.large_string()
+    assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.string()
 
 
 def test_binary_type_to_pyarrow() -> None:
     iceberg_type = BinaryType()
-    assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.large_binary()
+    assert visit(iceberg_type, _ConvertToArrowSchema()) == pa.binary()
 
 
 def test_struct_type_to_pyarrow(table_schema_simple: Schema) -> None:
     expected = pa.struct([
-        pa.field("foo", pa.large_string(), nullable=True, metadata={"field_id": "1"}),
+        pa.field("foo", pa.string(), nullable=True, metadata={"field_id": "1"}),
         pa.field("bar", pa.int32(), nullable=False, metadata={"field_id": "2"}),
         pa.field("baz", pa.bool_(), nullable=True, metadata={"field_id": "3"}),
     ])
@@ -513,7 +513,7 @@ def test_map_type_to_pyarrow() -> None:
     )
     assert visit(iceberg_map, _ConvertToArrowSchema()) == pa.map_(
         pa.field("key", pa.int32(), nullable=False, metadata={"field_id": "1"}),
-        pa.field("value", pa.large_string(), nullable=False, metadata={"field_id": "2"}),
+        pa.field("value", pa.string(), nullable=False, metadata={"field_id": "2"}),
     )
 
 
@@ -523,7 +523,7 @@ def test_list_type_to_pyarrow() -> None:
         element_type=IntegerType(),
         element_required=True,
     )
-    assert visit(iceberg_map, _ConvertToArrowSchema()) == pa.large_list(
+    assert visit(iceberg_map, _ConvertToArrowSchema()) == pa.list_(
         pa.field("element", pa.int32(), nullable=False, metadata={"field_id": "1"})
     )
 
diff --git a/tests/io/test_pyarrow_visitor.py b/tests/io/test_pyarrow_visitor.py
index d3b6217c7b..820fd0471f 100644
--- a/tests/io/test_pyarrow_visitor.py
+++ b/tests/io/test_pyarrow_visitor.py
@@ -213,14 +213,14 @@ def test_pyarrow_string_to_iceberg() -> None:
     pyarrow_type = pa.large_string()
     converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg())
     assert converted_iceberg_type == StringType()
-    assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pyarrow_type
+    assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pa.string()
 
 
 def test_pyarrow_variable_binary_to_iceberg() -> None:
     pyarrow_type = pa.large_binary()
     converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg())
     assert converted_iceberg_type == BinaryType()
-    assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pyarrow_type
+    assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pa.binary()
 
 
 def test_pyarrow_struct_to_iceberg() -> None:
diff --git a/tests/test_schema.py b/tests/test_schema.py
index 23b42ef49e..96109ce9c2 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -1610,7 +1610,7 @@ def test_arrow_schema() -> None:
     )
 
     expected_schema = pa.schema([
-        pa.field("foo", pa.large_string(), nullable=False),
+        pa.field("foo", pa.string(), nullable=False),
         pa.field("bar", pa.int32(), nullable=True),
         pa.field("baz", pa.bool_(), nullable=True),
     ])
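
Usage note (not part of the patch): a minimal sketch of how the new with_large_types flag is meant to be used on the read path, assuming this diff is applied. The catalog name and table identifier are hypothetical placeholders; everything else follows the APIs touched above.

from pyiceberg.catalog import load_catalog
from pyiceberg.io.pyarrow import schema_to_pyarrow

# Hypothetical catalog/table, used only for illustration.
catalog = load_catalog("default")
table = catalog.load_table("examples.events")

# Default behaviour is unchanged: the projected Arrow schema is still upcast
# to large types (large_string, large_binary, large_list).
large_tbl = table.scan().to_arrow()

# Opting out keeps the plain types now produced by _ConvertToArrowSchema.
small_tbl = table.scan().to_arrow(with_large_types=False)
small_from_batches = table.scan().to_arrow_batch_reader(with_large_types=False).read_all()
assert small_tbl.schema == small_from_batches.schema

# The same flag is exposed on schema conversion; by default the result is
# still passed through _pyarrow_schema_ensure_large_types.
arrow_schema_small = schema_to_pyarrow(table.schema(), with_large_types=False)
arrow_schema_large = schema_to_pyarrow(table.schema())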