-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy path: train_regress.py
More file actions
129 lines (106 loc) · 4.38 KB
/
train_regress.py
File metadata and controls
129 lines (106 loc) · 4.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CLI script to regress a model onto another, pre-loaded model."""
import functools
from typing import Any, Dict, Mapping, Optional
import sacred
from evaluating_rewards import datasets
from evaluating_rewards.rewards import base, comparisons
from evaluating_rewards.scripts import regress_utils, script_utils
train_regress_ex = sacred.Experiment("train_regress")
@train_regress_ex.config
def default_config():
"""Default configuration values."""
locals().update(**regress_utils.DEFAULT_CONFIG)
checkpoint_interval = 50 # save every checkpoint_interval epochs
dataset_factory = datasets.transitions_factory_from_serialized_policy
dataset_factory_kwargs = {}
# Model to train and hyperparameters
model_reward_type = base.MLPRewardModel
total_timesteps = 10e6
batch_size = 4096
learning_rate = 2e-2
_ = locals() # quieten flake8 unused variable warning
del _
@train_regress_ex.config
def default_kwargs(dataset_factory, dataset_factory_kwargs):
"""Sets dataset_factory_kwargs to defaults when dataset_factory not overridden."""
# TODO(): remove this function when Sacred issue #238 is fixed
if ( # pylint:disable=comparison-with-callable
dataset_factory == datasets.transitions_factory_from_serialized_policy
and not dataset_factory_kwargs
):
dataset_factory_kwargs = dict(policy_type="random", policy_path="dummy")
_ = locals() # quieten flake8 unused variable warning
del _
FAST_CONFIG = dict(total_timesteps=8192)
# Duplicate to have consistent interface
train_regress_ex.add_named_config("test", FAST_CONFIG)
train_regress_ex.add_named_config("fast", FAST_CONFIG)
@train_regress_ex.named_config
def dataset_random_transition():
"""Randomly samples state and action and computes next state from dynamics."""
dataset_factory = ( # noqa: F841 pylint:disable=unused-variable
datasets.transitions_factory_from_random_model
)
script_utils.add_logging_config(train_regress_ex, "train_regress")
@train_regress_ex.main
def train_regress(
_seed: int, # pylint:disable=invalid-name
# Dataset
env_name: str,
discount: float,
dataset_factory: datasets.TransitionsFactory,
dataset_factory_kwargs: Dict[str, Any],
# Target specification
target_reward_type: str,
target_reward_path: str,
# Model parameters
model_reward_type: regress_utils.EnvRewardFactory,
total_timesteps: int,
batch_size: int,
learning_rate: float,
# Logging
checkpoint_interval: int,
log_dir: str,
) -> Mapping[str, Any]:
"""Entry-point into script to regress source onto target reward model."""
with dataset_factory(env_name, seed=_seed, **dataset_factory_kwargs) as dataset_generator:
make_source = functools.partial(regress_utils.make_model, model_reward_type)
def make_trainer(model, model_scope, target):
del model_scope
return comparisons.RegressModel(model, target, learning_rate=learning_rate)
def do_training(target, trainer, callback: Optional[base.Callback]):
del target
return trainer.fit(
dataset=dataset_generator,
total_timesteps=total_timesteps,
batch_size=batch_size,
callback=callback,
)
return regress_utils.regress(
seed=_seed,
env_name=env_name,
discount=discount,
make_source=make_source,
source_init=True,
make_trainer=make_trainer,
do_training=do_training,
target_reward_type=target_reward_type,
target_reward_path=target_reward_path,
log_dir=log_dir,
checkpoint_interval=checkpoint_interval,
)
if __name__ == "__main__":
script_utils.experiment_main(train_regress_ex, "train_regress")