@@ -37,7 +37,7 @@ def __init__(self, run_command="mpiexec", platform_info=None):
3737 self .arg_nnodes = ("--LIBE_NNODES_ARG_EMPTY" ,)
3838 self .arg_ppn = ("--LIBE_PPN_ARG_EMPTY" ,)
3939 self .default_mpi_options = None
40- self .default_gpu_arg = None
40+ self .default_gpu_args = None
4141 self .default_gpu_arg_type = None
4242 self .platform_info = platform_info
4343
@@ -126,16 +126,24 @@ def _set_gpu_env_var(self, wresources, task, gpus_per_node, gpus_env):
126126
def _local_runner_set_gpus(self, task, wresources, extra_args, gpus_per_node, ppn):
    """Apply this runner's default GPU setting and return the updated extra_args.

    The runner's ``default_gpu_arg_type`` selects an entry of
    ``default_gpu_args`` to use as the command line option name. When no
    option name is available, GPUs are assigned through the
    ``CUDA_VISIBLE_DEVICES`` environment variable instead.

    Args:
        task: Task passed through to ``_set_gpu_env_var``.
        wresources: Worker resources passed through to the GPU setters.
        extra_args: Extra runner arguments accumulated so far.
        gpus_per_node: Number of GPUs to assign on each node.
        ppn: Processes per node; divides ``gpus_per_node`` for per-task options.

    Returns:
        ``extra_args`` as returned by ``_set_gpu_cli_option`` when a command
        line option is used, otherwise ``extra_args`` unchanged.
    """
    arg_type = self.default_gpu_arg_type

    # default_gpu_args is None on runners without a GPU option, and an entry
    # may itself be None (a runner can list a type it has no option for).
    # Either way there is no option name to emit, so never forward None.
    gpu_setting_name = None
    if arg_type is not None:
        gpu_setting_name = (self.default_gpu_args or {}).get(arg_type)

    if gpu_setting_name is None:
        self._set_gpu_env_var(wresources, task, gpus_per_node, "CUDA_VISIBLE_DEVICES")
        return extra_args

    # Per-task options take each process's share; per-node options take the node total.
    gpu_value = gpus_per_node // ppn if arg_type == "option_gpus_per_task" else gpus_per_node
    return self._set_gpu_cli_option(wresources, extra_args, gpu_setting_name, gpu_value)
138139
def _get_default_arg(self, gpu_setting_type):
    """Return this runner's default option name for the given gpu_setting_type.

    Args:
        gpu_setting_type: Either "option_gpus_per_node" or "option_gpus_per_task".

    Returns:
        The command line option name stored in ``self.default_gpu_args``
        for ``gpu_setting_type``.

    Every failure is reported through ``jassert`` (presumably raising the
    executor's error type when its condition is false - confirm against its
    definition) rather than a bare ``assert``, which is removed when Python
    runs with ``-O``.
    """
    jassert(
        gpu_setting_type in ("option_gpus_per_node", "option_gpus_per_task"),
        f"Unrecognized gpu_setting_type {gpu_setting_type}",
    )
    # default_gpu_args is None on runners that define no GPU option at all;
    # report that clearly instead of failing on a None subscript.
    jassert(
        self.default_gpu_args is not None,
        "The current MPI runner has no default command line option for setting GPUs",
    )
    # Safe lookup so a missing key is reported the same way as a None entry.
    gpu_setting_name = (self.default_gpu_args or {}).get(gpu_setting_type)
    jassert(gpu_setting_name is not None, f"No default GPU setting for {gpu_setting_type}")
    return gpu_setting_name
