@@ -37,7 +37,7 @@ def __init__(self, run_command="mpiexec", platform_info=None):
3737 self .arg_nnodes = ("--LIBE_NNODES_ARG_EMPTY" ,)
3838 self .arg_ppn = ("--LIBE_PPN_ARG_EMPTY" ,)
3939 self .default_mpi_options = None
40- self .default_gpu_arg = None
40+ self .default_gpu_args = None
4141 self .default_gpu_arg_type = None
4242 self .platform_info = platform_info
4343
@@ -126,16 +126,32 @@ def _set_gpu_env_var(self, wresources, task, gpus_per_node, gpus_env):
126126
127127 def _local_runner_set_gpus (self , task , wresources , extra_args , gpus_per_node , ppn ):
128128 """Set default GPU setting for MPI runner"""
129- if self .default_gpu_arg is not None :
130- arg_type = self .default_gpu_arg_type
129+
130+ arg_type = self .default_gpu_arg_type
131+ if arg_type is not None :
131132 gpu_value = gpus_per_node // ppn if arg_type == "option_gpus_per_task" else gpus_per_node
132- gpu_setting_name = self .default_gpu_arg
133+ gpu_setting_name = self .default_gpu_args [arg_type ]
134+ jassert (gpu_setting_name is not None , f"No default gpu_setting_name for { arg_type } " )
133135 extra_args = self ._set_gpu_cli_option (wresources , extra_args , gpu_setting_name , gpu_value )
134136 else :
135137 gpus_env = "CUDA_VISIBLE_DEVICES"
136138 self ._set_gpu_env_var (wresources , task , gpus_per_node , gpus_env )
137139 return extra_args
138140
141+ def _get_default_arg (self , gpu_setting_type ):
142+ """Return default setting for the given gpu_setting_type if it exists, else error"""
143+ jassert (
144+ gpu_setting_type in ["option_gpus_per_node" , "option_gpus_per_task" ],
145+ f"Unrecognized gpu_setting_type { gpu_setting_type } " ,
146+ )
147+ jassert (
148+ self .default_gpu_args is not None ,
149+ "The current MPI runner has no default command line option for setting GPUs" ,
150+ )
151+ gpu_setting_name = self .default_gpu_args [gpu_setting_type ]
152+ jassert (gpu_setting_name is not None , f"No default GPU setting for { gpu_setting_type } " )
153+ return gpu_setting_name
154+
139155 def _assign_gpus (self , task , resources , nprocs , nnodes , ppn , ngpus , extra_args , match_procs_to_gpus ):
140156 """Assign GPU resources to slots, limited by ngpus if present.
141157
@@ -199,7 +215,7 @@ def _assign_gpus(self, task, resources, nprocs, nnodes, ppn, ngpus, extra_args,
199215
200216 elif gpu_setting_type in ["option_gpus_per_node" , "option_gpus_per_task" ]:
201217 gpu_value = gpus_per_node // ppn if gpu_setting_type == "option_gpus_per_task" else gpus_per_node
202- gpu_setting_name = self .platform_info .get ("gpu_setting_name" , self .default_gpu_arg )
218+ gpu_setting_name = self .platform_info .get ("gpu_setting_name" , self ._get_default_arg ( gpu_setting_type ) )
203219 extra_args = self ._set_gpu_cli_option (wresources , extra_args , gpu_setting_name , gpu_value )
204220
205221 elif gpu_setting_type == "env" :
@@ -319,7 +335,7 @@ def __init__(self, run_command="mpirun", platform_info=None):
319335 self .arg_nnodes = ("--LIBE_NNODES_ARG_EMPTY" ,)
320336 self .arg_ppn = ("--ppn" , "-ppn" )
321337 self .default_mpi_options = None
322- self .default_gpu_arg = None
338+ self .default_gpu_args = None
323339 self .default_gpu_arg_type = None
324340 self .platform_info = platform_info
325341
@@ -343,7 +359,7 @@ def __init__(self, run_command="mpirun", platform_info=None):
343359 self .arg_nnodes = ("--LIBE_NNODES_ARG_EMPTY" ,)
344360 self .arg_ppn = ("-npernode" ,)
345361 self .default_mpi_options = None
346- self .default_gpu_arg = None
362+ self .default_gpu_args = None
347363 self .default_gpu_arg_type = None
348364 self .platform_info = platform_info
349365 self .mpi_command = [
@@ -388,7 +404,7 @@ def __init__(self, run_command="aprun", platform_info=None):
388404 self .arg_nnodes = ("--LIBE_NNODES_ARG_EMPTY" ,)
389405 self .arg_ppn = ("-N" ,)
390406 self .default_mpi_options = None
391- self .default_gpu_arg = None
407+ self .default_gpu_args = None
392408 self .default_gpu_arg_type = None
393409 self .platform_info = platform_info
394410 self .mpi_command = [
@@ -410,7 +426,7 @@ def __init__(self, run_command="mpiexec", platform_info=None):
410426 self .arg_nnodes = ("--LIBE_NNODES_ARG_EMPTY" ,)
411427 self .arg_ppn = ("-cores" ,)
412428 self .default_mpi_options = None
413- self .default_gpu_arg = None
429+ self .default_gpu_args = None
414430 self .default_gpu_arg_type = None
415431 self .platform_info = platform_info
416432 self .mpi_command = [
@@ -431,8 +447,9 @@ def __init__(self, run_command="srun", platform_info=None):
431447 self .arg_nnodes = ("-N" , "--nodes" )
432448 self .arg_ppn = ("--ntasks-per-node" ,)
433449 self .default_mpi_options = "--exact"
434- self .default_gpu_arg = "--gpus-per-task"
435450 self .default_gpu_arg_type = "option_gpus_per_task"
451+ self .default_gpu_args = {"option_gpus_per_task" : "--gpus-per-task" , "option_gpus_per_node" : "--gpus-per-node" }
452+
436453 self .platform_info = platform_info
437454 self .mpi_command = [
438455 self .run_command ,
@@ -453,8 +470,8 @@ def __init__(self, run_command="jsrun", platform_info=None):
453470 self .arg_nnodes = ("--LIBE_NNODES_ARG_EMPTY" ,)
454471 self .arg_ppn = ("-r" ,)
455472 self .default_mpi_options = None
456- self .default_gpu_arg = "-g"
457473 self .default_gpu_arg_type = "option_gpus_per_task"
474+ self .default_gpu_args = {"option_gpus_per_task" : "-g" , "option_gpus_per_node" : None }
458475
459476 self .platform_info = platform_info
460477 self .mpi_command = [self .run_command , "-n {num_procs}" , "-r {procs_per_node}" , "{extra_args}" ]
0 commit comments