-
Notifications
You must be signed in to change notification settings - Fork 301
Add scripts and configs for hyperparameter tuning #675
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 51 commits
b4210c1
97bc063
9291225
276d863
37eb914
877383b
c8e55cb
6b9b306
8576465
d81eb68
27467d3
b59a768
1a3b6b8
7a438da
5bc5835
2e56de8
84e854a
73d8576
9fdf878
1c1dbc4
3467af2
44c4e97
4d493ae
ab01269
f64580e
e896d7d
64c3a8d
8fba0d3
12ab31c
19b0f2c
046b8d9
8eee082
5ce7658
a8be331
4ff006d
6933afa
77f9d9b
54eb8a6
d50238f
3fe22d4
c50aa20
000af61
8b55134
f3ba2b5
f8251c7
664fc37
8690e1d
dd9eb6a
b3930f4
40d87ef
71f6c92
747ad32
691e759
2038e60
35c7265
fdf4f49
5f9a4e6
91bb785
3d93c84
f59fea2
95110dc
75f3477
77c1115
1ba2b00
ba4b693
bb76ee1
278f225
01755a2
fdcef92
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,6 +62,6 @@ | |
| "n_episodes_eval": 50 | ||
| }, | ||
| "environment": { | ||
| "gym_id": "seals/Ant-v0" | ||
| "gym_id": "seals/Ant-v1" | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,6 +62,6 @@ | |
| "n_episodes_eval": 50 | ||
| }, | ||
| "environment": { | ||
| "gym_id": "seals/HalfCheetah-v0" | ||
| "gym_id": "seals/HalfCheetah-v1" | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -75,6 +75,6 @@ | |
| "n_episodes_eval": 50 | ||
| }, | ||
| "environment": { | ||
| "gym_id": "seals/Hopper-v0" | ||
| "gym_id": "seals/Hopper-v1" | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,6 +43,6 @@ | |
| "n_episodes_eval": 50 | ||
| }, | ||
| "environment": { | ||
| "gym_id": "seals/Ant-v0" | ||
| "gym_id": "seals/Ant-v1" | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,6 +43,6 @@ | |
| "n_episodes_eval": 50 | ||
| }, | ||
| "environment": { | ||
| "gym_id": "seals/HalfCheetah-v0" | ||
| "gym_id": "seals/HalfCheetah-v1" | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,6 +43,6 @@ | |
| "n_episodes_eval": 50 | ||
| }, | ||
| "environment": { | ||
| "gym_id": "seals/Hopper-v0" | ||
| "gym_id": "seals/Hopper-v1" | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,6 +43,6 @@ | |
| "n_episodes_eval": 50 | ||
| }, | ||
| "environment": { | ||
| "gym_id": "seals/Swimmer-v0" | ||
| "gym_id": "seals/Swimmer-v1" | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,6 +43,6 @@ | |
| "n_episodes_eval": 50 | ||
| }, | ||
| "environment": { | ||
| "gym_id": "seals/Walker2d-v0" | ||
| "gym_id": "seals/Walker2d-v1" | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,6 +47,6 @@ | |
| "n_episodes_eval": 50 | ||
| }, | ||
| "environment": { | ||
| "gym_id": "seals/Ant-v0" | ||
| "gym_id": "seals/Ant-v1" | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,6 +47,6 @@ | |
| "n_episodes_eval": 50 | ||
| }, | ||
| "environment": { | ||
| "gym_id": "seals/HalfCheetah-v0" | ||
| "gym_id": "seals/HalfCheetah-v1" | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,6 +47,6 @@ | |
| "n_episodes_eval": 50 | ||
| }, | ||
| "environment": { | ||
| "gym_id": "seals/Hopper-v0" | ||
| "gym_id": "seals/Hopper-v1" | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,6 +47,6 @@ | |
| "n_episodes_eval": 50 | ||
| }, | ||
| "environment": { | ||
| "gym_id": "seals/Swimmer-v0" | ||
| "gym_id": "seals/Swimmer-v1" | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,6 +47,6 @@ | |
| "n_episodes_eval": 50 | ||
| }, | ||
| "environment": { | ||
| "gym_id": "seals/Walker2d-v0" | ||
| "gym_id": "seals/Walker2d-v1" | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,6 +62,6 @@ | |
| "n_episodes_eval": 50 | ||
| }, | ||
| "environment": { | ||
| "gym_id": "seals/Ant-v0" | ||
| "gym_id": "seals/Ant-v1" | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,6 +62,6 @@ | |
| "n_episodes_eval": 50 | ||
| }, | ||
| "environment": { | ||
| "gym_id": "seals/HalfCheetah-v0" | ||
| "gym_id": "seals/HalfCheetah-v1" | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -75,6 +75,6 @@ | |
| "n_episodes_eval": 50 | ||
| }, | ||
| "environment": { | ||
| "gym_id": "seals/Hopper-v0" | ||
| "gym_id": "seals/Hopper-v1" | ||
| } | ||
| } | ||
|
taufeeque9 marked this conversation as resolved.
|
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,174 @@ | ||||||||
| """Tunes the hyperparameters of the algorithms.""" | ||||||||
|
|
||||||||
| import copy | ||||||||
| import pathlib | ||||||||
| from typing import Any, Dict | ||||||||
|
|
||||||||
| import numpy as np | ||||||||
| import ray | ||||||||
| from pandas.api import types as pd_types | ||||||||
| from ray.tune.search import optuna | ||||||||
| from sacred.observers import FileStorageObserver | ||||||||
| from tuning_config import parallel_ex, tuning_ex | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. PEP8: standard library, third-party then first party-imports with line separating each.
Suggested change
|
||||||||
|
|
||||||||
|
|
||||||||
@tuning_ex.main
def tune(
    parallel_run_config: Dict[str, Any],
    eval_best_trial_resource_multiplier: int = 1,
    num_eval_seeds: int = 5,
) -> None:
    """Tune hyperparameters of imitation algorithms using parallel script.

    Args:
        parallel_run_config: Dictionary of arguments to pass to the parallel script.
        eval_best_trial_resource_multiplier: Factor by which to multiply the
            number of cpus per trial in `resources_per_trial`. This is useful for
            allocating more resources per trial to the evaluation trials than the
            resources for hyperparameter tuning since number of evaluation trials
            is usually much smaller than the number of tuning trials.
        num_eval_seeds: Number of distinct seeds to evaluate the best trial on.
            Set to 0 to disable evaluation.

    Raises:
        ValueError: If no trials are returned by the parallel run of tuning.
    """
    # Deep-copy so the caller's config dictionary is never mutated.
    run_config = copy.deepcopy(parallel_run_config)
    # Inject an Optuna searcher into the ray.tune kwargs, creating the
    # sub-dict if the caller did not supply one.
    run_config.setdefault("tune_run_kwargs", {})["search_alg"] = optuna.OptunaSearch()

    tuning_run = parallel_ex.run(config_updates=run_config)
    experiment_analysis = tuning_run.result
    if not experiment_analysis.trials:
        raise ValueError(
            "No trials found. Please ensure that the `experiment_checkpoint_path` "
            "in `parallel_run_config` is passed correctly "
            "or that the tuning run finished properly.",
        )

    # The train_rl experiment logs its return under a different key than
    # the imitation experiments.
    return_key = (
        "monitor_return_mean"
        if run_config["sacred_ex_name"] == "train_rl"
        else "imit_stats/monitor_return_mean"
    )
    best_trial = find_best_trial(experiment_analysis, return_key, print_return=True)

    if num_eval_seeds > 0:  # evaluate the best trial
        eval_resources = copy.deepcopy(run_config["resources_per_trial"])
        # Scale cpus per trial only when explicitly provided in
        # `resources_per_trial`; otherwise ray's default (cpu=1) applies.
        if "cpu" in run_config["resources_per_trial"]:
            eval_resources["cpu"] *= eval_best_trial_resource_multiplier
        evaluate_trial(
            best_trial,
            num_eval_seeds,
            run_config["run_name"] + "_best_hp_eval",
            run_config,
            eval_resources,
            return_key,
        )
