Commit 36d383d

Fokko and kevinjqliu authored
PyArrow: Avoid buffer overflow by avoiding a sort (#1555)

Second attempt of #1539. This was already being discussed back here: #208 (comment).

This PR changes the approach from doing a sort and then a single pass over the table, to determining the unique partition tuples and filtering on each of them individually.

Fixes #1491, because the sort caused buffers to be joined, which would overflow in Arrow. I think this is an issue on the Arrow side, and it should automatically break up into smaller buffers; the `combine_chunks` method does this correctly.

Now:

```
Run 0 took: 0.42877754200890195
Run 1 took: 0.2507691659993725
Run 2 took: 0.24833179199777078
Run 3 took: 0.24401691700040828
Run 4 took: 0.2419595829996979
Average runtime of 0.28 seconds
```

Before:

```
Run 0 took: 1.0768639159941813
Run 1 took: 0.8784021250030492
Run 2 took: 0.8486490420036716
Run 3 took: 0.8614017910003895
Run 4 took: 0.8497851670108503
Average runtime of 0.9 seconds
```

So it comes with a nice speedup as well :)

---------

Co-authored-by: Kevin Liu <[email protected]>
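To make the new strategy concrete, here is a minimal standalone PyArrow sketch (not code from this commit) of the "unique partition tuples, then filter" idea on a toy table. The `_partition_year` column and the identity transform stand in for the partition-transform columns that `_determine_partitions` appends, and the sketch assumes a recent PyArrow where `Table.filter` accepts a compute expression:

```python
import functools
import operator

import pyarrow as pa
import pyarrow.compute as pc

# Toy table; "_partition_year" plays the role of a transformed partition column.
table = pa.table({
    "year": [2020, 2022, 2022, None],
    "animal": ["Flamingo", "Horse", "Parrot", "Centipede"],
})
table = table.append_column("_partition_year", table["year"])  # identity transform for this sketch

# Unique partition tuples: group by the partition columns with no aggregates.
unique_keys = table.select(["_partition_year"]).group_by(["_partition_year"]).aggregate([])

for key in unique_keys.to_pylist():
    # One filter expression per unique tuple; None keys become IS NULL checks.
    predicate = functools.reduce(
        operator.and_,
        [
            pc.field(name) == value if value is not None else pc.field(name).is_null()
            for name, value in key.items()
        ],
    )
    partition = table.filter(predicate).drop_columns(["_partition_year"])
    # combine_chunks() copies the slice into fresh, contiguous buffers before it is written out.
    print(key, partition.combine_chunks().num_rows)
```

No sort of the full table is needed, which is what previously caused Arrow to join buffers until they overflowed.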
1 parent 872a445 commit 36d383d

File tree

7 files changed (+805, -743 lines)


pyiceberg/io/pyarrow.py

Lines changed: 50 additions & 79 deletions
```diff
@@ -27,8 +27,10 @@
 
 import concurrent.futures
 import fnmatch
+import functools
 import itertools
 import logging
+import operator
 import os
 import re
 import uuid
@@ -2174,7 +2176,10 @@ def _partition_value(self, partition_field: PartitionField, schema: Schema) -> Any:
             raise ValueError(
                 f"Cannot infer partition value from parquet metadata as there are more than one partition values for Partition Field: {partition_field.name}. {lower_value=}, {upper_value=}"
             )
-        return lower_value
+
+        source_field = schema.find_field(partition_field.source_id)
+        transform = partition_field.transform.transform(source_field.field_type)
+        return transform(lower_value)
 
     def partition(self, partition_spec: PartitionSpec, schema: Schema) -> Record:
         return Record(**{field.name: self._partition_value(field, schema) for field in partition_spec.fields})
@@ -2558,38 +2563,8 @@ class _TablePartition:
     arrow_table_partition: pa.Table
 
 
-def _get_table_partitions(
-    arrow_table: pa.Table,
-    partition_spec: PartitionSpec,
-    schema: Schema,
-    slice_instructions: list[dict[str, Any]],
-) -> list[_TablePartition]:
-    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x["offset"])
-
-    partition_fields = partition_spec.fields
-
-    offsets = [inst["offset"] for inst in sorted_slice_instructions]
-    projected_and_filtered = {
-        partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
-        .take(offsets)
-        .to_pylist()
-        for partition_field in partition_fields
-    }
-
-    table_partitions = []
-    for idx, inst in enumerate(sorted_slice_instructions):
-        partition_slice = arrow_table.slice(**inst)
-        fieldvalues = [
-            PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
-            for partition_field in partition_fields
-        ]
-        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
-        table_partitions.append(_TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
-    return table_partitions
-
-
 def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[_TablePartition]:
-    """Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.
+    """Based on the iceberg table partition spec, filter the arrow table into partitions with their keys.
 
     Example:
     Input:
@@ -2598,54 +2573,50 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[_TablePartition]:
      'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
      'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse","Brittle stars", "Centipede"]}.
     The algorithm:
-    Firstly we group the rows into partitions by sorting with sort order [('n_legs', 'descending'), ('year', 'descending')]
-    and null_placement of "at_end".
-    This gives the same table as raw input.
-    Then we sort_indices using reverse order of [('n_legs', 'descending'), ('year', 'descending')]
-    and null_placement : "at_start".
-    This gives:
-    [8, 7, 4, 5, 6, 3, 1, 2, 0]
-    Based on this we get partition groups of indices:
-    [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
-    We then retrieve the partition keys by offsets.
-    And slice the arrow table by offsets and lengths of each partition.
+    - We determine the set of unique partition keys
+    - Then we produce a set of partitions by filtering on each of the combinations
+    - We combine the chunks to create a copy to avoid GIL congestion on the original table
     """
-    partition_columns: List[Tuple[PartitionField, NestedField]] = [
-        (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
-    ]
-    partition_values_table = pa.table(
-        {
-            str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
-            for partition, field in partition_columns
-        }
-    )
+    # Assign unique names to columns where the partition transform has been applied
+    # to avoid conflicts
+    partition_fields = [f"_partition_{field.name}" for field in spec.fields]
+
+    for partition, name in zip(spec.fields, partition_fields):
+        source_field = schema.find_field(partition.source_id)
+        arrow_table = arrow_table.append_column(
+            name, partition.transform.pyarrow_transform(source_field.field_type)(arrow_table[source_field.name])
+        )
+
+    unique_partition_fields = arrow_table.select(partition_fields).group_by(partition_fields).aggregate([])
+
+    table_partitions = []
+    # TODO: As a next step, we could also play around with yielding instead of materializing the full list
+    for unique_partition in unique_partition_fields.to_pylist():
+        partition_key = PartitionKey(
+            field_values=[
+                PartitionFieldValue(field=field, value=unique_partition[name])
+                for field, name in zip(spec.fields, partition_fields)
+            ],
+            partition_spec=spec,
+            schema=schema,
+        )
+        filtered_table = arrow_table.filter(
+            functools.reduce(
+                operator.and_,
+                [
+                    pc.field(partition_field_name) == unique_partition[partition_field_name]
+                    if unique_partition[partition_field_name] is not None
+                    else pc.field(partition_field_name).is_null()
+                    for field, partition_field_name in zip(spec.fields, partition_fields)
+                ],
+            )
+        )
+        filtered_table = filtered_table.drop_columns(partition_fields)
 
-    # Sort by partitions
-    sort_indices = pa.compute.sort_indices(
-        partition_values_table,
-        sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
-        null_placement="at_end",
-    ).to_pylist()
-    arrow_table = arrow_table.take(sort_indices)
-
-    # Get slice_instructions to group by partitions
-    partition_values_table = partition_values_table.take(sort_indices)
-    reversed_indices = pa.compute.sort_indices(
-        partition_values_table,
-        sort_keys=[(col, "descending") for col in partition_values_table.column_names],
-        null_placement="at_start",
-    ).to_pylist()
-    slice_instructions: List[Dict[str, Any]] = []
-    last = len(reversed_indices)
-    reversed_indices_size = len(reversed_indices)
-    ptr = 0
-    while ptr < reversed_indices_size:
-        group_size = last - reversed_indices[ptr]
-        offset = reversed_indices[ptr]
-        slice_instructions.append({"offset": offset, "length": group_size})
-        last = reversed_indices[ptr]
-        ptr = ptr + group_size
-
-    table_partitions: List[_TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)
+        # The combine_chunks seems to be counter-intuitive to do, but it actually returns
+        # fresh buffers that don't interfere with each other when it is written out to file
+        table_partitions.append(
+            _TablePartition(partition_key=partition_key, arrow_table_partition=filtered_table.combine_chunks())
+        )
 
     return table_partitions
```
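As a side note on the `combine_chunks` comment above: filtering a heavily chunked table yields a result that is itself backed by many small chunks, and `combine_chunks` copies them into fresh, contiguous buffers before the partition is written out. A tiny standalone sketch of that mechanic (not code from this commit):

```python
import pyarrow as pa

# A column assembled from many small chunks, similar to what filtering a
# heavily chunked table produces.
table = pa.table({"n": pa.chunked_array([pa.array([i, i + 1]) for i in range(500)])})
print(table["n"].num_chunks)     # 500 chunks, each backed by its own buffers

combined = table.combine_chunks()
print(combined["n"].num_chunks)  # 1 contiguous chunk backed by fresh buffers
```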

pyiceberg/partitioning.py

Lines changed: 30 additions & 9 deletions
```diff
@@ -29,6 +29,7 @@
     Optional,
     Tuple,
     TypeVar,
+    Union,
 )
 from urllib.parse import quote_plus
 
@@ -393,14 +394,14 @@ class PartitionFieldValue:
 
 @dataclass(frozen=True)
 class PartitionKey:
-    raw_partition_field_values: List[PartitionFieldValue]
+    field_values: List[PartitionFieldValue]
     partition_spec: PartitionSpec
     schema: Schema
 
     @cached_property
     def partition(self) -> Record:  # partition key transformed with iceberg internal representation as input
         iceberg_typed_key_values = {}
-        for raw_partition_field_value in self.raw_partition_field_values:
+        for raw_partition_field_value in self.field_values:
             partition_fields = self.partition_spec.source_id_to_fields_map[raw_partition_field_value.field.source_id]
             if len(partition_fields) != 1:
                 raise ValueError(f"Cannot have redundant partitions: {partition_fields}")
@@ -427,25 +428,45 @@ def partition_record_value(partition_field: PartitionField, value: Any, schema: Schema) -> Any:
 
     the final partition record value.
     """
     iceberg_type = schema.find_field(name_or_id=partition_field.source_id).field_type
-    iceberg_typed_value = _to_partition_representation(iceberg_type, value)
-    transformed_value = partition_field.transform.transform(iceberg_type)(iceberg_typed_value)
-    return transformed_value
+    return _to_partition_representation(iceberg_type, value)
 
 
 @singledispatch
 def _to_partition_representation(type: IcebergType, value: Any) -> Any:
+    """Strip the logical type into the physical type.
+
+    It can be that the value is already transformed into its physical type,
+    in this case it will return the original value. Keep in mind that the
+    bucket transform always will return an int, but an identity transform
+    can return date that still needs to be transformed into an int (days
+    since epoch).
+    """
     return TypeError(f"Unsupported partition field type: {type}")
 
 
 @_to_partition_representation.register(TimestampType)
 @_to_partition_representation.register(TimestamptzType)
-def _(type: IcebergType, value: Optional[datetime]) -> Optional[int]:
-    return datetime_to_micros(value) if value is not None else None
+def _(type: IcebergType, value: Optional[Union[int, datetime]]) -> Optional[int]:
+    if value is None:
+        return None
+    elif isinstance(value, int):
+        return value
+    elif isinstance(value, datetime):
+        return datetime_to_micros(value)
+    else:
+        raise ValueError(f"Unknown type: {value}")
 
 
 @_to_partition_representation.register(DateType)
-def _(type: IcebergType, value: Optional[date]) -> Optional[int]:
-    return date_to_days(value) if value is not None else None
+def _(type: IcebergType, value: Optional[Union[int, date]]) -> Optional[int]:
+    if value is None:
+        return None
+    elif isinstance(value, int):
+        return value
+    elif isinstance(value, date):
+        return date_to_days(value)
+    else:
+        raise ValueError(f"Unknown type: {value}")
 
 
 @_to_partition_representation.register(TimeType)
```
pyiceberg/table/__init__.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -453,8 +453,10 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
         with self._append_snapshot_producer(snapshot_properties) as append_files:
             # skip writing data files if the dataframe is empty
             if df.shape[0] > 0:
-                data_files = _dataframe_to_data_files(
-                    table_metadata=self.table_metadata, write_uuid=append_files.commit_uuid, df=df, io=self._table.io
+                data_files = list(
+                    _dataframe_to_data_files(
+                        table_metadata=self.table_metadata, write_uuid=append_files.commit_uuid, df=df, io=self._table.io
+                    )
                 )
                 for data_file in data_files:
                     append_files.append_data_file(data_file)
```

pyproject.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -1220,6 +1220,7 @@ markers = [
     "adls: marks a test as requiring access to adls compliant storage (use with --adls.account-name, --adls.account-key, and --adls.endpoint args)",
     "integration: marks integration tests against Apache Spark",
     "gcs: marks a test as requiring access to gcs compliant storage (use with --gs.token, --gs.project, and --gs.endpoint)",
+    "benchmark: collection of tests to validate read/write performance before and after a change"
 ]
 
 # Turns a warning into an error
```
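With the marker registered, the new benchmark can be selected via pytest's standard marker filter (e.g. `pytest tests/benchmark -m benchmark`) and excluded from regular runs with `-m "not benchmark"`; that is plain pytest behaviour rather than anything added by this commit.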

tests/benchmark/test_benchmark.py

Lines changed: 72 additions & 0 deletions
```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import statistics
import timeit
import urllib

import pyarrow as pa
import pyarrow.parquet as pq
import pytest

from pyiceberg.transforms import DayTransform


@pytest.fixture(scope="session")
def taxi_dataset(tmp_path_factory: pytest.TempPathFactory) -> pa.Table:
    """Reads the Taxi dataset to disk"""
    taxi_dataset = "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2022-01.parquet"
    taxi_dataset_dest = tmp_path_factory.mktemp("taxi_dataset") / "yellow_tripdata_2022-01.parquet"
    urllib.request.urlretrieve(taxi_dataset, taxi_dataset_dest)

    return pq.read_table(taxi_dataset_dest)


@pytest.mark.benchmark
def test_partitioned_write(tmp_path_factory: pytest.TempPathFactory, taxi_dataset: pa.Table) -> None:
    """Tests writing to a partitioned table with something that would be close a production-like situation"""
    from pyiceberg.catalog.sql import SqlCatalog

    warehouse_path = str(tmp_path_factory.mktemp("warehouse"))
    catalog = SqlCatalog(
        "default",
        uri=f"sqlite:///{warehouse_path}/pyiceberg_catalog.db",
        warehouse=f"file://{warehouse_path}",
    )

    catalog.create_namespace("default")

    tbl = catalog.create_table("default.taxi_partitioned", schema=taxi_dataset.schema)

    with tbl.update_spec() as spec:
        spec.add_field("tpep_pickup_datetime", DayTransform())

    # Profiling can sometimes be handy as well
    # with cProfile.Profile() as pr:
    #     tbl.append(taxi_dataset)
    #
    # pr.print_stats(sort=True)

    runs = []
    for run in range(5):
        start_time = timeit.default_timer()
        tbl.append(taxi_dataset)
        elapsed = timeit.default_timer() - start_time

        print(f"Run {run} took: {elapsed}")
        runs.append(elapsed)

    print(f"Average runtime of {round(statistics.mean(runs), 2)} seconds")
```

0 commit comments