
Commit 66c38d7

Merge pull request #119 from rubenlucas93/issue-118/pendulum_problem
Issue 118/pendulum problem
2 parents 63e7b28 + 4080ab6 commit 66c38d7

10 files changed

Lines changed: 832 additions & 0 deletions

rl_studio/agents/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -88,6 +88,13 @@ def __new__(cls, config):
                 )
 
                 return QlearnAutoparkingTrainer(config)
+
+        elif agent == AgentsType.PENDULUM.value:
+            if algorithm == AlgorithmsType.DDPG_TORCH.value:
+                from rl_studio.agents.pendulum.train_ddpg import (
+                    DDPGPendulumTrainer as PendulumTrainer,
+                )
+                return PendulumTrainer(config)
         else:
             raise NoValidTrainingType(agent)

@@ -143,5 +150,11 @@ def __new__(cls, config):
 
             return MountainCarInferencer(config)
 
+        elif agent == AgentsType.PENDULUM.value:
+            from rl_studio.agents.pendulum.inference_ddpg import (
+                DDPGPendulumInferencer as PendulumInferencer,
+            )
+
+            return PendulumInferencer(config)
         else:
             raise NoValidTrainingType(agent)
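
Both hunks dispatch on the string values of the AgentsType and AlgorithmsType enums, so a pendulum run is selected entirely from the loaded config. A rough sketch of a config stub that would reach the new branches follows; the keys shown are only the ones read by the code in this PR, and the overall schema (and the checkpoint path) is an assumption, not a file shipped with the PR:

    from types import SimpleNamespace

    # Hypothetical stub -- mirrors just the fields this PR's code reads.
    config = SimpleNamespace(
        settings={"params": {"logging_level": "info"}},
        agent={"name": "pendulum", "params": {}},
        algorithm={"name": "ddpg_torch", "params": {"batch_size": 64}},
        environment={
            "params": {
                "env_name": "Pendulum-v1",
                "runs": 100,
                "update_every": 10,
                "objective_reward": -300,
                "block_experience_batch": False,
                "non_recoverable_angle": 0.3,
            }
        },
        inference={"params": {"inference_file": "checkpoints/pendulum_actor.pkl"}},
    )
    # With agent "pendulum" and algorithm "ddpg_torch", the factories above
    # would return DDPGPendulumTrainer / DDPGPendulumInferencer for this config.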

rl_studio/agents/agents_type.py

Lines changed: 1 addition & 0 deletions

@@ -8,3 +8,4 @@ class AgentsType(Enum):
     MOUNTAIN_CAR = "mountain_car"
     CARTPOLE = "cartpole"
     AUTOPARKING = "autoparking"
+    PENDULUM = "pendulum"
rl_studio/agents/pendulum/inference_ddpg.py

Lines changed: 165 additions & 0 deletions

@@ -0,0 +1,165 @@
import datetime
import time
import random

import gymnasium as gym
import matplotlib.pyplot as plt
from torch.utils import tensorboard
from tqdm import tqdm
import numpy as np
import torch

import logging

from rl_studio.agents.pendulum import utils
from rl_studio.algorithms.ddpg_torch import Actor, Critic, Memory
from rl_studio.visual.ascii.images import JDEROBOT_LOGO
from rl_studio.visual.ascii.text import JDEROBOT, LETS_GO
from rl_studio.agents.pendulum.utils import store_rewards, save_metadata
from rl_studio.wrappers.inference_rlstudio import InferencerWrapper


# # https://github.com/openai/gym/blob/master/gym/core.py
# class NormalizedEnv(gym.ActionWrapper):
#     """ Wrap action """
#
#     def _action(self, action):
#         act_k = (self.action_space.high - self.action_space.low) / 2.
#         act_b = (self.action_space.high + self.action_space.low) / 2.
#         return act_k * action + act_b
#
#     def _reverse_action(self, action):
#         act_k_inv = 2. / (self.action_space.high - self.action_space.low)
#         act_b = (self.action_space.high + self.action_space.low) / 2.
#         return act_k_inv * (action - act_b)
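# The commented-out wrapper above, kept for reference, rescales a policy
# action in [-1, 1] to the env's native range via act_k * action + act_b.
# For Pendulum-v1 (torque range [-2.0, 2.0]): act_k = 2, act_b = 0, so an
# action of 0.5 would map to a torque of 1.0.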


class DDPGPendulumInferencer:
    def __init__(self, params):

        self.now = datetime.datetime.now()
        # self.environment params
        self.params = params
        self.environment_params = params.environment["params"]
        self.env_name = params.environment["params"]["env_name"]
        self.config = params.settings["params"]
        self.agent_config = params.agent["params"]

        if self.config["logging_level"] == "debug":
            self.LOGGING_LEVEL = logging.DEBUG
        elif self.config["logging_level"] == "error":
            self.LOGGING_LEVEL = logging.ERROR
        elif self.config["logging_level"] == "critical":
            self.LOGGING_LEVEL = logging.CRITICAL
        else:
            self.LOGGING_LEVEL = logging.INFO

        self.RANDOM_PERTURBATIONS_LEVEL = self.environment_params.get("random_perturbations_level", 0)
        self.PERTURBATIONS_INTENSITY_STD = self.environment_params.get("perturbations_intensity_std", 0)
        self.RANDOM_START_LEVEL = self.environment_params.get("random_start_level", 0)
        self.INITIAL_POLE_ANGLE = self.environment_params.get("initial_pole_angle", None)

        non_recoverable_angle = self.environment_params[
            "non_recoverable_angle"
        ]
        # Unfortunately, max_steps is not working with new_step_api=True and it is not giving any benefit.
        # self.env = gym.make(self.env_name, new_step_api=True, random_start_level=random_start_level)
        # self.env = NormalizedEnv(gym.make(self.env_name
        # # ,random_start_level=self.RANDOM_START_LEVEL, initial_pole_angle=self.INITIAL_POLE_ANGLE,
        # # non_recoverable_angle=non_recoverable_angle
        # ))
        self.env = gym.make(self.env_name, render_mode="human")
        self.RUNS = self.environment_params["runs"]
        self.UPDATE_EVERY = self.environment_params[
            "update_every"
        ]  # How often the current progress is recorded
        self.OBJECTIVE_REWARD = self.environment_params[
            "objective_reward"
        ]
        self.BLOCKED_EXPERIENCE_BATCH = self.environment_params[
            "block_experience_batch"
        ]

        self.losses_list, self.reward_list, self.episode_len_list = (
            [],
            [],
            [],
        )  # metrics
        # recorded for graph
        self.batch_size = params.algorithm["params"]["batch_size"]
        self.tau = 1e-2

        self.max_avg = -1000

        self.num_actions = self.env.action_space.shape[0]

        inference_file = params.inference["params"]["inference_file"]
        self.inferencer = InferencerWrapper("ddpg_torch", inference_file, env=self.env)

    def print_init_info(self):
        logging.info(JDEROBOT)
        logging.info(JDEROBOT_LOGO)
        logging.info(f"\t- Start hour: {datetime.datetime.now()}\n")
        logging.info(f"\t- self.environment params:\n{self.environment_params}")

    def gather_statistics(self, ep_len, episode_rew):
        self.reward_list.append(episode_rew)
        self.episode_len_list.append(ep_len)

    def main(self):
        epoch_start_time = datetime.datetime.now()

        logs_dir = 'logs/pendulum/ddpg/training/'
        logs_file_name = 'logs_file_' + str(self.RANDOM_START_LEVEL) + '_' + str(
            self.RANDOM_PERTURBATIONS_LEVEL) + '_' + str(epoch_start_time) \
            + str(self.PERTURBATIONS_INTENSITY_STD) + '.log'
        logging.basicConfig(filename=logs_dir + logs_file_name, filemode='a',
                            level=self.LOGGING_LEVEL,
                            format='%(name)s - %(levelname)s - %(message)s')
        self.print_init_info()

        start_time_format = epoch_start_time.strftime("%Y%m%d_%H%M")

        logging.info(LETS_GO)
        w = tensorboard.SummaryWriter(log_dir=f"{logs_dir}/tensorboard/{start_time_format}")

        total_reward_in_epoch = 0

        for episode in tqdm(range(self.RUNS)):
            state, _ = self.env.reset()
            done = False
            episode_reward = 0
            step = 0
            while not done:
                step += 1
                # if random.uniform(0, 1) < self.RANDOM_PERTURBATIONS_LEVEL:
                #     perturbation_action = random.randrange(self.env.action_space.n)
                #     state, done, _, _ = self.env.perturbate(perturbation_action, self.PERTURBATIONS_INTENSITY_STD)
                #     logging.debug("perturbated in step {} with action {}".format(episode_rew, perturbation_action))

                action = self.inferencer.inference(state)
                new_state, reward, _, done, _ = self.env.step(action)
                state = new_state
                episode_reward += reward
                total_reward_in_epoch += reward

            w.add_scalar("reward/episode_reward", episode_reward, global_step=episode)

            self.gather_statistics(step, episode_reward)

            # monitor progress
            if (episode + 1) % self.UPDATE_EVERY == 0:
                time_spent = datetime.datetime.now() - epoch_start_time
                epoch_start_time = datetime.datetime.now()
                updates_message = 'Run: {0} Average: {1} time spent {2}'.format(episode,
                                                                                total_reward_in_epoch / self.UPDATE_EVERY,
                                                                                str(time_spent))
                logging.info(updates_message)
                print(updates_message)
                total_reward_in_epoch = 0
        base_file_name = f'_rewards_rsl-{self.RANDOM_START_LEVEL}_rpl-{self.RANDOM_PERTURBATIONS_LEVEL}_pi-{self.PERTURBATIONS_INTENSITY_STD}'
        file_path = f'{logs_dir}{datetime.datetime.now()}_{base_file_name}.pkl'
        store_rewards(self.reward_list, file_path)
        plt.plot(self.reward_list)
        plt.legend("reward per episode")
        plt.show()
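
A note on the environment API used here: gymnasium's env.step returns a five-tuple (obs, reward, terminated, truncated, info), and the loop above binds the fourth element (truncated) to its done flag while discarding terminated. For Pendulum-v1 that is the natural stop condition, since the environment never sets terminated and episodes end only through time-limit truncation (200 steps by default). A minimal standalone sketch of the same evaluation loop, with a random policy standing in for the DDPG inferencer:

    import gymnasium as gym

    env = gym.make("Pendulum-v1")
    state, _ = env.reset()
    done = False
    episode_reward = 0.0
    while not done:
        # stand-in for self.inferencer.inference(state)
        action = env.action_space.sample()
        state, reward, terminated, truncated, _ = env.step(action)
        episode_reward += reward
        # Pendulum-v1 only ever truncates, so this matches the loop above
        done = terminated or truncated
    env.close()
    print(f"episode reward: {episode_reward:.1f}")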
Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
gym==0.26.2
gymnasium==0.27.0
markdownTable==6.0.0
matplotlib==3.3.2
numpy==1.17.4
torch==1.12.1
tqdm==4.64.0
