forked from joshuachen6/KrishLib
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
108 lines (92 loc) · 3.05 KB
/
train.py
File metadata and controls
108 lines (92 loc) · 3.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from gymnasium.envs.registration import register
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
import numpy as np
import random
from game import Game
from PIL import Image
import multiprocessing as mp
from env import MotionProfilePacman # Or whatever your class is
def _spawn_probe() -> None:
    """Target for the spawn test; must be module-level so it can be pickled."""
    print("Success!")


def test_spawn() -> bool:
    """Check that a subprocess can be created with the 'spawn' start method.

    Bug fix: the original used a lambda as the process target, but the
    'spawn' start method pickles the target to send it to the child, and
    lambdas are not picklable -- so ``p.start()`` always raised and the probe
    reported failure even when spawning worked. A module-level function is
    used instead.

    Returns:
        True if the subprocess started and joined cleanly, False otherwise.
    """
    try:
        ctx = mp.get_context("spawn")
        p = ctx.Process(target=_spawn_probe)
        p.start()
        p.join()
        print("Subprocess created successfully.")
        return True
    except Exception as e:
        # Best-effort diagnostic: report the failure, never raise.
        print(f"Subprocess failed: {e}")
        return False
# Register the custom environment with Gymnasium so it can be constructed
# by string id via gym.make / make_vec_env. Episodes are truncated after
# 300 steps by the TimeLimit wrapper that max_episode_steps installs.
register(
    id="MotionProfilePacman-v1",
    entry_point="env:MotionProfilePacman",
    max_episode_steps=300,
)
if __name__ == "__main__":
    # --- Training ---------------------------------------------------
    # Single headless environment wrapped in a DummyVecEnv for SB3.
    train_env = make_vec_env(
        "MotionProfilePacman-v1", n_envs=1, vec_env_cls=DummyVecEnv
    )

    model = DQN(
        "MultiInputPolicy",
        train_env,
        exploration_fraction=0.5,
        exploration_final_eps=0.05,
        verbose=1,
        learning_rate=1e-3,
        buffer_size=50000,
        tensorboard_log="tensorboard",
    )
    model.learn(total_timesteps=int(1e6))
    # NOTE(review): the checkpoint name says "ppo" but this is a DQN model;
    # kept as-is so any existing loading code keeps working.
    model.save("ppo_pacbot")

    # --- Evaluation -------------------------------------------------
    # Fresh single environment with on-screen rendering for a visual check.
    eval_env = make_vec_env(
        "MotionProfilePacman-v1",
        n_envs=1,
        env_kwargs={"render_mode": "human"},
        vec_env_cls=DummyVecEnv,
    )
    obs = eval_env.reset()
    # Run the greedy policy indefinitely; VecEnv.step returns the old-style
    # 4-tuple, with `done` as a per-env array.
    while True:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, info = eval_env.step(action)
        if done[0]:
            print("Episode finished!")
            obs = eval_env.reset()