
Commit d9b1c03

gabeiglio (Gabriel Igliozzi) authored
Implement column projection (#1443)
This is a fix for issue #1401, in which table scans needed to infer partition column values by following the column projection [rules](https://iceberg.apache.org/spec/#column-projection).

Fixes #1401

Co-authored-by: Gabriel Igliozzi <[email protected]>
Co-authored-by: Gabriel Igliozzi <[email protected]>
1 parent dd175aa commit d9b1c03
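The core of the change is the spec's identity-partition projection rule: any field id that the scan projects but that is missing from the data file's own schema takes its value from the file's partition record in the manifest, provided an identity-transform partition field is sourced from it. Below is a minimal, illustrative sketch of that rule, separate from the patch itself; the helper name `project_missing_identity_fields` is hypothetical, and the positional lookup into `Record` mirrors what the patch's position accessors do.

from typing import Any, Dict, Set

from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import Record


def project_missing_identity_fields(
    projected_field_ids: Set[int], file_field_ids: Set[int], spec: PartitionSpec, partition_record: Record
) -> Dict[str, Any]:
    # Hypothetical helper, for illustration only: field ids requested by the scan
    # but absent from the file take their value from the manifest's partition record
    # when an identity-transform partition field is sourced from them.
    missing: Dict[str, Any] = {}
    for pos, partition_field in enumerate(spec.fields):
        if partition_field.source_id in projected_field_ids - file_field_ids and isinstance(
            partition_field.transform, IdentityTransform
        ):
            # Positional lookup into the partition record, like the patch's accessors
            missing[partition_field.name] = partition_record[pos]
    return missing


# Example: the file only contains field id 1; field id 2 is identity-partitioned as "partition_id".
spec = PartitionSpec(PartitionField(2, 1000, IdentityTransform(), "partition_id"))
print(project_missing_identity_fields({1, 2}, {1}, spec, Record(partition_id=1)))  # {'partition_id': 1}

The patch's `_get_column_projection_values` below implements the same idea, using `build_position_accessors` over the spec's partition type instead of a hand-rolled positional lookup.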

File tree

2 files changed: +193 -5 lines changed


pyiceberg/io/pyarrow.py

Lines changed: 61 additions & 5 deletions
@@ -125,12 +125,14 @@
 )
 from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec, partition_record_value
 from pyiceberg.schema import (
+    Accessor,
     PartnerAccessor,
     PreOrderSchemaVisitor,
     Schema,
     SchemaVisitorPerPrimitiveType,
     SchemaWithPartnerVisitor,
     _check_schema_compatible,
+    build_position_accessors,
     pre_order_visit,
     promote,
     prune_columns,
@@ -141,7 +143,7 @@
 from pyiceberg.table.locations import load_location_provider
 from pyiceberg.table.metadata import TableMetadata
 from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping
-from pyiceberg.transforms import TruncateTransform
+from pyiceberg.transforms import IdentityTransform, TruncateTransform
 from pyiceberg.typedef import EMPTY_DICT, Properties, Record
 from pyiceberg.types import (
     BinaryType,
@@ -1298,6 +1300,45 @@ def _field_id(self, field: pa.Field) -> int:
         return -1
 
 
+def _get_column_projection_values(
+    file: DataFile, projected_schema: Schema, partition_spec: Optional[PartitionSpec], file_project_field_ids: Set[int]
+) -> Tuple[bool, Dict[str, Any]]:
+    """Apply Column Projection rules to File Schema."""
+    project_schema_diff = projected_schema.field_ids.difference(file_project_field_ids)
+    should_project_columns = len(project_schema_diff) > 0
+    projected_missing_fields: Dict[str, Any] = {}
+
+    if not should_project_columns:
+        return False, {}
+
+    partition_schema: StructType
+    accessors: Dict[int, Accessor]
+
+    if partition_spec is not None:
+        partition_schema = partition_spec.partition_type(projected_schema)
+        accessors = build_position_accessors(partition_schema)
+    else:
+        return False, {}
+
+    for field_id in project_schema_diff:
+        for partition_field in partition_spec.fields_by_source_id(field_id):
+            if isinstance(partition_field.transform, IdentityTransform):
+                accessor = accessors.get(partition_field.field_id)
+
+                if accessor is None:
+                    continue
+
+                # The partition field may not exist in the partition record of the data file.
+                # This can happen when new partition fields are introduced after the file was written.
+                try:
+                    if partition_value := accessor.get(file.partition):
+                        projected_missing_fields[partition_field.name] = partition_value
+                except IndexError:
+                    continue
+
+    return True, projected_missing_fields
+
+
 def _task_to_record_batches(
     fs: FileSystem,
     task: FileScanTask,
@@ -1308,6 +1349,7 @@ def _task_to_record_batches(
     case_sensitive: bool,
     name_mapping: Optional[NameMapping] = None,
     use_large_types: bool = True,
+    partition_spec: Optional[PartitionSpec] = None,
 ) -> Iterator[pa.RecordBatch]:
     _, _, path = _parse_location(task.file.file_path)
     arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
@@ -1319,16 +1361,20 @@
     # When V3 support is introduced, we will update `downcast_ns_timestamp_to_us` flag based on
     # the table format version.
     file_schema = pyarrow_to_schema(physical_schema, name_mapping, downcast_ns_timestamp_to_us=True)
+
     pyarrow_filter = None
     if bound_row_filter is not AlwaysTrue():
         translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive)
         bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
         pyarrow_filter = expression_to_pyarrow(bound_file_filter)
 
-    file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
+    # Apply column projection rules
+    # https://iceberg.apache.org/spec/#column-projection
+    should_project_columns, projected_missing_fields = _get_column_projection_values(
+        task.file, projected_schema, partition_spec, file_schema.field_ids
+    )
 
-    if file_schema is None:
-        raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}")
+    file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
 
     fragment_scanner = ds.Scanner.from_fragment(
         fragment=fragment,
@@ -1368,14 +1414,23 @@
                 continue
             output_batches = arrow_table.to_batches()
             for output_batch in output_batches:
-                yield _to_requested_schema(
+                result_batch = _to_requested_schema(
                     projected_schema,
                     file_project_schema,
                     output_batch,
                     downcast_ns_timestamp_to_us=True,
                     use_large_types=use_large_types,
                 )
 
+                # Inject projected column values if available
+                if should_project_columns:
+                    for name, value in projected_missing_fields.items():
+                        index = result_batch.schema.get_field_index(name)
+                        if index != -1:
+                            result_batch = result_batch.set_column(index, name, [value])
+
+                yield result_batch
+
 
 def _task_to_table(
     fs: FileSystem,
@@ -1597,6 +1652,7 @@ def _record_batches_from_scan_tasks_and_deletes(
                 self._case_sensitive,
                 self._table_metadata.name_mapping(),
                 self._use_large_types,
+                self._table_metadata.spec(),
             )
             for batch in batches:
                 if self._limit is not None:
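To wire this into reads, `_record_batches_from_scan_tasks_and_deletes` passes the table's current partition spec down to `_task_to_record_batches`, which then overwrites the projected columns on every emitted RecordBatch. A rough standalone sketch of that per-batch injection step follows; the batch contents and values here are invented for illustration, and it assumes a pyarrow version where `RecordBatch.set_column` is available, which the patch also relies on.

import pyarrow as pa

# A one-row batch standing in for data read from the file: the projected
# partition column exists in the requested schema but carries no real value.
batch = pa.RecordBatch.from_pydict({"other_field": ["foo"], "partition_id": [None]})
projected_missing_fields = {"partition_id": 1}  # values recovered from the manifest's partition record

for name, value in projected_missing_fields.items():
    index = batch.schema.get_field_index(name)
    if index != -1:
        # set_column returns a new RecordBatch; [value] builds a single-element
        # array matching this single-row batch
        batch = batch.set_column(index, name, [value])

print(batch.to_pydict())  # {'other_field': ['foo'], 'partition_id': [1]}

Because `set_column` returns a new batch, the patch rebinds `result_batch` before yielding it.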

tests/io/test_pyarrow.py

Lines changed: 132 additions & 0 deletions
@@ -69,14 +69,18 @@
     _read_deletes,
     _to_requested_schema,
     bin_pack_arrow_table,
+    compute_statistics_plan,
+    data_file_statistics_from_parquet_metadata,
     expression_to_pyarrow,
+    parquet_path_to_id_mapping,
     schema_to_pyarrow,
 )
 from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
 from pyiceberg.partitioning import PartitionField, PartitionSpec
 from pyiceberg.schema import Schema, make_compatible_name, visit
 from pyiceberg.table import FileScanTask, TableProperties
 from pyiceberg.table.metadata import TableMetadataV2
+from pyiceberg.table.name_mapping import create_mapping_from_schema
 from pyiceberg.transforms import IdentityTransform
 from pyiceberg.typedef import UTF8, Properties, Record
 from pyiceberg.types import (
@@ -99,6 +103,7 @@
     TimestamptzType,
     TimeType,
 )
+from tests.catalog.test_base import InMemoryCatalog
 from tests.conftest import UNIFIED_AWS_SESSION_PROPERTIES
 
 
@@ -1127,6 +1132,133 @@ def test_projection_concat_files(schema_int: Schema, file_int: str) -> None:
     assert repr(result_table.schema) == "id: int32"
 
 
+def test_identity_transform_column_projection(tmp_path: str, catalog: InMemoryCatalog) -> None:
+    # Test by adding a non-partitioned data file to a partitioned table, verifying partition value projection from manifest metadata.
+    # TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs.
+    # (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875)
+
+    schema = Schema(
+        NestedField(1, "other_field", StringType(), required=False), NestedField(2, "partition_id", IntegerType(), required=False)
+    )
+
+    partition_spec = PartitionSpec(
+        PartitionField(2, 1000, IdentityTransform(), "partition_id"),
+    )
+
+    table = catalog.create_table(
+        "default.test_projection_partition",
+        schema=schema,
+        partition_spec=partition_spec,
+        properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()},
+    )
+
+    file_data = pa.array(["foo"], type=pa.string())
+    file_loc = f"{tmp_path}/test.parquet"
+    pq.write_table(pa.table([file_data], names=["other_field"]), file_loc)
+
+    statistics = data_file_statistics_from_parquet_metadata(
+        parquet_metadata=pq.read_metadata(file_loc),
+        stats_columns=compute_statistics_plan(table.schema(), table.metadata.properties),
+        parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
+    )
+
+    unpartitioned_file = DataFile(
+        content=DataFileContent.DATA,
+        file_path=file_loc,
+        file_format=FileFormat.PARQUET,
+        # projected value
+        partition=Record(partition_id=1),
+        file_size_in_bytes=os.path.getsize(file_loc),
+        sort_order_id=None,
+        spec_id=table.metadata.default_spec_id,
+        equality_ids=None,
+        key_metadata=None,
+        **statistics.to_serialized_dict(),
+    )
+
+    with table.transaction() as transaction:
+        with transaction.update_snapshot().overwrite() as update:
+            update.append_data_file(unpartitioned_file)
+
+    assert (
+        str(table.scan().to_arrow())
+        == """pyarrow.Table
+other_field: large_string
+partition_id: int64
+----
+other_field: [["foo"]]
+partition_id: [[1]]"""
+    )
+
+
+def test_identity_transform_columns_projection(tmp_path: str, catalog: InMemoryCatalog) -> None:
+    # Test by adding a non-partitioned data file to a multi-partitioned table, verifying partition value projection from manifest metadata.
+    # TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs.
+    # (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875)
+    schema = Schema(
+        NestedField(1, "field_1", StringType(), required=False),
+        NestedField(2, "field_2", IntegerType(), required=False),
+        NestedField(3, "field_3", IntegerType(), required=False),
+    )
+
+    partition_spec = PartitionSpec(
+        PartitionField(2, 1000, IdentityTransform(), "field_2"),
+        PartitionField(3, 1001, IdentityTransform(), "field_3"),
+    )
+
+    table = catalog.create_table(
+        "default.test_projection_partitions",
+        schema=schema,
+        partition_spec=partition_spec,
+        properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()},
+    )
+
+    file_data = pa.array(["foo"], type=pa.string())
+    file_loc = f"{tmp_path}/test.parquet"
+    pq.write_table(pa.table([file_data], names=["field_1"]), file_loc)
+
+    statistics = data_file_statistics_from_parquet_metadata(
+        parquet_metadata=pq.read_metadata(file_loc),
+        stats_columns=compute_statistics_plan(table.schema(), table.metadata.properties),
+        parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
+    )
+
+    unpartitioned_file = DataFile(
+        content=DataFileContent.DATA,
+        file_path=file_loc,
+        file_format=FileFormat.PARQUET,
+        # projected value
+        partition=Record(field_2=2, field_3=3),
+        file_size_in_bytes=os.path.getsize(file_loc),
+        sort_order_id=None,
+        spec_id=table.metadata.default_spec_id,
+        equality_ids=None,
+        key_metadata=None,
+        **statistics.to_serialized_dict(),
+    )
+
+    with table.transaction() as transaction:
+        with transaction.update_snapshot().overwrite() as update:
+            update.append_data_file(unpartitioned_file)
+
+    assert (
+        str(table.scan().to_arrow())
+        == """pyarrow.Table
+field_1: large_string
+field_2: int64
+field_3: int64
+----
+field_1: [["foo"]]
+field_2: [[2]]
+field_3: [[3]]"""
+    )
+
+
+@pytest.fixture
+def catalog() -> InMemoryCatalog:
+    return InMemoryCatalog("test.in_memory.catalog", **{"test.key": "test.value"})
+
+
 def test_projection_filter(schema_int: Schema, file_int: str) -> None:
     result_table = project(schema_int, [file_int], GreaterThan("id", 4))
     assert len(result_table.columns[0]) == 0
