From dac2adcf998932c781378bd57c15ec4ebfb9771d Mon Sep 17 00:00:00 2001 From: YinZheng-Sun <51255903009@stu.ecnu.edu.cn> Date: Tue, 18 Mar 2025 16:17:49 +0800 Subject: [PATCH 1/2] feat: check whether other table ops conflict when committing --- pyiceberg/table/update/snapshot.py | 36 ++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index c705f3b9fd..43bb3baebf 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -55,6 +55,7 @@ from pyiceberg.partitioning import ( PartitionSpec, ) +from pyiceberg.table.metadata import TableMetadata from pyiceberg.table.snapshots import ( Operation, Snapshot, @@ -239,7 +240,21 @@ def _summary(self, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> Summary: truncate_full_table=self._operation == Operation.OVERWRITE, ) + def refresh(self) -> TableMetadata: + try: + table = self._transaction._table.refresh() + return table.metadata + except Exception: + return self._transaction._table.metadata + + @abstractmethod + def _validate(self, current_metadata: TableMetadata, Snapshot: Optional[Snapshot]) -> None: ... + def _commit(self) -> UpdatesAndRequirements: + current_snapshot = self._transaction.table_metadata.current_snapshot() + table_metadata = self.refresh() + self._validate(table_metadata, current_snapshot) + new_manifests = self._manifests() next_sequence_number = self._transaction.table_metadata.next_sequence_number() @@ -249,6 +264,7 @@ def _commit(self) -> UpdatesAndRequirements: attempt=0, commit_uuid=self.commit_uuid, ) + location_provider = self._transaction._table.location_provider() manifest_list_file_path = location_provider.new_metadata_location(file_name) with write_manifest_list( @@ -445,6 +461,14 @@ def files_affected(self) -> bool: """Indicate if any manifest-entries can be dropped.""" return len(self._deleted_entries()) > 0 + def _validate(self, current_metadata: TableMetadata, Snapshot: Optional[Snapshot]) -> None: + if Snapshot is None: + raise ValueError("Snapshot cannot be None.") + + if Snapshot.snapshot_id != current_metadata.snapshot_id: + raise ValueError("Operation conflicts are not allowed when performing deleting.") + return + class _FastAppendFiles(_SnapshotProducer["_FastAppendFiles"]): def _existing_manifests(self) -> List[ManifestFile]: @@ -474,6 +498,10 @@ def _deleted_entries(self) -> List[ManifestEntry]: """ return [] + def _validate(self, current_metadata: TableMetadata, Snapshot: Optional[Snapshot]) -> None: + """Other operations don't affect the appending operation, and we can just append.""" + return + class _MergeAppendFiles(_FastAppendFiles): _target_size_bytes: int @@ -602,6 +630,14 @@ def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]: else: return [] + def _validate(self, current_metadata: TableMetadata, Snapshot: Optional[Snapshot]) -> None: + if Snapshot is None: + raise ValueError("Snapshot cannot be None.") + + if Snapshot.snapshot_id != current_metadata.snapshot_id: + raise ValueError("Operation conflicts are not allowed when performing overwriting.") + return + class UpdateSnapshot: _transaction: Transaction From e7ea131ff2258848af7515c524b8833fcdd90090 Mon Sep 17 00:00:00 2001 From: YinZheng-Sun <51255903009@stu.ecnu.edu.cn> Date: Wed, 19 Mar 2025 20:55:08 +0800 Subject: [PATCH 2/2] fix code format --- pyiceberg/table/update/snapshot.py | 47 +++++++++++----------- tests/integration/test_add_files.py | 60 ++++++++++++++++++++++++++++- 2 files changed, 83 insertions(+), 24 deletions(-) diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index 43bb3baebf..bf950a68bd 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -27,6 +27,7 @@ from sortedcontainers import SortedList +from pyiceberg.exceptions import CommitFailedException from pyiceberg.expressions import ( AlwaysFalse, BooleanExpression, @@ -83,7 +84,7 @@ from pyiceberg.utils.properties import property_as_bool, property_as_int if TYPE_CHECKING: - from pyiceberg.table import Transaction + from pyiceberg.table import Transaction, Table def _new_manifest_file_name(num: int, commit_uuid: uuid.UUID) -> str: @@ -240,20 +241,20 @@ def _summary(self, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> Summary: truncate_full_table=self._operation == Operation.OVERWRITE, ) - def refresh(self) -> TableMetadata: - try: - table = self._transaction._table.refresh() - return table.metadata - except Exception: - return self._transaction._table.metadata + def refresh(self) -> Table: + table = self._transaction._table.refresh() + return table @abstractmethod - def _validate(self, current_metadata: TableMetadata, Snapshot: Optional[Snapshot]) -> None: ... + def _validate(self, current_metadata: TableMetadata, snapshot: Optional[Snapshot]) -> None: ... def _commit(self) -> UpdatesAndRequirements: current_snapshot = self._transaction.table_metadata.current_snapshot() - table_metadata = self.refresh() - self._validate(table_metadata, current_snapshot) + if current_snapshot is not None: + table = self.refresh() + if table is None: + raise CommitFailedException("Table is none.") + self._validate(table.metadata, current_snapshot) new_manifests = self._manifests() next_sequence_number = self._transaction.table_metadata.next_sequence_number() @@ -461,12 +462,12 @@ def files_affected(self) -> bool: """Indicate if any manifest-entries can be dropped.""" return len(self._deleted_entries()) > 0 - def _validate(self, current_metadata: TableMetadata, Snapshot: Optional[Snapshot]) -> None: - if Snapshot is None: - raise ValueError("Snapshot cannot be None.") - - if Snapshot.snapshot_id != current_metadata.snapshot_id: - raise ValueError("Operation conflicts are not allowed when performing deleting.") + def _validate(self, current_metadata: TableMetadata, snapshot: Optional[Snapshot]) -> None: + if snapshot is None: + raise CommitFailedException("Snapshot cannot be None.") + current_snapshot_id = current_metadata.current_snapshot_id + if current_snapshot_id != None and snapshot.snapshot_id != current_snapshot_id: + raise CommitFailedException("Operation conflicts are not allowed when performing deleting.") return @@ -498,7 +499,7 @@ def _deleted_entries(self) -> List[ManifestEntry]: """ return [] - def _validate(self, current_metadata: TableMetadata, Snapshot: Optional[Snapshot]) -> None: + def _validate(self, current_metadata: TableMetadata, snapshot: Optional[Snapshot]) -> None: """Other operations don't affect the appending operation, and we can just append.""" return @@ -630,12 +631,12 @@ def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]: else: return [] - def _validate(self, current_metadata: TableMetadata, Snapshot: Optional[Snapshot]) -> None: - if Snapshot is None: - raise ValueError("Snapshot cannot be None.") - - if Snapshot.snapshot_id != current_metadata.snapshot_id: - raise ValueError("Operation conflicts are not allowed when performing overwriting.") + def _validate(self, current_metadata: TableMetadata, snapshot: Optional[Snapshot]) -> None: + if snapshot is None: + raise CommitFailedException("Snapshot cannot be None.") + current_snapshot_id = current_metadata.current_snapshot_id + if current_snapshot_id != None and snapshot.snapshot_id != current_snapshot_id: + raise CommitFailedException("Operation conflicts are not allowed when performing overwriting.") return diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py index bfbc8db668..c2bafdb93b 100644 --- a/tests/integration/test_add_files.py +++ b/tests/integration/test_add_files.py @@ -31,7 +31,7 @@ from pytest_mock.plugin import MockerFixture from pyiceberg.catalog import Catalog -from pyiceberg.exceptions import NoSuchTableError +from pyiceberg.exceptions import NoSuchTableError, CommitFailedException from pyiceberg.io import FileIO from pyiceberg.io.pyarrow import UnsupportedPyArrowTypeException, _pyarrow_schema_ensure_large_types from pyiceberg.manifest import DataFile @@ -903,3 +903,61 @@ def test_add_files_that_referenced_by_current_snapshot_with_check_duplicate_file with pytest.raises(ValueError) as exc_info: tbl.add_files(file_paths=[existing_files_in_table], check_duplicate_files=True) assert f"Cannot add files that are already referenced by table, files: {existing_files_in_table}" in str(exc_info.value) + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_conflict_delete_delete( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = "default.test_conflict" + tbl1 = _create_table(session_catalog, identifier, {"format-version": "1"}, [arrow_table_with_null]) + tbl2 = session_catalog.load_table(identifier) + + tbl1.delete("string == 'z'") + + with pytest.raises(CommitFailedException, match="(branch main has changed: expected id ).*"): + # tbl2 isn't aware of the commit by tbl1 + tbl2.delete("string == 'z'") + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_conflict_delete_append( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = "default.test_conflict" + tbl1 = _create_table(session_catalog, identifier, {"format-version": "1"}, [arrow_table_with_null]) + tbl2 = session_catalog.load_table(identifier) + + # This is allowed + tbl1.delete("string == 'z'") + tbl2.append(arrow_table_with_null) + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_conflict_append_delete( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = "default.test_conflict" + tbl1 = _create_table(session_catalog, identifier, {"format-version": "1"}, [arrow_table_with_null]) + tbl2 = session_catalog.load_table(identifier) + + tbl1.delete("string == 'z'") + + with pytest.raises(CommitFailedException, match="(branch main has changed: expected id ).*"): + # tbl2 isn't aware of the commit by tbl1 + tbl2.delete("string == 'z'") + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_conflict_append_append( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = "default.test_conflict" + tbl1 = _create_table(session_catalog, identifier, {"format-version": "1"}, [arrow_table_with_null]) + tbl2 = session_catalog.load_table(identifier) + + tbl1.append(arrow_table_with_null) + tbl2.append(arrow_table_with_null)