diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 199133f794..cd6736fbba 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -120,6 +120,7 @@
     Schema,
     SchemaVisitorPerPrimitiveType,
     SchemaWithPartnerVisitor,
+    _check_schema_compatible,
     pre_order_visit,
     promote,
     prune_columns,
@@ -1407,7 +1408,7 @@ def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array:
                 # This can be removed once this has been fixed:
                 # https://github.com/apache/arrow/issues/38809
                 list_array = pa.LargeListArray.from_arrays(list_array.offsets, value_array)
-
+            value_array = self._cast_if_needed(list_type.element_field, value_array)
             arrow_field = pa.large_list(self._construct_field(list_type.element_field, value_array.type))
             return list_array.cast(arrow_field)
         else:
@@ -1417,6 +1418,8 @@ def map(
         self, map_type: MapType, map_array: Optional[pa.Array], key_result: Optional[pa.Array], value_result: Optional[pa.Array]
     ) -> Optional[pa.Array]:
         if isinstance(map_array, pa.MapArray) and key_result is not None and value_result is not None:
+            key_result = self._cast_if_needed(map_type.key_field, key_result)
+            value_result = self._cast_if_needed(map_type.value_field, value_result)
             arrow_field = pa.map_(
                 self._construct_field(map_type.key_field, key_result.type),
                 self._construct_field(map_type.value_field, value_result.type),
@@ -1549,9 +1552,16 @@ def __init__(self, iceberg_type: PrimitiveType, physical_type_string: str, trunc
         expected_physical_type = _primitive_to_physical(iceberg_type)
         if expected_physical_type != physical_type_string:
-            raise ValueError(
-                f"Unexpected physical type {physical_type_string} for {iceberg_type}, expected {expected_physical_type}"
-            )
+            # Allow promotable physical types
+            # INT32 -> INT64 and FLOAT -> DOUBLE are safe type casts
+            if (physical_type_string == "INT32" and expected_physical_type == "INT64") or (
+                physical_type_string == "FLOAT" and expected_physical_type == "DOUBLE"
+            ):
+                pass
+            else:
+                raise ValueError(
+                    f"Unexpected physical type {physical_type_string} for {iceberg_type}, expected {expected_physical_type}"
+                )

         self.primitive_type = iceberg_type
@@ -1896,16 +1906,6 @@ def data_file_statistics_from_parquet_metadata(
             set the mode for column metrics collection
         parquet_column_mapping (Dict[str, int]): The mapping of the parquet file name to the field ID
     """
-    if parquet_metadata.num_columns != len(stats_columns):
-        raise ValueError(
-            f"Number of columns in statistics configuration ({len(stats_columns)}) is different from the number of columns in pyarrow table ({parquet_metadata.num_columns})"
-        )
-
-    if parquet_metadata.num_columns != len(parquet_column_mapping):
-        raise ValueError(
-            f"Number of columns in column mapping ({len(parquet_column_mapping)}) is different from the number of columns in pyarrow table ({parquet_metadata.num_columns})"
-        )
-
     column_sizes: Dict[int, int] = {}
     value_counts: Dict[int, int] = {}
     split_offsets: List[int] = []
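# Illustrative sketch (not from the patch): the physical-type allowance added to
# StatsAggregator above. A data file whose column was written with the narrower
# Parquet physical type (e.g. an int32 dataframe appended to a long column, or a
# pre-promotion file added via add_files) is still valid, so INT32-for-INT64 and
# FLOAT-for-DOUBLE no longer raise. The two-argument calls assume the truncation
# parameter (cut off in the hunk header) keeps its default.
from pyiceberg.io.pyarrow import StatsAggregator
from pyiceberg.types import DoubleType, LongType

StatsAggregator(LongType(), "INT32")    # accepted: INT32 data backing a long column
StatsAggregator(DoubleType(), "FLOAT")  # accepted: FLOAT data backing a double column
try:
    StatsAggregator(LongType(), "BYTE_ARRAY")  # still rejected: not a promotable physical type
except ValueError as err:
    print(err)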
@@ -1998,8 +1998,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
     )

     def write_parquet(task: WriteTask) -> DataFile:
-        table_schema = task.schema
-
+        table_schema = table_metadata.schema()
         # if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly
         # otherwise use the original schema
         if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
@@ -2011,7 +2010,7 @@ def write_parquet(task: WriteTask) -> DataFile:
         batches = [
             _to_requested_schema(
                 requested_schema=file_schema,
-                file_schema=table_schema,
+                file_schema=task.schema,
                 batch=batch,
                 downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
                 include_field_ids=True,
@@ -2070,47 +2069,30 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[
     return bin_packed_record_batches


-def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> None:
+def _check_pyarrow_schema_compatible(
+    requested_schema: Schema, provided_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False
+) -> None:
     """
-    Check if the `table_schema` is compatible with `other_schema`.
+    Check if the `requested_schema` is compatible with `provided_schema`.

     Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.

     Raises:
         ValueError: If the schemas are not compatible.
     """
-    name_mapping = table_schema.name_mapping
+    name_mapping = requested_schema.name_mapping
     try:
-        task_schema = pyarrow_to_schema(
-            other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+        provided_schema = pyarrow_to_schema(
+            provided_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
     except ValueError as e:
-        other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
-        additional_names = set(other_schema.column_names) - set(table_schema.column_names)
+        provided_schema = _pyarrow_to_schema_without_ids(provided_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
+        additional_names = set(provided_schema._name_to_id.keys()) - set(requested_schema._name_to_id.keys())
         raise ValueError(
             f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
         ) from e

-    if table_schema.as_struct() != task_schema.as_struct():
-        from rich.console import Console
-        from rich.table import Table as RichTable
-
-        console = Console(record=True)
-
-        rich_table = RichTable(show_header=True, header_style="bold")
-        rich_table.add_column("")
-        rich_table.add_column("Table field")
-        rich_table.add_column("Dataframe field")
-
-        for lhs in table_schema.fields:
-            try:
-                rhs = task_schema.find_field(lhs.field_id)
-                rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
-            except ValueError:
-                rich_table.add_row("❌", str(lhs), "Missing")
-
-        console.print(rich_table)
-        raise ValueError(f"Mismatch in fields:\n{console.export_text()}")
+    _check_schema_compatible(requested_schema, provided_schema)


 def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]:
@@ -2124,7 +2106,7 @@ def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_
                 f"Cannot add file {file_path} because it has field IDs. `add_files` only supports addition of files without field_ids"
             )
         schema = table_metadata.schema()
-        _check_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema())
+        _check_pyarrow_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema())

         statistics = data_file_statistics_from_parquet_metadata(
             parquet_metadata=parquet_metadata,
@@ -2205,7 +2187,7 @@ def _dataframe_to_data_files(
     Returns:
         An iterable that supplies datafiles that represent the table.
""" - from pyiceberg.table import PropertyUtil, TableProperties, WriteTask + from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, PropertyUtil, TableProperties, WriteTask counter = counter or itertools.count(0) write_uuid = write_uuid or uuid.uuid4() @@ -2214,13 +2196,16 @@ def _dataframe_to_data_files( property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES, default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT, ) + name_mapping = table_metadata.schema().name_mapping + downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False + task_schema = pyarrow_to_schema(df.schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us) if table_metadata.spec().is_unpartitioned(): yield from write_file( io=io, table_metadata=table_metadata, tasks=iter([ - WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema()) + WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema) for batches in bin_pack_arrow_table(df, target_file_size) ]), ) @@ -2235,7 +2220,7 @@ def _dataframe_to_data_files( task_id=next(counter), record_batches=batches, partition_key=partition.partition_key, - schema=table_metadata.schema(), + schema=task_schema, ) for partition in partitions for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size) diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py index 77f1addbf5..cfe3fe3a7b 100644 --- a/pyiceberg/schema.py +++ b/pyiceberg/schema.py @@ -1616,3 +1616,103 @@ def _(file_type: FixedType, read_type: IcebergType) -> IcebergType: return read_type else: raise ResolveError(f"Cannot promote {file_type} to {read_type}") + + +def _check_schema_compatible(requested_schema: Schema, provided_schema: Schema) -> None: + """ + Check if the `provided_schema` is compatible with `requested_schema`. + + Both Schemas must have valid IDs and share the same ID for the same field names. + + Two schemas are considered compatible when: + 1. All `required` fields in `requested_schema` are present and are also `required` in the `provided_schema` + 2. Field Types are consistent for fields that are present in both schemas. I.e. the field type + in the `provided_schema` can be promoted to the field type of the same field ID in `requested_schema` + + Raises: + ValueError: If the schemas are not compatible. + """ + pre_order_visit(requested_schema, _SchemaCompatibilityVisitor(provided_schema)) + + +class _SchemaCompatibilityVisitor(PreOrderSchemaVisitor[bool]): + provided_schema: Schema + + def __init__(self, provided_schema: Schema): + from rich.console import Console + from rich.table import Table as RichTable + + self.provided_schema = provided_schema + self.rich_table = RichTable(show_header=True, header_style="bold") + self.rich_table.add_column("") + self.rich_table.add_column("Table field") + self.rich_table.add_column("Dataframe field") + self.console = Console(record=True) + + def _is_field_compatible(self, lhs: NestedField) -> bool: + # Validate nullability first. 
diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py
index 77f1addbf5..cfe3fe3a7b 100644
--- a/pyiceberg/schema.py
+++ b/pyiceberg/schema.py
@@ -1616,3 +1616,103 @@ def _(file_type: FixedType, read_type: IcebergType) -> IcebergType:
         return read_type
     else:
         raise ResolveError(f"Cannot promote {file_type} to {read_type}")
+
+
+def _check_schema_compatible(requested_schema: Schema, provided_schema: Schema) -> None:
+    """
+    Check if the `provided_schema` is compatible with `requested_schema`.
+
+    Both Schemas must have valid IDs and share the same ID for the same field names.
+
+    Two schemas are considered compatible when:
+    1. All `required` fields in `requested_schema` are present and are also `required` in the `provided_schema`
+    2. Field Types are consistent for fields that are present in both schemas. I.e. the field type
+       in the `provided_schema` can be promoted to the field type of the same field ID in `requested_schema`
+
+    Raises:
+        ValueError: If the schemas are not compatible.
+    """
+    pre_order_visit(requested_schema, _SchemaCompatibilityVisitor(provided_schema))
+
+
+class _SchemaCompatibilityVisitor(PreOrderSchemaVisitor[bool]):
+    provided_schema: Schema
+
+    def __init__(self, provided_schema: Schema):
+        from rich.console import Console
+        from rich.table import Table as RichTable
+
+        self.provided_schema = provided_schema
+        self.rich_table = RichTable(show_header=True, header_style="bold")
+        self.rich_table.add_column("")
+        self.rich_table.add_column("Table field")
+        self.rich_table.add_column("Dataframe field")
+        self.console = Console(record=True)
+
+    def _is_field_compatible(self, lhs: NestedField) -> bool:
+        # Validate nullability first.
+        # An optional field can be missing in the provided schema
+        # But a required field must exist as a required field
+        try:
+            rhs = self.provided_schema.find_field(lhs.field_id)
+        except ValueError:
+            if lhs.required:
+                self.rich_table.add_row("❌", str(lhs), "Missing")
+                return False
+            else:
+                self.rich_table.add_row("✅", str(lhs), "Missing")
+                return True
+
+        if lhs.required and not rhs.required:
+            self.rich_table.add_row("❌", str(lhs), str(rhs))
+            return False
+
+        # Check type compatibility
+        if lhs.field_type == rhs.field_type:
+            self.rich_table.add_row("✅", str(lhs), str(rhs))
+            return True
+        # We only check that the parent node is also of the same type.
+        # We check the type of the child nodes when we traverse them later.
+        elif any(
+            (isinstance(lhs.field_type, container_type) and isinstance(rhs.field_type, container_type))
+            for container_type in {StructType, MapType, ListType}
+        ):
+            self.rich_table.add_row("✅", str(lhs), str(rhs))
+            return True
+        else:
+            try:
+                # If type can be promoted to the requested schema
+                # it is considered compatible
+                promote(rhs.field_type, lhs.field_type)
+                self.rich_table.add_row("✅", str(lhs), str(rhs))
+                return True
+            except ResolveError:
+                self.rich_table.add_row("❌", str(lhs), str(rhs))
+                return False
+
+    def schema(self, schema: Schema, struct_result: Callable[[], bool]) -> bool:
+        if not (result := struct_result()):
+            self.console.print(self.rich_table)
+            raise ValueError(f"Mismatch in fields:\n{self.console.export_text()}")
+        return result
+
+    def struct(self, struct: StructType, field_results: List[Callable[[], bool]]) -> bool:
+        results = [result() for result in field_results]
+        return all(results)
+
+    def field(self, field: NestedField, field_result: Callable[[], bool]) -> bool:
+        return self._is_field_compatible(field) and field_result()
+
+    def list(self, list_type: ListType, element_result: Callable[[], bool]) -> bool:
+        return self._is_field_compatible(list_type.element_field) and element_result()
+
+    def map(self, map_type: MapType, key_result: Callable[[], bool], value_result: Callable[[], bool]) -> bool:
+        return all([
+            self._is_field_compatible(map_type.key_field),
+            self._is_field_compatible(map_type.value_field),
+            key_result(),
+            value_result(),
+        ])
+
+    def primitive(self, primitive: PrimitiveType) -> bool:
+        return True
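# Illustrative sketch (not from the patch): the pre-order visitor compares two
# Iceberg schemas by field ID. With the hypothetical schemas below, int -> long on
# field 1 is accepted via promote(), while dropping a required field raises.
from pyiceberg.schema import Schema, _check_schema_compatible
from pyiceberg.types import IntegerType, LongType, NestedField, StringType

requested = Schema(
    NestedField(field_id=1, name="id", field_type=LongType(), required=True),
    NestedField(field_id=2, name="name", field_type=StringType(), required=False),
)
provided = Schema(
    NestedField(field_id=1, name="id", field_type=IntegerType(), required=True),
    NestedField(field_id=2, name="name", field_type=StringType(), required=False),
)
_check_schema_compatible(requested, provided)  # passes: field 1 promotes from int to long

try:
    _check_schema_compatible(requested, Schema(NestedField(field_id=2, name="name", field_type=StringType(), required=False)))
except ValueError as err:
    print(err)  # field 1 is required but missing from the provided schema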
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index b43dc3206b..0b211e673d 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -73,7 +73,6 @@
     manifest_evaluator,
 )
 from pyiceberg.io import FileIO, OutputFile, load_file_io
-from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files, expression_to_pyarrow, project_table
 from pyiceberg.manifest import (
     POSITIONAL_DELETE_SCHEMA,
     DataFile,
@@ -471,6 +470,8 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
         except ModuleNotFoundError as e:
             raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

+        from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files
+
         if not isinstance(df, pa.Table):
             raise ValueError(f"Expected PyArrow table, got: {df}")

@@ -481,8 +482,8 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
                 f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
             )
         downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
-        _check_schema_compatible(
-            self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+        _check_pyarrow_schema_compatible(
+            self._table.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )

         manifest_merge_enabled = PropertyUtil.property_as_bool(
@@ -528,6 +529,8 @@ def overwrite(
         except ModuleNotFoundError as e:
             raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

+        from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files
+
         if not isinstance(df, pa.Table):
             raise ValueError(f"Expected PyArrow table, got: {df}")

@@ -538,8 +541,8 @@ def overwrite(
                 f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
             )
         downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
-        _check_schema_compatible(
-            self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+        _check_pyarrow_schema_compatible(
+            self._table.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )

         self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)
@@ -566,6 +569,8 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
             delete_filter: A boolean expression to delete rows from a table
             snapshot_properties: Custom properties to be added to the snapshot summary
         """
+        from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_pyarrow, project_table
+
         if (
             self.table_metadata.properties.get(TableProperties.DELETE_MODE, TableProperties.DELETE_MODE_DEFAULT)
             == TableProperties.DELETE_MODE_MERGE_ON_READ
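# Illustrative sketch (not from the patch): what the Transaction changes above enable
# from the user's side. The catalog and table names are hypothetical; any existing
# table with an optional long column would behave the same way.
import pyarrow as pa
from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")           # assumes a configured catalog named "default"
tbl = catalog.load_table("default.events")  # assumes a table with an optional long column "id"

# An int32 Arrow column is now accepted against a long Iceberg column: the schema
# check treats it as a valid promotion and the write path upcasts the data.
tbl.append(pa.table({"id": pa.array([1, 2, 3], type=pa.int32())}))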
diff --git a/tests/conftest.py b/tests/conftest.py
index 91ab8f2e56..7f9a2bcfa8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2506,3 +2506,62 @@ def table_schema_with_all_microseconds_timestamp_precision() -> Schema:
         NestedField(field_id=10, name="timestamptz_ns_z", field_type=TimestamptzType(), required=False),
         NestedField(field_id=11, name="timestamptz_s_0000", field_type=TimestamptzType(), required=False),
     )
+
+
+@pytest.fixture(scope="session")
+def table_schema_with_promoted_types() -> Schema:
+    """Iceberg table Schema with longs, doubles and uuid in simple and nested types."""
+    return Schema(
+        NestedField(field_id=1, name="long", field_type=LongType(), required=False),
+        NestedField(
+            field_id=2,
+            name="list",
+            field_type=ListType(element_id=4, element_type=LongType(), element_required=False),
+            required=True,
+        ),
+        NestedField(
+            field_id=3,
+            name="map",
+            field_type=MapType(
+                key_id=5,
+                key_type=StringType(),
+                value_id=6,
+                value_type=LongType(),
+                value_required=False,
+            ),
+            required=True,
+        ),
+        NestedField(field_id=7, name="double", field_type=DoubleType(), required=False),
+        NestedField(field_id=8, name="uuid", field_type=UUIDType(), required=False),
+    )
+
+
+@pytest.fixture(scope="session")
+def pyarrow_schema_with_promoted_types() -> "pa.Schema":
+    """Pyarrow Schema with longs, doubles and uuid in simple and nested types."""
+    import pyarrow as pa
+
+    return pa.schema((
+        pa.field("long", pa.int32(), nullable=True),  # can support upcasting integer to long
+        pa.field("list", pa.list_(pa.int32()), nullable=False),  # can support upcasting integer to long
+        pa.field("map", pa.map_(pa.string(), pa.int32()), nullable=False),  # can support upcasting integer to long
+        pa.field("double", pa.float32(), nullable=True),  # can support upcasting float to double
+        pa.field("uuid", pa.binary(length=16), nullable=True),  # maps to UUID, stored as fixed-length binary of 16 bytes
+    ))
+
+
+@pytest.fixture(scope="session")
+def pyarrow_table_with_promoted_types(pyarrow_schema_with_promoted_types: "pa.Schema") -> "pa.Table":
+    """Pyarrow table with longs, doubles and uuid in simple and nested types."""
+    import pyarrow as pa
+
+    return pa.Table.from_pydict(
+        {
+            "long": [1, 9],
+            "list": [[1, 1], [2, 2]],
+            "map": [{"a": 1}, {"b": 2}],
+            "double": [1.1, 9.2],
+            "uuid": [b"qZx\xefNS@\x89\x9b\xf9:\xd0\xee\x9b\xf5E", b"\x97]\x87T^JDJ\x96\x97\xf4v\xe4\x03\x0c\xde"],
+        },
+        schema=pyarrow_schema_with_promoted_types,
+    )
diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py
index b8fd6d0926..3703a9e0b6 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -30,6 +30,7 @@
 from pyiceberg.catalog import Catalog
 from pyiceberg.exceptions import NoSuchTableError
 from pyiceberg.io import FileIO
+from pyiceberg.io.pyarrow import _pyarrow_schema_ensure_large_types
 from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
 from pyiceberg.schema import Schema
 from pyiceberg.table import Table
@@ -38,6 +39,7 @@
     BooleanType,
     DateType,
     IntegerType,
+    LongType,
     NestedField,
     StringType,
     TimestamptzType,
@@ -505,7 +507,7 @@ def test_add_files_fails_on_schema_mismatch(spark: SparkSession, session_catalog
 ┃    ┃ Table field              ┃ Dataframe field          ┃
 ┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
 │ ✅ │ 1: foo: optional boolean │ 1: foo: optional boolean │
-| ✅ │ 2: bar: optional string  │ 2: bar: optional string  │
+│ ✅ │ 2: bar: optional string  │ 2: bar: optional string  │
 │ ❌ │ 3: baz: optional int     │ 3: baz: optional string  │
 │ ✅ │ 4: qux: optional date    │ 4: qux: optional date    │
 └────┴──────────────────────────┴──────────────────────────┘
@@ -589,18 +591,7 @@ def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, format_v
     mocker.patch.dict(os.environ, values={"PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE": "True"})

     identifier = f"default.timestamptz_ns_added{format_version}"
-
-    try:
-        session_catalog.drop_table(identifier=identifier)
-    except NoSuchTableError:
-        pass
-
-    tbl = session_catalog.create_table(
-        identifier=identifier,
-        schema=nanoseconds_schema_iceberg,
-        properties={"format-version": str(format_version)},
-        partition_spec=PartitionSpec(),
-    )
+    tbl = _create_table(session_catalog, identifier, format_version, schema=nanoseconds_schema_iceberg)

     file_path = f"s3://warehouse/default/test_timestamp_tz/v{format_version}/test.parquet"
     # write parquet files
@@ -617,3 +608,127 @@ def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, format_v
         ),
     ):
         tbl.add_files(file_paths=[file_path])
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_add_file_with_valid_nullability_diff(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
+    identifier = f"default.test_table_with_valid_nullability_diff{format_version}"
+    table_schema = Schema(
+        NestedField(field_id=1, name="long", field_type=LongType(), required=False),
+    )
+    other_schema = pa.schema((
+        pa.field("long", pa.int64(), nullable=False),  # can support writing required pyarrow field to optional Iceberg field
+    ))
+    arrow_table = pa.Table.from_pydict(
+        {
+            "long": [1, 9],
+        },
+        schema=other_schema,
+    )
+    tbl = _create_table(session_catalog, identifier, format_version, schema=table_schema)
+
+    file_path = f"s3://warehouse/default/test_add_file_with_valid_nullability_diff/v{format_version}/test.parquet"
+    # write parquet files
+    fo = tbl.io.new_output(file_path)
+    with fo.create(overwrite=True) as fos:
+        with pq.ParquetWriter(fos, schema=other_schema) as writer:
+            writer.write_table(arrow_table)
+
+    tbl.add_files(file_paths=[file_path])
+    # table's long field should be read back as optional
+    written_arrow_table = tbl.scan().to_arrow()
+    assert written_arrow_table == arrow_table.cast(pa.schema((pa.field("long", pa.int64(), nullable=True),)))
+    lhs = spark.table(f"{identifier}").toPandas()
+    rhs = written_arrow_table.to_pandas()
+
+    for column in written_arrow_table.column_names:
+        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+            assert left == right
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_add_files_with_valid_upcast(
+    spark: SparkSession,
+    session_catalog: Catalog,
+    format_version: int,
+    table_schema_with_promoted_types: Schema,
+    pyarrow_schema_with_promoted_types: pa.Schema,
+    pyarrow_table_with_promoted_types: pa.Table,
+) -> None:
+    identifier = f"default.test_table_with_valid_upcast{format_version}"
+    tbl = _create_table(session_catalog, identifier, format_version, schema=table_schema_with_promoted_types)
+
+    file_path = f"s3://warehouse/default/test_add_files_with_valid_upcast/v{format_version}/test.parquet"
+    # write parquet files
+    fo = tbl.io.new_output(file_path)
+    with fo.create(overwrite=True) as fos:
+        with pq.ParquetWriter(fos, schema=pyarrow_schema_with_promoted_types) as writer:
+            writer.write_table(pyarrow_table_with_promoted_types)
+
+    tbl.add_files(file_paths=[file_path])
+    # table's long field should be read back as long
+    written_arrow_table = tbl.scan().to_arrow()
+    assert written_arrow_table == pyarrow_table_with_promoted_types.cast(
+        pa.schema((
+            pa.field("long", pa.int64(), nullable=True),
+            pa.field("list", pa.large_list(pa.int64()), nullable=False),
+            pa.field("map", pa.map_(pa.large_string(), pa.int64()), nullable=False),
+            pa.field("double", pa.float64(), nullable=True),
+            pa.field("uuid", pa.binary(length=16), nullable=True),  # UUID is read as fixed-length binary of length 16
+        ))
+    )
+    lhs = spark.table(f"{identifier}").toPandas()
+    rhs = written_arrow_table.to_pandas()
+
+    for column in written_arrow_table.column_names:
+        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+            if column == "map":
+                # Arrow returns a list of tuples, instead of a dict
+                right = dict(right)
+            if column == "list":
+                # Arrow returns an array, convert to list for equality check
+                left, right = list(left), list(right)
+            if column == "uuid":
+                # Spark Iceberg represents UUID as hex string like '715a78ef-4e53-4089-9bf9-3ad0ee9bf545'
+                # whereas PyIceberg represents UUID as bytes on read
+                left, right = left.replace("-", ""), right.hex()
+            assert left == right
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_add_files_subset_of_schema(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
+    identifier = f"default.test_table_subset_of_schema{format_version}"
+    tbl = _create_table(session_catalog, identifier, format_version)
+
+    file_path = f"s3://warehouse/default/test_add_files_subset_of_schema/v{format_version}/test.parquet"
+    arrow_table_without_some_columns = ARROW_TABLE.combine_chunks().drop(ARROW_TABLE.column_names[0])
+
+    # write parquet files
+    fo = tbl.io.new_output(file_path)
+    with fo.create(overwrite=True) as fos:
+        with pq.ParquetWriter(fos, schema=arrow_table_without_some_columns.schema) as writer:
+            writer.write_table(arrow_table_without_some_columns)
+
+    tbl.add_files(file_paths=[file_path])
+    written_arrow_table = tbl.scan().to_arrow()
+    assert tbl.scan().to_arrow() == pa.Table.from_pylist(
+        [
+            {
+                "foo": None,  # Missing column is read as None on read
+                "bar": "bar_string",
+                "baz": 123,
+                "qux": date(2024, 3, 7),
+            }
+        ],
+        schema=_pyarrow_schema_ensure_large_types(ARROW_SCHEMA),
+    )
+
+    lhs = spark.table(f"{identifier}").toPandas()
+    rhs = written_arrow_table.to_pandas()
+
+    for column in written_arrow_table.column_names:
+        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+            assert left == right
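# Illustrative sketch (not from the patch): the subset-of-schema assertion above
# compares against large Arrow types because PyIceberg materializes string, binary
# and list columns as their large variants on read. The helper imported at the top
# of test_add_files.py makes that conversion explicit; the schema here is made up.
import pyarrow as pa
from pyiceberg.io.pyarrow import _pyarrow_schema_ensure_large_types

small = pa.schema([pa.field("name", pa.string()), pa.field("payload", pa.list_(pa.int32()))])
print(_pyarrow_schema_ensure_large_types(small))  # name: large_string, payload: large_list<int32>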
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index 41bc6fb5bf..09fe654d29 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -43,7 +43,7 @@
 from pyiceberg.schema import Schema
 from pyiceberg.table import TableProperties
 from pyiceberg.transforms import IdentityTransform
-from pyiceberg.types import IntegerType, NestedField
+from pyiceberg.types import IntegerType, LongType, NestedField
 from utils import _create_table


@@ -964,9 +964,10 @@ def test_sanitize_character_partitioned(catalog: Catalog) -> None:
     assert len(tbl.scan().to_arrow()) == 22


+@pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
-def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
-    identifier = "default.table_append_subset_of_schema"
+def test_table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
+    identifier = "default.test_table_write_subset_of_schema"
     tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table_with_null])
     arrow_table_without_some_columns = arrow_table_with_null.combine_chunks().drop(arrow_table_with_null.column_names[0])
     assert len(arrow_table_without_some_columns.columns) < len(arrow_table_with_null.columns)
@@ -976,6 +977,101 @@ def table_write_subset_of_schema(
     assert len(tbl.scan().to_arrow()) == len(arrow_table_without_some_columns) * 2


+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_table_write_out_of_order_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
+    identifier = "default.test_table_write_out_of_order_schema"
+    # rotate the schema fields by 1
+    fields = list(arrow_table_with_null.schema)
+    rotated_fields = fields[1:] + fields[:1]
+    rotated_schema = pa.schema(rotated_fields)
+    assert arrow_table_with_null.schema != rotated_schema
+    tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=rotated_schema)
+
+    tbl.overwrite(arrow_table_with_null)
+    tbl.append(arrow_table_with_null)
+    # overwrite and then append should produce twice the data
+    assert len(tbl.scan().to_arrow()) == len(arrow_table_with_null) * 2
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_table_write_schema_with_valid_nullability_diff(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = "default.test_table_write_with_valid_nullability_diff"
+    table_schema = Schema(
+        NestedField(field_id=1, name="long", field_type=LongType(), required=False),
+    )
+    other_schema = pa.schema((
+        pa.field("long", pa.int64(), nullable=False),  # can support writing required pyarrow field to optional Iceberg field
+    ))
+    arrow_table = pa.Table.from_pydict(
+        {
+            "long": [1, 9],
+        },
+        schema=other_schema,
+    )
+    tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table], schema=table_schema)
+    # table's long field should be read back as optional
+    written_arrow_table = tbl.scan().to_arrow()
+    assert written_arrow_table == arrow_table.cast(pa.schema((pa.field("long", pa.int64(), nullable=True),)))
+    lhs = spark.table(f"{identifier}").toPandas()
+    rhs = written_arrow_table.to_pandas()
+
+    for column in written_arrow_table.column_names:
+        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+            assert left == right
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_table_write_schema_with_valid_upcast(
+    spark: SparkSession,
+    session_catalog: Catalog,
+    format_version: int,
+    table_schema_with_promoted_types: Schema,
+    pyarrow_schema_with_promoted_types: pa.Schema,
+    pyarrow_table_with_promoted_types: pa.Table,
+) -> None:
+    identifier = "default.test_table_write_with_valid_upcast"
+
+    tbl = _create_table(
+        session_catalog,
+        identifier,
+        {"format-version": format_version},
+        [pyarrow_table_with_promoted_types],
+        schema=table_schema_with_promoted_types,
+    )
+    # table's long field should be read back as long
+    written_arrow_table = tbl.scan().to_arrow()
+    assert written_arrow_table == pyarrow_table_with_promoted_types.cast(
+        pa.schema((
+            pa.field("long", pa.int64(), nullable=True),
+            pa.field("list", pa.large_list(pa.int64()), nullable=False),
+            pa.field("map", pa.map_(pa.large_string(), pa.int64()), nullable=False),
+            pa.field("double", pa.float64(), nullable=True),  # can support upcasting float to double
+            pa.field("uuid", pa.binary(length=16), nullable=True),  # UUID is read as fixed-length binary of length 16
+        ))
+    )
+    lhs = spark.table(f"{identifier}").toPandas()
+    rhs = written_arrow_table.to_pandas()
+
+    for column in written_arrow_table.column_names:
+        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+            if column == "map":
+                # Arrow returns a list of tuples, instead of a dict
+                right = dict(right)
+            if column == "list":
+                # Arrow returns an array, convert to list for equality check
+                left, right = list(left), list(right)
+            if column == "uuid":
+                # Spark Iceberg represents UUID as hex string like '715a78ef-4e53-4089-9bf9-3ad0ee9bf545'
+                # whereas PyIceberg represents UUID as bytes on read
+                left, right = left.replace("-", ""), right.hex()
+            assert left == right
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
 def test_write_all_timestamp_precision(
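# Illustrative sketch (not from the patch): the out-of-order write test works because
# the write path now resolves the dataframe schema through the table's name mapping
# (see task_schema in _dataframe_to_data_files) instead of relying on column position.
# The two-column schema below is hypothetical.
import pyarrow as pa
from pyiceberg.io.pyarrow import pyarrow_to_schema
from pyiceberg.schema import Schema
from pyiceberg.types import LongType, NestedField, StringType

table_schema = Schema(
    NestedField(field_id=1, name="id", field_type=LongType(), required=False),
    NestedField(field_id=2, name="name", field_type=StringType(), required=False),
)
# Columns arrive in the opposite order, but the names map back to the same field IDs.
df_schema = pa.schema([pa.field("name", pa.string()), pa.field("id", pa.int64())])
print(pyarrow_to_schema(df_schema, name_mapping=table_schema.name_mapping))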
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index 37198b7edb..d61a50bb0d 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -60,7 +60,7 @@
     PyArrowFile,
     PyArrowFileIO,
     StatsAggregator,
-    _check_schema_compatible,
+    _check_pyarrow_schema_compatible,
     _ConvertToArrowSchema,
     _determine_partitions,
     _primitive_to_physical,
@@ -1742,7 +1742,7 @@ def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
     """

     with pytest.raises(ValueError, match=expected):
-        _check_schema_compatible(table_schema_simple, other_schema)
+        _check_pyarrow_schema_compatible(table_schema_simple, other_schema)


 def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
@@ -1763,7 +1763,20 @@ def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
     """

     with pytest.raises(ValueError, match=expected):
-        _check_schema_compatible(table_schema_simple, other_schema)
+        _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_compatible_nullability_diff(table_schema_simple: Schema) -> None:
+    other_schema = pa.schema((
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
+        pa.field("baz", pa.bool_(), nullable=False),
+    ))
+
+    try:
+        _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")


 def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
@@ -1783,21 +1796,114 @@ def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
     """

     with pytest.raises(ValueError, match=expected):
-        _check_schema_compatible(table_schema_simple, other_schema)
+        _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_compatible_missing_nullable_field_nested(table_schema_nested: Schema) -> None:
+    schema = table_schema_nested.as_arrow()
+    schema = schema.remove(6).insert(
+        6,
+        pa.field(
+            "person",
+            pa.struct([
+                pa.field("age", pa.int32(), nullable=False),
+            ]),
+            nullable=True,
+        ),
+    )
+    try:
+        _check_pyarrow_schema_compatible(table_schema_nested, schema)
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")
+
+
+def test_schema_mismatch_missing_required_field_nested(table_schema_nested: Schema) -> None:
+    other_schema = table_schema_nested.as_arrow()
+    other_schema = other_schema.remove(6).insert(
+        6,
+        pa.field(
+            "person",
+            pa.struct([
+                pa.field("name", pa.string(), nullable=True),
+            ]),
+            nullable=True,
+        ),
+    )
+    expected = """Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃    ┃ Table field                        ┃ Dataframe field                    ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string            │ 1: foo: optional string            │
+│ ✅ │ 2: bar: required int               │ 2: bar: required int               │
+│ ✅ │ 3: baz: optional boolean           │ 3: baz: optional boolean           │
+│ ✅ │ 4: qux: required list<string>      │ 4: qux: required list<string>      │
+│ ✅ │ 5: element: required string        │ 5: element: required string        │
+│ ✅ │ 6: quux: required map<string,      │ 6: quux: required map<string,      │
+│    │ map<string, int>>                  │ map<string, int>>                  │
+│ ✅ │ 7: key: required string            │ 7: key: required string            │
+│ ✅ │ 8: value: required map<string,     │ 8: value: required map<string,     │
+│    │ int>                               │ int>                               │
+│ ✅ │ 9: key: required string            │ 9: key: required string            │
+│ ✅ │ 10: value: required int            │ 10: value: required int            │
+│ ✅ │ 11: location: required             │ 11: location: required             │
+│    │ list<struct<13: latitude: optional │ list<struct<13: latitude: optional │
+│    │ float, 14: longitude: optional     │ float, 14: longitude: optional     │
+│    │ float>>                            │ float>>                            │
+│ ✅ │ 12: element: required struct<13:   │ 12: element: required struct<13:   │
+│    │ latitude: optional float, 14:      │ latitude: optional float, 14:      │
+│    │ longitude: optional float>         │ longitude: optional float>         │
+│ ✅ │ 13: latitude: optional float       │ 13: latitude: optional float       │
+│ ✅ │ 14: longitude: optional float      │ 14: longitude: optional float      │
+│ ✅ │ 15: person: optional struct<16:    │ 15: person: optional struct<16:    │
+│    │ name: optional string, 17: age:    │ name: optional string>             │
+│    │ required int>                      │                                    │
+│ ✅ │ 16: name: optional string          │ 16: name: optional string          │
+│ ❌ │ 17: age: required int              │ Missing                            │
+└────┴────────────────────────────────────┴────────────────────────────────────┘
+"""
+
+    with pytest.raises(ValueError, match=expected):
+        _check_pyarrow_schema_compatible(table_schema_nested, other_schema)
+
+
+def test_schema_compatible_nested(table_schema_nested: Schema) -> None:
+    try:
+        _check_pyarrow_schema_compatible(table_schema_nested, table_schema_nested.as_arrow())
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")


 def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
     other_schema = pa.schema((
         pa.field("foo", pa.string(), nullable=True),
-        pa.field("bar", pa.int32(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
         pa.field("baz", pa.bool_(), nullable=True),
         pa.field("new_field", pa.date32(), nullable=True),
     ))

-    expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."
+    with pytest.raises(
+        ValueError, match=r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."
+    ):
+        _check_pyarrow_schema_compatible(table_schema_simple, other_schema)

-    with pytest.raises(ValueError, match=expected):
-        _check_schema_compatible(table_schema_simple, other_schema)
+
+def test_schema_compatible(table_schema_simple: Schema) -> None:
+    try:
+        _check_pyarrow_schema_compatible(table_schema_simple, table_schema_simple.as_arrow())
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")
+
+
+def test_schema_projection(table_schema_simple: Schema) -> None:
+    # remove optional `baz` field from `table_schema_simple`
+    other_schema = pa.schema((
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
+    ))
+    try:
+        _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")


 def test_schema_downcast(table_schema_simple: Schema) -> None:
@@ -1809,9 +1915,9 @@ def test_schema_downcast(table_schema_simple: Schema) -> None:
     ))

     try:
-        _check_schema_compatible(table_schema_simple, other_schema)
+        _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
     except Exception:
-        pytest.fail("Unexpected Exception raised when calling `_check_schema`")
+        pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")


 def test_partition_for_demo() -> None:
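# Illustrative sketch (not from the patch): a possible companion test in the same
# style as the cases above, tying the new conftest fixtures to the PyArrow entry
# point. It is a suggestion only and not part of the submitted suite.
import pytest
from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible
from pyiceberg.schema import Schema


def test_schema_compatible_promoted_types(
    table_schema_with_promoted_types: Schema, pyarrow_schema_with_promoted_types: "pa.Schema"
) -> None:
    # int32, float32 and 16-byte fixed binary columns should be accepted against
    # the long, double and uuid fields of the requested schema
    try:
        _check_pyarrow_schema_compatible(table_schema_with_promoted_types, pyarrow_schema_with_promoted_types)
    except Exception:
        pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")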