Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(
checkpoint_layout: options_lib.CheckpointLayout | None = None,
deletion_options: options_lib.DeletionOptions | None = None,
memory_options: options_lib.MemoryOptions | None = None,
safetensors_options: options_lib.SafetensorsOptions | None = None,
):
self._pytree_options = pytree_options or (
context.pytree_options if context else options_lib.PyTreeOptions()
Expand Down Expand Up @@ -156,6 +157,11 @@ def __init__(
self._memory_options = memory_options or (
context.memory_options if context else options_lib.MemoryOptions()
)
self._safetensors_options = safetensors_options or (
context.safetensors_options
if context
else options_lib.SafetensorsOptions()
)

@property
def pytree_options(self) -> options_lib.PyTreeOptions:
Expand Down Expand Up @@ -197,6 +203,10 @@ def deletion_options(self) -> options_lib.DeletionOptions:
def memory_options(self) -> options_lib.MemoryOptions:
return self._memory_options

@property
def safetensors_options(self) -> options_lib.SafetensorsOptions:
return self._safetensors_options

def operation_id(self) -> str:
return synchronization.OperationIdGenerator.get_current_operation_id()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,19 @@ class MemoryOptions:
is_prioritized_key_fn: serialization_types.IsPrioritizedKeyFn | None = None


@dataclasses.dataclass(frozen=True, kw_only=True)
class SafetensorsOptions:
"""Options for configuring Safetensors loading.

Attributes:
ignore_load_sharding: If True, skips sharding of the tensors across
hosts/devices during load. Whole tensors will be present on each host,
allowing for efficient conversion.
"""

ignore_load_sharding: bool = False


class CheckpointLayout(enum.Enum):
"""The layout of the checkpoint.

Expand Down
Loading
Loading