diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 68089beb54..08e3ada8bb 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -450,6 +450,32 @@ def update_statistics(self) -> UpdateStatistics: """ return UpdateStatistics(transaction=self) + def replace( + self, + files_to_delete: Iterable[DataFile], + files_to_add: Iterable[DataFile], + snapshot_properties: dict[str, str] = EMPTY_DICT, + branch: str | None = MAIN_BRANCH, + ) -> None: + """ + Shorthand for replacing existing files with new files. + + A replace will produce a REPLACE snapshot that will ignore existing + files and replace them with the new files. + + Args: + files_to_delete: The files to delete + files_to_add: The new files to add that replace the deleted files + snapshot_properties: Custom properties to be added to the snapshot summary + branch: Branch Reference to run the replace operation + """ + with self.update_snapshot(snapshot_properties=snapshot_properties, branch=branch).replace() as replace_snapshot: + for file_to_delete in files_to_delete: + replace_snapshot.delete_data_file(file_to_delete) + + for data_file in files_to_add: + replace_snapshot.append_data_file(data_file) + def append(self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT, branch: str | None = MAIN_BRANCH) -> None: """ Shorthand API for appending a PyArrow table to a table transaction. @@ -1384,6 +1410,33 @@ def append(self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT, with self.transaction() as tx: tx.append(df=df, snapshot_properties=snapshot_properties, branch=branch) + def replace( + self, + files_to_delete: Iterable[DataFile], + files_to_add: Iterable[DataFile], + snapshot_properties: dict[str, str] = EMPTY_DICT, + branch: str | None = MAIN_BRANCH, + ) -> None: + """ + Shorthand for replacing existing files with new files. + + A replace will produce a REPLACE snapshot that will ignore existing + files and replace them with the new files. + + Args: + files_to_delete: The files to delete + files_to_add: The new files to add that replace the deleted files + snapshot_properties: Custom properties to be added to the snapshot summary + branch: Branch Reference to run the replace operation + """ + with self.transaction() as tx: + tx.replace( + files_to_delete=files_to_delete, + files_to_add=files_to_add, + snapshot_properties=snapshot_properties, + branch=branch, + ) + def dynamic_partition_overwrite( self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT, branch: str | None = MAIN_BRANCH ) -> None: diff --git a/pyiceberg/table/snapshots.py b/pyiceberg/table/snapshots.py index 7e4c6eb1ec..7bd4597399 100644 --- a/pyiceberg/table/snapshots.py +++ b/pyiceberg/table/snapshots.py @@ -344,7 +344,7 @@ def _partition_summary(self, update_metrics: UpdateMetrics) -> str: def update_snapshot_summaries(summary: Summary, previous_summary: Mapping[str, str] | None = None) -> Summary: - if summary.operation not in {Operation.APPEND, Operation.OVERWRITE, Operation.DELETE}: + if summary.operation not in {Operation.APPEND, Operation.OVERWRITE, Operation.DELETE, Operation.REPLACE}: raise ValueError(f"Operation not implemented: {summary.operation}") if not previous_summary: diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index 37d120969a..0157a40eb8 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -667,6 +667,82 @@ def _get_entries(manifest: ManifestFile) -> list[ManifestEntry]: return [] +class _RewriteFiles(_SnapshotProducer["_RewriteFiles"]): + """A snapshot producer that rewrites data files.""" + + def __init__(self, operation: Operation, transaction: Transaction, io: FileIO, snapshot_properties: dict[str, str]): + super().__init__(operation, transaction, io, snapshot_properties=snapshot_properties) + + def _commit(self) -> UpdatesAndRequirements: + # Only produce a commit when there is something to rewrite + if self._deleted_data_files or self._added_data_files: + return super()._commit() + else: + return (), () + + def _deleted_entries(self) -> list[ManifestEntry]: + """Check if we need to mark the files as deleted.""" + if self._parent_snapshot_id is not None: + previous_snapshot = self._transaction.table_metadata.snapshot_by_id(self._parent_snapshot_id) + if previous_snapshot is None: + raise ValueError(f"Could not find the previous snapshot: {self._parent_snapshot_id}") + + executor = ExecutorFactory.get_or_create() + + def _get_entries(manifest: ManifestFile) -> list[ManifestEntry]: + return [ + ManifestEntry.from_args( + status=ManifestEntryStatus.DELETED, + snapshot_id=entry.snapshot_id, + sequence_number=entry.sequence_number, + file_sequence_number=entry.file_sequence_number, + data_file=entry.data_file, + ) + for entry in manifest.fetch_manifest_entry(self._io, discard_deleted=True) + if entry.data_file.content == DataFileContent.DATA and entry.data_file in self._deleted_data_files + ] + + list_of_entries = executor.map(_get_entries, previous_snapshot.manifests(self._io)) + return list(itertools.chain(*list_of_entries)) + else: + return [] + + def _existing_manifests(self) -> list[ManifestFile]: + """To determine if there are any existing manifests.""" + existing_files = [] + if snapshot := self._transaction.table_metadata.snapshot_by_name(name=self._target_branch): + for manifest_file in snapshot.manifests(io=self._io): + entries_to_write: set[ManifestEntry] = set() + found_deleted_entries: set[ManifestEntry] = set() + + for entry in manifest_file.fetch_manifest_entry(io=self._io, discard_deleted=True): + if entry.data_file in self._deleted_data_files: + found_deleted_entries.add(entry) + else: + entries_to_write.add(entry) + + if len(found_deleted_entries) == 0: + existing_files.append(manifest_file) + continue + + if len(entries_to_write) == 0: + continue + + with self.new_manifest_writer(self.spec(manifest_file.partition_spec_id)) as writer: + for entry in entries_to_write: + writer.add_entry( + ManifestEntry.from_args( + status=ManifestEntryStatus.EXISTING, + snapshot_id=entry.snapshot_id, + sequence_number=entry.sequence_number, + file_sequence_number=entry.file_sequence_number, + data_file=entry.data_file, + ) + ) + existing_files.append(writer.to_manifest_file()) + return existing_files + + class UpdateSnapshot: _transaction: Transaction _io: FileIO @@ -724,6 +800,14 @@ def delete(self) -> _DeleteFiles: snapshot_properties=self._snapshot_properties, ) + def replace(self) -> _RewriteFiles: + return _RewriteFiles( + operation=Operation.REPLACE, + transaction=self._transaction, + io=self._io, + snapshot_properties=self._snapshot_properties, + ) + class _ManifestMergeManager(Generic[U]): _target_size_bytes: int diff --git a/tests/table/test_replace.py b/tests/table/test_replace.py new file mode 100644 index 0000000000..60f270c0a7 --- /dev/null +++ b/tests/table/test_replace.py @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from pyiceberg.catalog import Catalog +from pyiceberg.manifest import DataFile, DataFileContent, FileFormat +from pyiceberg.schema import Schema +from pyiceberg.table.snapshots import Operation +from pyiceberg.typedef import Record + + +def test_replace_api(catalog: Catalog) -> None: + # Setup a basic table using the catalog fixture + catalog.create_namespace("default") + table = catalog.create_table( + identifier="default.test_replace", + schema=Schema(), + ) + + # Create mock DataFiles for the test + file_to_delete = DataFile.from_args( + file_path="s3://bucket/test/data/deleted.parquet", + file_format=FileFormat.PARQUET, + partition=Record(), + record_count=100, + file_size_in_bytes=1024, + content=DataFileContent.DATA, + ) + file_to_delete.spec_id = 0 + + file_to_add = DataFile.from_args( + file_path="s3://bucket/test/data/added.parquet", + file_format=FileFormat.PARQUET, + partition=Record(), + record_count=100, + file_size_in_bytes=1024, + content=DataFileContent.DATA, + ) + file_to_add.spec_id = 0 + + # Initially append to have something to replace + with table.transaction() as tx: + with tx.update_snapshot().fast_append() as append_snapshot: + append_snapshot.append_data_file(file_to_delete) + + # Verify initial append snapshot + assert len(table.history()) == 1 + snapshot = table.current_snapshot() + assert snapshot is not None + assert snapshot.summary is not None + assert snapshot.summary["operation"] == Operation.APPEND + + # Call the replace API + table.replace(files_to_delete=[file_to_delete], files_to_add=[file_to_add]) + + # Verify the replacement created a REPLACE snapshot + assert len(table.history()) == 2 + snapshot = table.current_snapshot() + assert snapshot is not None + assert snapshot.summary is not None + assert snapshot.summary["operation"] == Operation.REPLACE + + # Verify the correct files are added and deleted + # The summary property tracks these counts + assert snapshot.summary["added-data-files"] == "1" + assert snapshot.summary["deleted-data-files"] == "1" + assert snapshot.summary["added-records"] == "100" + assert snapshot.summary["deleted-records"] == "100" + + # Verify the new file exists in the new manifest + manifest_files = snapshot.manifests(table.io) + assert len(manifest_files) == 2 # One for ADDED, one for DELETED + + # Check that sequence numbers were handled properly natively by verifying the manifest contents + entries = [] + for manifest in manifest_files: + for entry in manifest.fetch_manifest_entry(table.io, discard_deleted=False): + entries.append(entry) + + # One entry for ADDED (new file), one for DELETED (old file) + assert len(entries) == 2 + + +def test_replace_empty_files(catalog: Catalog) -> None: + # Setup a basic table using the catalog fixture + catalog.create_namespace("default") + table = catalog.create_table( + identifier="default.test_replace_empty", + schema=Schema(), + ) + + # Replacing empty lists should not throw errors, but should produce no changes. + table.replace([], []) + + # History should be completely empty since no files were rewritten + assert len(table.history()) == 0 + assert table.current_snapshot() is None diff --git a/tests/table/test_snapshots.py b/tests/table/test_snapshots.py index cfdc516227..7f78a7546d 100644 --- a/tests/table/test_snapshots.py +++ b/tests/table/test_snapshots.py @@ -398,8 +398,8 @@ def test_merge_snapshot_summaries_overwrite_summary() -> None: def test_invalid_operation() -> None: with pytest.raises(ValueError) as e: - update_snapshot_summaries(summary=Summary(Operation.REPLACE)) - assert "Operation not implemented: Operation.REPLACE" in str(e.value) + update_snapshot_summaries(summary=Summary.model_construct(operation="unknown_operation")) + assert "Operation not implemented: unknown_operation" in str(e.value) def test_invalid_type() -> None: