@@ -131,6 +131,7 @@ def _local_runner_set_gpus(self, task, wresources, extra_args, gpus_per_node, pp
131131 if arg_type is not None :
132132 gpu_value = gpus_per_node // ppn if arg_type == "option_gpus_per_task" else gpus_per_node
133133 gpu_setting_name = self .default_gpu_args [arg_type ]
134+ jassert (gpu_setting_name is not None , f"No default gpu_setting_name for { arg_type } " )
134135 extra_args = self ._set_gpu_cli_option (wresources , extra_args , gpu_setting_name , gpu_value )
135136 else :
136137 gpus_env = "CUDA_VISIBLE_DEVICES"
@@ -139,7 +140,14 @@ def _local_runner_set_gpus(self, task, wresources, extra_args, gpus_per_node, pp
139140
140141 def _get_default_arg (self , gpu_setting_type ):
141142 """Return default setting for the given gpu_setting_type if it exists, else error"""
142- assert gpu_setting_type in ["option_gpus_per_node" , "option_gpus_per_task" ]
143+ jassert (
144+ gpu_setting_type in ["option_gpus_per_node" , "option_gpus_per_task" ],
145+ f"Unrecognized gpu_setting_type { gpu_setting_type } " ,
146+ )
147+ jassert (
148+ self .default_gpu_args is not None ,
149+ "The current MPI runner has no default command line option for setting GPUs" ,
150+ )
143151 gpu_setting_name = self .default_gpu_args [gpu_setting_type ]
144152 jassert (gpu_setting_name is not None , f"No default GPU setting for { gpu_setting_type } " )
145153 return gpu_setting_name
@@ -464,7 +472,7 @@ def __init__(self, run_command="jsrun", platform_info=None):
464472 self .arg_ppn = ("-r" ,)
465473 self .default_mpi_options = None
466474 self .default_gpu_arg_type = "option_gpus_per_task"
467- self .default_gpu_args = {"option_gpus_per_task" : None , "option_gpus_per_node" : "-g" }
475+ self .default_gpu_args = {"option_gpus_per_task" : "-g" , "option_gpus_per_node" : None }
468476
469477 self .platform_info = platform_info
470478 self .mpi_command = [self .run_command , "-n {num_procs}" , "-r {procs_per_node}" , "{extra_args}" ]
0 commit comments