Skip to content

Commit 378826c

Browse files
JustinPan-goog authored and Orbax Authors committed
Enable Memory Regulator for Orbax checkpoint saving.
PiperOrigin-RevId: 897298923
1 parent 58bcb29 commit 378826c

5 files changed

Lines changed: 359 additions & 95 deletions

File tree

checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py

Lines changed: 105 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from __future__ import annotations
2323

2424
import asyncio
25+
import contextlib
2526
import dataclasses
2627
import functools
2728
import json
@@ -50,6 +51,7 @@
5051
from orbax.checkpoint._src.path import format_utils
5152
from orbax.checkpoint._src.path import types as path_types
5253
from orbax.checkpoint._src.serialization import limits
54+
from orbax.checkpoint._src.serialization import memory_regulator
5355
from orbax.checkpoint._src.serialization import ocdbt_utils
5456
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
5557
from orbax.checkpoint._src.serialization import type_handler_registry as type_handler_registry_lib
@@ -323,6 +325,17 @@ def _maybe_set_default_save_restore_args(v, leaf_args):
323325
)
324326

325327

328+
@contextlib.contextmanager
329+
def _memory_profiler_context():
330+
"""Context manager for memory_regulator profiler."""
331+
memory_regulator.profiler_start()
332+
try:
333+
yield
334+
finally:
335+
# Explicitly stop the bg thread if an exception occurs
336+
memory_regulator.profiler_end()
337+
338+
326339

327340

