Skip to content

Feature/write to branch #2009

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyiceberg/cli/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from pyiceberg.cli.output import ConsoleOutput, JsonOutput, Output
from pyiceberg.exceptions import NoSuchNamespaceError, NoSuchPropertyException, NoSuchTableError
from pyiceberg.table import TableProperties
from pyiceberg.table.refs import SnapshotRef
from pyiceberg.table.refs import SnapshotRef, SnapshotRefType
from pyiceberg.utils.properties import property_as_int


Expand Down Expand Up @@ -419,7 +419,7 @@ def list_refs(ctx: Context, identifier: str, type: str, verbose: bool) -> None:
refs = table.refs()
if type:
type = type.lower()
if type not in {"branch", "tag"}:
if type not in {SnapshotRefType.BRANCH, SnapshotRefType.TAG}:
raise ValueError(f"Type must be either branch or tag, got: {type}")

relevant_refs = [
Expand All @@ -433,7 +433,7 @@ def list_refs(ctx: Context, identifier: str, type: str, verbose: bool) -> None:

def _retention_properties(ref: SnapshotRef, table_properties: Dict[str, str]) -> Dict[str, str]:
retention_properties = {}
if ref.snapshot_ref_type == "branch":
if ref.snapshot_ref_type == SnapshotRefType.BRANCH:
default_min_snapshots_to_keep = property_as_int(
table_properties,
TableProperties.MIN_SNAPSHOTS_TO_KEEP,
Expand Down
56 changes: 39 additions & 17 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
from pyiceberg.table.name_mapping import (
NameMapping,
)
from pyiceberg.table.refs import SnapshotRef
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
from pyiceberg.table.snapshots import (
Snapshot,
SnapshotLogEntry,
Expand Down Expand Up @@ -398,7 +398,7 @@ def _build_partition_predicate(self, partition_records: Set[Record]) -> BooleanE
expr = Or(expr, match_partition_expression)
return expr

def _append_snapshot_producer(self, snapshot_properties: Dict[str, str]) -> _FastAppendFiles:
def _append_snapshot_producer(self, snapshot_properties: Dict[str, str], branch: str = MAIN_BRANCH) -> _FastAppendFiles:
"""Determine the append type based on table properties.

Args:
Expand All @@ -411,7 +411,7 @@ def _append_snapshot_producer(self, snapshot_properties: Dict[str, str]) -> _Fas
TableProperties.MANIFEST_MERGE_ENABLED,
TableProperties.MANIFEST_MERGE_ENABLED_DEFAULT,
)
update_snapshot = self.update_snapshot(snapshot_properties=snapshot_properties)
update_snapshot = self.update_snapshot(snapshot_properties=snapshot_properties, branch=branch)
return update_snapshot.merge_append() if manifest_merge_enabled else update_snapshot.fast_append()

def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema:
Expand All @@ -431,13 +431,13 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive
name_mapping=self.table_metadata.name_mapping(),
)

def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> UpdateSnapshot:
def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str = MAIN_BRANCH) -> UpdateSnapshot:
"""Create a new UpdateSnapshot to produce a new snapshot for the table.

Returns:
A new UpdateSnapshot
"""
return UpdateSnapshot(self, io=self._table.io, snapshot_properties=snapshot_properties)
return UpdateSnapshot(self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties)

def update_statistics(self) -> UpdateStatistics:
"""
Expand All @@ -448,13 +448,14 @@ def update_statistics(self) -> UpdateStatistics:
"""
return UpdateStatistics(transaction=self)

def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str = MAIN_BRANCH) -> None:
"""
Shorthand API for appending a PyArrow table to a table transaction.

Args:
df: The Arrow dataframe that will be appended to overwrite the table
snapshot_properties: Custom properties to be added to the snapshot summary
branch: Branch Reference to run the append operation
"""
try:
import pyarrow as pa
Expand All @@ -477,7 +478,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)

with self._append_snapshot_producer(snapshot_properties) as append_files:
with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files:
# skip writing data files if the dataframe is empty
if df.shape[0] > 0:
data_files = list(
Expand Down Expand Up @@ -549,6 +550,7 @@ def overwrite(
df: pa.Table,
overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
branch: str = MAIN_BRANCH,
case_sensitive: bool = True,
) -> None:
"""
Expand All @@ -566,6 +568,7 @@ def overwrite(
or a boolean expression in case of a partial overwrite
case_sensitive: A bool determine if the provided `overwrite_filter` is case-sensitive
snapshot_properties: Custom properties to be added to the snapshot summary
branch: Branch Reference to run the overwrite operation
"""
try:
import pyarrow as pa
Expand All @@ -590,9 +593,14 @@ def overwrite(

if overwrite_filter != AlwaysFalse():
# Only delete when the filter is != AlwaysFalse
self.delete(delete_filter=overwrite_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties)
self.delete(
delete_filter=overwrite_filter,
case_sensitive=case_sensitive,
snapshot_properties=snapshot_properties,
branch=branch,
)

with self._append_snapshot_producer(snapshot_properties) as append_files:
with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files:
# skip writing data files if the dataframe is empty
if df.shape[0] > 0:
data_files = _dataframe_to_data_files(
Expand All @@ -605,6 +613,7 @@ def delete(
self,
delete_filter: Union[str, BooleanExpression],
snapshot_properties: Dict[str, str] = EMPTY_DICT,
branch: str = MAIN_BRANCH,
case_sensitive: bool = True,
) -> None:
"""
Expand All @@ -618,6 +627,7 @@ def delete(
Args:
delete_filter: A boolean expression to delete rows from a table
snapshot_properties: Custom properties to be added to the snapshot summary
branch: Branch Reference to run the delete operation
case_sensitive: A bool determine if the provided `delete_filter` is case-sensitive
"""
from pyiceberg.io.pyarrow import (
Expand All @@ -635,15 +645,15 @@ def delete(
if isinstance(delete_filter, str):
delete_filter = _parse_row_filter(delete_filter)

with self.update_snapshot(snapshot_properties=snapshot_properties).delete() as delete_snapshot:
with self.update_snapshot(snapshot_properties=snapshot_properties, branch=branch).delete() as delete_snapshot:
delete_snapshot.delete_by_predicate(delete_filter, case_sensitive)

# Check if there are any files that require an actual rewrite of a data file
if delete_snapshot.rewrites_needed is True:
bound_delete_filter = bind(self.table_metadata.schema(), delete_filter, case_sensitive)
preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter)

files = self._scan(row_filter=delete_filter, case_sensitive=case_sensitive).plan_files()
files = self._scan(row_filter=delete_filter, case_sensitive=case_sensitive).use_ref(branch).plan_files()

commit_uuid = uuid.uuid4()
counter = itertools.count(0)
Expand Down Expand Up @@ -685,7 +695,9 @@ def delete(
)

if len(replaced_files) > 0:
with self.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as overwrite_snapshot:
with self.update_snapshot(
branch=branch, snapshot_properties=snapshot_properties
).overwrite() as overwrite_snapshot:
overwrite_snapshot.commit_uuid = commit_uuid
for original_data_file, replaced_data_files in replaced_files:
overwrite_snapshot.delete_data_file(original_data_file)
Expand Down Expand Up @@ -1284,16 +1296,17 @@ def upsert(
case_sensitive=case_sensitive,
)

def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str = MAIN_BRANCH) -> None:
"""
Shorthand API for appending a PyArrow table to the table.

Args:
df: The Arrow dataframe that will be appended to overwrite the table
snapshot_properties: Custom properties to be added to the snapshot summary
branch: Branch Reference to run the append operation
"""
with self.transaction() as tx:
tx.append(df=df, snapshot_properties=snapshot_properties)
tx.append(df=df, snapshot_properties=snapshot_properties, branch=branch)

def dynamic_partition_overwrite(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
"""Shorthand for dynamic overwriting the table with a PyArrow table.
Expand All @@ -1311,6 +1324,7 @@ def overwrite(
df: pa.Table,
overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
branch: str = MAIN_BRANCH,
case_sensitive: bool = True,
) -> None:
"""
Expand All @@ -1327,17 +1341,23 @@ def overwrite(
overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
or a boolean expression in case of a partial overwrite
snapshot_properties: Custom properties to be added to the snapshot summary
branch: Branch Reference to run the overwrite operation
case_sensitive: A bool determine if the provided `overwrite_filter` is case-sensitive
"""
with self.transaction() as tx:
tx.overwrite(
df=df, overwrite_filter=overwrite_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties
df=df,
overwrite_filter=overwrite_filter,
case_sensitive=case_sensitive,
snapshot_properties=snapshot_properties,
branch=branch,
)

def delete(
self,
delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
branch: str = MAIN_BRANCH,
case_sensitive: bool = True,
) -> None:
"""
Expand All @@ -1346,10 +1366,12 @@ def delete(
Args:
delete_filter: The predicate that used to remove rows
snapshot_properties: Custom properties to be added to the snapshot summary
case_sensitive: A bool determine if the provided `delete_filter` is case-sensitive
branch: Branch Reference to run the delete operation
"""
with self.transaction() as tx:
tx.delete(delete_filter=delete_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties)
tx.delete(
delete_filter=delete_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties, branch=branch
)

def add_files(
self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True
Expand Down
8 changes: 6 additions & 2 deletions pyiceberg/table/update/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from pyiceberg.partitioning import PARTITION_FIELD_ID_START, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table.metadata import SUPPORTED_TABLE_FORMAT_VERSION, TableMetadata, TableMetadataUtil
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType
from pyiceberg.table.snapshots import (
MetadataLogEntry,
Snapshot,
Expand Down Expand Up @@ -139,7 +139,7 @@ class AddSnapshotUpdate(IcebergBaseModel):
class SetSnapshotRefUpdate(IcebergBaseModel):
action: Literal["set-snapshot-ref"] = Field(default="set-snapshot-ref")
ref_name: str = Field(alias="ref-name")
type: Literal["tag", "branch"]
type: Literal[SnapshotRefType.TAG, SnapshotRefType.BRANCH]
snapshot_id: int = Field(alias="snapshot-id")
max_ref_age_ms: Annotated[Optional[int], Field(alias="max-ref-age-ms", default=None)]
max_snapshot_age_ms: Annotated[Optional[int], Field(alias="max-snapshot-age-ms", default=None)]
Expand Down Expand Up @@ -702,6 +702,10 @@ class AssertRefSnapshotId(ValidatableTableRequirement):
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is None:
raise CommitFailedException("Requirement failed: current table metadata is missing")
elif len(base_metadata.snapshots) == 0 and self.ref != MAIN_BRANCH:
raise CommitFailedException(
f"Requirement failed: Table has no snapshots and can only be written to the {MAIN_BRANCH} BRANCH."
)
elif snapshot_ref := base_metadata.refs.get(self.ref):
ref_type = snapshot_ref.snapshot_ref_type
if self.snapshot_id is None:
Expand Down
Loading