From 628fc42854be01f85c4c83b85b3a033d98769f94 Mon Sep 17 00:00:00 2001 From: OrangeX4 <318483724@qq.com> Date: Thu, 1 Aug 2024 10:37:57 +0800 Subject: [PATCH 1/7] update: add antmaze configs --- scripts/configs/bt_iql/antmaze/default.yaml | 98 +++++++++++++++++ scripts/configs/ipl_iql/antmaze/default.yaml | 81 ++++++++++++++ scripts/configs/pt_iql/antmaze.yaml | 108 +++++++++++++++++++ wiserl/dataset/ipl_dataset.py | 4 + 4 files changed, 291 insertions(+) create mode 100644 scripts/configs/bt_iql/antmaze/default.yaml create mode 100644 scripts/configs/ipl_iql/antmaze/default.yaml create mode 100644 scripts/configs/pt_iql/antmaze.yaml diff --git a/scripts/configs/bt_iql/antmaze/default.yaml b/scripts/configs/bt_iql/antmaze/default.yaml new file mode 100644 index 0000000..a5daac0 --- /dev/null +++ b/scripts/configs/bt_iql/antmaze/default.yaml @@ -0,0 +1,98 @@ +algorithm: + class: BTIQL + beta: 0.3333 + expectile: 0.9 + max_exp_clip: 100.0 + reward_reg: 0.0 + rm_label: true + +checkpoint: null +seed: 0 +name: default +debug: false +device: null +wandb: + activate: false + entity: null + project: null + +env: antmaze-medium-diverse-v2 +env_kwargs: +env_wrapper: +env_wrapper_kwargs: + +optim: + default: + class: Adam + lr: 0.0003 + +network: + reward: + class: EnsembleMLP + ensemble_size: 1 + hidden_dims: [256, 256] + reward_act: identity + actor: + class: SquashedGaussianActor + hidden_dims: [256, 256] + reparameterize: false + conditioned_logstd: false + logstd_min: -5 + logstd_max: 2 + critic: + class: Critic + ensemble_size: 2 + hidden_dims: [256, 256] + value: + class: Critic + ensemble_size: 1 + hidden_dims: [256, 256] + +rm_dataset: + - class: IPLComparisonOfflineDataset + env: antmaze-medium-diverse-v2 + batch_size: 64 + segment_length: null + mode: human +rm_dataloader: + num_workers: 2 + batch_size: null + +rl_dataset: + - class: D4RLOfflineDataset + env: antmaze-medium-diverse-v2 + batch_size: 256 + mode: transition + reward_normalize: true +rl_dataloader: + num_workers: 2 + batch_size: null + +trainer: + env_freq: null + rm_label: true + rm_steps: 50000 + rl_steps: 1000000 + log_freq: 500 + profile_freq: 500 + eval_freq: 5000 + +rm_eval: + function: eval_reward_model + eval_dataset_kwargs: + class: IPLComparisonOfflineDataset + env: antmaze-medium-diverse-v2 + batch_size: 32 + mode: human + eval: false +rl_eval: + function: eval_offline + num_ep: 10 + deterministic: true + +schedulers: + actor: + class: CosineAnnealingLR + T_max: 1000000 + +processor: null diff --git a/scripts/configs/ipl_iql/antmaze/default.yaml b/scripts/configs/ipl_iql/antmaze/default.yaml new file mode 100644 index 0000000..7d07a0d --- /dev/null +++ b/scripts/configs/ipl_iql/antmaze/default.yaml @@ -0,0 +1,81 @@ +algorithm: + class: IPL_IQL + beta: 0.3333 + expectile: 0.7 + reward_reg: 0.5 + reg_replay_weight: 0.5 + actor_replay_weight: 0.5 + value_replay_weight: 0.5 + tau: 0.005 + max_exp_clip: 100.0 + discount: 0.99 + target_freq: 1 + +checkpoint: null +seed: 0 +name: default +debug: false +device: null +wandb: + activate: false + entity: null + project: null + +env: antmaze-large-diverse-v2 +env_kwargs: +env_wrapper: +env_wrapper_kwargs: + +optim: + default: + class: Adam + lr: 0.0003 + +network: + actor: + class: SquashedGaussianActor + hidden_dims: [256, 256] + reparameterize: false + conditioned_logstd: false + logstd_min: -5 + logstd_max: 2 + critic: + class: Critic + ensemble_size: 2 + hidden_dims: [256, 256] + value: + class: Critic + ensemble_size: 1 + hidden_dims: [256, 256] + +dataset: + - class: IPLComparisonOfflineDataset + env: antmaze-large-diverse-v2 + batch_size: 8 + mode: human + segment_length: null + - class: D4RLOfflineDataset + env: antmaze-large-diverse-v2 + batch_size: 256 + mode: transition + reward_normalize: true + +dataloader: + num_workers: 0 # use the main thread to sample data + batch_size: null # do not merge the data along batch axis + +trainer: + env_freq: null + total_steps: 1000000 + log_freq: 500 + profile_freq: 500 + eval_freq: 5000 + +eval: + function: eval_offline + num_ep: 10 + deterministic: true + +schedulers: + +processor: null diff --git a/scripts/configs/pt_iql/antmaze.yaml b/scripts/configs/pt_iql/antmaze.yaml new file mode 100644 index 0000000..0d930f7 --- /dev/null +++ b/scripts/configs/pt_iql/antmaze.yaml @@ -0,0 +1,108 @@ +algorithm: + class: PTIQL + beta: 0.3333 + expectile: 0.7 + max_exp_clip: 100.0 + reward_reg: 0.0 + use_weighted_sum: true + max_seq_len: 100 + rm_label: true + +checkpoint: null +seed: 0 +name: default +debug: false +device: null +wandb: + activate: false + entity: null + project: null + +env: antmaze-large-diverse-v2 +env_kwargs: +env_wrapper: +env_wrapper_kwargs: + +optim: + default: + class: Adam + lr: 0.0003 + reward: + class: AdamW + lr: 0.0003 + +network: + reward: + num_layers: 3 + embed_dim: 256 + pref_embed_dim: 256 + num_heads: 1 + reward_act: sigmoid + actor: + class: SquashedGaussianActor + hidden_dims: [256, 256] + reparameterize: false + conditioned_logstd: false + logstd_min: -5 + logstd_max: 2 + critic: + class: Critic + ensemble_size: 2 + hidden_dims: [256, 256] + value: + class: Critic + ensemble_size: 1 + hidden_dims: [256, 256] + +rm_dataset: + - class: IPLComparisonOfflineDataset + env: antmaze-large-diverse-v2 + batch_size: 64 + segment_length: null + mode: human +rm_dataloader: + num_workers: 2 + batch_size: null + +rl_dataset: + - class: D4RLOfflineDataset + env: antmaze-large-diverse-v2 + batch_size: 256 + mode: trajectory + segment_length: 20 +rl_dataloader: + num_workers: 2 + batch_size: null + +trainer: + env_freq: null + rm_label: true + rm_steps: 100000 + rl_steps: 1000000 + log_freq: 500 + profile_freq: 500 + eval_freq: 5000 + +rm_eval: + function: eval_reward_model + eval_dataset_kwargs: + class: IPLComparisonOfflineDataset + env: antmaze-large-diverse-v2 + batch_size: 32 + mode: human + eval: false +rl_eval: + function: eval_offline + num_ep: 10 + deterministic: true + +schedulers: + actor: + class: CosineAnnealingLR + T_max: 1000000 + reward: + warmup_steps: 10000 + max_steps: 100000 + + +processor: null diff --git a/wiserl/dataset/ipl_dataset.py b/wiserl/dataset/ipl_dataset.py index 82e6e70..36eec11 100644 --- a/wiserl/dataset/ipl_dataset.py +++ b/wiserl/dataset/ipl_dataset.py @@ -19,6 +19,10 @@ "hammer-human-v1": f"{prefix}/preference_transformer/hammer-human-v1/num100", "pen-cloned-v1": f"{prefix}/preference_transformer/pen-cloned-v1/num100", "pen-human-v1": f"{prefix}/preference_transformer/pen-human-v1/num100", + "antmaze-large-diverse-v2": f"{prefix}/preference_transformer/antmaze-large-diverse-v2/num1000", + "antmaze-large-play-v2": f"{prefix}/preference_transformer/antmaze-large-play-v2/num1000", + "antmaze-medium-diverse-v2": f"{prefix}/preference_transformer/antmaze-medium-diverse-v2/num1000", + "antmaze-medium-play-v2": f"{prefix}/preference_transformer/antmaze-medium-play-v2/num1000", "Can-mh": f"{prefix}/preference_transformer/Can/num500_q100", "Can-ph": f"{prefix}/preference_transformer/Can/num100_q50", "Lift-mh": f"{prefix}/preference_transformer/Lift/num500_q100", From 0e5b07cb36bc844609f0d7db3308d9c35115925b Mon Sep 17 00:00:00 2001 From: typoverflow Date: Thu, 1 Aug 2024 17:17:13 +0800 Subject: [PATCH 2/7] update: some configs, also add normalization and substration for antmaze dataset rewards --- scripts/configs/bt_iql/antmaze/default.yaml | 4 +- scripts/configs/ipl_iql/antmaze/default.yaml | 10 ++-- scripts/configs/pt_iql/antmaze.yaml | 12 ++-- wiserl/dataset/d4rl_dataset.py | 62 ++++++++++++++++---- 4 files changed, 63 insertions(+), 25 deletions(-) diff --git a/scripts/configs/bt_iql/antmaze/default.yaml b/scripts/configs/bt_iql/antmaze/default.yaml index a5daac0..e570574 100644 --- a/scripts/configs/bt_iql/antmaze/default.yaml +++ b/scripts/configs/bt_iql/antmaze/default.yaml @@ -1,6 +1,6 @@ algorithm: class: BTIQL - beta: 0.3333 + beta: 0.1 expectile: 0.9 max_exp_clip: 100.0 reward_reg: 0.0 @@ -87,7 +87,7 @@ rm_eval: eval: false rl_eval: function: eval_offline - num_ep: 10 + num_ep: 100 deterministic: true schedulers: diff --git a/scripts/configs/ipl_iql/antmaze/default.yaml b/scripts/configs/ipl_iql/antmaze/default.yaml index 7d07a0d..98a603c 100644 --- a/scripts/configs/ipl_iql/antmaze/default.yaml +++ b/scripts/configs/ipl_iql/antmaze/default.yaml @@ -1,7 +1,7 @@ algorithm: class: IPL_IQL - beta: 0.3333 - expectile: 0.7 + beta: 0.1 + expectile: 0.9 reward_reg: 0.5 reg_replay_weight: 0.5 actor_replay_weight: 0.5 @@ -21,7 +21,7 @@ wandb: entity: null project: null -env: antmaze-large-diverse-v2 +env: antmaze-medium-diverse-v2 env_kwargs: env_wrapper: env_wrapper_kwargs: @@ -50,12 +50,12 @@ network: dataset: - class: IPLComparisonOfflineDataset - env: antmaze-large-diverse-v2 + env: antmaze-medium-diverse-v2 batch_size: 8 mode: human segment_length: null - class: D4RLOfflineDataset - env: antmaze-large-diverse-v2 + env: antmaze-medium-diverse-v2 batch_size: 256 mode: transition reward_normalize: true diff --git a/scripts/configs/pt_iql/antmaze.yaml b/scripts/configs/pt_iql/antmaze.yaml index 0d930f7..f774f90 100644 --- a/scripts/configs/pt_iql/antmaze.yaml +++ b/scripts/configs/pt_iql/antmaze.yaml @@ -1,7 +1,7 @@ algorithm: class: PTIQL - beta: 0.3333 - expectile: 0.7 + beta: 0.1 + expectile: 0.9 max_exp_clip: 100.0 reward_reg: 0.0 use_weighted_sum: true @@ -18,7 +18,7 @@ wandb: entity: null project: null -env: antmaze-large-diverse-v2 +env: antmaze-medium-diverse-v2 env_kwargs: env_wrapper: env_wrapper_kwargs: @@ -56,7 +56,7 @@ network: rm_dataset: - class: IPLComparisonOfflineDataset - env: antmaze-large-diverse-v2 + env: antmaze-medium-diverse-v2 batch_size: 64 segment_length: null mode: human @@ -66,7 +66,7 @@ rm_dataloader: rl_dataset: - class: D4RLOfflineDataset - env: antmaze-large-diverse-v2 + env: antmaze-medium-diverse-v2 batch_size: 256 mode: trajectory segment_length: 20 @@ -87,7 +87,7 @@ rm_eval: function: eval_reward_model eval_dataset_kwargs: class: IPLComparisonOfflineDataset - env: antmaze-large-diverse-v2 + env: antmaze-medium-diverse-v2 batch_size: 32 mode: human eval: false diff --git a/wiserl/dataset/d4rl_dataset.py b/wiserl/dataset/d4rl_dataset.py index b676ab7..7604215 100644 --- a/wiserl/dataset/d4rl_dataset.py +++ b/wiserl/dataset/d4rl_dataset.py @@ -254,23 +254,61 @@ def relabel_reward(self, agent): for t in reversed(range(return_.shape[1]-1)): return_[:, t] += return_[:, t+1] self.data["return"] = return_ - # normalization - prev_return_min, prev_return_max = return_[:, 0].min(), return_[:, 0].max() - max_return = max(abs(return_[:, 0].max()), abs(return_[:, 0].min()), return_[:, 0].max()-return_[:, 0].min(), 1.0) - norm = 1000. / max_return - self.data["reward"] *= norm - self.data["return"] *= norm - print(f"[D4RLOfflineDataset]: return range: [{prev_return_min}, {prev_return_max}], multiplying norm factor {norm}.") + if "antmaze" in self.env_name: + # normalization on antmaze may look weird, we borrow it from preference transformer + # https://github.com/csmile-1006/PreferenceTransformer/blob/f71647bb075c8287e2f26aded78aa8f1ac176eb5/train_offline.py#L78 and + # https://github.com/csmile-1006/PreferenceTransformer/blob/f71647bb075c8287e2f26aded78aa8f1ac176eb5/train_offline.py#L119 + min_return, max_return = self.data["return"][:, 0].min(), self.data["return"][:, 0].max() + norm = 1000. / (max_return - min_return) + self.data["reward"] *= norm + self.data["reward"] -= 1.0 + self.data["reward"] *= self.data["mask"] + + # for i in range(self.data["reward"].shape[0]): + # self.data["reward"][i] -= (1. + self.data["return"][i, 0] / self.traj_len[i]) * self.data["mask"][i] + return_ = self.data["reward"].copy() + for t in reversed(range(return_.shape[1]-1)): + return_[:, t] += return_[:, t+1] + self.data["return"] = return_ + else: + prev_return_min, prev_return_max = return_[:, 0].min(), return_[:, 0].max() + max_return = max(abs(return_[:, 0].max()), abs(return_[:, 0].min()), return_[:, 0].max()-return_[:, 0].min(), 1.0) + norm = 1000. / max_return + self.data["reward"] *= norm + self.data["return"] *= norm + print(f"[D4RLOfflineDataset]: return range: [{prev_return_min}, {prev_return_max}], multiplying norm factor {norm}.") elif self.mode == "transition": ep_reward_ = [] + ep_length_ = [] episode_reward = 0 + episode_length = 0 N = self.data["reward"].shape[0] for i in range(N): episode_reward += self.data["reward"][i] + episode_length += 1 if self.data["end"][i]: ep_reward_.append(episode_reward) - episode_reward = 0 - max_return = max(abs(min(ep_reward_)).item(), abs(max(ep_reward_)).item(), (max(ep_reward_)-min(ep_reward_)).item(), 1.0) - norm = 1000 / max_return - self.data["reward"] *= norm - print(f"[D4RLOfflineDataset]: return range: [{min(ep_reward_)}, {max(ep_reward_)}], multiplying norm factor {norm}.") + ep_length_.append(episode_length) + episode_reward = episode_length = 0 + + if "antmaze" in self.env_name: + # normalization on antmaze may look weird, we borrow it from preference transformer + # https://github.com/csmile-1006/PreferenceTransformer/blob/f71647bb075c8287e2f26aded78aa8f1ac176eb5/train_offline.py#L78 and + # https://github.com/csmile-1006/PreferenceTransformer/blob/f71647bb075c8287e2f26aded78aa8f1ac176eb5/train_offline.py#L119 + min_return, max_return = min(ep_reward_), max(ep_reward_) + # idx = 0 + # for ep_len in ep_length_: + # for l in range(ep_len): + # self.data["reward"][idx] -= min_return / ep_len + # idx += 1 + # assert idx == N + norm = 1000 / (max_return - min_return) + self.data["reward"] *= norm + self.data["reward"] -= 1.0 + self.data["reward"] *= self.data["mask"] + print(f"[D4RLOfflineDataset]: return range: [{min_return}, {max_return}], multiplying norm factor {norm}.") + else: + max_return = max(abs(min(ep_reward_)).item(), abs(max(ep_reward_)).item(), (max(ep_reward_)-min(ep_reward_)).item(), 1.0) + norm = 1000 / max_return + self.data["reward"] *= norm + print(f"[D4RLOfflineDataset]: return range: [{min(ep_reward_)}, {max(ep_reward_)}], multiplying norm factor {norm}.") From 1b160ef871fcd306418f252fbc08238346f19d65 Mon Sep 17 00:00:00 2001 From: typoverflow Date: Thu, 1 Aug 2024 19:17:52 +0800 Subject: [PATCH 3/7] update --- scripts/configs/bt_iql/antmaze/default.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configs/bt_iql/antmaze/default.yaml b/scripts/configs/bt_iql/antmaze/default.yaml index e570574..22991fb 100644 --- a/scripts/configs/bt_iql/antmaze/default.yaml +++ b/scripts/configs/bt_iql/antmaze/default.yaml @@ -75,7 +75,7 @@ trainer: rl_steps: 1000000 log_freq: 500 profile_freq: 500 - eval_freq: 5000 + eval_freq: 10000 rm_eval: function: eval_reward_model From 55de17e0319e8eda9aa4a934f844159d4da72153 Mon Sep 17 00:00:00 2001 From: typoverflow Date: Fri, 2 Aug 2024 11:50:30 +0800 Subject: [PATCH 4/7] update --- scripts/configs/bt_iql/antmaze/default.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/configs/bt_iql/antmaze/default.yaml b/scripts/configs/bt_iql/antmaze/default.yaml index 22991fb..45b6c76 100644 --- a/scripts/configs/bt_iql/antmaze/default.yaml +++ b/scripts/configs/bt_iql/antmaze/default.yaml @@ -84,7 +84,7 @@ rm_eval: env: antmaze-medium-diverse-v2 batch_size: 32 mode: human - eval: false + eval: true rl_eval: function: eval_offline num_ep: 100 From 638800f2a908be6c32f567fb8f7d37336519c261 Mon Sep 17 00:00:00 2001 From: typoverflow Date: Sat, 3 Aug 2024 13:27:08 +0800 Subject: [PATCH 5/7] update: add sigmoid activation --- wiserl/dataset/d4rl_dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/wiserl/dataset/d4rl_dataset.py b/wiserl/dataset/d4rl_dataset.py index 7604215..9478bf0 100644 --- a/wiserl/dataset/d4rl_dataset.py +++ b/wiserl/dataset/d4rl_dataset.py @@ -297,13 +297,14 @@ def relabel_reward(self, agent): # https://github.com/csmile-1006/PreferenceTransformer/blob/f71647bb075c8287e2f26aded78aa8f1ac176eb5/train_offline.py#L119 min_return, max_return = min(ep_reward_), max(ep_reward_) # idx = 0 + print("max return: ", max_return, " min_return: ", min_return) # for ep_len in ep_length_: # for l in range(ep_len): - # self.data["reward"][idx] -= min_return / ep_len + # self.data["reward"][idx] -= max_return / ep_len # idx += 1 # assert idx == N norm = 1000 / (max_return - min_return) - self.data["reward"] *= norm + # self.data["reward"] *= norm self.data["reward"] -= 1.0 self.data["reward"] *= self.data["mask"] print(f"[D4RLOfflineDataset]: return range: [{min_return}, {max_return}], multiplying norm factor {norm}.") From 6172f956f88ab4866b4c41648fd4885b90120093 Mon Sep 17 00:00:00 2001 From: typoverflow Date: Sat, 3 Aug 2024 14:40:45 +0800 Subject: [PATCH 6/7] update: add hpl config --- scripts/configs/hpl/discrete/antmaze.yaml | 134 ++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 scripts/configs/hpl/discrete/antmaze.yaml diff --git a/scripts/configs/hpl/discrete/antmaze.yaml b/scripts/configs/hpl/discrete/antmaze.yaml new file mode 100644 index 0000000..4394e36 --- /dev/null +++ b/scripts/configs/hpl/discrete/antmaze.yaml @@ -0,0 +1,134 @@ +algorithm: + class: HindsightPreferenceLearning + expectile: 0.9 + beta: 0.1 + max_exp_clip: 100.0 + discount: 0.99 + tau: 0.005 + seq_len: 100 + future_len: 5 # [5, 10, 20] + z_dim: 128 # [128] + prior_sample: 20 + vae_steps: 250000 # [250k, 200k] + reward_steps: 100000 # [20k] + kl_loss_coef: 0.1 # [0.5, 5.0] + kl_balance_coef: 0.5 + reg_coef: 0.0 + discrete: true + discrete_group: 8 + stoc_encoding: true # for hopper-medium-expert, try true + rm_label: true + +checkpoint: null +seed: 0 +name: default +debug: false +device: null +wandb: + activate: false + entity: null + project: null + +env: antmaze-medium-play-v2 +env_kwargs: +env_wrapper: +env_wrapper_kwargs: + +optim: + default: + class: Adam + lr: 0.0003 + +network: + encoder: + embed_dim: 128 + num_layers: 3 + num_heads: 4 + dropout: 0.1 + decoder: + embed_dim: 128 + hidden_dims: [256, 256, 256] # shallower? + ortho_init: true + prior: + hidden_dims: [256, 256] + ortho_init: true + reward: + class: Critic + hidden_dims: [256, 256, 256] # tune + ortho_init: true + reward_act: sigmoid + actor: + class: SquashedGaussianActor + hidden_dims: [256, 256, 256] + reparameterize: false + conditioned_logstd: true + logstd_min: -7 + logstd_max: 2 + ortho_init: true + critic: + class: Critic + ensemble_size: 2 + hidden_dims: [256, 256, 256] + ortho_init: true + value: + class: Critic + ensemble_size: 1 + hidden_dims: [256, 256, 256] + ortho_init: true + + +rm_dataset: + - class: D4RLOfflineDataset + env: antmaze-medium-play-v2 + batch_size: 64 # [64, 128] + mode: trajectory + segment_length: 100 + padding_mode: none + - class: IPLComparisonOfflineDataset + env: antmaze-medium-play-v2 + batch_size: 8 + mode: human + - class: D4RLOfflineDataset + env: antmaze-medium-play-v2 + batch_size: 512 + mode: transition +rm_dataloader: + num_workers: 2 + batch_size: null + +rl_dataset: + - class: D4RLOfflineDataset + env: antmaze-medium-play-v2 + batch_size: 512 + mode: transition +rl_dataloader: + num_workers: 2 + batch_size: null + +trainer: + env_freq: null + rm_label: true + rm_steps: 350000 + rl_steps: 1000000 + log_freq: 500 + profile_freq: 500 + eval_freq: 10000 # don't do eval + +rm_eval: + function: eval_reward_model + eval_dataset_kwargs: + class: IPLComparisonOfflineDataset + env: antmaze-medium-diverse-v2 + batch_size: 32 + mode: human + eval: true +rl_eval: + function: eval_offline + num_ep: 100 + deterministic: true + +schedulers: + +processor: null + +# finalized From 63b2eee32df3705f0890d59ff7242c0f413bc5f0 Mon Sep 17 00:00:00 2001 From: typoverflow Date: Sat, 3 Aug 2024 16:39:24 +0800 Subject: [PATCH 7/7] update: add early stop --- scripts/configs/bt_iql/antmaze/earlystop.yaml | 103 +++++++++++++ .../hpl/discrete/antmaze_earlystop.yaml | 139 ++++++++++++++++++ wiserl/algorithm/base.py | 42 ++++-- wiserl/trainer/rmb_offline_trainer.py | 31 ++++ wiserl/utils/earlystop.py | 45 ++++++ 5 files changed, 344 insertions(+), 16 deletions(-) create mode 100644 scripts/configs/bt_iql/antmaze/earlystop.yaml create mode 100644 scripts/configs/hpl/discrete/antmaze_earlystop.yaml create mode 100644 wiserl/utils/earlystop.py diff --git a/scripts/configs/bt_iql/antmaze/earlystop.yaml b/scripts/configs/bt_iql/antmaze/earlystop.yaml new file mode 100644 index 0000000..0fea89c --- /dev/null +++ b/scripts/configs/bt_iql/antmaze/earlystop.yaml @@ -0,0 +1,103 @@ +algorithm: + class: BTIQL + beta: 0.1 + expectile: 0.9 + max_exp_clip: 100.0 + reward_reg: 0.0 + rm_label: true + +checkpoint: null +seed: 0 +name: default +debug: false +device: null +wandb: + activate: false + entity: null + project: null + +env: antmaze-medium-diverse-v2 +env_kwargs: +env_wrapper: +env_wrapper_kwargs: + +optim: + default: + class: Adam + lr: 0.0003 + +network: + reward: + class: EnsembleMLP + ensemble_size: 1 + hidden_dims: [256, 256] + reward_act: identity + actor: + class: SquashedGaussianActor + hidden_dims: [256, 256] + reparameterize: false + conditioned_logstd: false + logstd_min: -5 + logstd_max: 2 + critic: + class: Critic + ensemble_size: 2 + hidden_dims: [256, 256] + value: + class: Critic + ensemble_size: 1 + hidden_dims: [256, 256] + +rm_dataset: + - class: IPLComparisonOfflineDataset + env: antmaze-medium-diverse-v2 + batch_size: 64 + segment_length: null + mode: human +rm_dataloader: + num_workers: 2 + batch_size: null + +rl_dataset: + - class: D4RLOfflineDataset + env: antmaze-medium-diverse-v2 + batch_size: 256 + mode: transition + reward_normalize: true +rl_dataloader: + num_workers: 2 + batch_size: null + +trainer: + env_freq: null + rm_label: true + rm_steps: null # use early stop + rl_steps: 1000000 + log_freq: 500 + profile_freq: 500 + eval_freq: 10000 + # early stop + earlystop_tolerance: 5 + earlystop_metric: val_acc + earlystop_mode: max + earlystop_start_step: 0 + +rm_eval: + function: eval_reward_model + eval_dataset_kwargs: + class: IPLComparisonOfflineDataset + env: antmaze-medium-diverse-v2 + batch_size: 32 + mode: human + eval: true +rl_eval: + function: eval_offline + num_ep: 100 + deterministic: true + +schedulers: + actor: + class: CosineAnnealingLR + T_max: 1000000 + +processor: null diff --git a/scripts/configs/hpl/discrete/antmaze_earlystop.yaml b/scripts/configs/hpl/discrete/antmaze_earlystop.yaml new file mode 100644 index 0000000..b623a61 --- /dev/null +++ b/scripts/configs/hpl/discrete/antmaze_earlystop.yaml @@ -0,0 +1,139 @@ +algorithm: + class: HindsightPreferenceLearning + expectile: 0.9 + beta: 0.1 + max_exp_clip: 100.0 + discount: 0.99 + tau: 0.005 + seq_len: 100 + future_len: 5 # [5, 10, 20] + z_dim: 128 # [128] + prior_sample: 20 + vae_steps: 250000 # [250k, 200k] + reward_steps: 100000 # [20k] + kl_loss_coef: 0.1 # [0.5, 5.0] + kl_balance_coef: 0.5 + reg_coef: 0.0 + discrete: true + discrete_group: 8 + stoc_encoding: true # for hopper-medium-expert, try true + rm_label: true + +checkpoint: null +seed: 0 +name: default +debug: false +device: null +wandb: + activate: false + entity: null + project: null + +env: antmaze-medium-play-v2 +env_kwargs: +env_wrapper: +env_wrapper_kwargs: + +optim: + default: + class: Adam + lr: 0.0003 + +network: + encoder: + embed_dim: 128 + num_layers: 3 + num_heads: 4 + dropout: 0.1 + decoder: + embed_dim: 128 + hidden_dims: [256, 256, 256] # shallower? + ortho_init: true + prior: + hidden_dims: [256, 256] + ortho_init: true + reward: + class: Critic + hidden_dims: [256, 256, 256] # tune + ortho_init: true + reward_act: sigmoid + actor: + class: SquashedGaussianActor + hidden_dims: [256, 256, 256] + reparameterize: false + conditioned_logstd: true + logstd_min: -7 + logstd_max: 2 + ortho_init: true + critic: + class: Critic + ensemble_size: 2 + hidden_dims: [256, 256, 256] + ortho_init: true + value: + class: Critic + ensemble_size: 1 + hidden_dims: [256, 256, 256] + ortho_init: true + + +rm_dataset: + - class: D4RLOfflineDataset + env: antmaze-medium-play-v2 + batch_size: 64 # [64, 128] + mode: trajectory + segment_length: 100 + padding_mode: none + - class: IPLComparisonOfflineDataset + env: antmaze-medium-play-v2 + batch_size: 8 + mode: human + - class: D4RLOfflineDataset + env: antmaze-medium-play-v2 + batch_size: 512 + mode: transition +rm_dataloader: + num_workers: 2 + batch_size: null + +rl_dataset: + - class: D4RLOfflineDataset + env: antmaze-medium-play-v2 + batch_size: 512 + mode: transition +rl_dataloader: + num_workers: 2 + batch_size: null + +trainer: + env_freq: null + rm_label: true + rm_steps: null + rl_steps: 1000000 + log_freq: 500 + profile_freq: 500 + eval_freq: 10000 # don't do eval + # early stop + earlystop_tolerance: 5 + earlystop_metric: val_acc + earlystop_mode: max + earlystop_start_step: 100000 + +rm_eval: + function: eval_reward_model + eval_dataset_kwargs: + class: IPLComparisonOfflineDataset + env: antmaze-medium-diverse-v2 + batch_size: 32 + mode: human + eval: true +rl_eval: + function: eval_offline + num_ep: 100 + deterministic: true + +schedulers: + +processor: null + +# finalized diff --git a/wiserl/algorithm/base.py b/wiserl/algorithm/base.py index 678e5a4..0a75121 100644 --- a/wiserl/algorithm/base.py +++ b/wiserl/algorithm/base.py @@ -180,10 +180,7 @@ def setup_schedulers(self, scheduler_kwargs): **kwargs ) - def save(self, path: str, name: str, metadata: Optional[Dict]=None) -> None: - """ - Saves a checkpoint of the model and the optimizers - """ + def state_dict(self): save_dict = {} if len(self.optim) > 0: save_dict["optim"] = {k: v.state_dict() for k, v in self.optim.items()} @@ -196,19 +193,10 @@ def save(self, path: str, name: str, metadata: Optional[Dict]=None) -> None: else: assert isinstance(attr, torch.nn.Parameter), "Can only save Modules or Parameters." save_dict[k] = attr + return save_dict - # Add the metadata - save_dict["metadata"] = {} if metadata is None else metadata - os.makedirs(path, exist_ok=True) - save_path = os.path.join(path, name) - torch.save(save_dict, save_path) - - def load(self, ckpt_path: str, strict: bool=True) -> Dict: - """ - Loads the model and its associated checkpoints. - If we haven't created the optimizers and schedulers, do not load those. - """ - checkpoint = torch.load(ckpt_path, map_location=self.device) + def load_state_dict(self, state_dict, strict=False): + checkpoint = state_dict remaining_checkpoint_keys = set(checkpoint.keys()) # First load everything except for the optim @@ -258,6 +246,28 @@ def load(self, ckpt_path: str, strict: bool=True) -> Dict: return checkpoint["metadata"] + def save(self, path: str, name: str, metadata: Optional[Dict]=None) -> None: + """ + Saves a checkpoint of the model and the optimizers + """ + save_dict = self.state_dict() + + # Add the metadata + save_dict["metadata"] = {} if metadata is None else metadata + os.makedirs(path, exist_ok=True) + save_path = os.path.join(path, name) + torch.save(save_dict, save_path) + + def load(self, ckpt_path: str, strict: bool=True) -> Dict: + """ + Loads the model and its associated checkpoints. + If we haven't created the optimizers and schedulers, do not load those. + """ + checkpoint = torch.load(ckpt_path, map_location=self.device) + self.load_state_dict(checkpoint, strict=strict) + + return checkpoint["metadata"] + def format_batch(self, batches) -> Any: def convert(data): if data.dtype == np.float64: diff --git a/wiserl/trainer/rmb_offline_trainer.py b/wiserl/trainer/rmb_offline_trainer.py index 91525cb..ac8a0c2 100644 --- a/wiserl/trainer/rmb_offline_trainer.py +++ b/wiserl/trainer/rmb_offline_trainer.py @@ -14,6 +14,7 @@ import wiserl.dataset import wiserl.eval from wiserl.trainer.offline_trainer import OfflineTrainer +from wiserl.utils.earlystop import EarlyStopManager class RewardModelBasedOfflineTrainer(OfflineTrainer): @@ -26,6 +27,10 @@ def __init__( rm_dataloader_kwargs: Optional[Sequence[Dict]] = None, rm_steps: int = 1000, rm_eval_kwargs: Optional[dict] = None, + earlystop_tolerance: Optional[int] = None, + earlystop_metric: Optional[str] = None, + earlystop_mode: str="min", + earlystop_start_step: int=0, rl_dataset_kwargs: Optional[Sequence[str]] = None, rl_dataloader_kwargs: Optional[Sequence[Dict]] = None, rl_steps: int = 1000, @@ -60,6 +65,10 @@ def __init__( self.rm_steps = rm_steps self.rl_steps = rl_steps self.rm_label = rm_label + self.earlystop_tolerance = earlystop_tolerance + self.earlystop_metric = earlystop_metric + self.earlystop_mode = earlystop_mode + self.earlystop_start_step = earlystop_start_step self.load_rm_path = load_rm_path self.save_rm_path = save_rm_path # rm & rl datasets, dataloaders, and evals @@ -86,6 +95,21 @@ def train(self): self._rm_datasets = self.setup_datasets(self.rm_dataset_kwargs) self._rm_dataloaders, self._rm_dataloaders_iter = self.setup_dataloaders(self._rm_datasets, self.rm_dataloader_kwargs) + # register early stopping + if self.earlystop_tolerance is not None: + print(f"Registering early stopping, using metric={self.earlystop_metric}, mode={self.earlystop_mode}, tolerance={self.earlystop_tolerance}") + self.earlystop_manager = EarlyStopManager( + self.earlystop_tolerance, + self.earlystop_mode, + ) + self.earlystop_manager.reset() + if self.rm_steps is None: + print("Early stopping is enabled, but rm_steps is not set. Setting rm_steps to 9e9 ... ") + self.rm_steps = int(9e9) + else: + self.earlystop_manager = None + assert self.rm_steps is not None + self.logger.info("Starting pretraining ... ") for step in trange(0, self.rm_steps+1, desc="pretrain"): batches = [next(d) for d in self._rm_dataloaders_iter] @@ -101,6 +125,13 @@ def train(self): self.logger.log_scalars("eval", eval_metrics, step=step) self.algorithm.train() + if step >= self.earlystop_start_step and self.earlystop_manager is not None: + should_stop, best_model, best_metric = self.earlystop_manager.step(self.algorithm, eval_metrics[self.earlystop_metric]) + if should_stop: + self.logger.info(f"Early stopping triggered at step {step}, best metric={best_metric}") + self.algorithm.load_state_dict(best_model) + break + if self.save_rm_path is not None: self.logger.info(f"Saving pretrained model to {self.save_rm_path} ...") self.algorithm.save_pretrain(self.save_rm_path) diff --git a/wiserl/utils/earlystop.py b/wiserl/utils/earlystop.py new file mode 100644 index 0000000..be2cf81 --- /dev/null +++ b/wiserl/utils/earlystop.py @@ -0,0 +1,45 @@ +import torch + + +class EarlyStopManager(): + def __init__( + self, + tolerance: int, + mode: str = "min" + ): + assert mode in {"min", "max"}, f"EarlyStopManager: metric should be either max or min" + self.tolerance = tolerance + self.mode = mode + self.best = None + self.best_metric = None + self.counter = 0 + + self.reset() + + def reset(self): + self.best = None + self.best_metrics = 9e9 if self.mode == "min" else -9e9 + self.counter = 0 + + def step(self, model, metric): + if self.best is None: + self.best = model.state_dict() + self.best_metric = metric + self.counter = 0 + elif self.mode == "min" and metric < self.best_metric: + self.best = model.state_dict() + self.best_metric = metric + self.counter = 0 + elif self.mode == "max" and metric > self.best_metric: + self.best = model.state_dict() + self.best_metric = metric + self.counter = 0 + else: + self.counter += 1 + + if self.counter >= self.tolerance: + should_stop = True + else: + should_stop = False + + return should_stop, self.best, self.best_metric