Skip to content

Commit dd7b2cb

Browse files
committed
don't try to pickle backend
1 parent 4e0c951 commit dd7b2cb

1 file changed

Lines changed: 17 additions & 14 deletions

File tree

pufferlib/pufferl.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,14 @@ def validate_config(args):
167167
assert minibatch_size <= horizon * total_agents, \
168168
f'minibatch_size {minibatch_size} > total_agents {total_agents} * horizon {horizon}'
169169

170-
def _train_worker(args, backend=_C):
170+
def _resolve_backend(args):
171+
if args.get('slowly'):
172+
from pufferlib.torch_pufferl import PuffeRL
173+
return PuffeRL
174+
return _C
175+
176+
def _train_worker(args):
177+
backend = _resolve_backend(args)
171178
pufferl = backend.create_pufferl(args)
172179
args.pop('nccl_id', None)
173180
while pufferl.global_step < args['train']['total_timesteps']:
@@ -176,8 +183,9 @@ def _train_worker(args, backend=_C):
176183

177184
backend.close(pufferl)
178185

179-
def _train(env_name, args, backend=_C, sweep_obj=None, result_queue=None, verbose=False):
186+
def _train(env_name, args, sweep_obj=None, result_queue=None, verbose=False):
180187
'''Single-GPU training worker. Process target for both DDP ranks and sweep trials.'''
188+
backend = _resolve_backend(args)
181189
rank = args['rank']
182190
run_id = str(int(1000*time.time()))
183191
if args['wandb']:
@@ -302,7 +310,7 @@ def _train(env_name, args, backend=_C, sweep_obj=None, result_queue=None, verbos
302310
if result_queue is not None:
303311
result_queue.put((args['gpu_id'], metrics['env/score'], metrics['uptime'], metrics['agent_steps']))
304312

305-
def train(env_name, args=None, gpus=None, backend=_C, **kwargs):
313+
def train(env_name, args=None, gpus=None, **kwargs):
306314
args = args or load_config(env_name)
307315
validate_config(args)
308316

@@ -321,12 +329,12 @@ def train(env_name, args=None, gpus=None, backend=_C, **kwargs):
321329
worker_args['rank'] = rank
322330
worker_args['gpu_id'] = gpu_id
323331
if rank == 0 and not subprocess:
324-
_train(env_name, worker_args, backend=backend, verbose=True)
332+
_train(env_name, worker_args, verbose=True)
325333
else:
326334
ctx.Process(target=_train, args=(env_name, worker_args),
327-
kwargs={**kwargs, 'backend': backend}).start()
335+
kwargs=kwargs).start()
328336

329-
def sweep(env_name, args=None, pareto=False, backend=_C):
337+
def sweep(env_name, args=None, pareto=False):
330338
'''Train entry point. Handles single-GPU, multi-GPU DDP, and sweeps.'''
331339
args = args or load_config(env_name)
332340
exp_gpus = args['train']['gpus']
@@ -384,7 +392,7 @@ def sweep(env_name, args=None, pareto=False, backend=_C):
384392
exp_args = deepcopy(args)
385393
active[gpu_id] = exp_args
386394
train(env_name, exp_args, range(gpu_id, gpu_id + exp_gpus),
387-
backend=backend, sweep_obj=sweep_obj, result_queue=result_queue)
395+
sweep_obj=sweep_obj, result_queue=result_queue)
388396

389397
def eval(env_name, args=None, load_path=None):
390398
'''Evaluate a trained policy using the native pipeline.
@@ -490,17 +498,12 @@ def main():
490498
env_name = sys.argv.pop(1)
491499
args = load_config(env_name)
492500

493-
backend = _C
494-
if args.get('slowly'):
495-
from pufferlib.torch_pufferl import PuffeRL
496-
backend = PuffeRL
497-
498501
if 'train' in mode:
499-
train(env_name=env_name, args=args, backend=backend)
502+
train(env_name=env_name, args=args)
500503
elif 'eval' in mode:
501504
eval(env_name=env_name, args=args)
502505
elif 'sweep' in mode:
503-
sweep(env_name=env_name, args=args, pareto='pareto' in mode, backend=backend)
506+
sweep(env_name=env_name, args=args, pareto='pareto' in mode)
504507
else:
505508
raise ValueError(err)
506509

0 commit comments

Comments (0)