Skip to content

Commit 8ec7d7a

Browse files
author
Daniel Nichols
committed
run drivers in driver root for relocatable runs
1 parent 2ee65c6 commit 8ec7d7a

1 file changed

Lines changed: 26 additions & 4 deletions

File tree

drivers/run-all.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
"""
66
# std imports
77
from argparse import ArgumentParser
8+
import contextlib
89
import json
910
import logging
1011
import os
11-
import tempfile
1212
from typing import Optional
1313

1414
# tpl imports
@@ -30,6 +30,7 @@ def get_args():
3030
parser.add_argument("input_json", type=str, help="Input JSON file containing the test cases.")
3131
parser.add_argument("-o", "--output", type=str, help="Output JSON file containing the results.")
3232
parser.add_argument("--scratch-dir", type=str, help="If provided, put scratch files here.")
33+
parser.add_argument("--driver-root", type=str, help="Where to look for the driver files, if not in cwd.")
3334
parser.add_argument("--launch-configs", type=str, default="launch-configs.json",
3435
help="config for how to run samples.")
3536
parser.add_argument("--build-configs", type=str, default="build-configs.json",
@@ -58,7 +59,15 @@ def get_args():
5859
parser.add_argument("--log-runs", action="store_true", help="Display the stderr and stdout of runs.")
5960
return parser.parse_args()
6061

61-
def get_driver(prompt: dict, scratch_dir: Optional[os.PathLike], launch_configs: dict, build_configs: dict, problem_sizes: dict, dry: bool, **kwargs) -> DriverWrapper:
62+
def get_driver(
63+
prompt: dict,
64+
scratch_dir: Optional[os.PathLike],
65+
launch_configs: dict,
66+
build_configs: dict,
67+
problem_sizes: dict,
68+
dry: bool,
69+
**kwargs
70+
) -> DriverWrapper:
6271
""" Get the language drive wrapper for this prompt """
6372
driver_cls = LANGUAGE_DRIVERS[prompt["language"]]
6473
return driver_cls(parallelism_model=prompt["parallelism_model"], launch_configs=launch_configs,
@@ -112,6 +121,17 @@ def main():
112121
problem_sizes = load_json(args.problem_sizes)
113122
logging.info(f"Loaded problem sizes from {args.problem_sizes}.")
114123

124+
# set driver root; If provided, use user argument. If it's not provided, then check if the PAREVAL_ROOT environment
125+
# variable is set, then use "${PAREVAL_ROOT}/drivers" as the root. If neither is set, then use the location of
126+
# this script as the root.
127+
if args.driver_root:
128+
DRIVER_ROOT = args.driver_root
129+
elif "PAREVAL_ROOT" in os.environ:
130+
DRIVER_ROOT = os.path.join(os.environ["PAREVAL_ROOT"], "drivers")
131+
else:
132+
DRIVER_ROOT = os.path.dirname(os.path.abspath(__file__))
133+
logging.info(f"Using driver root: {DRIVER_ROOT}")
134+
115135
# gather the list of parallelism models to test
116136
models_to_test = args.include_models if args.include_models else ["serial", "omp", "mpi", "mpi+omp", "kokkos", "cuda", "hip"]
117137
if args.exclude_models:
@@ -152,9 +172,11 @@ def main():
152172
display_runs=args.log_runs,
153173
early_exit_runs=args.early_exit_runs,
154174
build_timeout=args.build_timeout,
155-
run_timeout=args.run_timeout
175+
run_timeout=args.run_timeout,
156176
)
157-
driver.test_all_outputs_in_prompt(prompt)
177+
178+
with contextlib.chdir(DRIVER_ROOT):
179+
driver.test_all_outputs_in_prompt(prompt)
158180

159181
# go ahead and write out outputs now
160182
if args.output and args.output != '-':

0 commit comments

Comments
 (0)