|
|
||||||||
|
|
||||||||
def find_best_trial(
    experiment_analysis: ray.tune.analysis.ExperimentAnalysis,
    return_key: str,
    print_return: bool = False,
) -> ray.tune.experiment.Trial:
    """Find the trial with the best mean return across all seeds.

    Args:
        experiment_analysis: The result of a parallel/tuning experiment.
        return_key: The key of the return metric in the results dataframe.
        print_return: Whether to print the mean and std of the returns
            of the best trial.

    Returns:
        best_trial: The trial with the best mean return across all seeds.
    """
    results = experiment_analysis.results_df
    # df.groupby requires hashable keys, so cast object-dtype columns to str.
    object_cols = [c for c in results.columns if pd_types.is_object_dtype(results[c])]
    for col in object_cols:
        results[col] = results[col].astype("str")
    # Rows sharing every non-seed "config*" column form one hyperparameter group.
    config_cols = [
        c for c in results.columns if c.startswith("config") and "seed" not in c
    ]
    # Attach each group's mean return (across seeds) to every row of the group.
    results["mean_return"] = results.groupby(config_cols)[return_key].transform("mean")
    top_rows = results[results["mean_return"] == results["mean_return"].max()]
    # If several groups tie on mean return, the first row's group wins.
    top_row = top_rows.iloc[0]
    winning_tag = top_row["experiment_tag"]
    assert experiment_analysis.trials is not None  # for mypy
    matching_trials = [
        t for t in experiment_analysis.trials if winning_tag in t.experiment_tag
    ]
    best_trial = matching_trials[0]

    if print_return:
        # NOTE(review): selects every row whose mean matches the max, which may
        # span multiple tied groups — accepted as unlikely in practice.
        seed_returns = results[results["mean_return"] == top_row["mean_return"]][
            return_key
        ].to_numpy()
        print("All returns:", seed_returns)
        print("Mean return:", top_row["mean_return"])
        print("Std return:", np.std(seed_returns))
        print("Total seeds:", len(seed_returns))
    return best_trial
