Skip to content

Commit d9241e6

Browse files
authored
Merge branch 'PufferAI:4.0' into 4.0
2 parents e59b6a7 + 91e4ce9 commit d9241e6

16 files changed

Lines changed: 564 additions & 300 deletions

File tree

cache_data.py

Lines changed: 67 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
#'impulse_wars',
1313
#'pacman',
1414
#'tetris',
15-
'g2048',
15+
#'g2048',
1616
#'moba',
17-
'pong',
17+
#'pong',
1818
#'tower_climb',
19-
'grid',
20-
'nmmo3',
19+
#'grid',
20+
#'nmmo3',
2121
#'snake',
2222
#'tripletriad'
2323
])
@@ -38,7 +38,7 @@
3838
'train/eps',
3939
'train/prio_alpha',
4040
'train/prio_beta0',
41-
#'train/horizon',
41+
'train/horizon',
4242
'train/replay_ratio',
4343
'train/minibatch_size',
4444
'policy/hidden_size',
@@ -65,22 +65,21 @@ def pareto_idx(steps, costs, scores):
6565

6666
return idxs
6767

68-
def load_sweep_data(path):
68+
def cached_load(path, env_name, cache):
6969
data = {}
70-
sweep_metadata = {}
7170
num_metrics = 0
7271
for fpath in glob.glob(path):
73-
if 'cache.json' in fpath:
74-
continue
75-
76-
with open(fpath, 'r') as f:
77-
try:
78-
exp = json.load(f)
79-
except json.decoder.JSONDecodeError:
80-
print(f'Skipping {fpath}')
81-
continue
72+
if fpath in cache:
73+
exp = cache[fpath]
74+
else:
75+
with open(fpath, 'r') as f:
76+
try:
77+
exp = json.load(f)
78+
except json.decoder.JSONDecodeError:
79+
print(f'Skipping {fpath}')
80+
continue
8281

83-
sweep_metadata = exp.pop('sweep')
82+
cache[fpath] = exp
8483

8584
data_len = len(exp['metrics']['agent_steps'])
8685
if data_len > 100:
@@ -91,7 +90,7 @@ def load_sweep_data(path):
9190
num_metrics = len(exp['metrics'])
9291

9392
skip = False
94-
metrics = exp.pop('metrics')
93+
metrics = exp['metrics']
9594

9695
if len(metrics) != num_metrics:
9796
print(f'Skipping {fpath} (num_metrics={len(metrics)} != {num_metrics})')
@@ -120,62 +119,77 @@ def load_sweep_data(path):
120119
breakpoint()
121120
pass
122121

122+
sweep_metadata = exp['sweep']
123+
123124
for k, v in pufferlib.unroll_nested_dict(exp):
124125
if k not in data:
125126
data[k] = []
126127

127128
data[k].append([v]*n)
128129

130+
for hyper in HYPERS:
131+
prefix, suffix = hyper.split('/')
132+
if prefix not in sweep_metadata:
133+
continue
134+
135+
group = sweep_metadata[prefix]
136+
if suffix not in group:
137+
continue
138+
139+
param = group[suffix]
140+
141+
key = f'{prefix}/{suffix}_norm'
142+
if key not in data:
143+
data[key] = []
144+
145+
mmin = param['min']
146+
mmax = param['max']
147+
dist = param['distribution']
148+
val = exp[prefix][suffix]
149+
150+
if 'log' in dist or 'pow2' in dist:
151+
mmin = np.log(mmin)
152+
mmax = np.log(mmax)
153+
val = np.log(val)
154+
155+
norm = (val - mmin) / (mmax - mmin)
156+
data[key].append([norm]*n)
157+
129158
for k, v in data.items():
130159
data[k] = [item for sublist in v for item in sublist]
131160

132-
#steps = data['agent_steps']
133-
#costs = data['uptime']
134-
#scores = data['env/score']
135-
#idxs = pareto_idx(steps, costs, scores)
136161
# Filter to pareto
137-
#for k in data:
138-
# data[k] = [data[k][i] for i in idxs]
162+
steps = data['agent_steps']
163+
costs = data['uptime']
164+
scores = data['env/score']
165+
idxs = pareto_idx(steps, costs, scores)
166+
for k in data:
167+
data[k] = [data[k][i] for i in idxs]
139168

140169
data['sweep'] = sweep_metadata
141170
return data
142171

143-
def cached_sweep_load(path, env_name):
144-
cache_file = os.path.join(path, 'c_cache.json')
145-
if not os.path.exists(cache_file):
146-
data = load_sweep_data(os.path.join(path, '*.json'))
147-
with open(cache_file, 'w') as f:
148-
json.dump(data, f)
149-
150-
with open(cache_file, 'r') as f:
151-
data = json.load(f)
152-
153-
print(f'Loaded {env_name}')
154-
return data
155-
156172
def compute_tsne():
157173
all_data = {}
158174
normed = {}
159175

