-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathuse_purejaxrl.py
More file actions
278 lines (229 loc) · 10.7 KB
/
use_purejaxrl.py
File metadata and controls
278 lines (229 loc) · 10.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
'''Train HighJax with PureJaxRL's PPO pipeline.
PureJaxRL (https://github.com/luchris429/purejaxrl) is not a library you
import — it's a collection of reference training scripts for end-to-end
JIT-compiled RL in JAX. This example reproduces PureJaxRL's core PPO loop
with HighJax as a drop-in gymnax environment.
The only HighJax-specific line is the env creation. Everything else is
standard PureJaxRL code.
Requires (in addition to jax, flax and optax): pip install distrax gymnasium
'''
from __future__ import annotations
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
from typing import NamedTuple
import distrax
import highjax
# --- PureJaxRL network (same as purejaxrl/ppo.py) ---
class ActorCritic(nn.Module):
    """Actor-critic MLP with separate policy and value towers (PureJaxRL PPO).

    Each tower is two 64-unit Dense layers with tanh activations. The policy
    head produces ``action_dim`` logits wrapped in ``distrax.Categorical``; the
    value head produces a single unit that is squeezed away.

    Attributes:
        action_dim: Number of discrete actions (length of the logits vector).
    """

    action_dim: int

    @nn.compact
    def __call__(self, x):
        """Return ``(policy, value)`` for observation(s) ``x``.

        ``value`` has shape ``x.shape[:-1]`` because the trailing size-1 axis
        of the value head is squeezed.
        """
        hidden_init = orthogonal(jnp.sqrt(2))
        zero_bias = constant(0.0)

        # Creation order is significant: Flax auto-names compact submodules
        # Dense_0, Dense_1, ... in the order they are built, so the policy
        # tower (3 layers) must be constructed before the value tower.
        logits = x
        for _ in range(2):
            logits = nn.tanh(
                nn.Dense(64, kernel_init=hidden_init,
                         bias_init=zero_bias)(logits))
        logits = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01),
                          bias_init=zero_bias)(logits)
        policy = distrax.Categorical(logits=logits)

        value = x
        for _ in range(2):
            value = nn.tanh(
                nn.Dense(64, kernel_init=hidden_init,
                         bias_init=zero_bias)(value))
        value = nn.Dense(1, kernel_init=orthogonal(1.0),
                         bias_init=zero_bias)(value)

        return policy, jnp.squeeze(value, axis=-1)
# One rollout step per environment; make_train stacks these along a leading
# time axis with jax.lax.scan. As a NamedTuple it is automatically a JAX
# pytree. Field order matters: make_train constructs it positionally as
# (done, action, value, reward, log_prob, obs).
Transition = NamedTuple(
    'Transition',
    [
        ('done', jnp.ndarray),      # ``done`` flag returned by env.step
        ('action', jnp.ndarray),    # action sampled from the policy
        ('value', jnp.ndarray),     # critic estimate for ``obs``
        ('reward', jnp.ndarray),    # reward returned by env.step
        ('log_prob', jnp.ndarray),  # log-prob of ``action`` when sampled
        ('obs', jnp.ndarray),       # observation the action was chosen from
    ],
)
# --- PureJaxRL observation wrapper ---
class FlattenObsWrapper:
    """Expose a wrapped environment's observations as flat 1-D vectors.

    ``reset`` and ``step`` forward to the wrapped env and reshape the returned
    observation with ``jnp.reshape(obs, (-1,))``. Any other attribute (for
    example ``action_space``) is looked up on the wrapped env.
    """

    def __init__(self, env):
        self._env = env

    def __getattr__(self, name):
        # __getattr__ only runs after normal lookup fails. On an instance
        # created without __init__ (copy.copy, copy.deepcopy, pickle.loads),
        # '_env' is absent, so forwarding it would re-enter __getattr__ until
        # RecursionError. Raise AttributeError so hasattr()/getattr(default)
        # probes such as '__setstate__' behave normally.
        if name == '_env':
            raise AttributeError(name)
        return getattr(self._env, name)

    def observation_space(self, params):
        """Return a flat ``gymnasium.spaces.Box`` sized to the wrapped space.

        Requires ``gymnasium``. The [-1, 1] bounds are fixed here rather than
        taken from the wrapped space; ``make_train`` only reads ``.shape``.
        """
        import math

        import gymnasium
        orig = self._env.observation_space(params)
        flat_size = math.prod(orig.shape)
        return gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(flat_size,))

    def reset(self, key, params=None):
        """Reset the wrapped env; return ``(flat_obs, state)``."""
        obs, state = self._env.reset(key, params)
        return jnp.reshape(obs, (-1,)), state

    def step(self, key, state, action, params=None):
        """Step the wrapped env; return ``(flat_obs, state, reward, done, info)``."""
        obs, state, reward, done, info = self._env.step(
            key, state, action, params)
        return jnp.reshape(obs, (-1,)), state, reward, done, info