328341
def _format_bytes(bytes_value: Optional[int]) -> str:
@@ -347,7 +360,9 @@ def __init__(
347360
*,
348361
save_concurrent_bytes: Optional[int] = None,
349362
restore_concurrent_bytes: Optional[int] = None,
350-
save_device_host_concurrent_bytes: Optional[int] = None,
363+
save_device_host_concurrent_bytes: int | str | None = None,
364+
max_save_device_host_concurrent_bytes: int | None = None,
365+
fallback_host_limit_gb: int | None = None,
351366
use_ocdbt: bool = True,
352367
use_zarr3: bool = False,
353368
use_compression: bool = True,
@@ -377,7 +392,12 @@ def __init__(
377392
save_device_host_concurrent_bytes: max concurrent bytes allowed to be
378393
transferred from device to host memory at once when saving. When the
379394
limit is reached, arrays must be finished writing to the checkpoint
380-
before a new array can start being transferred.
395+
before a new array can start being transferred. Can be "auto".
396+
max_save_device_host_concurrent_bytes: The maximum memory limit in bytes
397+
allowed for regulation. Required if `save_device_host_concurrent_bytes`
398+
is "auto".
399+
fallback_host_limit_gb: Fallback physical machine size in GB to use if the
400+
profiler fails to fetch the total memory dynamically.
381401
use_ocdbt: Whether to use OCDBT format for saving.
382402
use_zarr3: If True, use Zarr ver3 otherwise Zarr ver2.
383403
use_compression: If True, use zstd compression.
@@ -408,6 +428,30 @@ def __init__(
408428
self._save_concurrent_bytes = save_concurrent_bytes
409429
self._restore_concurrent_bytes = restore_concurrent_bytes
410430
self._save_device_host_concurrent_bytes = save_device_host_concurrent_bytes
431+
self._max_save_device_host_concurrent_bytes = (
432+
max_save_device_host_concurrent_bytes
433+
)
434+
self._fallback_host_limit_gib = None
435+
if fallback_host_limit_gb is not None:
436+
self._fallback_host_limit_gib = (
437+
fallback_host_limit_gb * 10**9
438+
) / (1024**3)
439+
if self._save_device_host_concurrent_bytes == 'auto':
440+
if self._max_save_device_host_concurrent_bytes is None:
441+
raise ValueError(
442+
'max_save_device_host_concurrent_bytes must be provided if'
443+
' save_device_host_concurrent_bytes is "auto"'
444+
)
445+
max_memory_limit_gib = self._max_save_device_host_concurrent_bytes / (
446+
1024**3
447+
)
448+
self._memory_regulator = memory_regulator.MemoryRegulator(
449+
max_memory_limit_gib=max_memory_limit_gib,
450+
fallback_host_limit_gib=self._fallback_host_limit_gib,
451+
)
452+
self._current_device_host_limit_bytes = int(
453+
self._memory_regulator.min_memory_limit_gib * 1024**3
454+
)
411455
self._use_ocdbt = use_ocdbt
412456
self._use_zarr3 = use_zarr3
413457
self._use_compression = use_compression
@@ -656,9 +700,38 @@ async def async_save(
656700

657701
save_args = _fill_missing_save_or_restore_args(item, save_args, mode='save')
658702
byte_limiter = limits.get_byte_limiter(self._save_concurrent_bytes)
659-
device_host_byte_limiter = limits.get_byte_limiter(
660-
self._save_device_host_concurrent_bytes
661-
)
703+
704+
device_host_concurrent_bytes = self._save_device_host_concurrent_bytes
705+
if device_host_concurrent_bytes == 'auto':
706+
peak_usage_gib = memory_regulator.profiler_peak_usage_gib()
707+
blocking_time_sec = memory_regulator.get_prev_blocking_time_sec()
708+
expected_surge_gib = memory_regulator.get_expected_surge_gib()
709+
710+
total_memory_gib = memory_regulator.get_total_memory_gib()
711+
current_limit_gib = self._current_device_host_limit_bytes / (1024**3)
712+
next_limit_gib = self._memory_regulator.get_next_memory_limit(
713+
current_limit_gib=current_limit_gib,
714+
peak_memory_usage_gib=peak_usage_gib,
715+
blocking_time_sec=blocking_time_sec,
716+
expected_surge_gib=expected_surge_gib,
717+
total_memory_gib=total_memory_gib,
718+
)
719+
self._current_device_host_limit_bytes = int(next_limit_gib * 1024**3)
720+
logging.info(
721+
'MemoryRegulated: Updated device_host_concurrent_bytes to %s'
722+
' (peak=%f GiB)',
723+
humanize.naturalsize(
724+
self._current_device_host_limit_bytes, binary=True
725+
),
726+
peak_usage_gib,
727+
)
728+
device_host_byte_limiter = limits.get_byte_limiter(
729+
self._current_device_host_limit_bytes
730+
)
731+
else:
732+
device_host_byte_limiter = limits.get_byte_limiter(
733+
device_host_concurrent_bytes
734+
)
662735
param_infos = self._get_param_infos(
663736
item,
664737
directory,
@@ -698,27 +771,34 @@ async def async_save(
698771
directory / PYTREE_METADATA_FILE
699772
)
700773
batch_requests_ready_time = time.time()
701-
if partial_save:
702-
serialize_ops, tree_memory_size, param_infos, save_args = (
703-
await self._async_partial_save(
704-
directory, item, batch_requests, param_infos, save_args
705-
)
706-
)
707-
else:
708-
tree_memory_size = 0
709-
for request in batch_requests:
710-
serialize_ops += [
711-
_logging_serialize(
712-
request.handler,
713-
request.handler.serialize(
714-
request.values, request.infos, request.args
715-
),
774+
with _memory_profiler_context():
775+
if partial_save:
776+
serialize_ops, tree_memory_size, param_infos, save_args = (
777+
await self._async_partial_save(
778+
directory, item, batch_requests, param_infos, save_args
716779
)
717-
]
718-
write_size, _ = _get_batch_memory_size(request.handler, request.values)
719-
tree_memory_size += write_size
720-
# Await copy futures. Returns List[List[future.Future]].
721-
commit_futures = await asyncio.gather(*serialize_ops)
780+
)
781+
else:
782+
tree_memory_size = 0
783+
for request in batch_requests:
784+
serialize_ops += [
785+
_logging_serialize(
786+
request.handler,
787+
request.handler.serialize(
788+
request.values, request.infos, request.args
789+
),
790+
)
791+
]
792+
write_size, _ = _get_batch_memory_size(
793+
request.handler, request.values
794+
)
795+
tree_memory_size += write_size
796+
# Await copy futures. Returns List[List[future.Future]].
797+
commit_futures = await asyncio.gather(*serialize_ops)
798+
logging.info(
799+
'MemoryRegulated: Peak usage: %f GiB',
800+
memory_regulator.profiler_peak_usage_gib(),
801+
)
722802
# Flatten to List[future.Future].
723803
commit_futures, _ = jax.tree.flatten(commit_futures)
724804

checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -460,8 +460,10 @@ def _get_impl_save_args(
460460

461461

462462
def _concurrent_bytes(
463-
concurrent_gb: int | None, *, use_default_if_none: bool = True
464-
) -> int | None:
463+
concurrent_gb: int | str | None, *, use_default_if_none: bool = True
464+
) -> int | str | None:
465+
if concurrent_gb == 'auto':
466+
return 'auto'
465467
if concurrent_gb is None:
466468
if use_default_if_none:
467469
return DEFAULT_CONCURRENT_GB * 10**9
@@ -500,7 +502,9 @@ def __init__(
500502
*,
501503
save_concurrent_gb: Optional[int] = None,
502504
restore_concurrent_gb: Optional[int] = None,
503-
save_device_host_concurrent_gb: Optional[int] = None,
505+
save_device_host_concurrent_gb: int | str | None = None,
506+
max_save_device_host_concurrent_gb: int | None = None,
507+
fallback_host_limit_gb: int | None = None,
504508
use_ocdbt: bool = True,
505509
use_zarr3: bool = False,
506510
use_compression: bool = True,
@@ -544,6 +548,12 @@ def __init__(
544548
transferred from device to host. Note that asynchronous saves may not be
545549
truly asynchronous with this option enabled, as we have to block on some
546550
array writes before beginning others. Also see `is_prioritized_key_fn`.
551+
Can be set to "auto" to enable Memory Regulator.
552+
max_save_device_host_concurrent_gb: The maximum memory limit in GB allowed
553+
for regulation. Required if `save_device_host_concurrent_gb` is set to
554+
"auto".
555+
fallback_host_limit_gb: Fallback physical machine size in GB to use if
556+
the profiler fails to fetch the total memory dynamically.
547557
use_ocdbt: enables Tensorstore OCDBT driver. This option allows using a
548558
different checkpoint format which is faster to read and write, as well
549559
as more space efficient.
@@ -586,9 +596,22 @@ def __init__(
586596
self._type_handler_registry = type_handler_registry
587597
self._save_concurrent_bytes = _concurrent_bytes(save_concurrent_gb)
588598
self._restore_concurrent_bytes = _concurrent_bytes(restore_concurrent_gb)
599+
if (
600+
save_device_host_concurrent_gb == 'auto'
601+
and max_save_device_host_concurrent_gb is None
602+
):
603+
raise ValueError(
604+
'max_save_device_host_concurrent_gb must be provided if'
605+
' save_device_host_concurrent_gb is "auto"'
606+
)
589607
self._save_device_host_concurrent_bytes = _concurrent_bytes(
590608
save_device_host_concurrent_gb, use_default_if_none=False
591609
)
610+
max_save_device_host_concurrent_bytes = (
611+
None
612+
if max_save_device_host_concurrent_gb is None
613+
else int(max_save_device_host_concurrent_gb * 10**9)
614+
)
592615
logging.info(
593616
'save_device_host_concurrent_bytes=%s',
594617
self._save_device_host_concurrent_bytes,
@@ -597,6 +620,8 @@ def __init__(
597620
save_concurrent_bytes=self._save_concurrent_bytes,
598621
restore_concurrent_bytes=self._restore_concurrent_bytes,
599622
save_device_host_concurrent_bytes=self._save_device_host_concurrent_bytes,
623+
max_save_device_host_concurrent_bytes=max_save_device_host_concurrent_bytes,
624+
fallback_host_limit_gb=fallback_host_limit_gb,
600625
use_ocdbt=use_ocdbt,
601626
use_zarr3=use_zarr3,
602627
use_compression=use_compression,

checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler_test.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1943,7 +1943,8 @@ def test_concurrent_gb_save(self, limit_bytes):
19431943
with mock.patch.object(
19441944
limits,
19451945
'get_byte_limiter',
1946-
new=lambda _: byte_limiter,
1946+
autospec=True,
1947+
return_value=byte_limiter,
19471948
):
19481949
handler.save(self.directory, args=PyTreeSaveArgs(tree))
19491950
# Replicated shards are handled within the _write_array_shard function.
@@ -1959,6 +1960,25 @@ def test_concurrent_gb_save(self, limit_bytes):
19591960
sleep_time,
19601961
)
19611962

1963+
def test_concurrent_gb_auto_save(self):
1964+
handler = PyTreeCheckpointHandler(
1965+
save_device_host_concurrent_gb='auto',
1966+
max_save_device_host_concurrent_gb=80,
1967+
use_ocdbt=False,
1968+
)
1969+
tree = {'a': np.ones(10)}
1970+
# Verify it doesn't crash when saving with "auto".
1971+
with mock.patch.object(
1972+
base_pytree_checkpoint_handler.memory_regulator,
1973+
'get_total_memory_gib',
1974+
autospec=True,
1975+
return_value=250.0,
1976+
):
1977+
handler.save(self.directory, args=PyTreeSaveArgs(tree))
1978+
# Verify we can also restore
1979+
restored = handler.restore(self.directory)
1980+
np.testing.assert_array_equal(restored['a'], tree['a'])
1981+
19621982
@parameterized.parameters((5,), (9,))
19631983
def test_concurrent_gb_restore(self, limit_bytes):
19641984
# TODO(b/346811105): Enable for Pathways.
@@ -1975,7 +1995,8 @@ def test_concurrent_gb_restore(self, limit_bytes):
19751995
with mock.patch.object(
19761996
limits,
19771997
'get_byte_limiter',
1978-
new=lambda _,: byte_limiter,
1998+
autospec=True,
1999+
return_value=byte_limiter,
19792000
):
19802001
restored = handler.restore(
19812002
self.directory,

0 commit comments

Comments (0)