On write operation, cast data to Iceberg Table's pyarrow schema #523

Merged (8 commits) on Mar 28, 2024
Changes from all commits
2 changes: 1 addition & 1 deletion pyiceberg/io/pyarrow.py
@@ -1784,7 +1784,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT

file_path = f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}'
schema = table_metadata.schema()
-    arrow_file_schema = schema_to_pyarrow(schema)
+    arrow_file_schema = schema.as_arrow()

fo = io.new_output(file_path)
row_group_size = PropertyUtil.property_as_int(
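The one-line change above swaps the module-level helper for the Schema.as_arrow() convenience method this PR introduces. A minimal sketch of the intended equivalence, assuming as_arrow() simply delegates to schema_to_pyarrow(); the toy schema below is illustrative, not from the PR:

import pyarrow as pa

from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, StringType

# A toy Iceberg schema with one optional string field.
schema = Schema(NestedField(field_id=1, name="foo", field_type=StringType(), required=False))

# as_arrow() should produce the same pa.Schema as the helper it replaces.
assert schema.as_arrow() == schema_to_pyarrow(schema)
assert schema.as_arrow().field("foo").type == pa.string()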
20 changes: 17 additions & 3 deletions pyiceberg/table/__init__.py
@@ -145,7 +145,15 @@
_JAVA_LONG_MAX = 9223372036854775807


-def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None:
+def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema") -> None:
"""
Check if the `table_schema` is compatible with `other_schema`.

Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.

Raises:
ValueError: If the schemas are not compatible.
"""
from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema

name_mapping = table_schema.name_mapping
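For orientation, here is a condensed sketch of the check this helper performs, inferred from the imports and the name-mapping lookup visible in this hunk. The _sketch_ name is mine, and the real implementation's error reporting is richer:

import pyarrow as pa

from pyiceberg.io.pyarrow import pyarrow_to_schema
from pyiceberg.schema import Schema

def _sketch_check_schema_compatible(table_schema: Schema, other_schema: pa.Schema) -> None:
    # Arrow schemas carry no Iceberg field IDs, so the table's name
    # mapping supplies them when re-deriving an Iceberg schema.
    name_mapping = table_schema.name_mapping
    provided_schema = pyarrow_to_schema(other_schema, name_mapping=name_mapping)
    # "Compatible" means equal as Iceberg types, even when the Arrow
    # types differ (e.g. large_string vs string).
    if table_schema.as_struct() != provided_schema.as_struct():
        raise ValueError(f"Mismatch in fields:\n{table_schema}\n{provided_schema}")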
@@ -1118,7 +1126,10 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

-    _check_schema(self.schema(), other_schema=df.schema)
+    _check_schema_compatible(self.schema(), other_schema=df.schema)
+    # cast if the two schemas are compatible but not equal
+    if self.schema().as_arrow() != df.schema:
+        df = df.cast(self.schema().as_arrow())

[Review comment from a Contributor] nit: It would be good to call as_arrow() just once in case we need to cast.

with self.transaction() as txn:
with txn.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
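In practice this means append() now accepts any Arrow table whose schema round-trips to the same Iceberg schema, even if the Arrow types are not identical. A usage sketch; the catalog name and table identifier are placeholders, and the table is assumed to have a string 'foo' column:

import pyarrow as pa
from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")     # hypothetical catalog
table = catalog.load_table("db.tbl")  # hypothetical table with a string 'foo' column

# large_string is compatible with the table's string type, so this no
# longer raises; the data is cast to the table's Arrow schema before writing.
df = pa.table({"foo": pa.array(["A", "B"], type=pa.large_string())})
table.append(df)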
@@ -1156,7 +1167,10 @@ def overwrite(
if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

-    _check_schema(self.schema(), other_schema=df.schema)
+    _check_schema_compatible(self.schema(), other_schema=df.schema)
+    # cast if the two schemas are compatible but not equal
+    if self.schema().as_arrow() != df.schema:
+        df = df.cast(self.schema().as_arrow())

with self.transaction() as txn:
with txn.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as update_snapshot:
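The cast itself is plain PyArrow. A minimal standalone illustration of what df.cast(...) does for the compatible-but-unequal case:

import pyarrow as pa

source = pa.table({"foo": pa.array(["A", "B"], type=pa.large_string())})
target = pa.schema([pa.field("foo", pa.string(), nullable=True)])

# Table.cast requires matching field names and casts each column;
# large_string -> string is lossless, so this succeeds.
cast_table = source.cast(target)
assert cast_table.schema == target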
33 changes: 33 additions & 0 deletions tests/catalog/test_sql.py
@@ -193,6 +193,39 @@ def test_create_table_with_pyarrow_schema(
catalog.drop_table(random_identifier)


+@pytest.mark.parametrize(
+    'catalog',
+    [
+        lazy_fixture('catalog_memory'),
+        # lazy_fixture('catalog_sqlite'),
+    ],
+)
+def test_write_pyarrow_schema(catalog: SqlCatalog, random_identifier: Identifier) -> None:
+    import pyarrow as pa
+
+    pyarrow_table = pa.Table.from_arrays(
+        [
+            pa.array([None, "A", "B", "C"]),  # 'foo' column
+            pa.array([1, 2, 3, 4]),  # 'bar' column
+            pa.array([True, None, False, True]),  # 'baz' column
+            pa.array([None, "A", "B", "C"]),  # 'large' column
+        ],
+        schema=pa.schema([
+            pa.field('foo', pa.string(), nullable=True),
+            pa.field('bar', pa.int32(), nullable=False),
+            pa.field('baz', pa.bool_(), nullable=True),
+            pa.field('large', pa.large_string(), nullable=True),
+        ]),
+    )
+    database_name, _table_name = random_identifier
+    catalog.create_namespace(database_name)
+    table = catalog.create_table(random_identifier, pyarrow_table.schema)
+    print(pyarrow_table.schema)
+    print(table.schema().as_struct())
+    print()
+    table.overwrite(pyarrow_table)
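A possible follow-up assertion, not part of this PR, would read the data back through the existing scan path to confirm the cast write succeeded:

# Hypothetical extra check for this test:
result = table.scan().to_arrow()
assert result.num_rows == 4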


@pytest.mark.parametrize(
'catalog',
[
24 changes: 19 additions & 5 deletions tests/table/test_init.py
@@ -63,7 +63,7 @@
TableIdentifier,
UpdateSchema,
_apply_table_update,
-    _check_schema,
+    _check_schema_compatible,
_match_deletes_to_data_file,
_TableMetadataUpdateContext,
update_table_metadata,
@@ -1033,7 +1033,7 @@ def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
"""

with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)


def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
@@ -1054,7 +1054,7 @@ def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
"""

with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)


def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
@@ -1074,7 +1074,7 @@ def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
"""

with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)


def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
@@ -1088,7 +1088,21 @@ def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."

with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)


+def test_schema_downcast(table_schema_simple: Schema) -> None:
+    # large_string type is compatible with string type
+    other_schema = pa.schema((
+        pa.field("foo", pa.large_string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ))
+
+    try:
+        _check_schema_compatible(table_schema_simple, other_schema)
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`")
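The underlying reason this downcast is safe: large_string differs from string only in using 64-bit rather than 32-bit offsets, so values round-trip exactly. A quick standalone check:

import pyarrow as pa

arr = pa.array([None, "A", "B"], type=pa.large_string())
# Lossless: only the offset width changes, not the values.
assert arr.cast(pa.string()).to_pylist() == [None, "A", "B"]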


def test_table_properties(example_table_metadata_v2: Dict[str, Any]) -> None: