diff --git a/pyiceberg/cli/console.py b/pyiceberg/cli/console.py
index 83e67a3cbb..10465a1a43 100644
--- a/pyiceberg/cli/console.py
+++ b/pyiceberg/cli/console.py
@@ -33,7 +33,7 @@
 from pyiceberg.cli.output import ConsoleOutput, JsonOutput, Output
 from pyiceberg.exceptions import NoSuchNamespaceError, NoSuchPropertyException, NoSuchTableError
 from pyiceberg.table import TableProperties
-from pyiceberg.table.refs import SnapshotRef
+from pyiceberg.table.refs import SnapshotRef, SnapshotRefType
 from pyiceberg.utils.properties import property_as_int


@@ -419,7 +419,7 @@ def list_refs(ctx: Context, identifier: str, type: str, verbose: bool) -> None:
     refs = table.refs()
     if type:
         type = type.lower()
-        if type not in {"branch", "tag"}:
+        if type not in {SnapshotRefType.BRANCH, SnapshotRefType.TAG}:
             raise ValueError(f"Type must be either branch or tag, got: {type}")

     relevant_refs = [
@@ -433,7 +433,7 @@ def list_refs(ctx: Context, identifier: str, type: str, verbose: bool) -> None:

 def _retention_properties(ref: SnapshotRef, table_properties: Dict[str, str]) -> Dict[str, str]:
     retention_properties = {}
-    if ref.snapshot_ref_type == "branch":
+    if ref.snapshot_ref_type == SnapshotRefType.BRANCH:
         default_min_snapshots_to_keep = property_as_int(
             table_properties,
             TableProperties.MIN_SNAPSHOTS_TO_KEEP,
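The console changes are a pure refactor: the string literals `"branch"` and `"tag"` give way to the `SnapshotRefType` enum. This is behavior-preserving because `SnapshotRefType` is a `str`-backed enum, so its members still compare equal to the raw strings coming from user input and table metadata. A minimal sketch of that assumption (the enum below mirrors `pyiceberg.table.refs` rather than importing it):

```python
from enum import Enum


class SnapshotRefType(str, Enum):  # mirrors pyiceberg.table.refs
    BRANCH = "branch"
    TAG = "tag"


# str-backed enum members compare equal to plain strings, so the
# rewritten membership test accepts the same CLI input as before:
assert SnapshotRefType.BRANCH == "branch"
assert "tag" in {SnapshotRefType.BRANCH, SnapshotRefType.TAG}
```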
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 78676a774a..1458e30dc3 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -87,7 +87,7 @@
 from pyiceberg.table.name_mapping import (
     NameMapping,
 )
-from pyiceberg.table.refs import SnapshotRef
+from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
 from pyiceberg.table.snapshots import (
     Snapshot,
     SnapshotLogEntry,
@@ -398,7 +398,7 @@ def _build_partition_predicate(self, partition_records: Set[Record]) -> BooleanE
             expr = Or(expr, match_partition_expression)
         return expr

-    def _append_snapshot_producer(self, snapshot_properties: Dict[str, str]) -> _FastAppendFiles:
+    def _append_snapshot_producer(self, snapshot_properties: Dict[str, str], branch: str = MAIN_BRANCH) -> _FastAppendFiles:
         """Determine the append type based on table properties.

         Args:
@@ -411,7 +411,7 @@ def _append_snapshot_producer(self, snapshot_properties: Dict[str, str]) -> _Fas
             TableProperties.MANIFEST_MERGE_ENABLED,
             TableProperties.MANIFEST_MERGE_ENABLED_DEFAULT,
         )
-        update_snapshot = self.update_snapshot(snapshot_properties=snapshot_properties)
+        update_snapshot = self.update_snapshot(snapshot_properties=snapshot_properties, branch=branch)
         return update_snapshot.merge_append() if manifest_merge_enabled else update_snapshot.fast_append()

     def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema:
@@ -431,13 +431,13 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive
             name_mapping=self.table_metadata.name_mapping(),
         )

-    def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> UpdateSnapshot:
+    def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str = MAIN_BRANCH) -> UpdateSnapshot:
         """Create a new UpdateSnapshot to produce a new snapshot for the table.

         Returns:
             A new UpdateSnapshot
         """
-        return UpdateSnapshot(self, io=self._table.io, snapshot_properties=snapshot_properties)
+        return UpdateSnapshot(self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties)

     def update_statistics(self) -> UpdateStatistics:
         """
@@ -448,13 +448,14 @@ def update_statistics(self) -> UpdateStatistics:
         """
         return UpdateStatistics(transaction=self)

-    def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
+    def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str = MAIN_BRANCH) -> None:
         """
         Shorthand API for appending a PyArrow table to a table transaction.

         Args:
             df: The Arrow dataframe that will be appended to overwrite the table
             snapshot_properties: Custom properties to be added to the snapshot summary
+            branch: Branch reference that the append operation will be committed to
         """
         try:
             import pyarrow as pa
@@ -477,7 +478,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
             self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )

-        with self._append_snapshot_producer(snapshot_properties) as append_files:
+        with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files:
             # skip writing data files if the dataframe is empty
             if df.shape[0] > 0:
                 data_files = list(
@@ -549,6 +550,7 @@ def overwrite(
         df: pa.Table,
         overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
+        branch: str = MAIN_BRANCH,
         case_sensitive: bool = True,
     ) -> None:
         """
@@ -566,6 +568,7 @@ def overwrite(
             or a boolean expression in case of a partial overwrite
             case_sensitive: A bool determine if the provided `overwrite_filter` is case-sensitive
             snapshot_properties: Custom properties to be added to the snapshot summary
+            branch: Branch reference that the overwrite operation will be committed to
         """
         try:
             import pyarrow as pa
@@ -590,9 +593,14 @@ def overwrite(

         if overwrite_filter != AlwaysFalse():
             # Only delete when the filter is != AlwaysFalse
-            self.delete(delete_filter=overwrite_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties)
+            self.delete(
+                delete_filter=overwrite_filter,
+                case_sensitive=case_sensitive,
+                snapshot_properties=snapshot_properties,
+                branch=branch,
+            )

-        with self._append_snapshot_producer(snapshot_properties) as append_files:
+        with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files:
             # skip writing data files if the dataframe is empty
             if df.shape[0] > 0:
                 data_files = _dataframe_to_data_files(
@@ -605,6 +613,7 @@ def delete(
         self,
         delete_filter: Union[str, BooleanExpression],
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
+        branch: str = MAIN_BRANCH,
         case_sensitive: bool = True,
     ) -> None:
         """
@@ -618,6 +627,7 @@ def delete(
         Args:
             delete_filter: A boolean expression to delete rows from a table
             snapshot_properties: Custom properties to be added to the snapshot summary
+            branch: Branch reference that the delete operation will be committed to
             case_sensitive: A bool determine if the provided `delete_filter` is case-sensitive
         """
         from pyiceberg.io.pyarrow import (
@@ -635,7 +645,7 @@ def delete(

         if isinstance(delete_filter, str):
             delete_filter = _parse_row_filter(delete_filter)

-        with self.update_snapshot(snapshot_properties=snapshot_properties).delete() as delete_snapshot:
+        with self.update_snapshot(snapshot_properties=snapshot_properties, branch=branch).delete() as delete_snapshot:
             delete_snapshot.delete_by_predicate(delete_filter, case_sensitive)

         # Check if there are any files that require an actual rewrite of a data file
@@ -643,7 +653,7 @@ def delete(
             bound_delete_filter = bind(self.table_metadata.schema(), delete_filter, case_sensitive)
             preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter)

-            files = self._scan(row_filter=delete_filter, case_sensitive=case_sensitive).plan_files()
+            files = self._scan(row_filter=delete_filter, case_sensitive=case_sensitive).use_ref(branch).plan_files()

             commit_uuid = uuid.uuid4()
             counter = itertools.count(0)
@@ -685,7 +695,9 @@ def delete(
             )

             if len(replaced_files) > 0:
-                with self.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as overwrite_snapshot:
+                with self.update_snapshot(
+                    branch=branch, snapshot_properties=snapshot_properties
+                ).overwrite() as overwrite_snapshot:
                     overwrite_snapshot.commit_uuid = commit_uuid
                     for original_data_file, replaced_data_files in replaced_files:
                         overwrite_snapshot.delete_data_file(original_data_file)
@@ -1284,16 +1296,17 @@ def upsert(
             case_sensitive=case_sensitive,
         )

-    def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
+    def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str = MAIN_BRANCH) -> None:
         """
         Shorthand API for appending a PyArrow table to the table.

         Args:
             df: The Arrow dataframe that will be appended to overwrite the table
             snapshot_properties: Custom properties to be added to the snapshot summary
+            branch: Branch reference that the append operation will be committed to
         """
         with self.transaction() as tx:
-            tx.append(df=df, snapshot_properties=snapshot_properties)
+            tx.append(df=df, snapshot_properties=snapshot_properties, branch=branch)

     def dynamic_partition_overwrite(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
         """Shorthand for dynamic overwriting the table with a PyArrow table.
@@ -1311,6 +1324,7 @@ def overwrite(
         df: pa.Table,
         overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
+        branch: str = MAIN_BRANCH,
         case_sensitive: bool = True,
     ) -> None:
         """
@@ -1327,17 +1341,23 @@ def overwrite(
             overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
             or a boolean expression in case of a partial overwrite
             snapshot_properties: Custom properties to be added to the snapshot summary
+            branch: Branch reference that the overwrite operation will be committed to
             case_sensitive: A bool determine if the provided `overwrite_filter` is case-sensitive
         """
         with self.transaction() as tx:
             tx.overwrite(
-                df=df, overwrite_filter=overwrite_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties
+                df=df,
+                overwrite_filter=overwrite_filter,
+                case_sensitive=case_sensitive,
+                snapshot_properties=snapshot_properties,
+                branch=branch,
             )

     def delete(
         self,
         delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
+        branch: str = MAIN_BRANCH,
         case_sensitive: bool = True,
     ) -> None:
         """
@@ -1346,10 +1366,13 @@ def delete(
         Args:
             delete_filter: The predicate that used to remove rows
             snapshot_properties: Custom properties to be added to the snapshot summary
-            case_sensitive: A bool determine if the provided `delete_filter` is case-sensitive
+            branch: Branch reference that the delete operation will be committed to
+            case_sensitive: A bool that determines whether the provided `delete_filter` is case-sensitive
         """
         with self.transaction() as tx:
-            tx.delete(delete_filter=delete_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties)
+            tx.delete(
+                delete_filter=delete_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties, branch=branch
+            )

     def add_files(
         self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True
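Taken together, the `pyiceberg/table/__init__.py` changes thread a `branch` argument from the public `Table.append`/`Table.overwrite`/`Table.delete` APIs down through the transaction to the snapshot producers. A hedged usage sketch follows; the catalog name, table identifier, branch name, and schema are illustrative, and the table is assumed to already have a current snapshot for `create_branch` to point at:

```python
import pyarrow as pa

from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")  # assumes a configured catalog named "default"
tbl = catalog.load_table("default.events")  # hypothetical table with an "int" column

# Branch off the current snapshot, then write to the branch only.
tbl.manage_snapshots().create_branch(
    snapshot_id=tbl.metadata.current_snapshot_id, branch_name="audit"
).commit()

df = pa.table({"int": pa.array([1, 9], type=pa.int32())})
tbl.append(df, branch="audit")                       # commit lands on "audit"
tbl.delete(delete_filter="int = 9", branch="audit")  # so does the delete

# Reads resolve through the ref; "main" is untouched by the branch writes.
audit_rows = len(tbl.scan().use_ref("audit").to_arrow())
main_rows = len(tbl.scan().to_arrow())
```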
diff --git a/pyiceberg/table/update/__init__.py b/pyiceberg/table/update/__init__.py
index 4905c31bfb..6653f119f0 100644
--- a/pyiceberg/table/update/__init__.py
+++ b/pyiceberg/table/update/__init__.py
@@ -29,7 +29,7 @@
 from pyiceberg.partitioning import PARTITION_FIELD_ID_START, PartitionSpec
 from pyiceberg.schema import Schema
 from pyiceberg.table.metadata import SUPPORTED_TABLE_FORMAT_VERSION, TableMetadata, TableMetadataUtil
-from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
+from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType
 from pyiceberg.table.snapshots import (
     MetadataLogEntry,
     Snapshot,
@@ -139,7 +139,7 @@ class AddSnapshotUpdate(IcebergBaseModel):
 class SetSnapshotRefUpdate(IcebergBaseModel):
     action: Literal["set-snapshot-ref"] = Field(default="set-snapshot-ref")
     ref_name: str = Field(alias="ref-name")
-    type: Literal["tag", "branch"]
+    type: Literal[SnapshotRefType.TAG, SnapshotRefType.BRANCH]
     snapshot_id: int = Field(alias="snapshot-id")
     max_ref_age_ms: Annotated[Optional[int], Field(alias="max-ref-age-ms", default=None)]
     max_snapshot_age_ms: Annotated[Optional[int], Field(alias="max-snapshot-age-ms", default=None)]
@@ -702,6 +702,10 @@ class AssertRefSnapshotId(ValidatableTableRequirement):
     def validate(self, base_metadata: Optional[TableMetadata]) -> None:
         if base_metadata is None:
             raise CommitFailedException("Requirement failed: current table metadata is missing")
+        elif len(base_metadata.snapshots) == 0 and self.ref != MAIN_BRANCH:
+            raise CommitFailedException(
+                f"Requirement failed: Table has no snapshots and can only be written to the {MAIN_BRANCH} branch."
+            )
         elif snapshot_ref := base_metadata.refs.get(self.ref):
             ref_type = snapshot_ref.snapshot_ref_type
             if self.snapshot_id is None:
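The new first guard in `AssertRefSnapshotId.validate` rejects branch-targeted commits against a table that has no snapshots yet. A small sketch of what that surfaces as, assuming `empty_metadata` is the `TableMetadata` of a freshly created table (not constructed here):

```python
from pyiceberg.exceptions import CommitFailedException
from pyiceberg.table.update import AssertRefSnapshotId

try:
    # snapshot_id=None means "this ref should not exist yet"
    AssertRefSnapshotId(ref="audit", snapshot_id=None).validate(empty_metadata)
except CommitFailedException as err:
    # "Requirement failed: Table has no snapshots and can only be
    #  written to the main branch."
    print(err)
```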
diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py
index a82167744d..cc219ace69 100644
--- a/pyiceberg/table/update/snapshot.py
+++ b/pyiceberg/table/update/snapshot.py
@@ -55,6 +55,7 @@ from pyiceberg.partitioning import (
     PartitionSpec,
 )
+from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRefType
 from pyiceberg.table.snapshots import (
     Operation,
     Snapshot,
@@ -104,12 +105,14 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]):
     _added_data_files: List[DataFile]
     _manifest_num_counter: itertools.count[int]
     _deleted_data_files: Set[DataFile]
+    _branch: str

     def __init__(
         self,
         operation: Operation,
         transaction: Transaction,
         io: FileIO,
+        branch: str,
         commit_uuid: Optional[uuid.UUID] = None,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
     ) -> None:
@@ -118,9 +121,9 @@ def __init__(
         self._io = io
         self._operation = operation
         self._snapshot_id = self._transaction.table_metadata.new_snapshot_id()
-        # Since we only support the main branch for now
+        self._branch = branch
         self._parent_snapshot_id = (
-            snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.current_snapshot()) else None
+            snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.snapshot_by_name(self._branch)) else None
         )
         self._added_data_files = []
         self._deleted_data_files = set()
@@ -271,10 +274,20 @@ def _commit(self) -> UpdatesAndRequirements:
             (
                 AddSnapshotUpdate(snapshot=snapshot),
                 SetSnapshotRefUpdate(
-                    snapshot_id=self._snapshot_id, parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch"
+                    snapshot_id=self._snapshot_id,
+                    parent_snapshot_id=self._parent_snapshot_id,
+                    ref_name=self._branch,
+                    type=SnapshotRefType.BRANCH,
+                ),
+            ),
+            (
+                AssertRefSnapshotId(
+                    snapshot_id=self._transaction.table_metadata.refs[self._branch].snapshot_id
+                    if self._branch in self._transaction.table_metadata.refs
+                    else self._transaction.table_metadata.current_snapshot_id,
+                    ref=self._branch,
                 ),
             ),
-            (AssertRefSnapshotId(snapshot_id=self._transaction.table_metadata.current_snapshot_id, ref="main"),),
         )

     @property
@@ -321,10 +334,11 @@ def __init__(
         operation: Operation,
         transaction: Transaction,
         io: FileIO,
+        branch: str,
         commit_uuid: Optional[uuid.UUID] = None,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
     ):
-        super().__init__(operation, transaction, io, commit_uuid, snapshot_properties)
+        super().__init__(operation, transaction, io, branch, commit_uuid, snapshot_properties)
         self._predicate = AlwaysFalse()
         self._case_sensitive = True
@@ -483,12 +497,13 @@ def __init__(
         operation: Operation,
         transaction: Transaction,
         io: FileIO,
+        branch: str,
         commit_uuid: Optional[uuid.UUID] = None,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
     ) -> None:
         from pyiceberg.table import TableProperties

-        super().__init__(operation, transaction, io, commit_uuid, snapshot_properties)
+        super().__init__(operation, transaction, io, branch, commit_uuid, snapshot_properties)
         self._target_size_bytes = property_as_int(
             self._transaction.table_metadata.properties,
             TableProperties.MANIFEST_TARGET_SIZE_BYTES,
@@ -534,7 +549,7 @@ def _existing_manifests(self) -> List[ManifestFile]:
         """Determine if there are any existing manifest files."""
         existing_files = []

-        if snapshot := self._transaction.table_metadata.current_snapshot():
+        if snapshot := self._transaction.table_metadata.snapshot_by_name(name=self._branch):
             for manifest_file in snapshot.manifests(io=self._io):
                 entries = manifest_file.fetch_manifest_entry(io=self._io, discard_deleted=True)
                 found_deleted_data_files = [entry.data_file for entry in entries if entry.data_file in self._deleted_data_files]
@@ -604,21 +619,33 @@ def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]:
 class UpdateSnapshot:
     _transaction: Transaction
     _io: FileIO
+    _branch: str
     _snapshot_properties: Dict[str, str]

-    def __init__(self, transaction: Transaction, io: FileIO, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
+    def __init__(
+        self, transaction: Transaction, io: FileIO, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str = MAIN_BRANCH
+    ) -> None:
         self._transaction = transaction
         self._io = io
         self._snapshot_properties = snapshot_properties
+        self._branch = branch

     def fast_append(self) -> _FastAppendFiles:
         return _FastAppendFiles(
-            operation=Operation.APPEND, transaction=self._transaction, io=self._io, snapshot_properties=self._snapshot_properties
+            operation=Operation.APPEND,
+            transaction=self._transaction,
+            io=self._io,
+            branch=self._branch,
+            snapshot_properties=self._snapshot_properties,
         )

     def merge_append(self) -> _MergeAppendFiles:
         return _MergeAppendFiles(
-            operation=Operation.APPEND, transaction=self._transaction, io=self._io, snapshot_properties=self._snapshot_properties
+            operation=Operation.APPEND,
+            transaction=self._transaction,
+            io=self._io,
+            branch=self._branch,
+            snapshot_properties=self._snapshot_properties,
         )

     def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> _OverwriteFiles:
@@ -629,6 +656,7 @@ def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> _OverwriteFiles:
             else Operation.APPEND,
             transaction=self._transaction,
             io=self._io,
+            branch=self._branch,
             snapshot_properties=self._snapshot_properties,
         )

@@ -637,6 +665,7 @@ def delete(self) -> _DeleteFiles:
             operation=Operation.DELETE,
             transaction=self._transaction,
             io=self._io,
+            branch=self._branch,
             snapshot_properties=self._snapshot_properties,
         )
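With `_SnapshotProducer` now branch-aware, both the parent snapshot and the commit requirement are resolved against the target branch rather than hard-coded `main`. The lower-level producer API can be driven the same way; a sketch under the assumption that `tbl` is a loaded `Table`, `"audit"` an existing branch, and `data_files` an iterable of already-prepared `DataFile` objects:

```python
with tbl.transaction() as tx:
    update = tx.update_snapshot(snapshot_properties={"source": "demo"}, branch="audit")
    with update.fast_append() as append_files:
        for data_file in data_files:  # assumed pre-built DataFile objects
            append_files.append_data_file(data_file)
# On commit, the producer emits SetSnapshotRefUpdate(ref_name="audit", ...)
# plus an AssertRefSnapshotId pinned to the branch's current head, so a
# concurrent writer to the same branch fails the requirement check.
```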
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index 150d2b750c..343cde3b06 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -41,12 +41,13 @@
 from pyiceberg.catalog.hive import HiveCatalog
 from pyiceberg.catalog.rest import RestCatalog
 from pyiceberg.catalog.sql import SqlCatalog
-from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.exceptions import CommitFailedException, NoSuchTableError
 from pyiceberg.expressions import And, EqualTo, GreaterThanOrEqual, In, LessThan, Not
 from pyiceberg.io.pyarrow import _dataframe_to_data_files
 from pyiceberg.partitioning import PartitionField, PartitionSpec
 from pyiceberg.schema import Schema
 from pyiceberg.table import TableProperties
+from pyiceberg.table.refs import MAIN_BRANCH
 from pyiceberg.table.sorting import SortDirection, SortField, SortOrder
 from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform
 from pyiceberg.types import (
@@ -1745,6 +1746,163 @@ def test_abort_table_transaction_on_exception(
     assert len(tbl.scan().to_pandas()) == table_size  # type: ignore


+@pytest.mark.integration
+def test_append_to_non_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
+    identifier = "default.test_non_existing_branch"
+    tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [])
+    with pytest.raises(
+        CommitFailedException, match=f"Table has no snapshots and can only be written to the {MAIN_BRANCH} branch."
+    ):
+        tbl.append(arrow_table_with_null, branch="non_existing_branch")
+
+
+@pytest.mark.integration
+def test_append_to_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
+    identifier = "default.test_existing_branch_append"
+    branch = "existing_branch"
+    tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])
+
+    assert tbl.metadata.current_snapshot_id is not None
+
+    tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit()
+    tbl.append(arrow_table_with_null, branch=branch)
+
+    assert len(tbl.scan().use_ref(branch).to_arrow()) == 6
+    assert len(tbl.scan().to_arrow()) == 3
+    branch_snapshot = tbl.metadata.snapshot_by_name(branch)
+    assert branch_snapshot is not None
+    main_snapshot = tbl.metadata.snapshot_by_name("main")
+    assert main_snapshot is not None
+    assert branch_snapshot.parent_snapshot_id == main_snapshot.snapshot_id
+
+
+@pytest.mark.integration
+def test_delete_to_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
+    identifier = "default.test_existing_branch_delete"
+    branch = "existing_branch"
+    tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])
+
+    assert tbl.metadata.current_snapshot_id is not None
+
+    tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit()
+    tbl.delete(delete_filter="int = 9", branch=branch)
+
+    assert len(tbl.scan().use_ref(branch).to_arrow()) == 2
+    assert len(tbl.scan().to_arrow()) == 3
+    branch_snapshot = tbl.metadata.snapshot_by_name(branch)
+    assert branch_snapshot is not None
+    main_snapshot = tbl.metadata.snapshot_by_name("main")
+    assert main_snapshot is not None
+    assert branch_snapshot.parent_snapshot_id == main_snapshot.snapshot_id
+
+
+@pytest.mark.integration
+def test_overwrite_to_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
+    identifier = "default.test_existing_branch_overwrite"
+    branch = "existing_branch"
+    tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])
+
+    assert tbl.metadata.current_snapshot_id is not None
+
+    tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit()
+    tbl.overwrite(arrow_table_with_null, branch=branch)
+
+    assert len(tbl.scan().use_ref(branch).to_arrow()) == 3
+    assert len(tbl.scan().to_arrow()) == 3
+    branch_snapshot = tbl.metadata.snapshot_by_name(branch)
+    assert branch_snapshot is not None and branch_snapshot.parent_snapshot_id is not None
+    delete_snapshot = tbl.metadata.snapshot_by_id(branch_snapshot.parent_snapshot_id)
+    assert delete_snapshot is not None
+    main_snapshot = tbl.metadata.snapshot_by_name("main")
+    assert main_snapshot is not None
+    assert (
+        delete_snapshot.parent_snapshot_id == main_snapshot.snapshot_id
+    )  # Currently overwrite is a delete followed by an append operation
+
+
+@pytest.mark.integration
+def test_intertwined_branch_writes(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
+    identifier = "default.test_intertwined_branch_operations"
+    branch1 = "existing_branch_1"
+    branch2 = "existing_branch_2"
+
+    tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])
+
+    assert tbl.metadata.current_snapshot_id is not None
+
+    tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch1).commit()
+
+    tbl.delete("int = 9", branch=branch1)
+
+    tbl.append(arrow_table_with_null)
+
+    tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch2).commit()
+
+    tbl.overwrite(arrow_table_with_null, branch=branch2)
+
+    assert len(tbl.scan().use_ref(branch1).to_arrow()) == 2
+    assert len(tbl.scan().use_ref(branch2).to_arrow()) == 3
+    assert len(tbl.scan().to_arrow()) == 6
+
+
+@pytest.mark.integration
+def test_branch_spark_write_py_read(session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table) -> None:
+    # Initialize table with branch
+    identifier = "default.test_branch_spark_write_py_read"
+    tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])
+    branch = "existing_spark_branch"
+
+    # Create branch in Spark
+    spark.sql(f"ALTER TABLE {identifier} CREATE BRANCH {branch}")
+
+    # Spark Write
+    spark.sql(
+        f"""
+        DELETE FROM {identifier}.branch_{branch}
+        WHERE int = 9
+        """
+    )
+
+    # Refresh table to get new refs
+    tbl.refresh()
+
+    # Python Read
+    assert len(tbl.scan().to_arrow()) == 3
+    assert len(tbl.scan().use_ref(branch).to_arrow()) == 2
+
+
+@pytest.mark.integration
+def test_branch_py_write_spark_read(session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table) -> None:
+    # Initialize table with branch
+    identifier = "default.test_branch_py_write_spark_read"
+    tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null])
+    branch = "existing_py_branch"
+
+    assert tbl.metadata.current_snapshot_id is not None
+
+    # Create branch
+    tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit()
+
+    # Python Write
+    tbl.delete("int = 9", branch=branch)
+
+    # Spark Read
+    main_df = spark.sql(
+        f"""
+        SELECT *
+        FROM {identifier}
+        """
+    )
+    branch_df = spark.sql(
+        f"""
+        SELECT *
+        FROM {identifier}.branch_{branch}
+        """
+    )
+    assert main_df.count() == 3
+    assert branch_df.count() == 2
+
+
 @pytest.mark.integration
 def test_write_optional_list(session_catalog: Catalog) -> None:
     identifier = "default.test_write_optional_list"
match=f"Requirement failed: branch {MAIN_BRANCH} was created concurrently", ): - AssertRefSnapshotId(ref="main", snapshot_id=None).validate(base_metadata) + AssertRefSnapshotId(ref=MAIN_BRANCH, snapshot_id=None).validate(base_metadata) with pytest.raises( CommitFailedException, - match="Requirement failed: branch main has changed: expected id 1, found 3055729675574597004", + match=f"Requirement failed: branch {MAIN_BRANCH} has changed: expected id 1, found 3055729675574597004", ): - AssertRefSnapshotId(ref="main", snapshot_id=1).validate(base_metadata) + AssertRefSnapshotId(ref=MAIN_BRANCH, snapshot_id=1).validate(base_metadata) + + non_existing_ref = "not_exist_branch_or_tag" + assert table_v2.refs().get("not_exist_branch_or_tag") is None + + with pytest.raises( + CommitFailedException, + match=f"Requirement failed: branch or tag {non_existing_ref} is missing, expected 1", + ): + AssertRefSnapshotId(ref=non_existing_ref, snapshot_id=1).validate(base_metadata) + + # existing Tag in metadata: test + ref_tag = table_v2.refs().get("test") + assert ref_tag is not None + assert ref_tag.snapshot_ref_type == SnapshotRefType.TAG, "TAG test should be present in table to be tested" with pytest.raises( CommitFailedException, - match="Requirement failed: branch or tag not_exist is missing, expected 1", + match="Requirement failed: tag test has changed: expected id 3055729675574597004, found 3051729675574597004", ): - AssertRefSnapshotId(ref="not_exist", snapshot_id=1).validate(base_metadata) + AssertRefSnapshotId(ref="test", snapshot_id=3055729675574597004).validate(base_metadata) def test_assert_last_assigned_field_id(table_v2: Table) -> None: