Skip to content

Commit 090b876

Browse files
BlaziusMaximusOrbax Authors
authored andcommitted
Add multi-host support to Safetensors loading in Orbax.
PiperOrigin-RevId: 895468231
1 parent ff834e2 commit 090b876

4 files changed

Lines changed: 570 additions & 4 deletions

File tree

checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def __init__(
119119
checkpoint_layout: options_lib.CheckpointLayout | None = None,
120120
deletion_options: options_lib.DeletionOptions | None = None,
121121
memory_options: options_lib.MemoryOptions | None = None,
122+
safetensors_options: options_lib.SafetensorsOptions | None = None,
122123
):
123124
self._pytree_options = pytree_options or (
124125
context.pytree_options if context else options_lib.PyTreeOptions()
@@ -156,6 +157,11 @@ def __init__(
156157
self._memory_options = memory_options or (
157158
context.memory_options if context else options_lib.MemoryOptions()
158159
)
160+
self._safetensors_options = safetensors_options or (
161+
context.safetensors_options
162+
if context
163+
else options_lib.SafetensorsOptions()
164+
)
159165

160166
@property
161167
def pytree_options(self) -> options_lib.PyTreeOptions:
@@ -197,6 +203,10 @@ def deletion_options(self) -> options_lib.DeletionOptions:
197203
def memory_options(self) -> options_lib.MemoryOptions:
198204
return self._memory_options
199205

206+
@property
207+
def safetensors_options(self) -> options_lib.SafetensorsOptions:
208+
return self._safetensors_options
209+
200210
def operation_id(self) -> str:
201211
return synchronization.OperationIdGenerator.get_current_operation_id()
202212

checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,19 @@ class MemoryOptions:
570570
is_prioritized_key_fn: serialization_types.IsPrioritizedKeyFn | None = None
571571

572572

573+
@dataclasses.dataclass(frozen=True, kw_only=True)
574+
class SafetensorsOptions:
575+
"""Options for configuring Safetensors loading.
576+
577+
Attributes:
578+
ignore_load_sharding: If True, skips sharding of the tensors across
579+
hosts/devices during load. Whole tensors will be present on each host,
580+
allowing for efficient conversion.
581+
"""
582+
583+
ignore_load_sharding: bool = False
584+
585+
573586
class CheckpointLayout(enum.Enum):
574587
"""The layout of the checkpoint.
575588

0 commit comments

Comments
 (0)