diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 72de14880a..9dbb5e8abc 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -1784,7 +1784,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
 
     file_path = f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}'
     schema = table_metadata.schema()
-    arrow_file_schema = schema_to_pyarrow(schema)
+    arrow_file_schema = schema.as_arrow()
 
     fo = io.new_output(file_path)
     row_group_size = PropertyUtil.property_as_int(
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 2b15cdeb08..2ad1f7fe81 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -145,7 +145,15 @@
 _JAVA_LONG_MAX = 9223372036854775807
 
 
-def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None:
+def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema") -> None:
+    """
+    Check if the `table_schema` is compatible with `other_schema`.
+
+    Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.
+
+    Raises:
+        ValueError: If the schemas are not compatible.
+    """
     from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema
 
     name_mapping = table_schema.name_mapping
@@ -1118,7 +1126,10 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
         if len(self.spec().fields) > 0:
             raise ValueError("Cannot write to partitioned tables")
 
-        _check_schema(self.schema(), other_schema=df.schema)
+        _check_schema_compatible(self.schema(), other_schema=df.schema)
+        # cast if the two schemas are compatible but not equal
+        if self.schema().as_arrow() != df.schema:
+            df = df.cast(self.schema().as_arrow())
 
         with self.transaction() as txn:
             with txn.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
@@ -1156,7 +1167,10 @@ def overwrite(
         if len(self.spec().fields) > 0:
             raise ValueError("Cannot write to partitioned tables")
 
-        _check_schema(self.schema(), other_schema=df.schema)
+        _check_schema_compatible(self.schema(), other_schema=df.schema)
+        # cast if the two schemas are compatible but not equal
+        if self.schema().as_arrow() != df.schema:
+            df = df.cast(self.schema().as_arrow())
 
         with self.transaction() as txn:
             with txn.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as update_snapshot:
diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py
index 3a77f8678a..b20f617e32 100644
--- a/tests/catalog/test_sql.py
+++ b/tests/catalog/test_sql.py
@@ -193,6 +193,39 @@ def test_create_table_with_pyarrow_schema(
     catalog.drop_table(random_identifier)
 
 
+@pytest.mark.parametrize(
+    'catalog',
+    [
+        lazy_fixture('catalog_memory'),
+        # lazy_fixture('catalog_sqlite'),
+    ],
+)
+def test_write_pyarrow_schema(catalog: SqlCatalog, random_identifier: Identifier) -> None:
+    import pyarrow as pa
+
+    pyarrow_table = pa.Table.from_arrays(
+        [
+            pa.array([None, "A", "B", "C"]),  # 'foo' column
+            pa.array([1, 2, 3, 4]),  # 'bar' column
+            pa.array([True, None, False, True]),  # 'baz' column
+            pa.array([None, "A", "B", "C"]),  # 'large' column
+        ],
+        schema=pa.schema([
+            pa.field('foo', pa.string(), nullable=True),
+            pa.field('bar', pa.int32(), nullable=False),
+            pa.field('baz', pa.bool_(), nullable=True),
+            pa.field('large', pa.large_string(), nullable=True),
+        ]),
+    )
+    database_name, _table_name = random_identifier
+    catalog.create_namespace(database_name)
+    table = catalog.create_table(random_identifier, pyarrow_table.schema)
+    print(pyarrow_table.schema)
+    print(table.schema().as_struct())
+    print()
+    table.overwrite(pyarrow_table)
+
+
 @pytest.mark.parametrize(
     'catalog',
     [
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index bb212d696e..f1191295f3 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -63,7 +63,7 @@
     TableIdentifier,
     UpdateSchema,
     _apply_table_update,
-    _check_schema,
+    _check_schema_compatible,
     _match_deletes_to_data_file,
     _TableMetadataUpdateContext,
     update_table_metadata,
@@ -1033,7 +1033,7 @@ def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
 """
 
     with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)
 
 
 def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
@@ -1054,7 +1054,7 @@ def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
 """
 
     with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)
 
 
 def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
@@ -1074,7 +1074,7 @@ def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
 """
 
    with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)
 
 
 def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
@@ -1088,7 +1088,21 @@ def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
     expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."
 
     with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_downcast(table_schema_simple: Schema) -> None:
+    # large_string type is compatible with string type
+    other_schema = pa.schema((
+        pa.field("foo", pa.large_string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ))
+
+    try:
+        _check_schema_compatible(table_schema_simple, other_schema)
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`")
 
 
 def test_table_properties(example_table_metadata_v2: Dict[str, Any]) -> None:
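For context, the cast added to `append()`/`overwrite()` relies on Arrow's ability to convert a table between schemas that are the same Iceberg schema logically but differ in physical representation (for example `large_string` vs. `string`). Below is a minimal, self-contained sketch of that pattern outside any catalog setup; the column names and values are illustrative only and not taken from this change:

```python
import pyarrow as pa

# A frame whose string column was materialized as large_string,
# as some Arrow-based readers do.
df = pa.table({
    "foo": pa.array(["A", "B", None], type=pa.large_string()),
    "bar": pa.array([1, 2, 3], type=pa.int32()),
})

# The Arrow schema the Iceberg table expects: string rather than large_string.
target_schema = pa.schema([
    ("foo", pa.string()),
    ("bar", pa.int32()),
])

# Same pattern as the new append()/overwrite() path: the schemas are
# compatible (equal as Iceberg schemas) but not equal as Arrow schemas,
# so cast before handing the data to the writer.
if df.schema != target_schema:
    df = df.cast(target_schema)

print(df.schema)  # foo: string, bar: int32
```

By default `pyarrow.Table.cast` refuses unsafe (lossy) conversions and mismatched field names, so a genuinely incompatible frame still fails before reaching the writer; `_check_schema_compatible` surfaces that mismatch earlier with an Iceberg-level error message.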