Skip to content

Commit b370a1e

Browse files
angel-core authored and Orbax Authors committed
Add generic options field to SerializationParam to pass SaveArgs through.
PiperOrigin-RevId: 897257284
1 parent 7ad416d commit b370a1e

7 files changed

Lines changed: 31 additions & 70 deletions

File tree

checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
from absl import logging
2727
import jax
28-
import jax.numpy as jnp
2928
from orbax.checkpoint._src.arrays import types as arrays_types_v0
3029
from orbax.checkpoint._src.futures import future
3130
from orbax.checkpoint._src.metadata import sharding as sharding_metadata
@@ -39,7 +38,9 @@
3938

4039
Shape = arrays_types_v0.Shape
4140
AbstractShardedArray = types.AbstractShardedArray
42-
ArraySerializationParam = types.SerializationParam[jax.Array]
41+
ArraySerializationParam = types.SerializationParam[
42+
jax.Array, type_handlers_v0.SaveArgs
43+
]
4344
ArrayDeserializationParam = types.DeserializationParam[AbstractShardedArray]
4445

4546

@@ -107,23 +108,6 @@ def _create_v0_saving_paraminfo(
107108
)
108109

109110

110-
def _create_v0_savearg(
111-
param: ArraySerializationParam,
112-
context: context_lib.Context,
113-
) -> type_handlers_v0.SaveArgs:
114-
"""Creates a V0 `SaveArgs` from V1 params and context for saving."""
115-
fn = context.pytree_options.saving.create_array_storage_options_fn
116-
if fn:
117-
storage_options = fn(param.keypath, param.value)
118-
else:
119-
storage_options = context.array_options.saving.storage_options
120-
return type_handlers_v0.SaveArgs(
121-
dtype=jnp.dtype(storage_options.dtype) if storage_options.dtype else None,
122-
chunk_byte_size=storage_options.chunk_byte_size,
123-
shard_axes=storage_options.shard_axes,
124-
)
125-
126-
127111
def _create_v0_restore_paraminfo(
128112
param: (
129113
types.DeserializationParam[None]
@@ -223,7 +207,7 @@ async def serialize(
223207
_create_v0_saving_paraminfo(p, self._context, serialization_context)
224208
for p in params
225209
]
226-
saveargs = [_create_v0_savearg(p, self._context) for p in params]
210+
saveargs = [p.options for p in params]
227211

228212
commit_futures = await self._handler_impl.serialize(
229213
values, paraminfos, saveargs

checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/compatibility.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,12 @@ def _keypath_from_param_name(param_name: str) -> tree_types.PyTreeKeyPath:
8989
def _construct_serialization_param(
9090
value: types.Leaf,
9191
info: types_v0.ParamInfo,
92-
) -> types.SerializationParam[types.Leaf]:
92+
options: Any | None = None,
93+
) -> types.SerializationParam[types.Leaf, Any]:
9394
return types.SerializationParam(
9495
keypath=_keypath_from_param_name(info.name),
9596
value=value,
97+
options=options,
9698
)
9799

98100

@@ -290,9 +292,9 @@ async def serialize(
290292

291293
params = []
292294
info0 = infos[0]
293-
for info, value in zip(infos, values):
295+
for info, value, arg in zip(infos, values, args or [None] * len(values)):
294296
logging.vlog(1, 'info: %s', info)
295-
params.append(_construct_serialization_param(value, info))
297+
params.append(_construct_serialization_param(value, info, options=arg))
296298
serialization_context = _construct_serialization_context(info0)
297299
serialization_task = await self._leaf_handler.serialize(
298300
params, serialization_context

checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@
3434
from orbax.checkpoint.experimental.v1._src.serialization import types
3535

3636

37-
NumpySerializationParam = types.SerializationParam[np.ndarray]
38-
NumpyDeserializationParam = types.DeserializationParam[
39-
types.AbstractArray
37+
NumpySerializationParam = types.SerializationParam[
38+
np.ndarray, type_handlers_v0.SaveArgs
4039
]
40+
NumpyDeserializationParam = types.DeserializationParam[types.AbstractArray]
4141
Shape = arrays_types.Shape
4242
AbstractArray = types.AbstractArray
4343

@@ -94,23 +94,6 @@ def _create_v0_saving_paraminfo(
9494
)
9595

9696

97-
def _create_v0_savearg(
98-
param: NumpySerializationParam,
99-
context: context_lib.Context,
100-
) -> type_handlers_v0.SaveArgs:
101-
"""Creates a V0 `SaveArgs` from V1 params and context for saving."""
102-
fn = context.pytree_options.saving.create_array_storage_options_fn
103-
if fn:
104-
storage_options = fn(param.keypath, param.value)
105-
else:
106-
storage_options = context.array_options.saving.storage_options
107-
return type_handlers_v0.SaveArgs(
108-
dtype=np.dtype(storage_options.dtype) if storage_options.dtype else None,
109-
chunk_byte_size=storage_options.chunk_byte_size,
110-
shard_axes=storage_options.shard_axes,
111-
)
112-
113-
11497
def _create_v0_restore_paraminfo(
11598
param: types.DeserializationParam[AbstractArray | None],
11699
context: context_lib.Context,
@@ -188,7 +171,7 @@ async def serialize(
188171
_create_v0_saving_paraminfo(p, self._context, serialization_context)
189172
for p in params
190173
]
191-
saveargs = [_create_v0_savearg(p, self._context) for p in params]
174+
saveargs = [p.options for p in params]
192175

193176
commit_futures = await self._handler_impl.serialize(
194177
values, paraminfos, saveargs

checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@
3232

3333
Scalar = types.Scalar
3434
AbstractScalar = types.AbstractScalar
35-
ScalarSerializationParam = types.SerializationParam[Scalar]
36-
ScalarDeserializationParam = types.DeserializationParam[
37-
AbstractScalar
35+
ScalarSerializationParam = types.SerializationParam[
36+
Scalar, type_handlers_v0.SaveArgs
3837
]
38+
ScalarDeserializationParam = types.DeserializationParam[AbstractScalar]
3939

4040

4141
def _create_v0_scalar_handler() -> type_handlers_v0.ScalarHandler:
@@ -65,23 +65,6 @@ def _create_v0_saving_paraminfo(
6565
)
6666

6767

68-
def _create_v0_savearg(
69-
param: ScalarSerializationParam,
70-
context: context_lib.Context,
71-
) -> type_handlers_v0.SaveArgs:
72-
"""Creates a V0 SaveArgs from V1 params and context for saving."""
73-
fn = context.pytree_options.saving.create_array_storage_options_fn
74-
if fn:
75-
storage_options = fn(param.keypath, param.value)
76-
else:
77-
storage_options = context.array_options.saving.storage_options
78-
return type_handlers_v0.SaveArgs(
79-
dtype=np.dtype(storage_options.dtype) if storage_options.dtype else None,
80-
chunk_byte_size=storage_options.chunk_byte_size,
81-
shard_axes=storage_options.shard_axes,
82-
)
83-
84-
8568
def _create_v0_restore_paraminfo(
8669
param: types.DeserializationParam[
8770
AbstractScalar | Type[AbstractScalar] | None
@@ -168,7 +151,7 @@ async def serialize(
168151
_create_v0_saving_paraminfo(p, self._context, serialization_context)
169152
for p in params
170153
]
171-
saveargs = [_create_v0_savearg(p, self._context) for p in params]
154+
saveargs = [p.options for p in params]
172155

173156
commit_futures = await self._handler_impl.serialize(
174157
values, paraminfos, saveargs

checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from orbax.checkpoint.experimental.v1._src.serialization import types
2929

3030
AbstractString = types.AbstractString
31-
StringSerializationParam = types.SerializationParam[str]
31+
StringSerializationParam = types.SerializationParam[str, None]
3232
StringDeserializationParam = types.DeserializationParam[
3333
AbstractString
3434
]

checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/types.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,11 @@ class AbstractShardedArray(Protocol):
9797
def is_placeholder(value: Any) -> bool:
9898
return value is PLACEHOLDER
9999

100+
OptionsT = TypeVar('OptionsT')
101+
100102

101103
@dataclasses.dataclass
102-
class SerializationParam(Generic[Leaf]):
104+
class SerializationParam(Generic[Leaf, OptionsT]):
103105
"""Represents a specific leaf-level parameter within a PyTree.
104106
105107
SerializationParam represents a single PyTree leaf by pairing its value
@@ -133,9 +135,11 @@ async def serialize(
133135
value (Any): The data associated with the leaf. This could be a jax.Array,
134136
a numpy.ndarray, or a metadata object depending on the stage of
135137
the checkpointing process.
138+
options: Optional[OptionsT] = None: Optional options for the leaf.
136139
"""
137140
keypath: tree_types.PyTreeKeyPath
138141
value: Leaf
142+
options: OptionsT | None = None
139143

140144
@property
141145
def name(self) -> str:
@@ -287,7 +291,7 @@ class LeafHandler(Protocol[Leaf, AbstractLeaf]):
287291

288292
async def serialize(
289293
self,
290-
params: Sequence[SerializationParam[Leaf]],
294+
params: Sequence[SerializationParam[Leaf, Any]],
291295
serialization_context: SerializationContext,
292296
) -> Awaitable[None]:
293297
"""Writes the specified leaves of a checkpointable to a storage location.

checkpoint/orbax/checkpoint/experimental/v1/_src/testing/handler_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,11 +388,16 @@ def __init__(self, context=None):
388388

389389
async def serialize(
390390
self,
391-
params: Sequence[types.SerializationParam[LazyArray]],
391+
params: Sequence[types.SerializationParam[LazyArray, Any]],
392392
serialization_context: types.SerializationContext,
393393
) -> Awaitable[None]:
394394
array_params = [
395-
types.SerializationParam(p.keypath, p.value.array) for p in params
395+
types.SerializationParam(
396+
p.keypath,
397+
p.value.array,
398+
p.options,
399+
)
400+
for p in params
396401
]
397402
return await self._array_handler.serialize(
398403
array_params, serialization_context

0 commit comments

Comments
 (0)