diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 57f09ba172..f391abfea2 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -1714,7 +1714,7 @@ def fill_parquet_file_metadata(
     data_file.split_offsets = split_offsets


-def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
+def write_file(table: Table, tasks: Iterator[WriteTask], file_schema: Optional[Schema] = None) -> Iterator[DataFile]:
     task = next(tasks)

     try:
@@ -1727,7 +1727,8 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
     parquet_writer_kwargs = _get_parquet_writer_kwargs(table.properties)

     file_path = f'{table.location()}/data/{task.generate_data_file_filename("parquet")}'
-    file_schema = schema_to_pyarrow(table.schema())
+    file_schema = file_schema or table.schema()
+    arrow_file_schema = schema_to_pyarrow(file_schema)

     fo = table.io.new_output(file_path)
     row_group_size = PropertyUtil.property_as_int(
@@ -1736,7 +1737,7 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
         default=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT,
     )
     with fo.create(overwrite=True) as fos:
-        with pq.ParquetWriter(fos, schema=file_schema, **parquet_writer_kwargs) as writer:
+        with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer:
             writer.write_table(task.df, row_group_size=row_group_size)

     data_file = DataFile(
@@ -1758,8 +1759,8 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
     fill_parquet_file_metadata(
         data_file=data_file,
         parquet_metadata=writer.writer.metadata,
-        stats_columns=compute_statistics_plan(table.schema(), table.properties),
-        parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
+        stats_columns=compute_statistics_plan(file_schema, table.properties),
+        parquet_column_mapping=parquet_path_to_id_mapping(file_schema),
     )

     return iter([data_file])
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index a87435fcfb..a939294f30 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -329,6 +329,14 @@ def update_schema(self) -> UpdateSchema:
         """
         return UpdateSchema(self._table, self)

+    def update_snapshot(self) -> UpdateSnapshot:
+        """Create a new UpdateSnapshot to produce a new snapshot for the table.
+
+        Returns:
+            A new UpdateSnapshot
+        """
+        return UpdateSnapshot(self._table, self)
+
     def remove_properties(self, *removals: str) -> Transaction:
         """Remove properties.

@@ -351,6 +359,12 @@ def update_location(self, location: str) -> Transaction:
         """
         raise NotImplementedError("Not yet implemented")

+    def schema(self) -> Schema:
+        try:
+            return next(update for update in self._updates if isinstance(update, AddSchemaUpdate)).schema_
+        except StopIteration:
+            return self._table.schema()
+
     def commit_transaction(self) -> Table:
         """Commit the changes to the catalog.

@@ -965,8 +979,21 @@ def history(self) -> List[SnapshotLogEntry]:
         return self.metadata.snapshot_log

     def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema:
+        """Create a new UpdateSchema to alter the columns of this table.
+
+        Returns:
+            A new UpdateSchema.
+        """
         return UpdateSchema(self, allow_incompatible_changes=allow_incompatible_changes, case_sensitive=case_sensitive)

+    def update_snapshot(self) -> UpdateSnapshot:
+        """Create a new UpdateSnapshot to produce a new snapshot for the table.
+
+        Returns:
+            A new UpdateSnapshot
+        """
+        return UpdateSnapshot(self)
+
     def name_mapping(self) -> NameMapping:
         """Return the table's field-id NameMapping."""
         if name_mapping_json := self.properties.get(TableProperties.DEFAULT_NAME_MAPPING):
@@ -976,7 +1003,7 @@ def name_mapping(self) -> NameMapping:

     def append(self, df: pa.Table) -> None:
         """
-        Append data to the table.
+        Shorthand API for appending a PyArrow table to the table.

         Args:
             df: The Arrow dataframe that will be appended to overwrite the table
@@ -992,19 +1019,16 @@ def append(self, df: pa.Table) -> None:
         if len(self.spec().fields) > 0:
             raise ValueError("Cannot write to partitioned tables")

-        merge = _MergingSnapshotProducer(operation=Operation.APPEND, table=self)
-
-        # skip writing data files if the dataframe is empty
-        if df.shape[0] > 0:
-            data_files = _dataframe_to_data_files(self, df=df)
-            for data_file in data_files:
-                merge.append_data_file(data_file)
-
-        merge.commit()
+        with self.update_snapshot().fast_append() as update_snapshot:
+            # skip writing data files if the dataframe is empty
+            if df.shape[0] > 0:
+                data_files = _dataframe_to_data_files(self, df=df)
+                for data_file in data_files:
+                    update_snapshot.append_data_file(data_file)

     def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE) -> None:
         """
-        Overwrite all the data in the table.
+        Shorthand for overwriting the table with a PyArrow table.

         Args:
             df: The Arrow dataframe that will be used to overwrite the table
@@ -1025,18 +1049,12 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_T
         if len(self.spec().fields) > 0:
             raise ValueError("Cannot write to partitioned tables")

-        merge = _MergingSnapshotProducer(
-            operation=Operation.OVERWRITE if self.current_snapshot() is not None else Operation.APPEND,
-            table=self,
-        )
-
-        # skip writing data files if the dataframe is empty
-        if df.shape[0] > 0:
-            data_files = _dataframe_to_data_files(self, df=df)
-            for data_file in data_files:
-                merge.append_data_file(data_file)
-
-        merge.commit()
+        with self.update_snapshot().overwrite() as update_snapshot:
+            # skip writing data files if the dataframe is empty
+            if df.shape[0] > 0:
+                data_files = _dataframe_to_data_files(self, df=df)
+                for data_file in data_files:
+                    update_snapshot.append_data_file(data_file)

     def refs(self) -> Dict[str, SnapshotRef]:
         """Return the snapshot references in the table."""
@@ -2331,7 +2349,12 @@ def _generate_manifest_list_path(location: str, snapshot_id: int, attempt: int,
     return f'{location}/metadata/snap-{snapshot_id}-{attempt}-{commit_uuid}.avro'


-def _dataframe_to_data_files(table: Table, df: pa.Table) -> Iterable[DataFile]:
+def _dataframe_to_data_files(table: Table, df: pa.Table, file_schema: Optional[Schema] = None) -> Iterable[DataFile]:
+    """Convert a PyArrow table into a DataFile.
+
+    Returns:
+        An iterable that supplies datafiles that represent the table.
+    """
     from pyiceberg.io.pyarrow import write_file

     if len(table.spec().fields) > 0:
@@ -2342,7 +2365,7 @@ def _dataframe_to_data_files(table: Table, df: pa.Table) -> Iterable[DataFile]:

     # This is an iter, so we don't have to materialize everything every time
     # This will be more relevant when we start doing partitioned writes
-    yield from write_file(table, iter([WriteTask(write_uuid, next(counter), df)]))
+    yield from write_file(table, iter([WriteTask(write_uuid, next(counter), df)]), file_schema=file_schema)


 class _MergingSnapshotProducer:
@@ -2352,8 +2375,9 @@ class _MergingSnapshotProducer:
     _parent_snapshot_id: Optional[int]
     _added_data_files: List[DataFile]
     _commit_uuid: uuid.UUID
+    _transaction: Optional[Transaction]

-    def __init__(self, operation: Operation, table: Table) -> None:
+    def __init__(self, operation: Operation, table: Table, transaction: Optional[Transaction] = None) -> None:
         self._operation = operation
         self._table = table
         self._snapshot_id = table.new_snapshot_id()
@@ -2361,46 +2385,25 @@ def __init__(self, operation: Operation, table: Table) -> None:
         self._parent_snapshot_id = snapshot.snapshot_id if (snapshot := self._table.current_snapshot()) else None
         self._added_data_files = []
         self._commit_uuid = uuid.uuid4()
+        self._transaction = transaction
+
+    def __enter__(self) -> _MergingSnapshotProducer:
+        """Start a transaction to update the table."""
+        return self
+
+    def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
+        """Close and commit the transaction."""
+        self.commit()

     def append_data_file(self, data_file: DataFile) -> _MergingSnapshotProducer:
         self._added_data_files.append(data_file)
         return self

-    def _deleted_entries(self) -> List[ManifestEntry]:
-        """To determine if we need to record any deleted entries.
-
-        With partial overwrites we have to use the predicate to evaluate
-        which entries are affected.
-        """
-        if self._operation == Operation.OVERWRITE:
-            if self._parent_snapshot_id is not None:
-                previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id)
-                if previous_snapshot is None:
-                    # This should never happen since you cannot overwrite an empty table
-                    raise ValueError(f"Could not find the previous snapshot: {self._parent_snapshot_id}")
-
-                executor = ExecutorFactory.get_or_create()
-
-                def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]:
-                    return [
-                        ManifestEntry(
-                            status=ManifestEntryStatus.DELETED,
-                            snapshot_id=entry.snapshot_id,
-                            data_sequence_number=entry.data_sequence_number,
-                            file_sequence_number=entry.file_sequence_number,
-                            data_file=entry.data_file,
-                        )
-                        for entry in manifest.fetch_manifest_entry(self._table.io, discard_deleted=True)
-                        if entry.data_file.content == DataFileContent.DATA
-                    ]
+    @abstractmethod
+    def _deleted_entries(self) -> List[ManifestEntry]: ...

-                list_of_entries = executor.map(_get_entries, previous_snapshot.manifests(self._table.io))
-                return list(chain(*list_of_entries))
-            return []
-        elif self._operation == Operation.APPEND:
-            return []
-        else:
-            raise ValueError(f"Not implemented for: {self._operation}")
+    @abstractmethod
+    def _existing_manifests(self) -> List[ManifestFile]: ...

     def _manifests(self) -> List[ManifestFile]:
         def _write_added_manifest() -> List[ManifestFile]:
@@ -2430,7 +2433,7 @@ def _write_added_manifest() -> List[ManifestFile]:
         def _write_delete_manifest() -> List[ManifestFile]:
             # Check if we need to mark the files as deleted
             deleted_entries = self._deleted_entries()
-            if deleted_entries:
+            if len(deleted_entries) > 0:
                 output_file_location = _new_manifest_path(location=self._table.location(), num=1, commit_uuid=self._commit_uuid)
                 with write_manifest(
                     format_version=self._table.format_version,
@@ -2445,32 +2448,11 @@ def _write_delete_manifest() -> List[ManifestFile]:
             else:
                 return []

-        def _fetch_existing_manifests() -> List[ManifestFile]:
-            existing_manifests = []
-
-            # Add existing manifests
-            if self._operation == Operation.APPEND and self._parent_snapshot_id is not None:
-                # In case we want to append, just add the existing manifests
-                previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id)
-
-                if previous_snapshot is None:
-                    raise ValueError(f"Snapshot could not be found: {self._parent_snapshot_id}")
-
-                for manifest in previous_snapshot.manifests(io=self._table.io):
-                    if (
-                        manifest.has_added_files()
-                        or manifest.has_existing_files()
-                        or manifest.added_snapshot_id == self._snapshot_id
-                    ):
-                        existing_manifests.append(manifest)
-
-            return existing_manifests
-
         executor = ExecutorFactory.get_or_create()

         added_manifests = executor.submit(_write_added_manifest)
         delete_manifests = executor.submit(_write_delete_manifest)
-        existing_manifests = executor.submit(_fetch_existing_manifests)
+        existing_manifests = executor.submit(self._existing_manifests)

         return added_manifests.result() + delete_manifests.result() + existing_manifests.result()

@@ -2515,10 +2497,107 @@ def commit(self) -> Snapshot:
             schema_id=self._table.schema().schema_id,
         )

