Skip to content

Commit 6bed1ee

Browse files
lyglst authored and Orbax Authors committed
Add an option to clean the directory after each benchmark run.
PiperOrigin-RevId: 897288357
1 parent 7ad416d commit 6bed1ee

4 files changed

Lines changed: 52 additions & 6 deletions

File tree

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/config_parsing.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,7 @@ def create_test_suite_from_config(
107107
config_path: str,
108108
output_dir: str | None = None,
109109
local_directory: str | None = None,
110+
remove_repeated_dir: bool = False,
110111
) -> core.TestSuite:
111112
"""Creates a single TestSuite object from the benchmark configuration.
112113
@@ -116,6 +117,8 @@ def create_test_suite_from_config(
116117
results will be stored in a temporary directory.
117118
local_directory: Optional local directory for benchmark results. This is
118119
used for ECM benchmarks.
120+
remove_repeated_dir: Whether to remove the generated repeat_* directories
121+
after execution.
119122
120123
Returns:
121124
A TestSuite object containing all benchmarks generated from the config.
@@ -200,4 +203,5 @@ def create_test_suite_from_config(
200203
num_repeats=num_repeats,
201204
output_dir=output_dir,
202205
local_directory=local_directory,
206+
remove_repeated_dir=remove_repeated_dir,
203207
)

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -98,6 +98,8 @@ class TestResult:
9898
error: Exception | None = (
9999
None # The error raised during the test run, if any.
100100
)
101+
path: epath.Path | None = None
102+
local_path: epath.Path | None = None
101103

102104
def is_successful(self) -> bool:
103105
"""Returns whether the test run was successful."""
@@ -203,6 +205,8 @@ def run(self, repeat_index: int | None = None) -> TestResult:
203205
)
204206
try:
205207
result = self.test_fn(context)
208+
result.path = path
209+
result.local_path = local_path
206210
except Exception as e: # pylint: disable=broad-exception-caught
207211
# We catch all exceptions to ensure that any error during the test
208212
# execution is recorded in the TestResult.
@@ -220,7 +224,12 @@ def run(self, repeat_index: int | None = None) -> TestResult:
220224
e,
221225
exc_info=True,
222226
)
223-
result = TestResult(metrics=metric_lib.Metrics(), error=e)
227+
result = TestResult(
228+
metrics=metric_lib.Metrics(),
229+
error=e,
230+
path=path,
231+
local_path=local_path,
232+
)
224233
result.metrics.name = name
225234

226235
result.metrics.report()
@@ -401,13 +410,15 @@ def __init__(
401410
skip_incompatible_mesh_configs: bool = True,
402411
num_repeats: int = 1,
403412
local_directory: str | None = None,
413+
remove_repeated_dir: bool = False,
404414
):
405415
self._name = name
406416
self._benchmarks_generators = benchmarks_generators
407417
self._skip_incompatible_mesh_configs = skip_incompatible_mesh_configs
408418
self._num_repeats = num_repeats
409419
self._output_dir = output_dir
410420
self._local_directory = local_directory
421+
self._remove_repeated_dir = remove_repeated_dir
411422
tensorboard_dir = None
412423
if output_dir:
413424
tensorboard_dir = epath.Path(output_dir) / "tensorboard"
@@ -459,6 +470,10 @@ def run(self) -> Sequence[TestResult]:
459470
checkpoint_config=benchmark.checkpoint_config,
460471
error=result.error,
461472
)
473+
if self._remove_repeated_dir:
474+
multihost.sync_global_processes("test_suite:repeat_cleanup")
475+
self._remove_repeat_directory(result.path)
476+
self._remove_repeat_directory(result.local_path)
462477

463478
if not all_results:
464479
logging.warning("No benchmarks were run for this suite.")

checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,11 @@
5555
False,
5656
'Enables HLO dumping to a subdirectory within --output_directory.',
5757
)
58+
_REMOVE_REPEATED_DIR = flags.DEFINE_bool(
59+
'remove_repeated_dir',
60+
False,
61+
'Remove the generated repeat_* directories after execution.',
62+
)
5863

5964

6065

@@ -126,7 +131,10 @@ def _configure_hlo_dump(output_directory: str):
126131

127132

128133
def _run_benchmarks(
129-
config_file: str, output_directory: str, local_directory: str | None = None
134+
config_file: str,
135+
output_directory: str,
136+
local_directory: str | None = None,
137+
remove_repeated_dir: bool = False,
130138
) -> None:
131139
"""Runs Orbax checkpoint benchmarks based on a generator class and a config file.
132140
@@ -135,6 +143,8 @@ def _run_benchmarks(
135143
output_directory: Directory to store benchmark results in.
136144
local_directory: Local directory for benchmark results. This is used for ECM
137145
benchmarks.
146+
remove_repeated_dir: Whether to remove the generated repeat_* directories
147+
after execution.
138148
139149
Raises:
140150
RuntimeError: If any benchmark test fails.
@@ -155,6 +165,7 @@ def _run_benchmarks(
155165
config_file,
156166
output_dir=output_directory,
157167
local_directory=local_directory,
168+
remove_repeated_dir=remove_repeated_dir,
158169
)
159170
except Exception as e:
160171
logging.error('Failed to create test suite from config: %s', e)
@@ -201,7 +212,10 @@ def main(argv: List[str]) -> None:
201212
logging.info('Set jax_enable_x64=True')
202213

203214
_run_benchmarks(
204-
_CONFIG_FILE.value, _OUTPUT_DIRECTORY.value, _LOCAL_DIRECTORY.value
215+
_CONFIG_FILE.value,
216+
_OUTPUT_DIRECTORY.value,
217+
_LOCAL_DIRECTORY.value,
218+
remove_repeated_dir=_REMOVE_REPEATED_DIR.value,
205219
)
206220

207221
logging.info('run_benchmarks.py finished.')

checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks_pytorch.py

Lines changed: 16 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,11 @@
3939
'Output directory for benchmark results.',
4040
required=True,
4141
)
42+
_REMOVE_REPEATED_DIR = flags.DEFINE_bool(
43+
'remove_repeated_dir',
44+
False,
45+
'Remove the generated repeat_* directories after execution.',
46+
)
4247

