From 69162f8d5013c352b854b95a11504172189e9c3f Mon Sep 17 00:00:00 2001
From: Maksym Shalenyi
Date: Thu, 2 May 2024 09:19:07 -0700
Subject: [PATCH 1/3] make `add_files` support `snapshot_properties` argument

---
 pyiceberg/table/__init__.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 13186c42cc..5b7d04b543 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -443,7 +443,7 @@ def overwrite(
             for data_file in data_files:
                 update_snapshot.append_data_file(data_file)
 
-    def add_files(self, file_paths: List[str]) -> None:
+    def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
         """
         Shorthand API for adding files as data files to the table transaction.
 
@@ -455,7 +455,7 @@ def add_files(self, file_paths: List[str]) -> None:
         """
         if self._table.name_mapping() is None:
             self.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING: self._table.schema().name_mapping.model_dump_json()})
-        with self.update_snapshot().fast_append() as update_snapshot:
+        with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
             data_files = _parquet_files_to_data_files(
                 table_metadata=self._table.metadata, file_paths=file_paths, io=self._table.io
             )
@@ -1341,7 +1341,7 @@ def overwrite(
         with self.transaction() as tx:
             tx.overwrite(df=df, overwrite_filter=overwrite_filter, snapshot_properties=snapshot_properties)
 
-    def add_files(self, file_paths: List[str]) -> None:
+    def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
         """
         Shorthand API for adding files as data files to the table.
 
@@ -1352,7 +1352,7 @@ def add_files(
         FileNotFoundError: If the file does not exist.
         """
         with self.transaction() as tx:
-            tx.add_files(file_paths=file_paths)
+            tx.add_files(file_paths=file_paths, snapshot_properties=snapshot_properties)
 
     def update_spec(self, case_sensitive: bool = True) -> UpdateSpec:
         return UpdateSpec(Transaction(self, autocommit=True), case_sensitive=case_sensitive)

From ee3278d644728f7ec18fb05c5265fe1fb57d5178 Mon Sep 17 00:00:00 2001
From: Maksym Shalenyi
Date: Thu, 2 May 2024 13:19:09 -0700
Subject: [PATCH 2/3] add tests

---
 tests/integration/test_add_files.py | 39 +++++++++++++++++++++++------
 1 file changed, 32 insertions(+), 7 deletions(-)

diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py
index 0de5d5f4ce..114b6bf273 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -16,6 +16,7 @@
 #  under the License.
 # pylint:disable=redefined-outer-name
 
+from collections.abc import Iterator
 from datetime import date
 from typing import Optional
 
@@ -122,8 +123,13 @@ def _create_table(
     return tbl
 
 
+@pytest.fixture(name="format_version", params=[pytest.param(1, id="format_version=1"), pytest.param(2, id="format_version=2")])
+def format_version_fixture(request: pytest.FixtureRequest) -> Iterator[int]:
+    """Fixture to run tests with different table format versions."""
+    yield request.param
+
+
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_unpartitioned_table(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
     identifier = f"default.unpartitioned_table_v{format_version}"
     tbl = _create_table(session_catalog, identifier, format_version)
@@ -163,7 +169,6 @@ def test_add_files_to_unpartitioned_table(spark: SparkSession, session_catalog:
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_unpartitioned_table_raises_file_not_found(
     spark: SparkSession, session_catalog: Catalog, format_version: int
 ) -> None:
@@ -184,7 +189,6 @@ def test_add_files_to_unpartitioned_table_raises_file_not_found(
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_unpartitioned_table_raises_has_field_ids(
     spark: SparkSession, session_catalog: Catalog, format_version: int
 ) -> None:
@@ -205,7 +209,6 @@ def test_add_files_to_unpartitioned_table_raises_has_field_ids(
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_unpartitioned_table_with_schema_updates(
     spark: SparkSession, session_catalog: Catalog, format_version: int
 ) -> None:
@@ -263,7 +266,6 @@ def test_add_files_to_unpartitioned_table_with_schema_updates(
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_partitioned_table(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
     identifier = f"default.partitioned_table_v{format_version}"
 
@@ -335,7 +337,6 @@ def test_add_files_to_partitioned_table(spark: SparkSession, session_catalog: Ca
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_bucket_partitioned_table_fails(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
     identifier = f"default.partitioned_table_bucket_fails_v{format_version}"
 
@@ -378,7 +379,6 @@ def test_add_files_to_bucket_partitioned_table_fails(spark: SparkSession, sessio
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_partitioned_table_fails_with_lower_and_upper_mismatch(
     spark: SparkSession, session_catalog: Catalog, format_version: int
 ) -> None:
@@ -424,3 +424,28 @@ def test_add_files_to_partitioned_table_fails_with_lower_and_upper_mismatch(
         "Cannot infer partition value from parquet metadata as there are more than one partition values for Partition Field: baz. lower_value=123, upper_value=124"
         in str(exc_info.value)
     )
+
+
+@pytest.mark.integration
+def test_add_files_snapshot_properties(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
+    identifier = f"default.unpartitioned_table_v{format_version}"
+    tbl = _create_table(session_catalog, identifier, format_version)
+
+    file_paths = [f"s3://warehouse/default/unpartitioned/v{format_version}/test-{i}.parquet" for i in range(5)]
+    # write parquet files
+    for file_path in file_paths:
+        fo = tbl.io.new_output(file_path)
+        with fo.create(overwrite=True) as fos:
+            with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer:
+                writer.write_table(ARROW_TABLE)
+
+    # add the parquet files as data files
+    tbl.add_files(file_paths=file_paths, snapshot_properties={"snapshot_prop_a": "test_prop_a"})
+
+    # NameMapping must have been set to enable reads
+    assert tbl.name_mapping() is not None
+
+    summary = spark.sql(f"SELECT * FROM {identifier}.snapshots;").collect()[0].summary
+
+    assert "snapshot_prop_a" in summary
+    assert summary["snapshot_prop_a"] == "test_prop_a"

From d350376820b2de5126b8c705f8d05db5c499c2ab Mon Sep 17 00:00:00 2001
From: Maksym Shalenyi
Date: Tue, 7 May 2024 00:03:54 -0700
Subject: [PATCH 3/3] replace `collections.abc.Iterator` with `typing.Iterator`

---
 tests/integration/test_add_files.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py
index 114b6bf273..94c73918c8 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -16,9 +16,8 @@
 #  under the License.
 # pylint:disable=redefined-outer-name
 
-from collections.abc import Iterator
 from datetime import date
-from typing import Optional
+from typing import Iterator, Optional
 
 import pyarrow as pa
 import pyarrow.parquet as pq
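
A minimal usage sketch of the new `snapshot_properties` argument, mirroring the integration test added above. The catalog name, table identifier, warehouse paths, and the one-column Arrow schema are illustrative assumptions rather than anything defined in these patches, and the final summary check goes through the snapshot object instead of the Spark metadata table used by the test.

# Sketch only: register existing parquet files and stamp the resulting snapshot
# with custom summary properties. Catalog/table/paths/schema below are assumptions.
import pyarrow as pa
import pyarrow.parquet as pq

from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")                               # assumed catalog name
tbl = catalog.load_table("default.unpartitioned_table_v2")      # assumed identifier

arrow_schema = pa.schema([pa.field("foo", pa.string())])        # assumed to match the table schema
arrow_table = pa.Table.from_pylist([{"foo": "bar"}], schema=arrow_schema)

# Write a few parquet files through the table's FileIO, as the integration test does.
file_paths = [f"s3://warehouse/default/unpartitioned/v2/test-{i}.parquet" for i in range(5)]
for file_path in file_paths:
    with tbl.io.new_output(file_path).create(overwrite=True) as fos:
        with pq.ParquetWriter(fos, schema=arrow_schema) as writer:
            writer.write_table(arrow_table)

# Register the files as data files; the fast-append snapshot created here carries
# the supplied properties in its summary.
tbl.add_files(file_paths=file_paths, snapshot_properties={"snapshot_prop_a": "test_prop_a"})

# Assumption: inspect the summary on the snapshot object; the integration test
# verifies the same thing through Spark's `<table>.snapshots` metadata table.
print(tbl.current_snapshot().summary)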