-        with self._table.transaction() as tx:
-            tx.add_snapshot(snapshot=snapshot)
-            tx.set_ref_snapshot(
+        if self._transaction is not None:
+            self._transaction.add_snapshot(snapshot=snapshot)
+            self._transaction.set_ref_snapshot(
                 snapshot_id=self._snapshot_id, parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch"
             )
+        else:
+            with self._table.transaction() as tx:
+                tx.add_snapshot(snapshot=snapshot)
+                tx.set_ref_snapshot(
+                    snapshot_id=self._snapshot_id, parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch"
+                )

         return snapshot
+
+
+class FastAppendFiles(_MergingSnapshotProducer):
+    def _existing_manifests(self) -> List[ManifestFile]:
+        """To determine if there are any existing manifest files.
+
+        A fast append will add another ManifestFile to the ManifestList.
+        All the existing manifest files are considered existing.
+        """
+        existing_manifests = []
+
+        if self._parent_snapshot_id is not None:
+            previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id)
+
+            if previous_snapshot is None:
+                raise ValueError(f"Snapshot could not be found: {self._parent_snapshot_id}")
+
+            for manifest in previous_snapshot.manifests(io=self._table.io):
+                if manifest.has_added_files() or manifest.has_existing_files() or manifest.added_snapshot_id == self._snapshot_id:
+                    existing_manifests.append(manifest)
+
+        return existing_manifests
+
+    def _deleted_entries(self) -> List[ManifestEntry]:
+        """To determine if we need to record any deleted manifest entries.
+
+        In case of an append, nothing is deleted.
+        """
+        return []
+
+
+class OverwriteFiles(_MergingSnapshotProducer):
+    def _existing_manifests(self) -> List[ManifestFile]:
+        """To determine if there are any existing manifest files.
+
+        In the case of a full overwrite, all the previous manifests are
+        considered deleted.
+        """
+        return []
+
+    def _deleted_entries(self) -> List[ManifestEntry]:
+        """To determine if we need to record any deleted entries.
+
+        With a full overwrite all the entries are considered deleted.
+        With partial overwrites we have to use the predicate to evaluate
+        which entries are affected.
+        """
+        if self._parent_snapshot_id is not None:
+            previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id)
+            if previous_snapshot is None:
+                # This should never happen since you cannot overwrite an empty table
+                raise ValueError(f"Could not find the previous snapshot: {self._parent_snapshot_id}")
+
+            executor = ExecutorFactory.get_or_create()
+
+            def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]:
+                return [
+                    ManifestEntry(
+                        status=ManifestEntryStatus.DELETED,
+                        snapshot_id=entry.snapshot_id,
+                        data_sequence_number=entry.data_sequence_number,
+                        file_sequence_number=entry.file_sequence_number,
+                        data_file=entry.data_file,
+                    )
+                    for entry in manifest.fetch_manifest_entry(self._table.io, discard_deleted=True)
+                    if entry.data_file.content == DataFileContent.DATA
+                ]
+
+            list_of_entries = executor.map(_get_entries, previous_snapshot.manifests(self._table.io))
+            return list(chain(*list_of_entries))
+        else:
+            return []
+
+
+class UpdateSnapshot:
+    _table: Table
+    _transaction: Optional[Transaction]
+
+    def __init__(self, table: Table, transaction: Optional[Transaction] = None) -> None:
+        self._table = table
+        self._transaction = transaction
+
+    def fast_append(self) -> FastAppendFiles:
+        return FastAppendFiles(table=self._table, operation=Operation.APPEND, transaction=self._transaction)
+
+    def overwrite(self) -> OverwriteFiles:
+        return OverwriteFiles(
+            table=self._table,
+            operation=Operation.OVERWRITE if self._table.current_snapshot() is not None else Operation.APPEND,
+            transaction=self._transaction,
+        )
diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes.py
index 54b647b8ed..fa5e93d925 100644
--- a/tests/integration/test_writes.py
+++ b/tests/integration/test_writes.py
@@ -35,6 +35,7 @@
 from pyiceberg.catalog.sql import SqlCatalog
 from pyiceberg.exceptions import NamespaceAlreadyExistsError, NoSuchTableError
 from pyiceberg.schema import Schema
