8 changes: 4 additions & 4 deletions pyiceberg/table/__init__.py
@@ -443,7 +443,7 @@ def overwrite(
for data_file in data_files:
update_snapshot.append_data_file(data_file)

def add_files(self, file_paths: List[str]) -> None:
def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
"""
Shorthand API for adding files as data files to the table transaction.

@@ -455,7 +455,7 @@ def add_files(self, file_paths: List[str]) -> None:
"""
if self._table.name_mapping() is None:
self.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING: self._table.schema().name_mapping.model_dump_json()})
with self.update_snapshot().fast_append() as update_snapshot:
with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
data_files = _parquet_files_to_data_files(
table_metadata=self._table.metadata, file_paths=file_paths, io=self._table.io
)
@@ -1341,7 +1341,7 @@ def overwrite(
with self.transaction() as tx:
tx.overwrite(df=df, overwrite_filter=overwrite_filter, snapshot_properties=snapshot_properties)

def add_files(self, file_paths: List[str]) -> None:
def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
"""
Shorthand API for adding files as data files to the table.

@@ -1352,7 +1352,7 @@ def add_files(self, file_paths: List[str]) -> None:
FileNotFoundError: If the file does not exist.
"""
with self.transaction() as tx:
tx.add_files(file_paths=file_paths)
tx.add_files(file_paths=file_paths, snapshot_properties=snapshot_properties)

def update_spec(self, case_sensitive: bool = True) -> UpdateSpec:
return UpdateSpec(Transaction(self, autocommit=True), case_sensitive=case_sensitive)
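For context, a minimal usage sketch of the new snapshot_properties argument introduced in this file. The add_files signature is taken from the diff above; the catalog, table, and file names are hypothetical and used only for illustration:

from pyiceberg.catalog import load_catalog

# Hypothetical catalog and table names, for illustration only.
catalog = load_catalog("default")
tbl = catalog.load_table("default.my_table")

# Register already-written Parquet files and attach custom key/value
# properties to the snapshot created by this commit.
tbl.add_files(
    file_paths=["s3://warehouse/default/data/file-0.parquet"],
    snapshot_properties={"ingested_by": "backfill-job"},
)

When snapshot_properties is omitted it defaults to EMPTY_DICT, so existing callers are unaffected.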
40 changes: 32 additions & 8 deletions tests/integration/test_add_files.py
@@ -17,7 +17,7 @@
# pylint:disable=redefined-outer-name

from datetime import date
from typing import Optional
from typing import Iterator, Optional

import pyarrow as pa
import pyarrow.parquet as pq
@@ -122,8 +122,13 @@ def _create_table(
return tbl


@pytest.fixture(name="format_version", params=[pytest.param(1, id="format_version=1"), pytest.param(2, id="format_version=2")])
def format_version_fixure(request: pytest.FixtureRequest) -> Iterator[int]:
@enkidulan (Contributor, Author) commented on May 2, 2024:

Making params more human-friendly:

test_add_files_to_unpartitioned_table[format_version=1] PASSED
test_add_files_to_unpartitioned_table[format_version=2] PASSED
..
# instead of 
test_add_files_to_unpartitioned_table[1] PASSED
test_add_files_to_unpartitioned_table[2] PASSED

"""Fixture to run tests with different table format versions."""
yield request.param


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_add_files_to_unpartitioned_table(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
identifier = f"default.unpartitioned_table_v{format_version}"
tbl = _create_table(session_catalog, identifier, format_version)
@@ -163,7 +168,6 @@ def test_add_files_to_unpartitioned_table(spark: SparkSession, session_catalog:


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_add_files_to_unpartitioned_table_raises_file_not_found(
spark: SparkSession, session_catalog: Catalog, format_version: int
) -> None:
@@ -184,7 +188,6 @@ def test_add_files_to_unpartitioned_table_raises_file_not_found(


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_add_files_to_unpartitioned_table_raises_has_field_ids(
spark: SparkSession, session_catalog: Catalog, format_version: int
) -> None:
@@ -205,7 +208,6 @@ def test_add_files_to_unpartitioned_table_raises_has_field_ids(


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_add_files_to_unpartitioned_table_with_schema_updates(
spark: SparkSession, session_catalog: Catalog, format_version: int
) -> None:
@@ -263,7 +265,6 @@ def test_add_files_to_unpartitioned_table_with_schema_updates(


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_add_files_to_partitioned_table(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
identifier = f"default.partitioned_table_v{format_version}"

@@ -335,7 +336,6 @@ def test_add_files_to_partitioned_table(spark: SparkSession, session_catalog: Ca


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_add_files_to_bucket_partitioned_table_fails(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
identifier = f"default.partitioned_table_bucket_fails_v{format_version}"

@@ -378,7 +378,6 @@ def test_add_files_to_bucket_partitioned_table_fails(spark: SparkSession, sessio


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_add_files_to_partitioned_table_fails_with_lower_and_upper_mismatch(
spark: SparkSession, session_catalog: Catalog, format_version: int
) -> None:
@@ -424,3 +423,28 @@ def test_add_files_to_partitioned_table_fails_with_lower_and_upper_mismatch(
"Cannot infer partition value from parquet metadata as there are more than one partition values for Partition Field: baz. lower_value=123, upper_value=124"
in str(exc_info.value)
)


@pytest.mark.integration
def test_add_files_snapshot_properties(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
identifier = f"default.unpartitioned_table_v{format_version}"
tbl = _create_table(session_catalog, identifier, format_version)

file_paths = [f"s3://warehouse/default/unpartitioned/v{format_version}/test-{i}.parquet" for i in range(5)]
# write parquet files
for file_path in file_paths:
fo = tbl.io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer:
writer.write_table(ARROW_TABLE)

# add the parquet files as data files
tbl.add_files(file_paths=file_paths, snapshot_properties={"snapshot_prop_a": "test_prop_a"})

# NameMapping must have been set to enable reads
assert tbl.name_mapping() is not None

summary = spark.sql(f"SELECT * FROM {identifier}.snapshots;").collect()[0].summary

assert "snapshot_prop_a" in summary
assert summary["snapshot_prop_a"] == "test_prop_a"