diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index c705f3b9fd..bf950a68bd 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -27,6 +27,7 @@ from sortedcontainers import SortedList +from pyiceberg.exceptions import CommitFailedException from pyiceberg.expressions import ( AlwaysFalse, BooleanExpression, @@ -55,6 +56,7 @@ from pyiceberg.partitioning import ( PartitionSpec, ) +from pyiceberg.table.metadata import TableMetadata from pyiceberg.table.snapshots import ( Operation, Snapshot, @@ -82,7 +84,7 @@ from pyiceberg.utils.properties import property_as_bool, property_as_int if TYPE_CHECKING: - from pyiceberg.table import Transaction + from pyiceberg.table import Transaction, Table def _new_manifest_file_name(num: int, commit_uuid: uuid.UUID) -> str: @@ -239,7 +241,21 @@ def _summary(self, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> Summary: truncate_full_table=self._operation == Operation.OVERWRITE, ) + def refresh(self) -> Table: + table = self._transaction._table.refresh() + return table + + @abstractmethod + def _validate(self, current_metadata: TableMetadata, snapshot: Optional[Snapshot]) -> None: ... + def _commit(self) -> UpdatesAndRequirements: + current_snapshot = self._transaction.table_metadata.current_snapshot() + if current_snapshot is not None: + table = self.refresh() + if table is None: + raise CommitFailedException("Table is none.") + self._validate(table.metadata, current_snapshot) + new_manifests = self._manifests() next_sequence_number = self._transaction.table_metadata.next_sequence_number() @@ -249,6 +265,7 @@ def _commit(self) -> UpdatesAndRequirements: attempt=0, commit_uuid=self.commit_uuid, ) + location_provider = self._transaction._table.location_provider() manifest_list_file_path = location_provider.new_metadata_location(file_name) with write_manifest_list( @@ -445,6 +462,14 @@ def files_affected(self) -> bool: """Indicate if any manifest-entries can be dropped.""" return len(self._deleted_entries()) > 0 + def _validate(self, current_metadata: TableMetadata, snapshot: Optional[Snapshot]) -> None: + if snapshot is None: + raise CommitFailedException("Snapshot cannot be None.") + current_snapshot_id = current_metadata.current_snapshot_id + if current_snapshot_id != None and snapshot.snapshot_id != current_snapshot_id: + raise CommitFailedException("Operation conflicts are not allowed when performing deleting.") + return + class _FastAppendFiles(_SnapshotProducer["_FastAppendFiles"]): def _existing_manifests(self) -> List[ManifestFile]: @@ -474,6 +499,10 @@ def _deleted_entries(self) -> List[ManifestEntry]: """ return [] + def _validate(self, current_metadata: TableMetadata, snapshot: Optional[Snapshot]) -> None: + """Other operations don't affect the appending operation, and we can just append.""" + return + class _MergeAppendFiles(_FastAppendFiles): _target_size_bytes: int @@ -602,6 +631,14 @@ def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]: else: return [] + def _validate(self, current_metadata: TableMetadata, snapshot: Optional[Snapshot]) -> None: + if snapshot is None: + raise CommitFailedException("Snapshot cannot be None.") + current_snapshot_id = current_metadata.current_snapshot_id + if current_snapshot_id != None and snapshot.snapshot_id != current_snapshot_id: + raise CommitFailedException("Operation conflicts are not allowed when performing overwriting.") + return + class UpdateSnapshot: _transaction: Transaction diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py index bfbc8db668..c2bafdb93b 100644 --- a/tests/integration/test_add_files.py +++ b/tests/integration/test_add_files.py @@ -31,7 +31,7 @@ from pytest_mock.plugin import MockerFixture from pyiceberg.catalog import Catalog -from pyiceberg.exceptions import NoSuchTableError +from pyiceberg.exceptions import NoSuchTableError, CommitFailedException from pyiceberg.io import FileIO from pyiceberg.io.pyarrow import UnsupportedPyArrowTypeException, _pyarrow_schema_ensure_large_types from pyiceberg.manifest import DataFile @@ -903,3 +903,61 @@ def test_add_files_that_referenced_by_current_snapshot_with_check_duplicate_file with pytest.raises(ValueError) as exc_info: tbl.add_files(file_paths=[existing_files_in_table], check_duplicate_files=True) assert f"Cannot add files that are already referenced by table, files: {existing_files_in_table}" in str(exc_info.value) + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_conflict_delete_delete( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = "default.test_conflict" + tbl1 = _create_table(session_catalog, identifier, {"format-version": "1"}, [arrow_table_with_null]) + tbl2 = session_catalog.load_table(identifier) + + tbl1.delete("string == 'z'") + + with pytest.raises(CommitFailedException, match="(branch main has changed: expected id ).*"): + # tbl2 isn't aware of the commit by tbl1 + tbl2.delete("string == 'z'") + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_conflict_delete_append( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = "default.test_conflict" + tbl1 = _create_table(session_catalog, identifier, {"format-version": "1"}, [arrow_table_with_null]) + tbl2 = session_catalog.load_table(identifier) + + # This is allowed + tbl1.delete("string == 'z'") + tbl2.append(arrow_table_with_null) + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_conflict_append_delete( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = "default.test_conflict" + tbl1 = _create_table(session_catalog, identifier, {"format-version": "1"}, [arrow_table_with_null]) + tbl2 = session_catalog.load_table(identifier) + + tbl1.delete("string == 'z'") + + with pytest.raises(CommitFailedException, match="(branch main has changed: expected id ).*"): + # tbl2 isn't aware of the commit by tbl1 + tbl2.delete("string == 'z'") + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_conflict_append_append( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = "default.test_conflict" + tbl1 = _create_table(session_catalog, identifier, {"format-version": "1"}, [arrow_table_with_null]) + tbl2 = session_catalog.load_table(identifier) + + tbl1.append(arrow_table_with_null) + tbl2.append(arrow_table_with_null)