176+
cache = {}
177+
cache_file = os.path.join('cache.json')
178+
if os.path.exists(cache_file):
179+
cache = json.load(open(cache_file, 'r'))
180+
160181
for env in env_names:
161-
env_data = cached_sweep_load(f'logs/puffer_{env}', env)
162-
sweep_metadata = env_data.pop('sweep')
163-
all_data[env] = env_data
182+
all_data[env] = cached_load(f'logs/puffer_{env}/*.json', env, cache)
164183

184+
with open(cache_file, 'w') as f:
185+
json.dump(cache, f)
186+
187+
for env in env_names:
188+
env_data = all_data[env]
165189
normed_env = []
166190
for key in HYPERS:
167-
prefix, suffix = key.split('/')
168-
mmin = sweep_metadata[prefix][suffix]['min']
169-
mmax = sweep_metadata[prefix][suffix]['max']
170-
dat = np.array(env_data[key])
171-
172-
dist = sweep_metadata[prefix][suffix]['distribution']
173-
if 'log' in dist or 'pow2' in dist:
174-
mmin = np.log(mmin)
175-
mmax = np.log(mmax)
176-
dat = np.log(dat)
177-
178-
normed_env.append((dat - mmin) / (mmax - mmin))
191+
norm_key = f'{key}_norm'
192+
normed_env.append(np.array(env_data[norm_key]))
179193

180194
normed[env] = np.stack(normed_env, axis=1)
181195

@@ -192,7 +206,6 @@ def compute_tsne():
192206
row = 0
193207
for env in env_names:
194208
sz = len(all_data[env]['agent_steps'])
195-
#all_data[env] = {k: v for k, v in all_data[env].items()}
196209
if reduced is not None:
197210
all_data[env]['tsne1'] = reduced[row:row+sz, 0].tolist()
198211
all_data[env]['tsne2'] = reduced[row:row+sz, 1].tolist()
@@ -203,7 +216,7 @@ def compute_tsne():
203216
row += sz
204217
print(f'Env {env} has {sz} points')
205218

206-
json.dump(all_data, open('all_cache.json', 'w'))
219+
json.dump(all_data, open('pufferlib/ocean/constellation/default.json', 'w'))
207220

208221
if __name__ == '__main__':
209222
compute_tsne()

pufferlib/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
original_stderr = sys.stderr
1919
sys.stdout = open(os.devnull, 'w')
2020
sys.stderr = open(os.devnull, 'w')
21-
try:
22-
import gymnasium
23-
import pygame
24-
except ImportError:
25-
pass
21+
#try:
22+
# import gymnasium
23+
# import pygame
24+
#except ImportError:
25+
# pass
2626
sys.stdout.close()
2727
sys.stderr.close()
2828
sys.stdout = original_stdout

pufferlib/config/default.ini

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -217,13 +217,13 @@ max = 1e-4
217217
scale = auto
218218

219219
[sweep.train.prio_alpha]
220-
distribution = logit_normal
221-
min = 0.1
222-
max = 0.99
220+
distribution = uniform
221+
min = 0.0
222+
max = 1.0
223223
scale = auto
224224

225225
[sweep.train.prio_beta0]
226-
distribution = logit_normal
227-
min = 0.1
228-
max = 0.99
226+
distribution = uniform
227+
min = 0.0
228+
max = 1.0
229229
scale = auto

pufferlib/ocean/breakout/binding.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#define NUM_ATNS 1
44
#define ACT_SIZES {3}
55
#define OBS_TENSOR_T FloatTensor
6-
#define ACT_TYPE DOUBLE
76

87
#define Env Breakout
98
#include "vecenv.h"