|
|
||||||||
|
|
||||||||
def evaluate_trial(
    trial: ray.tune.experiment.Trial,
    num_eval_seeds: int,
    run_name: str,
    parallel_run_config: Dict[str, Any],
    resources_per_trial: Dict[str, int],
    return_key: str,
    print_return: bool = False,
):
    """Evaluate a given trial of a parallel run on a separate set of seeds.

    Args:
        trial: The trial to evaluate.
        num_eval_seeds: Number of distinct seeds to evaluate the best trial on.
        run_name: The name of the evaluation run.
        parallel_run_config: Dictionary of arguments passed to the parallel
            script to get best_trial.
        resources_per_trial: Resources to be used for each evaluation trial.
        return_key: The key of the return metric in the results dataframe.
        print_return: Whether to print the mean and std of the evaluation returns.

    Returns:
        eval_run: The result of the evaluation run.
    """
    # Deep-copy the trial's config before editing: updating `trial.config`
    # in place would silently corrupt the tuning trial object for any
    # later use by the caller.
    config = copy.deepcopy(trial.config)
    # Evaluate on seeds 100..(100 + num_eval_seeds - 1), disjoint from the
    # seeds used during tuning.
    config["config_updates"].update(
        seed=ray.tune.grid_search(list(range(100, 100 + num_eval_seeds))),
    )
    eval_config_updates = parallel_run_config.copy()
    eval_config_updates.update(
        run_name=run_name,
        num_samples=1,
        search_space=config,
        resources_per_trial=resources_per_trial,
        search_alg=None,
        repeat=1,
        experiment_checkpoint_path="",
    )
    eval_run = parallel_ex.run(config_updates=eval_config_updates)
    eval_result = eval_run.result
    returns = eval_result.results_df[return_key].to_numpy()
    if print_return:
        print("All returns:", returns)
        print("Mean:", np.mean(returns))
        print("Std:", np.std(returns))
    return eval_run
|
|
||||||||
|
|
||||||||
def main_console():
    """Run the tuning experiment from the command line with a file observer."""
    sacred_dir = pathlib.Path.cwd() / "output" / "sacred" / "tuning"
    tuning_ex.observers.append(FileStorageObserver(sacred_dir))
    tuning_ex.run_commandline()


if __name__ == "__main__":  # pragma: no cover
    main_console()
Uh oh!
There was an error while loading. Please reload this page.