2222from __future__ import annotations
2323
2424import asyncio
25+ import contextlib
2526import dataclasses
2627import functools
2728import json
5051from orbax .checkpoint ._src .path import format_utils
5152from orbax .checkpoint ._src .path import types as path_types
5253from orbax .checkpoint ._src .serialization import limits
54+ from orbax .checkpoint ._src .serialization import memory_regulator
5355from orbax .checkpoint ._src .serialization import ocdbt_utils
5456from orbax .checkpoint ._src .serialization import tensorstore_utils as ts_utils
5557from orbax .checkpoint ._src .serialization import type_handler_registry as type_handler_registry_lib
@@ -323,6 +325,17 @@ def _maybe_set_default_save_restore_args(v, leaf_args):
323325 )
324326
325327
@contextlib.contextmanager
def _memory_profiler_context():
  """Scopes a `memory_regulator` profiling session around the wrapped body.

  Starts the background memory profiler on entry and stops it on exit.

  Yields:
    None. The profiler is active for the duration of the `with` body.
  """
  memory_regulator.profiler_start()
  try:
    yield
  finally:
    # Always stop the profiler's background thread on exit, whether the
    # wrapped body completed normally or raised an exception.
    memory_regulator.profiler_end()

338+
326339
327340
328341def _format_bytes (bytes_value : Optional [int ]) -> str :
@@ -347,7 +360,9 @@ def __init__(
347360 * ,
348361 save_concurrent_bytes : Optional [int ] = None ,
349362 restore_concurrent_bytes : Optional [int ] = None ,
350- save_device_host_concurrent_bytes : Optional [int ] = None ,
363+ save_device_host_concurrent_bytes : int | str | None = None ,
364+ max_save_device_host_concurrent_bytes : int | None = None ,
365+ fallback_host_limit_gb : int | None = None ,
351366 use_ocdbt : bool = True ,
352367 use_zarr3 : bool = False ,
353368 use_compression : bool = True ,
@@ -377,7 +392,12 @@ def __init__(
377392 save_device_host_concurrent_bytes: max concurrent bytes allowed to be
378393 transferred from device to host memory at once when saving. When the
379394 limit is reached, arrays must be finished writing to the checkpoint
380- before a new array can start being transferred.
395+ before a new array can start being transferred. Can be "auto".
396+ max_save_device_host_concurrent_bytes: The maximum memory limit in bytes
397+ allowed for regulation. Required if `save_device_host_concurrent_bytes`
398+ is "auto".
399+ fallback_host_limit_gb: Fallback physical machine size in GB to use if the
400+ profiler fails to fetch the total memory dynamically.
381401 use_ocdbt: Whether to use OCDBT format for saving.
382402 use_zarr3: If True, use Zarr ver3 otherwise Zarr ver2.
383403 use_compression: If True, use zstd compression.
@@ -408,6 +428,30 @@ def __init__(
408428 self ._save_concurrent_bytes = save_concurrent_bytes
409429 self ._restore_concurrent_bytes = restore_concurrent_bytes
410430 self ._save_device_host_concurrent_bytes = save_device_host_concurrent_bytes
431+ self ._max_save_device_host_concurrent_bytes = (
432+ max_save_device_host_concurrent_bytes
433+ )
434+ self ._fallback_host_limit_gib = None
435+ if fallback_host_limit_gb is not None :
436+ self ._fallback_host_limit_gib = (
437+ fallback_host_limit_gb * 10 ** 9
438+ ) / (1024 ** 3 )
439+ if self ._save_device_host_concurrent_bytes == 'auto' :
440+ if self ._max_save_device_host_concurrent_bytes is None :
441+ raise ValueError (
442+ 'max_save_device_host_concurrent_bytes must be provided if'
443+ ' save_device_host_concurrent_bytes is "auto"'
444+ )
445+ max_memory_limit_gib = self ._max_save_device_host_concurrent_bytes / (
446+ 1024 ** 3
447+ )
448+ self ._memory_regulator = memory_regulator .MemoryRegulator (
449+ max_memory_limit_gib = max_memory_limit_gib ,
450+ fallback_host_limit_gib = self ._fallback_host_limit_gib ,
451+ )
452+ self ._current_device_host_limit_bytes = int (
453+ self ._memory_regulator .min_memory_limit_gib * 1024 ** 3
454+ )
411455 self ._use_ocdbt = use_ocdbt
412456 self ._use_zarr3 = use_zarr3
413457 self ._use_compression = use_compression
@@ -656,9 +700,38 @@ async def async_save(
656700
657701 save_args = _fill_missing_save_or_restore_args (item , save_args , mode = 'save' )
658702 byte_limiter = limits .get_byte_limiter (self ._save_concurrent_bytes )
659- device_host_byte_limiter = limits .get_byte_limiter (
660- self ._save_device_host_concurrent_bytes
661- )
703+
704+ device_host_concurrent_bytes = self ._save_device_host_concurrent_bytes
705+ if device_host_concurrent_bytes == 'auto' :
706+ peak_usage_gib = memory_regulator .profiler_peak_usage_gib ()
707+ blocking_time_sec = memory_regulator .get_prev_blocking_time_sec ()
708+ expected_surge_gib = memory_regulator .get_expected_surge_gib ()
709+
710+ total_memory_gib = memory_regulator .get_total_memory_gib ()
711+ current_limit_gib = self ._current_device_host_limit_bytes / (1024 ** 3 )
712+ next_limit_gib = self ._memory_regulator .get_next_memory_limit (
713+ current_limit_gib = current_limit_gib ,
714+ peak_memory_usage_gib = peak_usage_gib ,
715+ blocking_time_sec = blocking_time_sec ,
716+ expected_surge_gib = expected_surge_gib ,
717+ total_memory_gib = total_memory_gib ,
718+ )
719+ self ._current_device_host_limit_bytes = int (next_limit_gib * 1024 ** 3 )
720+ logging .info (
721+ 'MemoryRegulated: Updated device_host_concurrent_bytes to %s'
722+ ' (peak=%f GiB)' ,
723+ humanize .naturalsize (
724+ self ._current_device_host_limit_bytes , binary = True
725+ ),
726+ peak_usage_gib ,
727+ )
728+ device_host_byte_limiter = limits .get_byte_limiter (
729+ self ._current_device_host_limit_bytes
730+ )
731+ else :
732+ device_host_byte_limiter = limits .get_byte_limiter (
733+ device_host_concurrent_bytes
734+ )
662735 param_infos = self ._get_param_infos (
663736 item ,
664737 directory ,
@@ -698,27 +771,34 @@ async def async_save(
698771 directory / PYTREE_METADATA_FILE
699772 )
700773 batch_requests_ready_time = time .time ()
701- if partial_save :
702- serialize_ops , tree_memory_size , param_infos , save_args = (
703- await self ._async_partial_save (
704- directory , item , batch_requests , param_infos , save_args
705- )
706- )
707- else :
708- tree_memory_size = 0
709- for request in batch_requests :
710- serialize_ops += [
711- _logging_serialize (
712- request .handler ,
713- request .handler .serialize (
714- request .values , request .infos , request .args
715- ),
774+ with _memory_profiler_context ():
775+ if partial_save :
776+ serialize_ops , tree_memory_size , param_infos , save_args = (
777+ await self ._async_partial_save (
778+ directory , item , batch_requests , param_infos , save_args
716779 )
717- ]
718- write_size , _ = _get_batch_memory_size (request .handler , request .values )
719- tree_memory_size += write_size
720- # Await copy futures. Returns List[List[future.Future]].
721- commit_futures = await asyncio .gather (* serialize_ops )
780+ )
781+ else :
782+ tree_memory_size = 0
783+ for request in batch_requests :
784+ serialize_ops += [
785+ _logging_serialize (
786+ request .handler ,
787+ request .handler .serialize (
788+ request .values , request .infos , request .args
789+ ),
790+ )
791+ ]
792+ write_size , _ = _get_batch_memory_size (
793+ request .handler , request .values
794+ )
795+ tree_memory_size += write_size
796+ # Await copy futures. Returns List[List[future.Future]].
797+ commit_futures = await asyncio .gather (* serialize_ops )
798+ logging .info (
799+ 'MemoryRegulated: Peak usage: %f GiB' ,
800+ memory_regulator .profiler_peak_usage_gib (),
801+ )
722802 # Flatten to List[future.Future].
723803 commit_futures , _ = jax .tree .flatten (commit_futures )
724804
0 commit comments