Skip to content

Commit 1aad92d

Browse files
BlaziusMaximus and Orbax Authors
authored and committed
Add multi-host support to Safetensors loading in Orbax.
PiperOrigin-RevId: 895468231
1 parent d7880aa commit 1aad92d

5 files changed

Lines changed: 913 additions & 4 deletions

File tree

checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def __init__(
119119
checkpoint_layout: options_lib.CheckpointLayout | None = None,
120120
deletion_options: options_lib.DeletionOptions | None = None,
121121
memory_options: options_lib.MemoryOptions | None = None,
122+
safetensors_options: options_lib.SafetensorsOptions | None = None,
122123
):
123124
self._pytree_options = pytree_options or (
124125
context.pytree_options if context else options_lib.PyTreeOptions()
@@ -156,6 +157,11 @@ def __init__(
156157
self._memory_options = memory_options or (
157158
context.memory_options if context else options_lib.MemoryOptions()
158159
)
160+
self._safetensors_options = safetensors_options or (
161+
context.safetensors_options
162+
if context
163+
else options_lib.SafetensorsOptions()
164+
)
159165

160166
@property
161167
def pytree_options(self) -> options_lib.PyTreeOptions:
@@ -197,6 +203,10 @@ def deletion_options(self) -> options_lib.DeletionOptions:
197203
def memory_options(self) -> options_lib.MemoryOptions:
198204
return self._memory_options
199205

206+
  @property
  def safetensors_options(self) -> options_lib.SafetensorsOptions:
    """Returns the resolved Safetensors loading options for this context."""
    return self._safetensors_options
209+
200210
def operation_id(self) -> str:
201211
return synchronization.OperationIdGenerator.get_current_operation_id()
202212

checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,19 @@ class MemoryOptions:
570570
is_prioritized_key_fn: serialization_types.IsPrioritizedKeyFn | None = None
571571

572572

573+
@dataclasses.dataclass(frozen=True, kw_only=True)
574+
class SafetensorsOptions:
575+
"""Options for configuring Safetensors loading.
576+
577+
Attributes:
578+
ignore_load_sharding: If True, skips sharding of the tensors across
579+
hosts/devices during load. Whole tensors will be present on each host,
580+
allowing for efficient conversion.
581+
"""
582+
583+
ignore_load_sharding: bool = False
584+
585+
573586
class CheckpointLayout(enum.Enum):
574587
"""The layout of the checkpoint.
575588
Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Benchmark for safetensors layout."""
16+
17+
import asyncio
import json
import time

from absl import app
from absl import flags
from absl import logging
from etils import epath
import jax
from jax.experimental import multihost_utils
import jax.sharding
import numpy as np
from orbax.checkpoint._src.arrays import numpy_utils
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint.experimental.v1._src.layout import safetensors_layout
31+
32+
33+
# Short aliases for the jax sharding types used throughout this benchmark.
Mesh = jax.sharding.Mesh
NamedSharding = jax.sharding.NamedSharding
PartitionSpec = jax.sharding.PartitionSpec

# Fixed leading dimension for every benchmark tensor; the column count is
# derived from the requested size so the element count matches the size flag.
_ROWS = 128

FLAGS = flags.FLAGS

_TENSOR_SIZES_MB = flags.DEFINE_list(
    "tensor_sizes_mb",
    ["256"],
    "List of tensor sizes in MB to include in the file.",
)
_GCS_DIR = flags.DEFINE_string(
    "gcs_dir",
    None,
    "GCS directory for benchmark.",
    required=True,
)
_DISABLE_OLD_BENCHMARK = flags.DEFINE_boolean(
    "disable_old_benchmark",
    False,
    "If true, only run the new benchmark (layout).",
)
57+
58+
59+
# Wrapper for tracking read bytes while performing real IO.
60+
class TrackingFile:
61+
"""Wrapper for tracking read bytes while performing real IO."""
62+
63+
def __init__(self, f):
64+
self.f = f
65+
self.bytes_read = 0
66+
67+
async def seek(self, offset):
68+
await self.f.seek(offset)
69+
70+
async def read(self, size=-1):
71+
data = await self.f.read(size)
72+
self.bytes_read += len(data)
73+
return data
74+
75+
76+
async def _read_non_contiguous_slice(
    f, idx, stored_shape, stored_dtype, tensor_file_offset
):
  """Reads one shard's (possibly non-contiguous) slice of a stored tensor.

  Args:
    f: Async file-like object supporting ``await seek()``/``await read()``.
    idx: Tuple of resolved slices (one per dimension) selecting the shard;
      assumes concrete start/stop and unit step — TODO confirm callers only
      pass resolved, step-1 slices.
    stored_shape: Global shape of the tensor as stored on disk (row-major).
    stored_dtype: NumPy dtype of the stored tensor.
    tensor_file_offset: Byte offset of the tensor's data within the file.

  Returns:
    A 1-D (scalar case) or shard-shaped NumPy array holding the slice data.
  """
  # An empty index denotes a scalar (0-d) tensor: read a single element.
  if not idx:
    await f.seek(tensor_file_offset)
    num_bytes = np.dtype(stored_dtype).itemsize
    data = await f.read(num_bytes)
    return np.frombuffer(data, dtype=stored_dtype)

  # Byte strides of the stored tensor, assuming contiguous row-major layout.
  itemsize = np.dtype(stored_dtype).itemsize
  global_strides = [itemsize] * len(stored_shape)
  for i in range(len(stored_shape) - 2, -1, -1):
    global_strides[i] = global_strides[i + 1] * stored_shape[i + 1]

  shard_shape = numpy_utils.slice_shape(idx)
  out_array = np.empty(shard_shape, dtype=stored_dtype)

  # Walk the outer dimensions recursively; only the innermost dimension is
  # contiguous on disk, so each leaf call issues exactly one seek + read.
  async def _read_slice_recursively(
      dim: int, base_offset: int, out_idx: tuple[int, ...]
  ):
    s = idx[dim]
    if dim == len(stored_shape) - 1:
      # Innermost dimension: one contiguous run of (stop - start) elements.
      start = base_offset + s.start * global_strides[dim]
      num_bytes = (s.stop - s.start) * itemsize
      await f.seek(tensor_file_offset + start)
      data = await f.read(num_bytes)

      # out_idx holds one index per outer dimension, so this assigns the
      # decoded bytes into a single 1-D row of the output array.
      out_array[out_idx] = np.frombuffer(data, dtype=stored_dtype)
      return

    # Outer dimension: recurse once per index covered by this dimension's
    # slice, accumulating the byte offset as we descend.
    for out_i, i in enumerate(range(s.start, s.stop)):
      offset = base_offset + i * global_strides[dim]
      await _read_slice_recursively(dim + 1, offset, out_idx + (out_i,))

  # Start the recursive reading process from the first dimension.
  await _read_slice_recursively(dim=0, base_offset=0, out_idx=())
  return out_array
118+
119+
120+
async def _benchmark_old(file_path, sharding, tensor_sizes: list[int]):
  """Benchmarks the manual (pre-layout) safetensors read path.

  Reads each tensor's shards directly with explicit seek/read calls, puts
  them on their devices, and assembles sharded global arrays.

  Args:
    file_path: Path to the safetensors file to read.
    sharding: Sharding describing how each tensor is split across devices.
    tensor_sizes: Per-tensor sizes in MB; tensors are float32 with shape
      (_ROWS, size // _ROWS), stored back-to-back after the header.

  Returns:
    Tuple of (restored jax arrays, bytes read by this host as np.int64).
  """
  logging.info("Starting _benchmark_old for %s", file_path)
  async with async_path.open_file(file_path, mode="rb") as raw_f:
    # Wrap the file so we can report how many bytes this host actually read.
    f = TrackingFile(raw_f)
    target_dtype = np.float32

    # safetensors format: first 8 bytes are the little-endian header size;
    # tensor data begins immediately after the JSON header.
    header_size_bytes = await f.read(8)
    header_size = int.from_bytes(header_size_bytes, byteorder="little")
    start_data_offset = 8 + header_size

    current_offset = 0
    restored_tensors = []
    for i, size_mb in enumerate(tensor_sizes):
      # float32 => 4 bytes per element.
      num_elements = size_mb * 1024 * 1024 // 4
      rows = _ROWS
      cols = num_elements // rows
      target_shape = (rows, cols)
      tensor_size_bytes = num_elements * 4
      tensor_offset = start_data_offset + current_offset
      current_offset += tensor_size_bytes

      device_indices_map = sharding.addressable_devices_indices_map(
          target_shape
      )
      logging.info(
          "Reading shards for tensor_%d for %d addressable devices",
          i,
          len(sharding.addressable_devices),
      )
      device_map = []
      # Guarantee strict iteration order matching addressable_devices
      for device in sharding.addressable_devices:
        idx = device_indices_map[device]
        resolved_idx = numpy_utils.resolve_slice(idx, target_shape)
        shard_shape = numpy_utils.slice_shape(resolved_idx)

        shard_np = await _read_non_contiguous_slice(
            f, resolved_idx, target_shape, target_dtype, tensor_offset
        )
        shard_np = shard_np.reshape(shard_shape)
        device_map.append(jax.device_put(shard_np, device))

      logging.info(
          "Assembling device arrays into global array for tensor_%d", i
      )
      restored = jax.make_array_from_single_device_arrays(
          target_shape, sharding, device_map
      )
      restored_tensors.append(restored)

    # Wait for all transfers so timing covers the full restore.
    logging.info("Blocking until ready (old)")
    for restored in restored_tensors:
      jax.block_until_ready(restored)
    logging.info("Finished _benchmark_old")

  return restored_tensors, np.int64(f.bytes_read)
178+
179+
180+
async def _benchmark_current(file_path, sharding, tensor_sizes: list[int]):
  """Benchmarks the new SafetensorsLayout implementation.

  Args:
    file_path: Path to the safetensors file to read.
    sharding: Sharding applied to every restored tensor.
    tensor_sizes: Per-tensor sizes in MB; tensors are float32 with shape
      (_ROWS, size // _ROWS).

  Returns:
    Tuple of (restored pytree, estimated bytes read by this host).
  """
  logging.info("Starting _benchmark_current (new) for %s", file_path)
  layout = safetensors_layout.SafetensorsLayout()
  # Build the abstract pytree describing every tensor the layout must load.
  abstract_pytree = {}
  for i, size_mb in enumerate(tensor_sizes):
    num_elements = size_mb * 1024 * 1024 // 4
    rows = _ROWS
    cols = num_elements // rows
    shape = (rows, cols)
    abstract_pytree[f"tensor_{i}"] = jax.ShapeDtypeStruct(
        shape=shape, dtype=np.float32, sharding=sharding
    )

  # load_pytree returns an awaitable that performs the actual restore.
  restore_fn = await layout.load_pytree(
      file_path, abstract_pytree=abstract_pytree
  )
  restored_pytree = await restore_fn

  # Wait for all transfers so timing covers the full restore.
  logging.info("Blocking until ready (current)")
  for i in range(len(tensor_sizes)):
    jax.block_until_ready(restored_pytree[f"tensor_{i}"])
  logging.info("Finished _benchmark_current")

  # The layout path does not expose a byte counter, so this is an ESTIMATE:
  # it assumes the total payload is split evenly across hosts.
  num_hosts = jax.process_count()
  total_size_bytes = sum(size * 1024 * 1024 for size in tensor_sizes)
  bytes_read = total_size_bytes // num_hosts

  return restored_pytree, np.int64(bytes_read)
209+
210+
211+
async def _create_file_if_needed(
    path: epath.Path,
    tensor_sizes: list[int],
):
  """Creates a dummy safetensors file at `path` if it doesn't exist.

  Only process 0 writes; callers must barrier afterwards before reading.
  The file holds one zero-filled float32 tensor of shape
  (_ROWS, size // _ROWS) per entry in `tensor_sizes`.

  Args:
    path: Destination file path.
    tensor_sizes: Per-tensor sizes in MB.
  """
  # Only one host creates the file, avoiding concurrent writers.
  if jax.process_index() != 0:
    return

  # Build the safetensors JSON header with back-to-back data offsets.
  header_dict = {}
  current_offset = 0
  for i, size_mb in enumerate(tensor_sizes):
    num_elements = size_mb * 1024 * 1024 // 4
    rows = _ROWS
    cols = num_elements // rows
    shape = [rows, cols]
    size_bytes = num_elements * 4
    header_dict[f"tensor_{i}"] = {
        "dtype": "F32",
        "shape": shape,
        "data_offsets": [current_offset, current_offset + size_bytes],
    }
    current_offset += size_bytes

  header_json = json.dumps(header_dict).encode("utf-8")

  # Pad header with spaces (valid JSON whitespace) to a multiple of 8 bytes.
  padding_len = (8 - len(header_json) % 8) % 8
  header_json += b" " * padding_len

  header_size = len(header_json)
  header_size_bytes = header_size.to_bytes(8, byteorder="little")

  total_bytes_to_write = current_offset
  expected_file_size = 8 + header_size + total_bytes_to_write

  # Reuse an existing file when its size matches exactly.
  if path.exists() and path.stat().length == expected_file_size:
    logging.info(
        "File %s already exists with correct size, skipping creation.", path
    )
    return

  logging.info("Creating dummy file %s with size %d", path, expected_file_size)
  with path.open("wb") as f:
    f.write(header_size_bytes)
    f.write(header_json)
    # Write zeros in 100 MiB chunks to bound peak memory usage.
    chunk_size = 1024 * 1024 * 100
    bytes_written = 0
    while bytes_written < total_bytes_to_write:
      write_size = min(chunk_size, total_bytes_to_write - bytes_written)
      f.write(b"\0" * write_size)
      bytes_written += write_size
262+
263+
264+
async def run_benchmarks(sharding_type, tensor_sizes: list[int]):
  """Runs the old and new safetensors benchmarks for one sharding type.

  Creates (once, on process 0) a dummy safetensors file in the GCS
  directory, reads it back with both the manual path and the
  SafetensorsLayout path, and logs timing plus bytes-read totals.

  Args:
    sharding_type: "leading" shards rows over the "data" mesh axis; any
      other value shards columns over the "model" axis.
    tensor_sizes: Per-tensor sizes in MB for the benchmark file.
  """
  if not _GCS_DIR.value:
    return

  dir_path = epath.Path(_GCS_DIR.value)
  if jax.process_index() == 0 and not dir_path.exists():
    dir_path.mkdir(parents=True, exist_ok=True)

  # Ensure directory is created by rank 0 before others proceed.
  # Uses the explicitly imported multihost_utils module: `import jax` alone
  # does not make jax.experimental.multihost_utils available as an attribute.
  multihost_utils.sync_global_devices("mkdir")

  devices = jax.devices()
  # Assumes an even device count >= 2; otherwise the reshape below raises.
  mesh_shape = (len(devices) // 2, 2)
  mesh = Mesh(np.array(devices).reshape(mesh_shape), ("data", "model"))

  if sharding_type == "leading":
    sharding_spec = PartitionSpec("data", None)
  else:
    sharding_spec = PartitionSpec(None, "model")

  sharding = NamedSharding(mesh, sharding_spec)

  sizes_str = "_".join(map(str, tensor_sizes))
  file_path = (
      dir_path / f"benchmark_v2_{sharding_type}_{sizes_str}mb.safetensors"
  )
  await _create_file_if_needed(file_path, tensor_sizes)
  multihost_utils.sync_global_devices("create_file")

  t_old = 0.0
  bytes_old_total = 0
  num_hosts = jax.process_count()

  if not _DISABLE_OLD_BENCHMARK.value:
    t0 = time.time()
    _, bytes_old = await _benchmark_old(file_path, sharding, tensor_sizes)
    t_old = time.time() - t0
    bytes_old_total = int(bytes_old) * num_hosts

  # Barrier so the two benchmarks' IO windows do not overlap across hosts.
  multihost_utils.sync_global_devices("sync_between_benchmarks")

  t0 = time.time()
  _, bytes_new = await _benchmark_current(file_path, sharding, tensor_sizes)
  t_new = time.time() - t0

  bytes_new_total = int(bytes_new) * num_hosts

  # Only rank 0 logs the aggregated results.
  if jax.process_index() == 0:
    res = "\n=======================================================\n"
    res += (
        f"Results for {sharding_type} sharding, sizes: {tensor_sizes} MB, "
        f"{num_hosts} hosts, gcs storage\n"
    )
    if not _DISABLE_OLD_BENCHMARK.value:
      res += (
          f"Old (Manual): {t_old*1000:.2f}ms, Bytes read:"
          f" {bytes_old_total / 1024 / 1024:.2f}MB\n"
      )
    res += (
        f"New (Layout): {t_new*1000:.2f}ms, Bytes read:"
        f" {bytes_new_total / 1024 / 1024:.2f}MB\n"
    )
    res += "=======================================================\n"
    logging.info(res)
331+
332+
333+
def main(_):
  """Parses the size flag and runs both sharding variants sequentially."""
  sizes_mb = [int(size) for size in _TENSOR_SIZES_MB.value]
  for axis_variant in ("leading", "trailing"):
    asyncio.run(run_benchmarks(axis_variant, sizes_mb))
337+
338+
339+
if __name__ == "__main__":
  # absl entry point: parses flags, then invokes main().
  app.run(main)

0 commit comments

Comments (0)