Commit e548117
PyArrow: Avoid buffer-overflow by avoiding a sort
This was already being discussed back here: #208 (comment). This PR changes the approach from sorting the table and then making a single pass over it, to determining the unique partition tuples and then filtering on them one by one. Fixes #1491: the sort caused buffers to be joined in a way that overflowed in Arrow. I think this is an issue on the Arrow side, and it should automatically break the data up into smaller buffers. The `combine_chunks` method does this correctly.

Now:

```
Run 0 took: 0.42877754200890195
Run 1 took: 0.2507691659993725
Run 2 took: 0.24833179199777078
Run 3 took: 0.24401691700040828
Run 4 took: 0.2419595829996979
Average runtime of 0.28 seconds
```

Before:

```
Run 0 took: 1.0768639159941813
Run 1 took: 0.8784021250030492
Run 2 took: 0.8486490420036716
Run 3 took: 0.8614017910003895
Run 4 took: 0.8497851670108503
Average runtime of 0.9 seconds
```

So it comes with a nice speedup as well :)
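To illustrate the new approach outside of pyiceberg (a minimal sketch, not code from this PR; the table and column names are made up): PyArrow's `group_by(...).aggregate([])` with no aggregations returns the distinct partition tuples, and each partition is then selected with a `filter` instead of sorting the whole table.

```python
import pyarrow as pa
import pyarrow.compute as pc

# Hypothetical table with one partition column, "day"
table = pa.table({"day": [1, 1, 2, 2, 3], "value": [10, 20, 30, 40, 50]})

# Distinct partition tuples: group by the partition column(s) without any aggregates
unique_days = table.select(["day"]).group_by(["day"]).aggregate([])
print(unique_days.num_rows)  # 3 unique partition tuples

# One filter pass per unique tuple, no global sort needed
for row in unique_days.to_pylist():
    part = table.filter(pc.field("day") == row["day"])
    print(row["day"], part.num_rows)  # partition sizes: 2, 2 and 1 rows
```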
1 parent 818cd15 commit e548117

File tree: 4 files changed (+144, -68 lines)

pyiceberg/io/pyarrow.py

Lines changed: 42 additions & 66 deletions
@@ -27,8 +27,10 @@
 
 import concurrent.futures
 import fnmatch
+import functools
 import itertools
 import logging
+import operator
 import os
 import re
 import uuid

@@ -2542,36 +2544,6 @@ class _TablePartition:
     arrow_table_partition: pa.Table
 
 
-def _get_table_partitions(
-    arrow_table: pa.Table,
-    partition_spec: PartitionSpec,
-    schema: Schema,
-    slice_instructions: list[dict[str, Any]],
-) -> list[_TablePartition]:
-    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x["offset"])
-
-    partition_fields = partition_spec.fields
-
-    offsets = [inst["offset"] for inst in sorted_slice_instructions]
-    projected_and_filtered = {
-        partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
-        .take(offsets)
-        .to_pylist()
-        for partition_field in partition_fields
-    }
-
-    table_partitions = []
-    for idx, inst in enumerate(sorted_slice_instructions):
-        partition_slice = arrow_table.slice(**inst)
-        fieldvalues = [
-            PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
-            for partition_field in partition_fields
-        ]
-        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
-        table_partitions.append(_TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
-    return table_partitions
-
-
 def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[_TablePartition]:
     """Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.

@@ -2594,42 +2566,46 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
     We then retrieve the partition keys by offsets.
     And slice the arrow table by offsets and lengths of each partition.
     """
-    partition_columns: List[Tuple[PartitionField, NestedField]] = [
-        (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
-    ]
-    partition_values_table = pa.table(
-        {
-            str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
-            for partition, field in partition_columns
-        }
-    )
+    # Assign unique names to columns where the partition transform has been applied
+    # to avoid conflicts
+    partition_fields = [f"_partition_{field.name}" for field in spec.fields]
+
+    for partition, name in zip(spec.fields, partition_fields):
+        source_field = schema.find_field(partition.source_id)
+        arrow_table = arrow_table.append_column(
+            name, partition.transform.pyarrow_transform(source_field.field_type)(arrow_table[source_field.name])
+        )
+
+    unique_partition_fields = arrow_table.select(partition_fields).group_by(partition_fields).aggregate([])
+
+    table_partitions = []
+    # TODO: As a next step, we could also play around with yielding instead of materializing the full list
+    for unique_partition in unique_partition_fields.to_pylist():
+        partition_key = PartitionKey(
+            raw_partition_field_values=[
+                PartitionFieldValue(field=field, value=unique_partition[name])
+                for field, name in zip(spec.fields, partition_fields)
+            ],
+            partition_spec=spec,
+            schema=schema,
+        )
+        filtered_table = arrow_table.filter(
+            functools.reduce(
+                operator.and_,
+                [
+                    pc.field(partition_field_name) == unique_partition[partition_field_name]
+                    if unique_partition[partition_field_name] is not None
+                    else pc.field(partition_field_name).is_null()
+                    for field, partition_field_name in zip(spec.fields, partition_fields)
+                ],
+            )
+        )
+        filtered_table = filtered_table.drop_columns(partition_fields)
 
-    # Sort by partitions
-    sort_indices = pa.compute.sort_indices(
-        partition_values_table,
-        sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
-        null_placement="at_end",
-    ).to_pylist()
-    arrow_table = arrow_table.take(sort_indices)
-
-    # Get slice_instructions to group by partitions
-    partition_values_table = partition_values_table.take(sort_indices)
-    reversed_indices = pa.compute.sort_indices(
-        partition_values_table,
-        sort_keys=[(col, "descending") for col in partition_values_table.column_names],
-        null_placement="at_start",
-    ).to_pylist()
-    slice_instructions: List[Dict[str, Any]] = []
-    last = len(reversed_indices)
-    reversed_indices_size = len(reversed_indices)
-    ptr = 0
-    while ptr < reversed_indices_size:
-        group_size = last - reversed_indices[ptr]
-        offset = reversed_indices[ptr]
-        slice_instructions.append({"offset": offset, "length": group_size})
-        last = reversed_indices[ptr]
-        ptr = ptr + group_size
-
-    table_partitions: List[_TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)
+        # The combine_chunks seems to be counter-intuitive to do, but it actually returns
+        # fresh buffers that don't interfere with each other when it is written out to file
+        table_partitions.append(
+            _TablePartition(partition_key=partition_key, arrow_table_partition=filtered_table.combine_chunks())
+        )
 
     return table_partitions
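Two details in the new code are worth calling out. First, the explicit `is_null()` branch: comparing a column against a Python value with `==` evaluates to null (not False) for null entries, and `Table.filter` drops those rows, so rows belonging to a null partition would otherwise be lost. A small standalone check of that Arrow behaviour (not code from this PR):

```python
import pyarrow as pa
import pyarrow.compute as pc

t = pa.table({"p": [1, None, 1]})

# Equality against a value evaluates to null for null entries, and filter() drops them,
# so the row with the null partition value is not selected here:
assert t.filter(pc.field("p") == 1).num_rows == 2

# The null partition has to be selected explicitly:
assert t.filter(pc.field("p").is_null()).num_rows == 1
```

Second, calling `combine_chunks()` on each filtered slice rewrites the partition into fresh buffers before it is handed to the writer, which, per the commit message, is what sidesteps the overflowing offsets that the old sort-and-take path ran into.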

pyiceberg/partitioning.py

Lines changed: 8 additions & 2 deletions
@@ -29,6 +29,7 @@
     Optional,
     Tuple,
     TypeVar,
+    Union,
 )
 from urllib.parse import quote_plus
 
@@ -425,8 +426,13 @@ def _to_partition_representation(type: IcebergType, value: Any) -> Any:
 
 @_to_partition_representation.register(TimestampType)
 @_to_partition_representation.register(TimestamptzType)
-def _(type: IcebergType, value: Optional[datetime]) -> Optional[int]:
-    return datetime_to_micros(value) if value is not None else None
+def _(type: IcebergType, value: Optional[Union[datetime, int]]) -> Optional[int]:
+    if value is None:
+        return None
+    elif isinstance(value, int):
+        return value
+    else:
+        return datetime_to_micros(value)
 
 
 @_to_partition_representation.register(DateType)
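The handler above is registered with `functools.singledispatch` on the Iceberg type, and the change lets values that are already epoch microseconds pass through untouched. A minimal standalone sketch of the same dispatch pattern, with stand-in names (`TimestampType` here is a dummy class and `to_micros` stands in for pyiceberg's `datetime_to_micros`):

```python
from datetime import datetime, timezone
from functools import singledispatch
from typing import Optional, Union


class TimestampType:  # dummy stand-in for pyiceberg.types.TimestampType
    pass


def to_micros(dt: datetime) -> int:
    # stand-in for pyiceberg's datetime_to_micros
    return int(dt.timestamp() * 1_000_000)


@singledispatch
def to_partition_representation(type_: object, value: object) -> object:
    return value


@to_partition_representation.register(TimestampType)
def _(type_: TimestampType, value: Optional[Union[datetime, int]]) -> Optional[int]:
    if value is None:
        return None
    elif isinstance(value, int):
        return value  # already microseconds since epoch: pass through unchanged
    else:
        return to_micros(value)


print(to_partition_representation(TimestampType(), datetime(2022, 1, 1, tzinfo=timezone.utc)))  # 1640995200000000
print(to_partition_representation(TimestampType(), 1640995200000000))  # unchanged
```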

tests/benchmark/test_benchmark.py

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import statistics
+import timeit
+import urllib
+
+import pyarrow as pa
+import pyarrow.parquet as pq
+import pytest
+
+from pyiceberg.transforms import DayTransform
+
+
+@pytest.fixture(scope="session")
+def taxi_dataset(tmp_path_factory: pytest.TempPathFactory) -> pa.Table:
+    """Reads the Taxi dataset to disk"""
+    taxi_dataset = "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2022-01.parquet"
+    taxi_dataset_dest = tmp_path_factory.mktemp("taxi_dataset") / "yellow_tripdata_2022-01.parquet"
+    urllib.request.urlretrieve(taxi_dataset, taxi_dataset_dest)
+
+    return pq.read_table(taxi_dataset_dest)
+
+
+def test_partitioned_write(tmp_path_factory: pytest.TempPathFactory, taxi_dataset: pa.Table) -> None:
+    """Tests writing to a partitioned table with something that would be close to a production-like situation"""
+    from pyiceberg.catalog.sql import SqlCatalog
+
+    warehouse_path = str(tmp_path_factory.mktemp("warehouse"))
+    catalog = SqlCatalog(
+        "default",
+        uri=f"sqlite:///{warehouse_path}/pyiceberg_catalog.db",
+        warehouse=f"file://{warehouse_path}",
+    )
+
+    catalog.create_namespace("default")
+
+    tbl = catalog.create_table("default.taxi_partitioned", schema=taxi_dataset.schema)
+
+    with tbl.update_spec() as spec:
+        spec.add_field("tpep_pickup_datetime", DayTransform())
+
+    # Profiling can sometimes be handy as well
+    # with cProfile.Profile() as pr:
+    #     tbl.append(taxi_dataset)
+    #
+    #     pr.print_stats(sort=True)
+
+    runs = []
+    for run in range(5):
+        start_time = timeit.default_timer()
+        tbl.append(taxi_dataset)
+        elapsed = timeit.default_timer() - start_time
+
+        print(f"Run {run} took: {elapsed}")
+        runs.append(elapsed)
+
+    print(f"Average runtime of {round(statistics.mean(runs), 2)} seconds")

tests/integration/test_writes/test_partitioned_writes.py

Lines changed: 23 additions & 0 deletions
@@ -17,6 +17,7 @@
 # pylint:disable=redefined-outer-name
 
 
+import random
 from datetime import date
 from typing import Any, Set
 
@@ -1126,3 +1127,25 @@ def test_append_multiple_partitions(
     """
     )
     assert files_df.count() == 6
+
+
+@pytest.mark.integration
+def test_pyarrow_overflow(session_catalog: Catalog) -> None:
+    """Test what happens when the offset is beyond 32 bits"""
+    identifier = "default.arrow_table_overflow"
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    x = pa.array([random.randint(0, 999) for _ in range(30_000)])
+    ta = pa.chunked_array([x] * 10_000)
+    y = ["fixed_string"] * 30_000
+    tb = pa.chunked_array([y] * 10_000)
+    # Create pa.table
+    arrow_table = pa.table({"a": ta, "b": tb})
+
+    table = session_catalog.create_table(identifier, arrow_table.schema)
+    with table.update_spec() as update_spec:
+        update_spec.add_field("b", IdentityTransform(), "pb")
+    table.append(arrow_table)
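A rough back-of-the-envelope check (not part of the test) of why this table trips 32-bit offsets: column "b" holds 300 million 12-byte strings, so concatenating it into a single chunk would need more character data than a `pa.string()` array's int32 offsets can address.

```python
rows = 30_000 * 10_000                    # 300,000,000 rows spread over 10,000 chunks
char_bytes = rows * len("fixed_string")   # 3,600,000,000 bytes of string data in column "b"
int32_limit = 2**31 - 1                   # 2,147,483,647: the largest offset pa.string() can hold
print(char_bytes > int32_limit)           # True: a single combined chunk cannot represent column "b"
```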
