Skip to content

Commit fb66e6a

Browse files
Nikhil Bansal authored and Orbax Authors committed
Simplify JAX distributed initialization in benchmarks.
PiperOrigin-RevId: 895163261
1 parent be05482 commit fb66e6a

4 files changed

Lines changed: 54 additions & 50 deletions

File tree

checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/checkpoint_manager_benchmark.yaml

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,8 @@ mesh_configs:
44
- mesh_axes: ["data", "model"]
55
ici_parallelism: {"data": 2, "model": 2}
66
- mesh_axes: ["data", "model"]
7-
ici_parallelism: {"data": 1, "model": 1}
8-
dcn_parallelism: {"data": 4, "model": 1}
7+
ici_parallelism: {"data": 16, "model": 1}
8+
dcn_parallelism: {"data": 2, "model": 1}
99

1010
checkpoint_config:
1111
spec:
@@ -16,4 +16,4 @@ benchmarks:
1616
options:
1717
save_interval_steps: [1]
1818
max_to_keep: [1]
19-
train_steps: [1]
19+
train_steps: [100]

checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/pytree_checkpoint_benchmark.yaml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@ mesh_configs:
88
process_is_granule: True
99
- mesh_axes: ["data", "stage", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]
1010
# DCN: Across processes assuming 4 processes.
11-
dcn_parallelism: {"tensor": 1, "data": 1}
11+
dcn_parallelism: {"tensor": 1, "data": 2}
1212
ici_parallelism: {"tensor": 16, "data": 1}
1313

1414
# The checkpoint configuration, shared across all generated benchmarks.
@@ -37,4 +37,4 @@ benchmarks:
3737
# - use_ocdbt
3838
# - use_zarr3
3939
use_ocdbt: [True]
40-
use_zarr3: [True, False]
40+
use_zarr3: [False]

checkpoint/orbax/checkpoint/_src/testing/benchmarks/pytree_checkpoint_benchmark.py

