Skip to content

Commit c1c31a2

Browse files
committed
refactor errors
1 parent 94511ba commit c1c31a2

3 files changed

Lines changed: 12 additions & 14 deletions

File tree

examples/vectorization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def close(self):
7272
try:
7373
vecenv = pufferlib.vector.make(SamplePufferEnv,
7474
num_envs=1, num_workers=2, batch_size=3, backend=pufferlib.vector.Multiprocessing)
75-
except pufferlib.APIUsageError:
75+
except (AssertionError, ValueError):
7676
#Make sure num_envs divides num_workers, and both num_envs and num_workers should divide batch_size
7777
pass
7878

pufferlib/pufferl.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -162,12 +162,10 @@ def validate_config(args):
162162
minibatch_size = args['train']['minibatch_size']
163163
horizon = args['train']['horizon']
164164
total_agents = args['vec']['total_agents']
165-
if (minibatch_size % horizon) != 0:
166-
raise pufferlib.APIUsageError(
167-
f'minibatch_size {minibatch_size} must be divisible by horizon {horizon}')
168-
if minibatch_size > horizon * total_agents:
169-
raise pufferlib.APIUsageError(
170-
f'minibatch_size {minibatch_size} > total_agents {total_agents} * horizon {horizon}')
165+
assert (minibatch_size % horizon) == 0, \
166+
f'minibatch_size {minibatch_size} must be divisible by horizon {horizon}'
167+
assert minibatch_size <= horizon * total_agents, \
168+
f'minibatch_size {minibatch_size} > total_agents {total_agents} * horizon {horizon}'
171169

172170
def _train_worker(args, backend=_C):
173171
pufferl = backend.create_pufferl(args)
@@ -343,7 +341,7 @@ def sweep(env_name, args=None, pareto=False, backend=_C):
343341
try:
344342
sweep_cls = getattr(pufferlib.sweep, method)
345343
except:
346-
raise pufferlib.APIUsageError(f'Invalid sweep method {method}. See pufferlib.sweep')
344+
raise ValueError(f'Invalid sweep method {method}. See pufferlib.sweep')
347345

348346
sweep_obj = sweep_cls(sweep_config)
349347
num_experiments = args['sweep']['max_runs']
@@ -379,7 +377,7 @@ def sweep(env_name, args=None, pareto=False, backend=_C):
379377

380378
try:
381379
validate_config(args)
382-
except pufferlib.APIUsageError as e:
380+
except (AssertionError, ValueError) as e:
383381
print(f'WARNING: {e}, skipping')
384382
sweep_obj.observe(args, 0, 0, is_failure=True)
385383
continue
@@ -449,7 +447,7 @@ def load_config(env_name):
449447
p.read([puffer_default_config, path])
450448
if env_name in p['base']['env_name'].split(): break
451449
else:
452-
raise pufferlib.APIUsageError('No config for env_name {}'.format(env_name))
450+
raise ValueError('No config for env_name {}'.format(env_name))
453451

454452
for section in p.sections():
455453
for key in p[section]:
@@ -487,15 +485,15 @@ def load_config(env_name):
487485
def main():
488486
err = 'Usage: puffer [train, eval, sweep, paretosweep] [env_name] [optional args]. --help for more info'
489487
if len(sys.argv) < 3:
490-
raise pufferlib.APIUsageError(err)
488+
raise ValueError(err)
491489

492490
mode = sys.argv.pop(1)
493491
env_name = sys.argv.pop(1)
494492
args = load_config(env_name)
495493

496494
backend = _C
497495
if args.get('slowly'):
498-
from pufferlib.python_pufferl import PuffeRL
496+
from pufferlib.torch_pufferl import PuffeRL
499497
backend = PuffeRL
500498

501499
if 'train' in mode:
@@ -505,7 +503,7 @@ def main():
505503
elif 'sweep' in mode:
506504
sweep(env_name=env_name, args=args, pareto='pareto' in mode, backend=backend)
507505
else:
508-
raise pufferlib.APIUsageError(err)
506+
raise ValueError(err)
509507

510508
if __name__ == '__main__':
511509
main()
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ def load_policy(args, vec):
443443
data_dir = artifact.download()
444444
path = f'{data_dir}/{max(os.listdir(data_dir))}'
445445
else:
446-
raise pufferlib.APIUsageError('load_id requires --wandb')
446+
raise ValueError('load_id requires --wandb')
447447

448448
state_dict = torch.load(path, map_location=device)
449449
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}

0 commit comments

Comments
 (0)