@@ -167,7 +167,14 @@ def validate_config(args):
167167 assert minibatch_size <= horizon * total_agents , \
168168 f'minibatch_size { minibatch_size } > total_agents { total_agents } * horizon { horizon } '
169169
170- def _train_worker (args , backend = _C ):
def _resolve_backend(args):
    '''Select the training backend for this run.

    Returns the pure-Python PuffeRL backend when the config sets the
    'slowly' flag; otherwise returns the default compiled backend (_C).
    '''
    if not args.get('slowly'):
        return _C
    # Imported lazily so the torch-based path is only loaded when requested.
    from pufferlib.torch_pufferl import PuffeRL
    return PuffeRL
176+ def _train_worker (args ):
177+ backend = _resolve_backend (args )
171178 pufferl = backend .create_pufferl (args )
172179 args .pop ('nccl_id' , None )
173180 while pufferl .global_step < args ['train' ]['total_timesteps' ]:
@@ -176,8 +183,9 @@ def _train_worker(args, backend=_C):
176183
177184 backend .close (pufferl )
178185
179- def _train (env_name , args , backend = _C , sweep_obj = None , result_queue = None , verbose = False ):
186+ def _train (env_name , args , sweep_obj = None , result_queue = None , verbose = False ):
180187 '''Single-GPU training worker. Process target for both DDP ranks and sweep trials.'''
188+ backend = _resolve_backend (args )
181189 rank = args ['rank' ]
182190 run_id = str (int (1000 * time .time ()))
183191 if args ['wandb' ]:
@@ -302,7 +310,7 @@ def _train(env_name, args, backend=_C, sweep_obj=None, result_queue=None, verbos
302310 if result_queue is not None :
303311 result_queue .put ((args ['gpu_id' ], metrics ['env/score' ], metrics ['uptime' ], metrics ['agent_steps' ]))
304312
305- def train (env_name , args = None , gpus = None , backend = _C , ** kwargs ):
313+ def train (env_name , args = None , gpus = None , ** kwargs ):
306314 args = args or load_config (env_name )
307315 validate_config (args )
308316
@@ -321,12 +329,12 @@ def train(env_name, args=None, gpus=None, backend=_C, **kwargs):
321329 worker_args ['rank' ] = rank
322330 worker_args ['gpu_id' ] = gpu_id
323331 if rank == 0 and not subprocess :
324- _train (env_name , worker_args , backend = backend , verbose = True )
332+ _train (env_name , worker_args , verbose = True )
325333 else :
326334 ctx .Process (target = _train , args = (env_name , worker_args ),
327- kwargs = { ** kwargs , 'backend' : backend } ).start ()
335+ kwargs = kwargs ).start ()
328336
329- def sweep (env_name , args = None , pareto = False , backend = _C ):
337+ def sweep (env_name , args = None , pareto = False ):
330338 '''Train entry point. Handles single-GPU, multi-GPU DDP, and sweeps.'''
331339 args = args or load_config (env_name )
332340 exp_gpus = args ['train' ]['gpus' ]
@@ -384,7 +392,7 @@ def sweep(env_name, args=None, pareto=False, backend=_C):
384392 exp_args = deepcopy (args )
385393 active [gpu_id ] = exp_args
386394 train (env_name , exp_args , range (gpu_id , gpu_id + exp_gpus ),
387- backend = backend , sweep_obj = sweep_obj , result_queue = result_queue )
395+ sweep_obj = sweep_obj , result_queue = result_queue )
388396
389397def eval (env_name , args = None , load_path = None ):
390398 '''Evaluate a trained policy using the native pipeline.
@@ -490,17 +498,12 @@ def main():
490498 env_name = sys .argv .pop (1 )
491499 args = load_config (env_name )
492500
493- backend = _C
494- if args .get ('slowly' ):
495- from pufferlib .torch_pufferl import PuffeRL
496- backend = PuffeRL
497-
498501 if 'train' in mode :
499- train (env_name = env_name , args = args , backend = backend )
502+ train (env_name = env_name , args = args )
500503 elif 'eval' in mode :
501504 eval (env_name = env_name , args = args )
502505 elif 'sweep' in mode :
503- sweep (env_name = env_name , args = args , pareto = 'pareto' in mode , backend = backend )
506+ sweep (env_name = env_name , args = args , pareto = 'pareto' in mode )
504507 else :
505508 raise ValueError (err )
506509
0 commit comments