# --- PureJaxRL training function (adapted from purejaxrl/ppo.py) ---
def make_train(config, env, env_params):
    """Build a PureJaxRL-style PPO training function for ``env``.

    Args:
        config: Hyperparameter dict. Keys read: ``NUM_ENVS``, ``NUM_STEPS``,
            ``NUM_MINIBATCHES``, ``UPDATE_EPOCHS``, ``LR``, ``MAX_GRAD_NORM``,
            ``GAMMA``, ``GAE_LAMBDA``, ``CLIP_EPS``, ``VF_COEF``, ``ENT_COEF``
            and ``NUM_UPDATES``. When ``NUM_UPDATES`` is absent it is derived
            as ``TOTAL_TIMESTEPS // NUM_STEPS // NUM_ENVS``. The dict is
            copied, never mutated.
        env: Environment exposing ``reset(key, params)``,
            ``step(key, state, action, params)``, ``action_space(params).n``
            and ``observation_space(params).shape``.
        env_params: Passed unbatched (``in_axes=None``) to every env call.

    Returns:
        ``train(rng) -> (runner_state, metrics)``, intended to be wrapped in
        ``jax.jit``. ``runner_state`` is ``(train_state, env_state, last_obs,
        rng)``; ``metrics`` holds the mean rollout reward of each update.

    Raises:
        ValueError: If ``NUM_MINIBATCHES`` is not positive or does not evenly
            divide ``NUM_ENVS * NUM_STEPS``. Without this check the failure
            would surface later as an opaque reshape error inside the trace.
    """
    config = {**config}
    rollout_size = config['NUM_ENVS'] * config['NUM_STEPS']
    num_minibatches = config['NUM_MINIBATCHES']
    if num_minibatches <= 0 or rollout_size % num_minibatches:
        raise ValueError(
            f'NUM_MINIBATCHES ({num_minibatches}) must be a positive divisor '
            f'of NUM_ENVS * NUM_STEPS ({rollout_size})')
    config['MINIBATCH_SIZE'] = rollout_size // num_minibatches
    if 'NUM_UPDATES' not in config:
        config['NUM_UPDATES'] = (
            config['TOTAL_TIMESTEPS'] // config['NUM_STEPS']
            // config['NUM_ENVS']
        )

    def train(rng):
        # --- Initialise network, optimiser and vectorised environments ---
        network = ActorCritic(env.action_space(env_params).n)
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        tx = optax.chain(
            optax.clip_by_global_norm(config['MAX_GRAD_NORM']),
            optax.adam(config['LR'], eps=1e-5),
        )
        train_state = TrainState.create(
            apply_fn=network.apply, params=network_params, tx=tx,
        )
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config['NUM_ENVS'])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(
            reset_rng, env_params)

        def _update_step(runner_state, unused):
            # Collect NUM_STEPS transitions from every env with the current
            # policy.
            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, rng = runner_state
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config['NUM_ENVS'])
                obsv, env_state, reward, done, _info = jax.vmap(
                    env.step, in_axes=(0, 0, 0, None)
                )(rng_step, env_state, action, env_params)
                transition = Transition(
                    done, action, value, reward, log_prob, last_obs)
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config['NUM_STEPS'])

            # Bootstrap value for the observation after the last step.
            train_state, env_state, last_obs, rng = runner_state
            _, last_val = network.apply(train_state.params, last_obs)

            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done, transition.value, transition.reward)
                    # (1 - done) stops bootstrapping across episode ends.
                    delta = (reward + config['GAMMA'] * next_value
                             * (1 - done) - value)
                    gae = (delta + config['GAMMA'] * config['GAE_LAMBDA']
                           * (1 - done) * gae)
                    return (gae, value), gae

                # Scan backwards in time: GAE is a reverse-discounted sum.
                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch, reverse=True, unroll=16)
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            def _update_epoch(update_state, unused):
                def _update_minbatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, traj_batch, gae, targets):
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)

                        # Clipped value loss.
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config['CLIP_EPS'], config['CLIP_EPS'])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(
                            value_pred_clipped - targets)
                        value_loss = 0.5 * jnp.maximum(
                            value_losses, value_losses_clipped).mean()

                        # Clipped surrogate policy loss on advantages
                        # normalised within this minibatch.
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = jnp.clip(
                            ratio, 1.0 - config['CLIP_EPS'],
                            1.0 + config['CLIP_EPS']) * gae
                        loss_actor = -jnp.minimum(
                            loss_actor1, loss_actor2).mean()
                        entropy = pi.entropy().mean()

                        total_loss = (loss_actor
                                      + config['VF_COEF'] * value_loss
                                      - config['ENT_COEF'] * entropy)
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets)
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                train_state, traj_batch, advantages, targets, rng = (
                    update_state)
                rng, _rng = jax.random.split(rng)
                # Merge the (NUM_STEPS, NUM_ENVS) leading axes into one batch
                # axis, shuffle it, then split into NUM_MINIBATCHES chunks.
                batch_size = (config['MINIBATCH_SIZE']
                              * config['NUM_MINIBATCHES'])
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch)
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch)
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config['NUM_MINIBATCHES'], -1]
                        + list(x.shape[1:])),
                    shuffled_batch)
                train_state, total_loss = jax.lax.scan(
                    _update_minbatch, train_state, minibatches)
                update_state = (
                    train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss

            update_state = (
                train_state, traj_batch, advantages, targets, rng)
            update_state, _ = jax.lax.scan(
                _update_epoch, update_state, None, config['UPDATE_EPOCHS'])
            train_state = update_state[0]
            metric = traj_batch.reward.mean()
            rng = update_state[-1]
            runner_state = (train_state, env_state, last_obs, rng)
            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)
        runner_state, metrics = jax.lax.scan(
            _update_step, runner_state, None, config['NUM_UPDATES'])
        return runner_state, metrics

    return train
