-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathexperiment_or.py
More file actions
36 lines (30 loc) · 1.27 KB
/
experiment_or.py
File metadata and controls
36 lines (30 loc) · 1.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""
Experiment for task collect blue OR purple objects
"""
import torch
from gym.wrappers import Monitor
from dqn import ComposedDQN, FloatTensor, get_action
from trainer import load
from gym_repoman.envs import CollectEnv
from wrappers import WarpFrame, MaxLength
if __name__ == '__main__':
    # Evaluation settings: number of episodes per policy and the per-episode
    # step cap (matches the MaxLength wrapper below).
    max_episodes = 50000
    max_trajectory = 50

    # Task is solved by collecting an object that is blue OR purple.
    task = MaxLength(WarpFrame(CollectEnv(goal_condition=lambda x: x.colour == 'blue' or x.colour == 'purple')),
                     max_trajectory)

    # Load the two single-goal policies and compose them with equal weights;
    # the composed DQN should solve the disjunctive (OR) task zero-shot.
    dqn_blue = load('./models/blue/model.dqn', task)
    dqn_purple = load('./models/purple/model.dqn', task)
    dqn_composed = ComposedDQN([dqn_blue, dqn_purple], [1, 1])

    # Evaluate each policy under its own Monitor so episode statistics are
    # recorded in a separate directory per policy.
    for dqn, name in [(dqn_blue, 'blue'), (dqn_purple, 'purple'), (dqn_composed, 'composed')]:
        env = Monitor(task, './experiment_or/' + name + '/', video_callable=False, force=True)
        try:
            for episode in range(max_episodes):
                if episode % 1000 == 0:
                    print(episode)  # coarse progress indicator
                obs = env.reset()
                for _ in range(max_trajectory):
                    # Keep the torch tensor in a separate local so `obs`
                    # always holds the raw numpy observation from the env.
                    state = torch.from_numpy(obs).type(FloatTensor).unsqueeze(0)
                    action = get_action(dqn, state)
                    obs, reward, done, _ = env.step(action)
                    env.render()
                    if done:
                        break
        finally:
            # Close the Monitor so its statistics files are flushed even on
            # interruption; the original leaked one wrapper per policy.
            env.close()