Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
b4210c1
Merge py file changes from benchmark-algs
taufeeque9 Jan 4, 2023
97bc063
Clean parallel script
taufeeque9 Jan 10, 2023
9291225
Undo the changes from #653 to the dagger benchmark config files.
ernestum Jan 26, 2023
276d863
Improve readability and interpretability of benchmarking tests.
ernestum Jan 25, 2023
37eb914
Add exponential beta scheduler for dagger
taufeeque9 Mar 1, 2023
877383b
Ignore coverage for unknown algorithms.
ernestum Feb 2, 2023
c8e55cb
Cleanup and extend tests for beta schedules in dagger.
ernestum Feb 2, 2023
6b9b306
Merge branch 'master' into benchmark-pr
taufeeque9 Feb 6, 2023
8576465
Fix test cases
taufeeque9 Feb 8, 2023
d81eb68
Add optuna to dependencies
taufeeque9 Feb 8, 2023
27467d3
Fix test case
taufeeque9 Feb 8, 2023
b59a768
Merge branch 'master' into benchmark-pr
taufeeque9 Feb 8, 2023
1a3b6b8
Clean up the scripts
taufeeque9 Feb 9, 2023
7a438da
Remove reporter(done) since mean_return is reported by the runs
taufeeque9 Feb 9, 2023
5bc5835
Merge branch 'master' into benchmark-pr
taufeeque9 Feb 20, 2023
2e56de8
Add beta_schedule parameter to dagger script
taufeeque9 Feb 23, 2023
84e854a
Merge branch 'master' into benchmark-pr
taufeeque9 Mar 16, 2023
73d8576
Update config policy kwargs
taufeeque9 Mar 16, 2023
9fdf878
Changes from review
taufeeque9 May 16, 2023
1c1dbc4
Fix errors with some configs
taufeeque9 May 16, 2023
3467af2
Merge branch 'master' into benchmark-pr
taufeeque9 May 16, 2023
44c4e97
Updates based on review
taufeeque9 Jun 14, 2023
4d493ae
Merge branch 'master' into benchmark-pr
taufeeque9 Jun 14, 2023
ab01269
Change metric everywhere
taufeeque9 Jun 14, 2023
f64580e
Merge branch 'master' into benchmark-pr
taufeeque9 Jul 11, 2023
e896d7d
Separate tuning code from parallel.py
taufeeque9 Jul 11, 2023
64c3a8d
Fix docstring
taufeeque9 Jul 11, 2023
8fba0d3
Removing resume option as it is getting tricky to correctly implement
taufeeque9 Jul 11, 2023
12ab31c
Minor fixes
taufeeque9 Jul 11, 2023
19b0f2c
Updates from review
taufeeque9 Jul 16, 2023
046b8d9
fix lint error
taufeeque9 Jul 16, 2023
8eee082
Add documentation for using the tuning script
taufeeque9 Jul 16, 2023
5ce7658
Fix lint error
taufeeque9 Jul 17, 2023
a8be331
Updates from the review
taufeeque9 Jul 18, 2023
4ff006d
Fix file name test errors
taufeeque9 Jul 18, 2023
6933afa
Add tune_run_kwargs in parallel script
taufeeque9 Jul 19, 2023
77f9d9b
Fix test errors
taufeeque9 Jul 19, 2023
54eb8a6
Fix test
taufeeque9 Jul 19, 2023
d50238f
Fix lint
taufeeque9 Jul 19, 2023
3fe22d4
Updates from review
taufeeque9 Jul 19, 2023
c50aa20
Simplify few lines of code
taufeeque9 Jul 20, 2023
000af61
Updates from review
taufeeque9 Aug 4, 2023
8b55134
Fix test
taufeeque9 Aug 4, 2023
f3ba2b5
Revert "Fix test"
taufeeque9 Aug 4, 2023
f8251c7
Fix test
taufeeque9 Aug 4, 2023
664fc37
Convert Dict to Mapping in input argument
taufeeque9 Aug 7, 2023
8690e1d
Ignore coverage in script configurations.
ernestum Aug 30, 2023
dd9eb6a
Pin huggingface_sb3 version.
ernestum Aug 30, 2023
b3930f4
Merge branch 'master' into benchmark-pr
ernestum Sep 26, 2023
40d87ef
Update to the newest seals environment versions.
ernestum Sep 26, 2023
71f6c92
Push gymnasium dependency to 0.29 to ensure mujoco envs work.
ernestum Sep 27, 2023
747ad32
Incorporate review comments
taufeeque9 Oct 4, 2023
691e759
Fix test errors
taufeeque9 Oct 4, 2023
2038e60
Move benchmarking/ to scripts/ and add named configs for tuned hyperp…
taufeeque9 Oct 4, 2023
35c7265
Bump cache version & remove unnecessary files
taufeeque9 Oct 5, 2023
fdf4f49
Include tuned hyperparam json files in package data
taufeeque9 Oct 5, 2023
5f9a4e6
Update storage hash
taufeeque9 Oct 5, 2023
91bb785
Update search space of bc
taufeeque9 Oct 5, 2023
3d93c84
Merge branch 'master' of github.com:HumanCompatibleAI/imitation into …
ZiyueWang25 Oct 5, 2023
f59fea2
update benchmark and hyper parameter tuning readme
ZiyueWang25 Oct 5, 2023
95110dc
Update README.md
taufeeque9 Oct 5, 2023
75f3477
Incorporate reviewer's comments in benchmarking readme
taufeeque9 Oct 6, 2023
77c1115
Merge branch 'master' into benchmark-pr
taufeeque9 Oct 6, 2023
1ba2b00
Update gymnasium version and render mode in eval policy
taufeeque9 Oct 7, 2023
ba4b693
Fix error
taufeeque9 Oct 7, 2023
bb76ee1
Merge branch 'update-gymnasium-dep' into benchmark-pr
taufeeque9 Oct 7, 2023
278f225
Merge branch 'master' into benchmark-pr
taufeeque9 Oct 8, 2023
01755a2
Update commands.py hex strings
taufeeque9 Oct 9, 2023
fdcef92
Merge branch 'master' into benchmark-pr
taufeeque9 Oct 9, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
"sacred>=0.8.4",
"tensorboard>=1.14",
"huggingface_sb3>=2.2.1",
"optuna>=3.0.1",
],
tests_require=TESTS_REQUIRE,
extras_require={
Expand Down
30 changes: 21 additions & 9 deletions src/imitation/scripts/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ def _get_algo_name(sd: sacred_util.SacredDicts) -> str:

def _return_summaries(sd: sacred_util.SacredDicts) -> dict:
imit_stats = get(sd.run, "result.imit_stats")
if imit_stats is None:
# stored in rollout key for preference comparison
imit_stats = get(sd.run, "result.rollout")
Comment thread
taufeeque9 marked this conversation as resolved.
Outdated
expert_stats = get(sd.run, "result.expert_stats")

expert_return_summary = None
Expand Down Expand Up @@ -262,26 +265,35 @@ def analyze_imitation(
csv_output_path: If provided, then save a CSV output file to this path.
tex_output_path: If provided, then save a LaTeX-format table to this path.
print_table: If True, then print the dataframe to stdout.
table_verbosity: Increasing levels of verbosity, from 0 to 2, increase the
number of columns in the table.
table_verbosity: Increasing levels of verbosity, from 0 to 3, increase the
number of columns in the table. Level 3 prints all of the columns available.

Returns:
The DataFrame generated from the Sacred logs.
"""
table_entry_fns_subset = _get_table_entry_fns_subset(table_verbosity)
if table_verbosity == 3:
Comment thread
taufeeque9 marked this conversation as resolved.
Outdated
table_entry_fns_subset = _get_table_entry_fns_subset(2)
else:
table_entry_fns_subset = _get_table_entry_fns_subset(table_verbosity)
Comment thread
taufeeque9 marked this conversation as resolved.
Outdated

rows = []
df = pd.DataFrame()
Comment thread
taufeeque9 marked this conversation as resolved.
Outdated
for sd in _gather_sacred_dicts():
row = {}
new_df = pd.DataFrame()
if table_verbosity == -1:
# gets all config columns
new_df = pd.json_normalize(sd.config)
else:
new_df = new_df.append({}, ignore_index=True)

for col_name, make_entry_fn in table_entry_fns_subset.items():
row[col_name] = make_entry_fn(sd)
rows.append(row)
new_df[col_name] = make_entry_fn(sd)

df = pd.concat([df, new_df])

df = pd.DataFrame(rows)
if len(df) > 0:
df.sort_values(by=["algo", "env_name"], inplace=True)

display_options = dict(index=False)
display_options: Mapping[str, Any] = dict(index=False)
if csv_output_path is not None:
df.to_csv(csv_output_path, **display_options)
print(f"Wrote CSV file to {csv_output_path}")
Expand Down
235 changes: 198 additions & 37 deletions src/imitation/scripts/config/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,20 @@
`@parallel_ex.named_config` to define a new parallel experiment.

Adding custom named configs is necessary because the CLI interface can't add
search spaces to the config like `"seed": tune.grid_search([0, 1, 2, 3])`.
search spaces to the config like `"seed": tune.choice([0, 1, 2, 3])`.

For tuning hyperparameters of an algorithm on a given environment, override
the `base_named_configs` argument with the named config of the environment.
Ex: python -m imitation.scripts.parallel with example_gail \
'base_named_configs=["logging.wandb_logging", "seals_half_cheetah"]'
"""

import numpy as np
import ray.tune as tune
import sacred
from torch import nn

from imitation.algorithms import dagger
from imitation.util.util import make_unique_timestamp

parallel_ex = sacred.Experiment("parallel")
Expand All @@ -33,17 +40,11 @@ def config():

local_dir = None # `local_dir` arg for `ray.tune.run`
upload_dir = None # `upload_dir` arg for `ray.tune.run`
n_seeds = 3 # Number of seeds to search over by default


@parallel_ex.config
def seeds(n_seeds):
search_space = {"config_updates": {"seed": tune.grid_search(list(range(n_seeds)))}}


@parallel_ex.named_config
def s3():
upload_dir = "s3://shwang-chai/private"
experiment_checkpoint_path = ""
eval_best_trial = False
eval_trial_seeds = 5 # Number of seeds to search over by default
num_samples = 1 # Number of samples per grid search configuration
repeat = 1


# Debug named configs
Expand All @@ -58,12 +59,12 @@ def generate_test_data():
"""
sacred_ex_name = "train_rl"
run_name = "TEST"
n_seeds = 1
repeat = 1
search_space = {
"config_updates": {
"rl": {
"rl_kwargs": {
"learning_rate": tune.grid_search(
"learning_rate": tune.choice(
[3e-4 * x for x in (1 / 3, 1 / 2)],
),
},
Expand All @@ -86,13 +87,13 @@ def generate_test_data():
def example_cartpole_rl():
sacred_ex_name = "train_rl"
run_name = "example-cartpole"
n_seeds = 2
repeat = 2
search_space = {
"config_updates": {
"rl": {
"rl_kwargs": {
"learning_rate": tune.grid_search(np.logspace(3e-6, 1e-1, num=3)),
"nminibatches": tune.grid_search([16, 32, 64]),
"learning_rate": tune.choice(np.logspace(3e-6, 1e-1, num=3)),
"nminibatches": tune.choice([16, 32, 64]),
},
},
},
Expand All @@ -105,44 +106,204 @@ def example_cartpole_rl():


@parallel_ex.named_config
def example_rl_easy():
def example_rl():
sacred_ex_name = "train_rl"
run_name = "example-rl-easy"
n_seeds = 2
run_name = "rl_tuning"
base_named_configs = ["logging.wandb_logging"]
base_config_updates = {"environment": {"num_vec": 1}}
search_space = {
"named_configs": tune.grid_search([[env] for env in EASY_ENVS]),
"config_updates": {
"rl": {
"batch_size": tune.choice([512, 1024, 2048, 4096, 8192]),
"rl_kwargs": {
"learning_rate": tune.grid_search(np.logspace(3e-6, 1e-1, num=3)),
"nminibatches": tune.grid_search([16, 32, 64]),
"learning_rate": tune.loguniform(1e-5, 1e-2),
"batch_size": tune.choice([64, 128, 256, 512]),
"n_epochs": tune.choice([5, 10, 20]),
},
},
},
}
resources_per_trial = dict(cpu=4)
num_samples = 100
eval_best_trial = True
eval_trial_seeds = 5
repeat = 1
resources_per_trial = dict(cpu=1)


@parallel_ex.named_config
def example_bc():
    """Search space for tuning behavioral cloning (BC) hyperparameters."""
    sacred_ex_name = "train_imitation"  # sacred experiment to launch per trial
    run_name = "bc_tuning"
    base_named_configs = ["logging.wandb_logging"]  # applied to every trial
    base_config_updates = {"environment": {"num_vec": 1}}
    # Ray Tune search space: sampled per trial and merged into the
    # sacred config of `train_imitation`.
    search_space = {
        "config_updates": {
            "bc_kwargs": dict(
                batch_size=tune.choice([8, 16, 32, 64]),
                l2_weight=tune.loguniform(1e-6, 1e-2),  # L2 regularization weight
                optimizer_kwargs=dict(
                    lr=tune.loguniform(1e-5, 1e-2),
                ),
            ),
            "bc_train_kwargs": dict(
                n_epochs=tune.choice([1, 5, 10, 20]),
            ),
        },
        "command_name": "bc",  # sacred command of `train_imitation` to run
    }
    num_samples = 64  # number of hyperparameter samples drawn from the space
    eval_best_trial = True  # re-evaluate the best trial after tuning
    eval_trial_seeds = 5  # seeds used for that re-evaluation
    repeat = 3  # times each sampled config is repeated (different seeds)
    resources_per_trial = dict(cpu=1)


@parallel_ex.named_config
def example_dagger():
    """Search space for tuning DAgger hyperparameters."""
    sacred_ex_name = "train_imitation"  # sacred experiment to launch per trial
    run_name = "dagger_tuning"
    base_named_configs = ["logging.wandb_logging"]  # applied to every trial
    # Fixed (non-searched) settings shared by all trials; the BC inner-loop
    # hyperparameters are pinned so the search focuses on DAgger itself.
    base_config_updates = {
        "environment": {"num_vec": 1},
        "dagger": {"total_timesteps": 1e5},
        "bc_kwargs": {
            "batch_size": 16,
            "l2_weight": 1e-4,
            "optimizer_kwargs": {"lr": 1e-3},
        },
    }
    # Ray Tune search space, merged into the sacred config per trial.
    search_space = {
        "config_updates": {
            "bc_train_kwargs": dict(
                n_epochs=tune.choice([1, 5, 10]),
            ),
            "dagger": dict(
                # Search over both linear and exponential beta decay schedules.
                beta_schedule=tune.choice(
                    [dagger.LinearBetaSchedule(i) for i in [1, 5, 15]]
                    + [dagger.ExponentialBetaSchedule(i) for i in [0.3, 0.5, 0.7]],
                ),
                rollout_round_min_episodes=tune.choice([3, 5, 10]),
            ),
        },
        "command_name": "dagger",  # sacred command of `train_imitation` to run
    }
    num_samples = 50  # number of hyperparameter samples drawn from the space
    repeat = 3  # times each sampled config is repeated (different seeds)
    eval_best_trial = True  # re-evaluate the best trial after tuning
    eval_trial_seeds = 5  # seeds used for that re-evaluation
    resources_per_trial = dict(cpu=1)


@parallel_ex.named_config
def example_gail():
    """Search space for tuning GAIL hyperparameters."""
    sacred_ex_name = "train_adversarial"  # sacred experiment to launch per trial
    run_name = "gail_tuning_hc"
    base_named_configs = ["logging.wandb_logging"]  # applied to every trial
    base_config_updates = {
        "environment": {"num_vec": 1},
        "total_timesteps": 1e7,
    }
    # Ray Tune search space, merged into the sacred config per trial.
    search_space = {
        "config_updates": {
            "algorithm_kwargs": dict(
                demo_batch_size=tune.choice([32, 128, 512, 2048, 8192]),
                n_disc_updates_per_round=tune.choice([8, 16]),
            ),
            # Hyperparameters of the generator's RL algorithm.
            "rl": {
                "batch_size": tune.choice([4096, 8192, 16384]),
                "rl_kwargs": {
                    "ent_coef": tune.loguniform(1e-7, 1e-3),
                    "learning_rate": tune.loguniform(1e-5, 1e-2),
                },
            },
            "algorithm_specific": {},
        },
        "command_name": "gail",  # sacred command of `train_adversarial` to run
    }
    num_samples = 100  # number of hyperparameter samples drawn from the space
    eval_best_trial = True  # re-evaluate the best trial after tuning
    eval_trial_seeds = 5  # seeds used for that re-evaluation
    repeat = 3  # times each sampled config is repeated (different seeds)
    resources_per_trial = dict(cpu=1)


@parallel_ex.named_config
def example_gail_easy():
def example_airl():
sacred_ex_name = "train_adversarial"
run_name = "example-gail-easy"
n_seeds = 1
run_name = "airl_tuning"
base_named_configs = ["logging.wandb_logging"]
base_config_updates = {
"environment": {"num_vec": 1},
"total_timesteps": 1e7,
}
search_space = {
"named_configs": tune.grid_search([[env] for env in EASY_ENVS]),
"config_updates": {
"init_trainer_kwargs": {
"rl": {
"rl_kwargs": {
"learning_rate": tune.grid_search(
np.logspace(3e-6, 1e-1, num=3),
),
"nminibatches": tune.grid_search([16, 32, 64]),
},
"algorithm_kwargs": dict(
demo_batch_size=tune.choice([32, 128, 512, 2048, 8192]),
n_disc_updates_per_round=tune.choice([8, 16]),
),
"rl": {
"batch_size": tune.choice([4096, 8192, 16384]),
"rl_kwargs": {
"ent_coef": tune.loguniform(1e-7, 1e-3),
"learning_rate": tune.loguniform(1e-5, 1e-2),
},
},
"algorithm_specific": {},
},
"command_name": "airl",
}
num_samples = 100
eval_best_trial = True
eval_trial_seeds = 5
repeat = 3
resources_per_trial = dict(cpu=1)


@parallel_ex.named_config
def example_pc():
    """Search space for tuning preference-comparisons hyperparameters."""
    sacred_ex_name = "train_preference_comparisons"  # experiment to launch
    run_name = "pc_tuning"
    base_named_configs = ["logging.wandb_logging"]  # applied to every trial
    base_config_updates = {
        "environment": {"num_vec": 1},
        "total_timesteps": 2e7,
        "total_comparisons": 5000,
        "query_schedule": "hyperbolic",
        "gatherer_kwargs": {"sample": True},
    }
    # Ray Tune search space, merged into the sacred config per trial.
    search_space = {
        # NOTE(review): "gail" looks wrong for a train_preference_comparisons
        # experiment — verify this command name exists on that experiment.
        "command_name": "gail",
        # Single-element choices keep these fields in the trial config
        # while effectively fixing their value.
        "named_configs": tune.choice(
            [
                ["reward.normalize_output_disable"],
            ],
        ),
        "config_updates": {
            "train": {
                "policy_kwargs": {
                    "activation_fn": tune.choice(
                        [
                            nn.ReLU,
                        ],
                    ),
                },
            },
            "num_iterations": tune.choice([25, 50]),
            "initial_comparison_frac": tune.choice([0.1, 0.25]),
            "reward_trainer_kwargs": {
                "epochs": tune.choice([1, 3, 6]),
            },
            # Hyperparameters of the policy's RL algorithm.
            "rl": {
                "batch_size": tune.choice([512, 2048, 8192]),
                "rl_kwargs": {
                    "learning_rate": tune.loguniform(1e-5, 1e-2),
                    "ent_coef": tune.loguniform(1e-7, 1e-3),
                },
            },
        },
    }
    num_samples = 100  # number of hyperparameter samples drawn from the space
    eval_best_trial = True  # re-evaluate the best trial after tuning
    eval_trial_seeds = 5  # seeds used for that re-evaluation
    repeat = 3  # times each sampled config is repeated (different seeds)
    resources_per_trial = dict(cpu=1)
Loading