Skip to content

Commit cd4e05c

Browse files
angel-coreOrbax Authors
authored andcommitted
Create v1 ocp.pytree_metadata backwards compatibility tests against static v0 and v1 checkpoints.
PiperOrigin-RevId: 895570820
1 parent 11260de commit cd4e05c

3 files changed

Lines changed: 344 additions & 24 deletions

File tree

checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_v0_layout.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,16 +112,18 @@ async def metadata(
112112
Returns:
113113
The metadata describing the Orbax checkpoint.
114114
"""
115-
checkpoint_metadata = await orbax_layout.read_checkpoint_metadata(
116-
path
117-
)
118-
# Delegate to OrbaxLayout if the checkpoint is a composite checkpoint.
119-
if checkpoint_metadata and isinstance(
120-
checkpoint_metadata.item_handlers, dict
115+
checkpoint_metadata = await orbax_layout.read_checkpoint_metadata(path)
116+
if (
117+
checkpoint_metadata
118+
and isinstance(checkpoint_metadata.item_handlers, str)
119+
or await orbax_layout.has_pytree_metadata_file(path)
121120
):
122-
return await self._orbax_layout.metadata(path)
123-
# Otherwise, load the metadata as a PyTree checkpoint.
124-
return await self._load_pytree_metadata(path, checkpoint_metadata)
121+
return await self._load_pytree_metadata(path, checkpoint_metadata)
122+
# Delegate to OrbaxLayout if the checkpoint is a composite checkpoint.
123+
# If there is no checkpoint metadata, we assume it is a composite
124+
# checkpoint, and even if it is a direct pytree checkpoint it will load as
125+
# a composite checkpoint as there is no metadata to indicate otherwise.
126+
return await self._orbax_layout.metadata(path)
125127

126128
async def _validate(self, path: Path) -> None:
127129
"""Validates a V0 checkpoint directory.

checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -141,14 +141,13 @@ def _determine_expected_outcome(
141141
is_pytree: bool,
142142
handler_registered: bool,
143143
pytree_registered: bool,
144-
) -> Tuple[bool, Type[Exception] | None, str | None]:
144+
) -> Tuple[Type[Exception] | None, str | None]:
145145
"""Encapsulates the complex boolean logic to determine load behavior."""
146146
# LAYOUT VALIDATION BEHAVIOR:
147147
if version == 'v1':
148148
# V1 strictly requires that checkpoint metadata is present.
149149
if not metadata_present:
150150
return (
151-
True,
152151
InvalidLayoutError,
153152
(
154153
r'Could not recognize the checkpoint at .* as a valid Orbax'
@@ -158,7 +157,6 @@ def _determine_expected_outcome(
158157
# V1 does not support loading a top-level pytree, redirects to V0.
159158
if checkpointable_name is None:
160159
return (
161-
True,
162160
ValueError,
163161
(
164162
r'Failed to interpret path .* as a .* Orbax PyTree'
@@ -173,7 +171,6 @@ def _determine_expected_outcome(
173171
or (not is_direct_checkpoint and checkpointable_name is None)
174172
):
175173
return (
176-
True,
177174
InvalidLayoutError,
178175
(
179176
r'Failed to interpret path .* as a .* Orbax PyTree'
@@ -197,15 +194,14 @@ def _determine_expected_outcome(
197194

198195
if not can_resolve:
199196
return (
200-
True,
201197
registration.NoEntryError,
202198
(
203199
r'Could not resolve a handler for .* and no \'pytree\' handler'
204200
r' found in .*'
205201
),
206202
)
207203

208-
return False, None, None
204+
return None, None
209205

210206
@parameterized.product(
211207
version=['v0', 'v1'],
@@ -244,7 +240,7 @@ def test_load_pytree_compatibility(
244240
pytree_registered,
245241
)
246242

247-
should_fail, error_type, expected_error_msg = (
243+
error_type, expected_error_msg = (
248244
self._determine_expected_outcome(
249245
version,
250246
checkpointable_name,
@@ -267,20 +263,20 @@ def test_load_pytree_compatibility(
267263
registry=registry
268264
)
269265
):
270-
if should_fail:
271-
with self.assertRaisesRegex(error_type, expected_error_msg):
272-
ocp.load_pytree(
273-
path,
274-
checkpointable_name=checkpointable_name,
275-
abstract_pytree=actual_abstract_pytree,
276-
)
277-
else:
266+
if error_type is None:
278267
loaded = ocp.load_pytree(
279268
path,
280269
checkpointable_name=checkpointable_name,
281270
abstract_pytree=actual_abstract_pytree,
282271
)
283272
test_utils.assert_tree_equal(self, loaded, self.expected_state)
273+
else:
274+
with self.assertRaisesRegex(error_type, expected_error_msg):
275+
ocp.load_pytree(
276+
path,
277+
checkpointable_name=checkpointable_name,
278+
abstract_pytree=actual_abstract_pytree,
279+
)
284280

285281
@parameterized.product(
286282
version=['v0', 'v1'],

0 commit comments

Comments
 (0)