Skip to content

Commit 3e9f09f

Browse files
JustinPan-googOrbax Authors
authored andcommitted
Enable rich type support by default in PyTreeMetadataOptions.
PiperOrigin-RevId: 893125190
1 parent 5d0e139 commit 3e9f09f

5 files changed

Lines changed: 86 additions & 24 deletions

File tree

checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -860,16 +860,15 @@ def _partial_restore_with_omission(
860860
restore_args: PyTree,
861861
) -> Tuple[PyTree, PyTree]:
862862
"""Restores leaves specified in `item`. Skips omitted leaves."""
863-
if not self._pytree_metadata_options.support_rich_types:
864-
# Replace empty containers with scalar values (zeros). During saving,
865-
# some empty containers (like named tuples) were given
866-
# ValueMetadataEntries as if they were scalars. We normalize these
867-
# containers to scalars so that tree_trim is none the wiser.
868-
serialized_item = jax.tree.map(
869-
lambda v: 0 if empty_values.is_empty_container(v) else v,
870-
serialized_item,
871-
is_leaf=tree_utils.is_empty_or_leaf,
872-
)
863+
# Replace empty containers with scalar values (zeros). During saving,
864+
# some empty containers (like named tuples) were given
865+
# ValueMetadataEntries as if they were scalars. We normalize these
866+
# containers to scalars so that tree_trim is none the wiser.
867+
serialized_item = jax.tree.map(
868+
lambda v: 0 if empty_values.is_empty_container(v) else v,
869+
serialized_item,
870+
is_leaf=tree_utils.is_empty_or_leaf,
871+
)
873872

874873
value_metadata_tree = tree_structure_utils.tree_trim(
875874
serialized_item, value_metadata_tree, strict=False
@@ -1057,12 +1056,35 @@ class TrainState:
10571056
serialized_item, value_metadata_tree
10581057
)
10591058
else:
1060-
# Deserialize value metadata tree to the same structure as item to allow
1061-
# for comparison with item that contains rich types.
1059+
# When support_rich_types is True, we convert types (dict to Params) so
1060+
# that the standard tree_difference lower down does not fail on container
1061+
# mismatches. Since deserialize_tree trims missing leaves, we must verify
1062+
# structural equality (ignoring types) first. If it fails, we raise early.
1063+
# The lower block acts as the verify-and-raise for when rich types are
1064+
# disabled, or a final safety check.
10621065
if self._pytree_metadata_options.support_rich_types:
1066+
if not tree_structure_utils.is_tree_structure_match_ignore_types(
1067+
value_metadata_tree, item
1068+
):
1069+
diff = tree_structure_utils.tree_difference(
1070+
serialized_item,
1071+
value_metadata_tree,
1072+
is_leaf=tree_utils.is_empty_or_leaf,
1073+
leaves_equal=lambda a, b: True,
1074+
)
1075+
formatted_diff = tree_structure_utils.format_tree_diff(
1076+
diff, source_label='Item', target_label='Metadata'
1077+
)
1078+
raise ValueError(
1079+
'User-provided restore item and on-disk value metadata tree'
1080+
f' structures do not match:\n{formatted_diff}\nIf this mismatch'
1081+
' is intentional, pass `partial_restore=True` to only restore'
1082+
' parameters found in `item`.'
1083+
)
10631084
value_metadata_tree = tree_utils.deserialize_tree(
10641085
value_metadata_tree, item
10651086
)
1087+
10661088
# is_empty_or_leaf is necessary here to treat empty nodes (e.g. empty
10671089
# dicts, lists, custom nodes) as leaves, as they do not contain any
10681090
# actual data to be restored, but are needed to maintain the structure.
@@ -1089,11 +1111,12 @@ class TrainState:
10891111
restore_args, self._pytree_metadata_options
10901112
)
10911113

1092-
if not self._pytree_metadata_options.support_rich_types:
1093-
value_metadata_tree = tree_utils.deserialize_tree(
1094-
value_metadata_tree, item
1095-
)
1096-
restore_args = tree_utils.deserialize_tree(restore_args, item)
1114+
value_metadata_tree_deserialized = tree_utils.deserialize_tree(
1115+
value_metadata_tree, item
1116+
)
1117+
restore_args_deserialized = tree_utils.deserialize_tree(restore_args, item)
1118+
value_metadata_tree = value_metadata_tree_deserialized
1119+
restore_args = restore_args_deserialized
10971120