Lines changed: 22 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -166,26 +166,28 @@ def test_fn(
166166
checkpointer = ocp.Checkpointer(handler)
167167
metrics_to_measure = _metrics_to_measure(options)
168168

169-
with metrics.measure("save", metrics_to_measure):
170-
checkpointer.save(save_path, args=ocp.args.PyTreeSave(pytree))
171-
172-
if options.async_enabled:
173-
with metrics.measure("wait_until_finished", metrics_to_measure):
174-
assert hasattr(checkpointer, "wait_until_finished")
175-
checkpointer.wait_until_finished()
176-
177-
context.pytree = self._clear_pytree(context.pytree)
178-
179-
with metrics.measure("restore", metrics_to_measure):
180-
restored_pytree = checkpointer.restore(
181-
save_path,
182-
args=ocp.args.PyTreeRestore(
183-
item=pytree,
184-
restore_args=ocp.checkpoint_utils.construct_restore_args(pytree),
185-
),
186-
)
187-
188-
self._clear_pytree(restored_pytree)
169+
for i in range(300):
170+
save_path = context.path / "pytree" / str(i)
171+
with metrics.measure("save", metrics_to_measure):
172+
checkpointer.save(
173+
save_path, args=ocp.args.PyTreeSave(pytree)
174+
)
175+
176+
if options.async_enabled:
177+
with metrics.measure("wait_until_finished", metrics_to_measure):
178+
assert hasattr(checkpointer, "wait_until_finished")
179+
checkpointer.wait_until_finished()
180+
181+
with metrics.measure("restore", metrics_to_measure):
182+
_ = checkpointer.restore(
183+
save_path,
184+
args=ocp.args.PyTreeRestore(
185+
item=pytree,
186+
restore_args=ocp.checkpoint_utils.construct_restore_args(
187+
pytree
188+
),
189+
),
190+
)
189191

190192
checkpointer.close()
191193
return benchmarks_core.TestResult(metrics=metrics)

checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py

Lines changed: 27 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
from etils import epath
3131
import jax
3232
from orbax.checkpoint._src.testing.benchmarks.core import config_parsing
33+
import pathwaysutils
3334

3435

3536
# Core Flags
@@ -63,29 +64,30 @@ def _init_jax_distributed():
6364
"""Initializes JAX distributed system if not managed by XManager."""
6465

6566
try:
66-
jax_platforms = os.environ.get('JAX_PLATFORMS')
67-
jax_coordinator_address = os.environ.get('JAX_COORDINATOR_ADDRESS')
68-
jax_process_id = os.environ.get('JAX_PROCESS_ID')
69-
jax_num_processes = os.environ.get('JAX_NUM_PROCESSES')
70-
jax_coordinator_port = os.environ.get('JAX_COORDINATOR_PORT')
71-
logging.info('JAX_PLATFORMS: %s', jax_platforms)
72-
logging.info(
73-
'JAX_COORDINATOR_ADDRESS: %s',
74-
jax_coordinator_address,
75-
)
76-
logging.info('JAX_PROCESS_ID: %s', jax_process_id)
77-
logging.info('JAX_NUM_PROCESSES: %s', jax_num_processes)
78-
logging.info('JAX_COORDINATOR_PORT: %s', jax_coordinator_port)
79-
if jax_num_processes is not None:
80-
jax_num_processes = int(jax_num_processes)
81-
if jax_process_id is not None:
82-
jax_process_id = int(jax_process_id)
83-
jax.distributed.initialize(
84-
coordinator_address=jax_coordinator_address,
85-
num_processes=jax_num_processes,
86-
process_id=jax_process_id,
87-
initialization_timeout=600,
88-
)
67+
# jax_platforms = os.environ.get('JAX_PLATFORMS')
68+
# jax_coordinator_address = os.environ.get('JAX_COORDINATOR_ADDRESS')
69+
# jax_process_id = os.environ.get('JAX_PROCESS_ID')
70+
# jax_num_processes = os.environ.get('JAX_NUM_PROCESSES')
71+
# jax_coordinator_port = os.environ.get('JAX_COORDINATOR_PORT')
72+
# logging.info('JAX_PLATFORMS: %s', jax_platforms)
73+
# logging.info(
74+
# 'JAX_COORDINATOR_ADDRESS: %s',
75+
# jax_coordinator_address,
76+
# )
77+
# logging.info('JAX_PROCESS_ID: %s', jax_process_id)
78+
# logging.info('JAX_NUM_PROCESSES: %s', jax_num_processes)
79+
# logging.info('JAX_COORDINATOR_PORT: %s', jax_coordinator_port)
80+
# if jax_num_processes is not None:
81+
# jax_num_processes = int(jax_num_processes)
82+
# if jax_process_id is not None:
83+
# jax_process_id = int(jax_process_id)
84+
# jax.distributed.initialize(
85+
# coordinator_address=jax_coordinator_address,
86+
# num_processes=jax_num_processes,
87+
# process_id=jax_process_id,
88+
# initialization_timeout=600,
89+
# )
90+
pathwaysutils.initialize()
8991
logging.info('JAX distributed system initialized.')
9092
except Exception as e: # pylint: disable=broad-exception-caught
9193
logging.warning(
@@ -96,8 +98,8 @@ def _init_jax_distributed():
9698
exc_info=True,
9799
)
98100

99-
logging.info('Default JAX backend: %s', jax.default_backend())
100-
logging.info('Available devices: %s', jax.devices())
101+
# logging.info('Default JAX backend: %s', jax.default_backend())
102+
# logging.info('Available devices: %s', jax.devices())
101103
logging.info('JAX process index: %d', jax.process_index())
102104
logging.info('JAX process count: %d', jax.process_count())
103105
logging.info('JAX device count: %d', jax.device_count())

0 commit comments

Comments
 (0)