pufferlib/ocean/breakout/breakout.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ typedef struct Breakout {
4040
Client* client;
4141
Log log;
4242
float* observations;
43-
double* actions;
43+
float* actions;
4444
float* rewards;
4545
float* terminals;
4646
int num_agents;
@@ -121,7 +121,7 @@ void init(Breakout* env) {
121121
void allocate(Breakout* env) {
122122
init(env);
123123
env->observations = (float*)calloc(11 + env->num_bricks, sizeof(float));
124-
env->actions = (double*)calloc(1, sizeof(double));
124+
env->actions = (float*)calloc(1, sizeof(float));
125125
env->rewards = (float*)calloc(1, sizeof(float));
126126
env->terminals = (float*)calloc(1, sizeof(float));
127127
}

pufferlib/ocean/constellation/constellation.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -889,7 +889,7 @@ int main(void) {
889889
"train/eps",
890890
"train/prio_alpha",
891891
"train/prio_beta0",
892-
//"train/horizon",
892+
"train/horizon",
893893
"train/replay_ratio",
894894
"train/minibatch_size",
895895
"policy/hidden_size",
@@ -978,14 +978,14 @@ int main(void) {
978978
int fig_env_idx = 0;
979979
bool fig_env_active = false;
980980
bool fig_x_active = false;
981-
int fig_x_idx = 0;
981+
int fig_x_idx = 1;
982982
bool fig_xscale_active = false;
983983
int fig_xscale_idx = 0;
984984
bool fig_y_active = false;
985985
int fig_y_idx = 2;
986986
bool fig_yscale_active = false;
987987
bool fig_z_active = false;
988-
int fig_z_idx = 1;
988+
int fig_z_idx = 0;
989989
bool fig_zscale_active = false;
990990
int fig_zscale_idx = 0;
991991
int fig_color_idx = 0;

pufferlib/ocean/torch.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,26 +72,43 @@ def __init__(self, env, hidden_size=512, output_size=512, **kwargs):
7272
self.multihot_dim = self.factors.sum()
7373
self.is_continuous = False
7474

75+
#self.map_2d = nn.Sequential(
76+
# pufferlib.pytorch.layer_init(nn.Conv2d(self.multihot_dim, 128, 5, stride=3)),
77+
# nn.ReLU(),
78+
# pufferlib.pytorch.layer_init(nn.Conv2d(128, 128, 3, stride=1)),
79+
# nn.Flatten(),
80+
#)
81+
7582
self.map_2d = nn.Sequential(
76-
pufferlib.pytorch.layer_init(nn.Conv2d(self.multihot_dim, 128, 5, stride=3)),
83+
nn.Conv2d(self.multihot_dim, 128, 5, stride=3),
7784
nn.ReLU(),
78-
pufferlib.pytorch.layer_init(nn.Conv2d(128, 128, 3, stride=1)),
85+
nn.Conv2d(128, 128, 3, stride=1),
7986
nn.Flatten(),
8087
)
8188

89+
8290
self.player_discrete_encoder = nn.Sequential(
8391
nn.Embedding(128, 32),
8492
nn.Flatten(),
8593
)
94+
95+
#self.proj = nn.Sequential(
96+
# pufferlib.pytorch.layer_init(nn.Linear(1817, hidden_size)),
97+
# nn.ReLU(),
98+
#)
99+
86100
self.proj = nn.Sequential(
87-
pufferlib.pytorch.layer_init(nn.Linear(1817, hidden_size)),
101+
nn.Linear(1817, hidden_size),
88102
nn.ReLU(),
89103
)
90104

91-
self.layer_norm = nn.LayerNorm(hidden_size)
92-
self.actor = pufferlib.pytorch.layer_init(
93-
nn.Linear(output_size, self.num_actions), std=0.01)
94-
self.value_fn = pufferlib.pytorch.layer_init(nn.Linear(output_size, 1), std=1)
105+
#self.layer_norm = nn.LayerNorm(hidden_size)
106+
#self.actor = pufferlib.pytorch.layer_init(
107+
# nn.Linear(output_size, self.num_actions), std=0.01)
108+
#self.value_fn = pufferlib.pytorch.layer_init(nn.Linear(output_size, 1), std=0.01)
109+
110+
self.actor = nn.Linear(output_size, self.num_actions)
111+
self.value_fn = nn.Linear(output_size, 1)
95112

96113
def forward(self, x, state=None):
97114
hidden = self.encode_observations(x)
@@ -120,7 +137,7 @@ def encode_observations(self, observations, state=None):
120137
return obs
121138

122139
def decode_actions(self, flat_hidden):
123-
flat_hidden = self.layer_norm(flat_hidden)
140+
#flat_hidden = self.layer_norm(flat_hidden)
124141
action = self.actor(flat_hidden)
125142
value = self.value_fn(flat_hidden)
126143
return action, value

pufferlib/src/bindings.cu

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ void py_puff_advantage(
218218
kernel<<<blocks, 256>>>(
219219
(const precision_t*)values_ptr, (const precision_t*)rewards_ptr,
220220
(const precision_t*)dones_ptr, (const precision_t*)importance_ptr,
221-
(float*)advantages_ptr,
221+
(precision_t*)advantages_ptr,
222222
gamma, lambda, rho_clip, c_clip, num_steps, horizon);
223223
}
224224

@@ -453,10 +453,6 @@ PYBIND11_MODULE(_C, m) {
453453
.def("__repr__", [](const PrecisionTensor& t) { return std::string(puf_repr(&t)); })
454454
.def("ndim", [](const PrecisionTensor& t) { return ndim(t.shape); })
455455
.def("numel", [](const PrecisionTensor& t) { return numel(t.shape); });
456-
py::class_<DoubleTensor>(m, "DoubleTensor")
457-
.def("__repr__", [](const DoubleTensor& t) { return std::string(puf_repr(&t)); })
458-
.def("ndim", [](const DoubleTensor& t) { return ndim(t.shape); })
459-
.def("numel", [](const DoubleTensor& t) { return numel(t.shape); });
460456
py::class_<FloatTensor>(m, "FloatTensor")
461457
.def("__repr__", [](const FloatTensor& t) { return std::string(puf_repr(&t)); })
462458
.def("ndim", [](const FloatTensor& t) { return ndim(t.shape); })

0 commit comments

Comments
 (0)