10981121
param_infos = self._get_param_infos(
10991122
item=value_metadata_tree,

checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,12 @@ def validate_metadata(
283283

284284
def _metadata(value):
285285
if empty_values.is_supported_empty_value(value, pytree_metadata_options):
286+
if (
287+
tree_utils.isinstance_of_namedtuple(value)
288+
and not value
289+
and pytree_metadata_options.support_rich_types
290+
):
291+
return empty_values.OrbaxEmptyNamedTuple()
286292
return value
287293
if isinstance(value, np.ndarray):
288294
return value_metadata.ArrayMetadata(
@@ -328,6 +334,9 @@ def _metadata(value):
328334
expected_reference_metadata_tree,
329335
is_leaf=tree_utils.is_empty_or_leaf,
330336
)
337+
expected_metadata = tree_utils.deserialize_tree(
338+
expected_metadata, actual_metadata.tree
339+
)
331340
test_utils.assert_tree_equal(self, expected_metadata, actual_metadata.tree)
332341

333342
def test_get_param_names(self):

checkpoint/orbax/checkpoint/_src/metadata/pytree_metadata_options.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,4 @@ class PyTreeMetadataOptions:
3434

3535

3636
# Global default options.
37-
PYTREE_METADATA_OPTIONS = PyTreeMetadataOptions(support_rich_types=False)
37+
PYTREE_METADATA_OPTIONS = PyTreeMetadataOptions(support_rich_types=True)

checkpoint/orbax/checkpoint/_src/tree/structure_utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,8 @@ def _tree_trim(
259259
structure_dict = structure
260260
elif structure is None:
261261
structure_dict = {}
262+
elif not strict and not isinstance(structure, abc.Mapping):
263+
structure_dict = {}
262264
else:
263265
raise TypeError(
264266
f'Type mismatch at key path {path}: template has type'
@@ -646,3 +648,37 @@ def build_mismatched_tree_structure_error(
646648

647649
formatted_diff = format_tree_diff(diff)
648650
return exception_cls(f'{log_message}.\n\n{formatted_diff}')
651+
652+
653+
def _get_raw_key(k: jax.tree_util.KeyEntry) -> str:
654+
"""Extracts raw key value from JAX KeyEntry by stripping punctuation."""
655+
s = jax.tree_util.keystr(k)
656+
return s.strip('[]"\'').strip('.')
657+
658+
659+
def is_tree_structure_match_ignore_types(
660+
a: PyTreeOf[Any], b: PyTreeOf[Any]
661+
) -> bool:
662+
"""Checks if two trees have identical keys at every level, ignoring container types."""
663+
is_a_leaf = utils.is_leaf_node(a)
664+
is_b_leaf = utils.is_leaf_node(b)
665+
666+
if is_a_leaf and is_b_leaf:
667+
return True
668+
if is_a_leaf or is_b_leaf:
669+
branch = a if not is_a_leaf else b
670+
children, _ = utils.tree_flatten_with_path_one_level(branch)
671+
return not children
672+
673+
a_children, _ = utils.tree_flatten_with_path_one_level(a)
674+
b_children, _ = utils.tree_flatten_with_path_one_level(b)
675+
676+
a_dict = {_get_raw_key(k): v for k, v in a_children}
677+
b_dict = {_get_raw_key(k): v for k, v in b_children}
678+
679+
if set(a_dict.keys()) != set(b_dict.keys()):
680+
return False
681+
682+
return all(
683+
is_tree_structure_match_ignore_types(a_dict[k], b_dict[k]) for k in a_dict
684+
)

checkpoint/orbax/checkpoint/aggregate_handlers.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,6 @@ async def serialize(
7575

7676
async def _serialize_fn(x):
7777
if utils.is_primary_host(self._primary_host):
78-
if self._pytree_metadata_options.support_rich_types:
79-
raise NotImplementedError(
80-
'Orbax does not support rich typed metadata in legacy msgpack'
81-
' checkpoint format. Please set'
82-
' PyTreeMetadataOptions.support_rich_types to False.'
83-
)
8478
serializable_dict = tree_metadata.serialize_tree(
8579
x, self._pytree_metadata_options
8680
)

0 commit comments

Comments
 (0)