@@ -162,12 +162,10 @@ def validate_config(args):
162162 minibatch_size = args ['train' ]['minibatch_size' ]
163163 horizon = args ['train' ]['horizon' ]
164164 total_agents = args ['vec' ]['total_agents' ]
165- if (minibatch_size % horizon ) != 0 :
166- raise pufferlib .APIUsageError (
167- f'minibatch_size { minibatch_size } must be divisible by horizon { horizon } ' )
168- if minibatch_size > horizon * total_agents :
169- raise pufferlib .APIUsageError (
170- f'minibatch_size { minibatch_size } > total_agents { total_agents } * horizon { horizon } ' )
165+ assert (minibatch_size % horizon ) == 0 , \
166+ f'minibatch_size { minibatch_size } must be divisible by horizon { horizon } '
167+ assert minibatch_size <= horizon * total_agents , \
168+ f'minibatch_size { minibatch_size } > total_agents { total_agents } * horizon { horizon } '
171169
172170def _train_worker (args , backend = _C ):
173171 pufferl = backend .create_pufferl (args )
@@ -343,7 +341,7 @@ def sweep(env_name, args=None, pareto=False, backend=_C):
343341 try :
344342 sweep_cls = getattr (pufferlib .sweep , method )
345343 except :
346- raise pufferlib . APIUsageError (f'Invalid sweep method { method } . See pufferlib.sweep' )
344+ raise ValueError (f'Invalid sweep method { method } . See pufferlib.sweep' )
347345
348346 sweep_obj = sweep_cls (sweep_config )
349347 num_experiments = args ['sweep' ]['max_runs' ]
@@ -379,7 +377,7 @@ def sweep(env_name, args=None, pareto=False, backend=_C):
379377
380378 try :
381379 validate_config (args )
382- except pufferlib . APIUsageError as e :
380+ except ( AssertionError , ValueError ) as e :
383381 print (f'WARNING: { e } , skipping' )
384382 sweep_obj .observe (args , 0 , 0 , is_failure = True )
385383 continue
@@ -449,7 +447,7 @@ def load_config(env_name):
449447 p .read ([puffer_default_config , path ])
450448 if env_name in p ['base' ]['env_name' ].split (): break
451449 else :
452- raise pufferlib . APIUsageError ('No config for env_name {}' .format (env_name ))
450+ raise ValueError ('No config for env_name {}' .format (env_name ))
453451
454452 for section in p .sections ():
455453 for key in p [section ]:
@@ -487,15 +485,15 @@ def load_config(env_name):
487485def main ():
488486 err = 'Usage: puffer [train, eval, sweep, paretosweep] [env_name] [optional args]. --help for more info'
489487 if len (sys .argv ) < 3 :
490- raise pufferlib . APIUsageError (err )
488+ raise ValueError (err )
491489
492490 mode = sys .argv .pop (1 )
493491 env_name = sys .argv .pop (1 )
494492 args = load_config (env_name )
495493
496494 backend = _C
497495 if args .get ('slowly' ):
498- from pufferlib .python_pufferl import PuffeRL
496+ from pufferlib .torch_pufferl import PuffeRL
499497 backend = PuffeRL
500498
501499 if 'train' in mode :
@@ -505,7 +503,7 @@ def main():
505503 elif 'sweep' in mode :
506504 sweep (env_name = env_name , args = args , pareto = 'pareto' in mode , backend = backend )
507505 else :
508- raise pufferlib . APIUsageError (err )
506+ raise ValueError (err )
509507
510508if __name__ == '__main__' :
511509 main ()
0 commit comments