@@ -860,16 +860,15 @@ def _partial_restore_with_omission(
860860 restore_args : PyTree ,
861861 ) -> Tuple [PyTree , PyTree ]:
862862 """Restores leaves specified in `item`. Skips omitted leaves."""
863- if not self ._pytree_metadata_options .support_rich_types :
864- # Replace empty containers with scalar values (zeros). During saving,
865- # some empty containers (like named tuples) were given
866- # ValueMetadataEntries as if they were scalars. We normalize these
867- # containers to scalars so that tree_trim is none the wiser.
868- serialized_item = jax .tree .map (
869- lambda v : 0 if empty_values .is_empty_container (v ) else v ,
870- serialized_item ,
871- is_leaf = tree_utils .is_empty_or_leaf ,
872- )
863+ # Replace empty containers with scalar values (zeros). During saving,
864+ # some empty containers (like named tuples) were given
865+ # ValueMetadataEntries as if they were scalars. We normalize these
866+ # containers to scalars so that tree_trim is none the wiser.
867+ serialized_item = jax .tree .map (
868+ lambda v : 0 if empty_values .is_empty_container (v ) else v ,
869+ serialized_item ,
870+ is_leaf = tree_utils .is_empty_or_leaf ,
871+ )
873872
874873 value_metadata_tree = tree_structure_utils .tree_trim (
875874 serialized_item , value_metadata_tree , strict = False
@@ -1057,12 +1056,35 @@ class TrainState:
10571056 serialized_item , value_metadata_tree
10581057 )
10591058 else :
1060- # Deserialize value metadata tree to the same structure as item to allow
1061- # for comparison with item that contains rich types.
1059+ # When support_rich_types is True, we convert types (dict to Params) so
1060+ # that the standard tree_difference lower down does not fail on container
1061+ # mismatches. Since deserialize_tree trims missing leaves, we must verify
1062+ # structural equality (ignoring types) first. If it fails, we raise early.
1063+ # The lower block acts as the verify-and-raise for when rich types are
1064+ # disabled, or a final safety check.
10621065 if self ._pytree_metadata_options .support_rich_types :
1066+ if not tree_structure_utils .is_tree_structure_match_ignore_types (
1067+ value_metadata_tree , item
1068+ ):
1069+ diff = tree_structure_utils .tree_difference (
1070+ serialized_item ,
1071+ value_metadata_tree ,
1072+ is_leaf = tree_utils .is_empty_or_leaf ,
1073+ leaves_equal = lambda a , b : True ,
1074+ )
1075+ formatted_diff = tree_structure_utils .format_tree_diff (
1076+ diff , source_label = 'Item' , target_label = 'Metadata'
1077+ )
1078+ raise ValueError (
1079+ 'User-provided restore item and on-disk value metadata tree'
1080+ f' structures do not match:\n { formatted_diff } \n If this mismatch'
1081+ ' is intentional, pass `partial_restore=True` to only restore'
1082+ ' parameters found in `item`.'
1083+ )
10631084 value_metadata_tree = tree_utils .deserialize_tree (
10641085 value_metadata_tree , item
10651086 )
1087+
10661088 # is_empty_or_leaf is necessary here to treat empty nodes (e.g. empty
10671089 # dicts, lists, custom nodes) as leaves, as they do not contain any
10681090 # actual data to be restored, but are needed to maintain the structure.
@@ -1089,11 +1111,12 @@ class TrainState:
10891111 restore_args , self ._pytree_metadata_options
10901112 )
10911113
1092- if not self ._pytree_metadata_options .support_rich_types :
1093- value_metadata_tree = tree_utils .deserialize_tree (
1094- value_metadata_tree , item
1095- )
1096- restore_args = tree_utils .deserialize_tree (restore_args , item )
1114+ value_metadata_tree_deserialized = tree_utils .deserialize_tree (
1115+ value_metadata_tree , item
1116+ )
1117+ restore_args_deserialized = tree_utils .deserialize_tree (restore_args , item )
1118+ value_metadata_tree = value_metadata_tree_deserialized
1119+ restore_args = restore_args_deserialized
10971120
10981121 param_infos = self ._get_param_infos (
10991122 item = value_metadata_tree ,
0 commit comments