@@ -141,14 +141,13 @@ def _determine_expected_outcome(
141141 is_pytree : bool ,
142142 handler_registered : bool ,
143143 pytree_registered : bool ,
144- ) -> Tuple [bool , Type [Exception ] | None , str | None ]:
144+ ) -> Tuple [Type [Exception ] | None , str | None ]:
145145 """Encapsulates the complex boolean logic to determine load behavior."""
146146 # LAYOUT VALIDATION BEHAVIOR:
147147 if version == 'v1' :
148148 # V1 strictly requires that checkpoint metadata is present.
149149 if not metadata_present :
150150 return (
151- True ,
152151 InvalidLayoutError ,
153152 (
154153 r'Could not recognize the checkpoint at .* as a valid Orbax'
@@ -158,7 +157,6 @@ def _determine_expected_outcome(
158157 # V1 does not support loading a top-level pytree, redirects to V0.
159158 if checkpointable_name is None :
160159 return (
161- True ,
162160 ValueError ,
163161 (
164162 r'Failed to interpret path .* as a .* Orbax PyTree'
@@ -173,7 +171,6 @@ def _determine_expected_outcome(
173171 or (not is_direct_checkpoint and checkpointable_name is None )
174172 ):
175173 return (
176- True ,
177174 InvalidLayoutError ,
178175 (
179176 r'Failed to interpret path .* as a .* Orbax PyTree'
@@ -197,15 +194,14 @@ def _determine_expected_outcome(
197194
198195 if not can_resolve :
199196 return (
200- True ,
201197 registration .NoEntryError ,
202198 (
203199 r'Could not resolve a handler for .* and no \'pytree\' handler'
204200 r' found in .*'
205201 ),
206202 )
207203
208- return False , None , None
204+ return None , None
209205
210206 @parameterized .product (
211207 version = ['v0' , 'v1' ],
@@ -244,7 +240,7 @@ def test_load_pytree_compatibility(
244240 pytree_registered ,
245241 )
246242
247- should_fail , error_type , expected_error_msg = (
243+ error_type , expected_error_msg = (
248244 self ._determine_expected_outcome (
249245 version ,
250246 checkpointable_name ,
@@ -267,20 +263,20 @@ def test_load_pytree_compatibility(
267263 registry = registry
268264 )
269265 ):
270- if should_fail :
271- with self .assertRaisesRegex (error_type , expected_error_msg ):
272- ocp .load_pytree (
273- path ,
274- checkpointable_name = checkpointable_name ,
275- abstract_pytree = actual_abstract_pytree ,
276- )
277- else :
266+ if error_type is None :
278267 loaded = ocp .load_pytree (
279268 path ,
280269 checkpointable_name = checkpointable_name ,
281270 abstract_pytree = actual_abstract_pytree ,
282271 )
283272 test_utils .assert_tree_equal (self , loaded , self .expected_state )
273+ else :
274+ with self .assertRaisesRegex (error_type , expected_error_msg ):
275+ ocp .load_pytree (
276+ path ,
277+ checkpointable_name = checkpointable_name ,
278+ abstract_pytree = actual_abstract_pytree ,
279+ )
284280
285281 @parameterized .product (
286282 version = ['v0' , 'v1' ],
0 commit comments