
Commit e24541b

kevinjqliu authored and HonahX committed
Cast data to Iceberg Table's pyarrow schema (apache#523)
Backport to 0.6.1
1 parent: b9362ee

4 files changed: +70 −9 lines

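The change is easiest to see at the Arrow level: incoming data can be compatible with the table's schema in Iceberg terms (for example, large_string vs. string) without the two Arrow schemas being equal. Below is a minimal pyarrow-only sketch of the cast this commit introduces; the target schema is a hypothetical stand-in for the table's Arrow schema.

import pyarrow as pa

# Incoming data uses large_string; the table's Arrow schema uses string.
# Iceberg treats both as its `string` type, so they are compatible.
df = pa.table({"foo": pa.array(["A", "B", None], type=pa.large_string())})
table_arrow_schema = pa.schema([pa.field("foo", pa.string(), nullable=True)])

# The commit's core idea: once compatibility is established, cast the
# data to the table's schema when the two are not already equal.
if df.schema != table_arrow_schema:
    df = df.cast(table_arrow_schema)

print(df.schema)  # foo: string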

pyiceberg/io/pyarrow.py

Lines changed: 1 addition & 1 deletion
@@ -1721,7 +1721,7 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
     parquet_writer_kwargs = _get_parquet_writer_kwargs(table.properties)

     file_path = f'{table.location()}/data/{task.generate_data_file_filename("parquet")}'
-    file_schema = schema_to_pyarrow(table.schema())
+    file_schema = table.schema().as_arrow()

     fo = table.io.new_output(file_path)
     row_group_size = PropertyUtil.property_as_int(
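This one-line change matters because pyarrow's ParquetWriter validates each table written against the schema the file was opened with, so a compatible-but-unequal Arrow schema fails at write time rather than at the compatibility check. A short sketch of the failure mode this avoids, together with the cast in append()/overwrite(); the file path is arbitrary:

import pyarrow as pa
import pyarrow.parquet as pq

file_schema = pa.schema([pa.field("foo", pa.string(), nullable=True)])
df = pa.table({"foo": pa.array(["A"], type=pa.large_string())})

with pq.ParquetWriter("/tmp/example.parquet", file_schema) as writer:
    # writer.write_table(df) would raise ValueError, since the table's
    # schema differs from the schema the file was opened with.
    writer.write_table(df.cast(file_schema))  # casting first succeeds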

pyiceberg/table/__init__.py

Lines changed: 17 additions & 3 deletions
@@ -132,7 +132,15 @@
 _JAVA_LONG_MAX = 9223372036854775807


-def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None:
+def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema") -> None:
+    """
+    Check if the `table_schema` is compatible with `other_schema`.
+
+    Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.
+
+    Raises:
+        ValueError: If the schemas are not compatible.
+    """
     from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema

     name_mapping = table_schema.name_mapping

@@ -1044,7 +1052,10 @@ def append(self, df: pa.Table) -> None:
         if len(self.spec().fields) > 0:
             raise ValueError("Cannot write to partitioned tables")

-        _check_schema(self.schema(), other_schema=df.schema)
+        _check_schema_compatible(self.schema(), other_schema=df.schema)
+        # cast if the two schemas are compatible but not equal
+        if self.schema().as_arrow() != df.schema:
+            df = df.cast(self.schema().as_arrow())

         merge = _MergingSnapshotProducer(operation=Operation.APPEND, table=self)

@@ -1079,7 +1090,10 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE) -> None:
         if len(self.spec().fields) > 0:
             raise ValueError("Cannot write to partitioned tables")

-        _check_schema(self.schema(), other_schema=df.schema)
+        _check_schema_compatible(self.schema(), other_schema=df.schema)
+        # cast if the two schemas are compatible but not equal
+        if self.schema().as_arrow() != df.schema:
+            df = df.cast(self.schema().as_arrow())

         merge = _MergingSnapshotProducer(
             operation=Operation.OVERWRITE if self.current_snapshot() is not None else Operation.APPEND,
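The guard added to both append() and overwrite() reads the same way in isolation. A standalone sketch follows; align_to_table_schema is a name invented for illustration, not part of the PyIceberg API:

import pyarrow as pa

def align_to_table_schema(df: pa.Table, table_arrow_schema: pa.Schema) -> pa.Table:
    # _check_schema_compatible() has already validated compatibility at
    # the Iceberg level, so this cast is expected to succeed; pyarrow
    # still raises if a value cannot be represented in the target type.
    if df.schema != table_arrow_schema:
        return df.cast(table_arrow_schema)
    return df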

tests/catalog/test_sql.py

Lines changed: 33 additions & 0 deletions
@@ -191,6 +191,39 @@ def test_create_table_with_pyarrow_schema(
     catalog.drop_table(random_identifier)


+@pytest.mark.parametrize(
+    'catalog',
+    [
+        lazy_fixture('catalog_memory'),
+        # lazy_fixture('catalog_sqlite'),
+    ],
+)
+def test_write_pyarrow_schema(catalog: SqlCatalog, random_identifier: Identifier) -> None:
+    import pyarrow as pa
+
+    pyarrow_table = pa.Table.from_arrays(
+        [
+            pa.array([None, "A", "B", "C"]),  # 'foo' column
+            pa.array([1, 2, 3, 4]),  # 'bar' column
+            pa.array([True, None, False, True]),  # 'baz' column
+            pa.array([None, "A", "B", "C"]),  # 'large' column
+        ],
+        schema=pa.schema([
+            pa.field('foo', pa.string(), nullable=True),
+            pa.field('bar', pa.int32(), nullable=False),
+            pa.field('baz', pa.bool_(), nullable=True),
+            pa.field('large', pa.large_string(), nullable=True),
+        ]),
+    )
+    database_name, _table_name = random_identifier
+    catalog.create_namespace(database_name)
+    table = catalog.create_table(random_identifier, pyarrow_table.schema)
+    print(pyarrow_table.schema)
+    print(table.schema().as_struct())
+    print()
+    table.overwrite(pyarrow_table)
+
+
 @pytest.mark.parametrize(
     'catalog',
     [
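To reproduce this test outside the suite, a comparable catalog can be stood up directly. The sketch below is modeled on the module's fixtures; the catalog name and warehouse path are illustrative choices, not values from the commit:

from pathlib import Path

from pyiceberg.catalog.sql import SqlCatalog

warehouse = Path("/tmp/warehouse")  # illustrative location for data files
warehouse.mkdir(parents=True, exist_ok=True)

# SQLite metadata in memory, data files under a local warehouse directory.
catalog = SqlCatalog(
    "test_sql_catalog",
    uri="sqlite:///:memory:",
    warehouse=f"file://{warehouse}",
)
catalog.create_tables()  # create the catalog's backing tables, as the fixtures do

The test body above can then run against this catalog as-is.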

tests/table/test_init.py

Lines changed: 19 additions & 5 deletions
@@ -58,7 +58,7 @@
     Table,
     UpdateSchema,
     _apply_table_update,
-    _check_schema,
+    _check_schema_compatible,
     _generate_snapshot_id,
     _match_deletes_to_data_file,
     _TableMetadataUpdateContext,

@@ -1004,7 +1004,7 @@ def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
     """

     with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)


 def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:

@@ -1025,7 +1025,7 @@ def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
     """

     with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)


 def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:

@@ -1045,7 +1045,7 @@ def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
     """

     with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)


 def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:

@@ -1059,4 +1059,18 @@ def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
     expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."

     with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_downcast(table_schema_simple: Schema) -> None:
+    # large_string type is compatible with string type
+    other_schema = pa.schema((
+        pa.field("foo", pa.large_string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ))
+
+    try:
+        _check_schema_compatible(table_schema_simple, other_schema)
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling `_check_schema`")
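test_schema_downcast passes because the compatibility check compares schemas after converting both sides to Iceberg types, where string and large_string collapse into the same type. A sketch of that normalization using the private helper the check itself imports (private API, subject to change):

import pyarrow as pa

from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids

small = pa.schema([pa.field("foo", pa.string(), nullable=True)])
large = pa.schema([pa.field("foo", pa.large_string(), nullable=True)])

# Both Arrow schemas normalize to an Iceberg schema with one optional
# `string` field, so _check_schema_compatible treats them as equal.
assert _pyarrow_to_schema_without_ids(small) == _pyarrow_to_schema_without_ids(large)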
