-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
37 lines (26 loc) · 1.12 KB
/
train.py
File metadata and controls
37 lines (26 loc) · 1.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#!/usr/bin/python
import Enviroment, Policy, numpy
import matplotlib
matplotlib.rcParams["backend"] = "TkAgg"
import matplotlib.pyplot as plt
if __name__ == "__main__":
    # Resume the episode-return history from a previous run, if a
    # checkpoint file exists; otherwise start fresh.
    try:
        reward_history = numpy.load("reward_history.npy").tolist()
    except IOError:
        reward_history = []
    try:
        # TODO: derive the network dimensions from the Enviroment module
        # instead of hard-coding them here.
        policy = Policy.ACGradientPolicy(action_dimension=3, obsv_dimension=15)
        trajectory_sampler = Enviroment.TrajectorySampler(policy=policy)
        for _ in range(200000):
            # Episode index continues from the loaded history on resume.
            print("Ep %d" % len(reward_history))
            observations, rewards, actions = trajectory_sampler.generate_trajectorys()
            policy.learn(observations, actions, rewards)
            # Episode return: sum of rewards, excluding the final entry
            # (presumably the terminal/bootstrap value — TODO confirm).
            episode_return = numpy.array(rewards[:-1]).sum()
            reward_history.append(episode_return)
            # BUG FIX: the original printed the mean of the *current*
            # episode's per-step rewards while labelling it "Last 10 mean".
            # Report the mean return over the last 10 episodes instead.
            print("Episode reward: %s\nLast 10 mean: %s"
                  % (episode_return, numpy.mean(reward_history[-10:])))
            # Checkpoint after every episode so progress survives Ctrl-C.
            numpy.save("reward_history", numpy.array(reward_history))
    except KeyboardInterrupt:
        # Ctrl-C stops training and displays the learning curve.
        plt.plot(reward_history)
        plt.title("Rewards")
        plt.show()