55"""
66# std imports
77from argparse import ArgumentParser
8+ import contextlib
89import json
910import logging
1011import os
11- import tempfile
1212from typing import Optional
1313
1414# tpl imports
@@ -30,6 +30,7 @@ def get_args():
3030 parser .add_argument ("input_json" , type = str , help = "Input JSON file containing the test cases." )
3131 parser .add_argument ("-o" , "--output" , type = str , help = "Output JSON file containing the results." )
3232 parser .add_argument ("--scratch-dir" , type = str , help = "If provided, put scratch files here." )
33+ parser .add_argument ("--driver-root" , type = str , help = "Where to look for the driver files, if not in cwd." )
3334 parser .add_argument ("--launch-configs" , type = str , default = "launch-configs.json" ,
3435 help = "config for how to run samples." )
3536 parser .add_argument ("--build-configs" , type = str , default = "build-configs.json" ,
@@ -58,7 +59,15 @@ def get_args():
5859 parser .add_argument ("--log-runs" , action = "store_true" , help = "Display the stderr and stdout of runs." )
5960 return parser .parse_args ()
6061
61- def get_driver (prompt : dict , scratch_dir : Optional [os .PathLike ], launch_configs : dict , build_configs : dict , problem_sizes : dict , dry : bool , ** kwargs ) -> DriverWrapper :
62+ def get_driver (
63+ prompt : dict ,
64+ scratch_dir : Optional [os .PathLike ],
65+ launch_configs : dict ,
66+ build_configs : dict ,
67+ problem_sizes : dict ,
68+ dry : bool ,
69+ ** kwargs
70+ ) -> DriverWrapper :
6271 """ Get the language drive wrapper for this prompt """
6372 driver_cls = LANGUAGE_DRIVERS [prompt ["language" ]]
6473 return driver_cls (parallelism_model = prompt ["parallelism_model" ], launch_configs = launch_configs ,
@@ -112,6 +121,17 @@ def main():
112121 problem_sizes = load_json (args .problem_sizes )
113122 logging .info (f"Loaded problem sizes from { args .problem_sizes } ." )
114123
124+ # set driver root; If provided, use user argument. If it's not provided, then check if the PAREVAL_ROOT environment
125+ # variable is set, then use "${PAREVAL_ROOT}/drivers" as the root. If neither is set, then use the location of
126+ # this script as the root.
127+ if args .driver_root :
128+ DRIVER_ROOT = args .driver_root
129+ elif "PAREVAL_ROOT" in os .environ :
130+ DRIVER_ROOT = os .path .join (os .environ ["PAREVAL_ROOT" ], "drivers" )
131+ else :
132+ DRIVER_ROOT = os .path .dirname (os .path .abspath (__file__ ))
133+ logging .info (f"Using driver root: { DRIVER_ROOT } " )
134+
115135 # gather the list of parallelism models to test
116136 models_to_test = args .include_models if args .include_models else ["serial" , "omp" , "mpi" , "mpi+omp" , "kokkos" , "cuda" , "hip" ]
117137 if args .exclude_models :
@@ -152,9 +172,11 @@ def main():
152172 display_runs = args .log_runs ,
153173 early_exit_runs = args .early_exit_runs ,
154174 build_timeout = args .build_timeout ,
155- run_timeout = args .run_timeout
175+ run_timeout = args .run_timeout ,
156176 )
157- driver .test_all_outputs_in_prompt (prompt )
177+
178+ with contextlib .chdir (DRIVER_ROOT ):
179+ driver .test_all_outputs_in_prompt (prompt )
158180
159181 # go ahead and write out outputs now
160182 if args .output and args .output != '-' :
0 commit comments