4348

4449

@@ -73,7 +78,9 @@ def _init_torch_distributed() -> None:
7378
raise
7479

7580

76-
def _run_benchmarks(config_file: str, output_directory: str) -> None:
81+
def _run_benchmarks(
82+
config_file: str, output_directory: str, remove_repeated_dir: bool = False
83+
) -> None:
7784
"""Runs the benchmarks."""
7885
logging.info('Running benchmarks from config: %s', config_file)
7986
logging.info('Output directory: %s', output_directory)
@@ -89,7 +96,9 @@ def _run_benchmarks(config_file: str, output_directory: str) -> None:
8996

9097
try:
9198
test_suite = config_parsing.create_test_suite_from_config(
92-
config_file, output_dir=output_directory
99+
config_file,
100+
output_dir=output_directory,
101+
remove_repeated_dir=remove_repeated_dir,
93102
)
94103
except Exception as e:
95104
logging.error('Failed to create test suite from config: %s', e)
@@ -119,7 +128,11 @@ def main(argv: Sequence[str]) -> None:
119128

120129
logging.info('run_benchmarks_pytorch.py started.')
121130
_init_torch_distributed()
122-
_run_benchmarks(_CONFIG_FILE.value, _OUTPUT_DIRECTORY.value)
131+
_run_benchmarks(
132+
_CONFIG_FILE.value,
133+
_OUTPUT_DIRECTORY.value,
134+
remove_repeated_dir=_REMOVE_REPEATED_DIR.value,
135+
)
123136
logging.info('run_benchmarks_pytorch.py finished.')
124137

125138

0 commit comments

Comments
 (0)