def main():
    """Train PPO on ``highjax-v0`` (``n_npcs=5``) and print a short summary."""
    # This is the only HighJax-specific line. Replace with gymnax.make()
    # for any other gymnax environment — the rest is identical.
    env, env_params = highjax.make('highjax-v0', n_npcs=5)
    env = FlattenObsWrapper(env)
    config = {
        'LR': 2.5e-4,
        'NUM_ENVS': 32,
        'NUM_STEPS': 40,
        'TOTAL_TIMESTEPS': 40_960,
        'UPDATE_EPOCHS': 4,
        'NUM_MINIBATCHES': 4,
        'GAMMA': 0.99,
        'GAE_LAMBDA': 0.95,
        'CLIP_EPS': 0.2,
        'ENT_COEF': 0.01,
        'VF_COEF': 0.5,
        'MAX_GRAD_NORM': 0.5,
    }
    config['NUM_UPDATES'] = (
        config['TOTAL_TIMESTEPS'] // config['NUM_STEPS'] // config['NUM_ENVS']
    )
    # Plain string: the original f-string had no placeholders (ruff F541).
    print('Training with PureJaxRL PPO on HighJax')
    print(f' {config["NUM_ENVS"]} envs, {config["NUM_STEPS"]} steps, '
          f'{config["NUM_UPDATES"]} updates')
    train_fn = make_train(config, env, env_params)
    train_jit = jax.jit(train_fn)
    rng = jax.random.PRNGKey(0)
    runner_state, metrics = train_jit(rng)
    print(f' Mean reward per update: {metrics}')
    print(f' Training steps completed: {runner_state[0].step}')
    print('Done.')


if __name__ == '__main__':
    main()