+from pyiceberg.table import _dataframe_to_data_files
 from pyiceberg.types import (
     BinaryType,
     BooleanType,
@@ -634,3 +635,44 @@ def test_duckdb_url_import(warehouse: Path, arrow_table_with_null: pa.Table) ->
             b'\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11',
         ),
     ]
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_write_and_evolve(session_catalog: Catalog, format_version: int) -> None:
+    identifier = f"default.arrow_write_data_and_evolve_schema_v{format_version}"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    pa_table = pa.Table.from_pydict(
+        {
+            'foo': ['a', None, 'z'],
+        },
+        schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
+    )
+
+    tbl = session_catalog.create_table(
+        identifier=identifier, schema=pa_table.schema, properties={"format-version": str(format_version)}
+    )
+
+    pa_table_with_column = pa.Table.from_pydict(
+        {
+            'foo': ['a', None, 'z'],
+            'bar': [19, None, 25],
+        },
+        schema=pa.schema([
+            pa.field("foo", pa.string(), nullable=True),
+            pa.field("bar", pa.int32(), nullable=True),
+        ]),
+    )
+
+    with tbl.transaction() as txn:
+        with txn.update_schema() as schema_txn:
+            schema_txn.union_by_name(pa_table_with_column.schema)
+
+        with txn.update_snapshot().fast_append() as snapshot_update:
+            for data_file in _dataframe_to_data_files(table=tbl, df=pa_table_with_column, file_schema=txn.schema()):
+                snapshot_update.append_data_file(data_file)