146+
139147 def _assign_gpus (self , task , resources , nprocs , nnodes , ppn , ngpus , extra_args , match_procs_to_gpus ):
140148 """Assign GPU resources to slots, limited by ngpus if present.
141149
@@ -199,7 +207,8 @@ def _assign_gpus(self, task, resources, nprocs, nnodes, ppn, ngpus, extra_args,
199207
200208 elif gpu_setting_type in ["option_gpus_per_node" , "option_gpus_per_task" ]:
201209 gpu_value = gpus_per_node // ppn if gpu_setting_type == "option_gpus_per_task" else gpus_per_node
202- gpu_setting_name = self .platform_info .get ("gpu_setting_name" , self .default_gpu_arg )
210+ print (f"{ gpu_setting_type = } " )
211+ gpu_setting_name = self .platform_info .get ("gpu_setting_name" , self ._get_default_arg (gpu_setting_type ))
203212 extra_args = self ._set_gpu_cli_option (wresources , extra_args , gpu_setting_name , gpu_value )
204213
205214 elif gpu_setting_type == "env" :
@@ -319,7 +328,7 @@ def __init__(self, run_command="mpirun", platform_info=None):
319328 self .arg_nnodes = ("--LIBE_NNODES_ARG_EMPTY" ,)
320329 self .arg_ppn = ("--ppn" , "-ppn" )
321330 self .default_mpi_options = None
322- self .default_gpu_arg = None
331+ self .default_gpu_args = None
323332 self .default_gpu_arg_type = None
324333 self .platform_info = platform_info
325334
@@ -343,7 +352,7 @@ def __init__(self, run_command="mpirun", platform_info=None):
343352 self .arg_nnodes = ("--LIBE_NNODES_ARG_EMPTY" ,)
344353 self .arg_ppn = ("-npernode" ,)
345354 self .default_mpi_options = None
346- self .default_gpu_arg = None
355+ self .default_gpu_args = None
347356 self .default_gpu_arg_type = None
348357 self .platform_info = platform_info
349358 self .mpi_command = [
@@ -388,7 +397,7 @@ def __init__(self, run_command="aprun", platform_info=None):
388397 self .arg_nnodes = ("--LIBE_NNODES_ARG_EMPTY" ,)
389398 self .arg_ppn = ("-N" ,)
390399 self .default_mpi_options = None
391- self .default_gpu_arg = None
400+ self .default_gpu_args = None
392401 self .default_gpu_arg_type = None
393402 self .platform_info = platform_info
394403 self .mpi_command = [
@@ -410,7 +419,7 @@ def __init__(self, run_command="mpiexec", platform_info=None):
410419 self .arg_nnodes = ("--LIBE_NNODES_ARG_EMPTY" ,)
411420 self .arg_ppn = ("-cores" ,)
412421 self .default_mpi_options = None
413- self .default_gpu_arg = None
422+ self .default_gpu_args = None
414423 self .default_gpu_arg_type = None
415424 self .platform_info = platform_info
416425 self .mpi_command = [
@@ -431,8 +440,9 @@ def __init__(self, run_command="srun", platform_info=None):
431440 self .arg_nnodes = ("-N" , "--nodes" )
432441 self .arg_ppn = ("--ntasks-per-node" ,)
433442 self .default_mpi_options = "--exact"
434- self .default_gpu_arg = "--gpus-per-task"
435443 self .default_gpu_arg_type = "option_gpus_per_task"
444+ self .default_gpu_args = {"option_gpus_per_task" : "--gpus-per-task" , "option_gpus_per_node" : "--gpus-per-node" }
445+
436446 self .platform_info = platform_info
437447 self .mpi_command = [
438448 self .run_command ,
@@ -453,8 +463,8 @@ def __init__(self, run_command="jsrun", platform_info=None):
453463 self .arg_nnodes = ("--LIBE_NNODES_ARG_EMPTY" ,)
454464 self .arg_ppn = ("-r" ,)
455465 self .default_mpi_options = None
456- self .default_gpu_arg = "-g"
457466 self .default_gpu_arg_type = "option_gpus_per_task"
467+ self .default_gpu_args = {"option_gpus_per_task" : None , "option_gpus_per_node" : "-g" }
458468
459469 self .platform_info = platform_info
460470 self .mpi_command = [self .run_command , "-n {num_procs}" , "-r {procs_per_node}" , "{extra_args}" ]
0 commit comments