Skip to content

Commit 24e7da0

Browse files
committed
cast to pyarrow schema
1 parent 36a505f commit 24e7da0

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

pyiceberg/table/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1126,6 +1126,8 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_T
11261126
except ModuleNotFoundError as e:
11271127
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
11281128

1129+
from pyiceberg.io.pyarrow import schema_to_pyarrow
1130+
11291131
if not isinstance(df, pa.Table):
11301132
raise ValueError(f"Expected PyArrow table, got: {df}")
11311133

@@ -1136,6 +1138,9 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_T
11361138
raise ValueError("Cannot write to partitioned tables")
11371139

11381140
_check_schema(self.schema(), other_schema=df.schema)
1141+
# safe to cast
1142+
pyarrow_schema = schema_to_pyarrow(self.schema())
1143+
df = df.cast(pyarrow_schema)
11391144

11401145
with self.transaction() as txn:
11411146
with txn.update_snapshot().overwrite() as update_snapshot:

tests/catalog/test_sql.py

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -193,6 +193,39 @@ def test_create_table_with_pyarrow_schema(
193193
catalog.drop_table(random_identifier)
194194

195195

196+
@pytest.mark.parametrize(
    'catalog',
    [
        lazy_fixture('catalog_memory'),
        # NOTE(review): the sqlite-backed catalog is disabled for this test;
        # confirm whether it can be re-enabled.
        # lazy_fixture('catalog_sqlite'),
    ],
)
def test_write_pyarrow_schema(catalog: SqlCatalog, random_identifier: Identifier) -> None:
    """Write a PyArrow table into a table created from that table's own schema.

    Builds a four-row PyArrow table with a string, a non-nullable int32, a
    bool, and a ``large_string`` column, creates a catalog table from the
    PyArrow schema, and overwrites the table with the data.

    The test has no explicit assertions: it passes when ``create_table`` and
    ``overwrite`` complete without raising. Presumably the ``large_string``
    column is what exercises the schema cast in ``overwrite`` -- confirm.
    """
    import pyarrow as pa

    pyarrow_table = pa.Table.from_arrays(
        [
            pa.array([None, "A", "B", "C"]),  # 'foo' column
            pa.array([1, 2, 3, 4]),  # 'bar' column
            pa.array([True, None, False, True]),  # 'baz' column
            pa.array([None, "A", "B", "C"]),  # 'large' column
        ],
        schema=pa.schema([
            pa.field('foo', pa.string(), nullable=True),
            pa.field('bar', pa.int32(), nullable=False),
            pa.field('baz', pa.bool_(), nullable=True),
            pa.field('large', pa.large_string(), nullable=True),
        ]),
    )
    # Only the namespace part of the identifier is needed to create it.
    database_name, _table_name = random_identifier
    catalog.create_namespace(database_name)
    table = catalog.create_table(random_identifier, pyarrow_table.schema)
    table.overwrite(pyarrow_table)
227+
228+
196229
@pytest.mark.parametrize(
197230
'catalog',
198231
[

0 commit comments

Comments
 (0)