Skip to content

Commit e47fdfd

Browse files
committed
Add replace_branch to ManageSnapshots
1 parent d99e463 commit e47fdfd

2 files changed

Lines changed: 89 additions & 1 deletion

File tree

pyiceberg/table/update/snapshot.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,38 @@ def remove_tag(self, tag_name: str) -> ManageSnapshots:
894894
"""
895895
return self._remove_ref_snapshot(ref_name=tag_name)
896896

897+
def replace_branch(self, branch_name: str, snapshot_id: int) -> ManageSnapshots:
    """
    Point an existing branch at a different snapshot.

    The ref keeps the BRANCH type, and its retention settings (max ref age,
    max snapshot age, min snapshots to keep) are copied over from the ref
    currently stored under ``branch_name``.

    Args:
        branch_name (str): Name of the branch to repoint.
        snapshot_id (int): Id of the snapshot the branch should reference.

    Returns:
        This for method chaining

    Raises:
        ValueError: If no ref named ``branch_name`` exists, or if the ref
            with that name is not a branch.
    """
    # Runs before the refs are read below; presumably applies ref updates
    # queued earlier on this object -- confirm against the helper.
    self._commit_if_ref_updates_exist()

    existing_ref = self._transaction.table_metadata.refs.get(branch_name)
    if existing_ref is None:
        raise ValueError(f"Branch does not exist: {branch_name}")
    if existing_ref.snapshot_ref_type != SnapshotRefType.BRANCH:
        raise ValueError(f"Ref {branch_name} is not a branch")

    # Carry the current retention configuration over to the replaced ref.
    ref_age_ms = existing_ref.max_ref_age_ms
    snapshot_age_ms = existing_ref.max_snapshot_age_ms
    keep_at_least = existing_ref.min_snapshots_to_keep

    new_updates, new_requirements = self._transaction._set_ref_snapshot(
        snapshot_id=snapshot_id,
        ref_name=branch_name,
        type=SnapshotRefType.BRANCH,
        max_ref_age_ms=ref_age_ms,
        max_snapshot_age_ms=snapshot_age_ms,
        min_snapshots_to_keep=keep_at_least,
    )
    self._updates += new_updates
    self._requirements += new_requirements
    return self
928+
897929
def create_branch(
898930
self,
899931
snapshot_id: int,

tests/integration/test_snapshot_operations.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from pyiceberg.catalog import Catalog
2525
from pyiceberg.table import Table
26-
from pyiceberg.table.refs import SnapshotRef
26+
from pyiceberg.table.refs import SnapshotRef, SnapshotRefType
2727

2828

2929
@pytest.fixture
@@ -107,6 +107,62 @@ def test_remove_branch(catalog: Catalog) -> None:
107107
assert tbl.metadata.refs.get(branch_name, None) is None
108108

109109

110+
@pytest.mark.integration
@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")])
def test_replace_branch(catalog: Catalog) -> None:
    """Check that replace_branch repoints a branch and keeps its retention settings."""

    def assert_branch_state(expected_snapshot_id: int) -> None:
        # Re-read the ref from the table metadata, then verify its target,
        # its type and the retention values passed to create_branch below.
        ref = table.metadata.refs.get(ref_name)
        assert ref is not None
        assert ref.snapshot_id == expected_snapshot_id
        assert ref.snapshot_ref_type == SnapshotRefType.BRANCH
        assert ref.max_ref_age_ms == 1
        assert ref.max_snapshot_age_ms == 2
        assert ref.min_snapshots_to_keep == 3

    table = catalog.load_table("default.test_table_snapshot_operations")
    assert len(table.history()) > 2

    newest_snapshot_id = table.history()[-1].snapshot_id
    previous_snapshot_id = table.history()[-2].snapshot_id

    ref_name = "my-branch"

    # Start with the branch on the second-to-last history entry.
    table.manage_snapshots().create_branch(previous_snapshot_id, ref_name, 1, 2, 3).commit()
    assert_branch_state(previous_snapshot_id)

    # Move it to the last history entry; everything but the snapshot id must be unchanged.
    table.manage_snapshots().replace_branch(ref_name, newest_snapshot_id).commit()
    assert_branch_state(newest_snapshot_id)
139+
140+
141+
@pytest.mark.integration
@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")])
def test_replace_missing_branch(catalog: Catalog) -> None:
    """Expect a ValueError when replacing the branch "test", which this test never creates."""
    table = catalog.load_table("default.test_table_snapshot_operations")
    latest_snapshot_id = table.history()[-1].snapshot_id

    with pytest.raises(ValueError, match="Branch does not exist: test"):
        manager = table.manage_snapshots()
        manager.replace_branch(branch_name="test", snapshot_id=latest_snapshot_id).commit()
150+
151+
152+
@pytest.mark.integration
@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")])
def test_replace_branch_with_tag(catalog: Catalog) -> None:
    """Expect a ValueError when replace_branch is given the name of a tag."""
    table = catalog.load_table("default.test_table_snapshot_operations")
    latest_snapshot_id = table.history()[-1].snapshot_id

    # Create a tag so that a ref with this name exists but is not a branch.
    ref_name = "my-tag"
    table.manage_snapshots().create_tag(snapshot_id=latest_snapshot_id, tag_name=ref_name).commit()

    with pytest.raises(ValueError, match="Ref my-tag is not a branch"):
        manager = table.manage_snapshots()
        manager.replace_branch(branch_name=ref_name, snapshot_id=latest_snapshot_id).commit()
164+
165+
110166
@pytest.mark.integration
111167
@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")])
112168
def test_set_current_snapshot(catalog: Catalog) -> None:

0 commit comments

Comments
 (0)