Skip to content

Commit 4d2787f

Browse files
committed
100x data load speed, severl fixes
1 parent f2008d2 commit 4d2787f

4 files changed

Lines changed: 124 additions & 227 deletions

File tree

build.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ CLANG_WARN="\
3434
-Wno-incompatible-pointer-types-discards-qualifiers \
3535
-Wno-error=array-parameter"
3636

37+
PLATFORM="$(uname -s)"
38+
3739
if [ -n "$DEBUG" ] || [ "$MODE" = "local" ]; then
3840
CLANG_OPT="-g -O0 $CLANG_WARN"
3941
NVCC_OPT="-O0 -g"
@@ -48,8 +50,6 @@ fi
4850
# ============================================================================
4951
# Platform + dependencies
5052
# ============================================================================
51-
52-
PLATFORM="$(uname -s)"
5353
if [ -d "ocean/$ENV" ]; then
5454
SRC_DIR="ocean/$ENV"
5555
elif [ -d "$ENV" ]; then

constellation/cache_data.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
import os
66

77
env_names = sorted([
8-
#'breakout',
8+
'breakout',
99
#'impulse_wars',
1010
#'pacman',
1111
#'tetris',
1212
#'g2048',
1313
#'moba',
14-
#'pong',
14+
'pong',
1515
#'tower_climb',
1616
#'grid',
1717
'nmmo3',
@@ -35,7 +35,7 @@
3535
'train/eps',
3636
'train/prio_alpha',
3737
'train/prio_beta0',
38-
#'train/horizon',
38+
'train/horizon',
3939
'train/replay_ratio',
4040
'train/minibatch_size',
4141
'policy/hidden_size',
@@ -101,10 +101,6 @@ def cached_load(path, env_name, cache):
101101
for k in list(exp['metrics'].keys()):
102102
if 'loss' in k:
103103
del exp['metrics'][k]
104-
data_len = len(exp['metrics']['agent_steps'])
105-
if data_len > 100:
106-
print(f'Skipping {fpath} (len={data_len})')
107-
continue
108104

109105
if num_metrics == 0:
110106
num_metrics = len(exp['metrics'])
@@ -114,7 +110,6 @@ def cached_load(path, env_name, cache):
114110
metrics = exp['metrics']
115111

116112
if len(metrics) != num_metrics:
117-
breakpoint()
118113
print(f'Skipping {fpath} (num_metrics={len(metrics)} != {num_metrics})')
119114
continue
120115

@@ -132,7 +127,6 @@ def cached_load(path, env_name, cache):
132127
break
133128

134129
if skip:
135-
breakpoint()
136130
print(f'Skipping {fpath} (bad data)')
137131
continue
138132

@@ -151,31 +145,34 @@ def cached_load(path, env_name, cache):
151145

152146
for hyper in HYPERS:
153147
prefix, suffix = hyper.split('/')
154-
if prefix not in sweep_metadata:
155-
continue
148+
#if prefix not in sweep_metadata:
149+
# continue
156150

157151
group = sweep_metadata[prefix]
158-
if suffix not in group:
159-
continue
152+
#if suffix not in group:
153+
# continue
160154

161-
param = group[suffix]
162155

163156
key = f'{prefix}/{suffix}_norm'
164157
if key not in data:
165158
data[key] = []
166159

167-
mmin = param['min']
168-
mmax = param['max']
169-
dist = param['distribution']
170-
val = exp[prefix][suffix]
160+
if suffix in group:
161+
param = group[suffix]
162+
mmin = param['min']
163+
mmax = param['max']
164+
dist = param['distribution']
165+
val = exp[prefix][suffix]
171166

172-
if 'log' in dist or 'pow2' in dist:
173-
mmin = np.log(mmin)
174-
mmax = np.log(mmax)
175-
val = np.log(val)
167+
if 'log' in dist or 'pow2' in dist:
168+
mmin = np.log(mmin)
169+
mmax = np.log(mmax)
170+
val = np.log(val)
176171

177-
norm = (val - mmin) / (mmax - mmin)
178-
data[key].append([norm]*n)
172+
norm = (val - mmin) / (mmax - mmin)
173+
data[key].append([norm]*n)
174+
else:
175+
data[key].append([1]*n)
179176

180177
for k, v in data.items():
181178
data[k] = [item for sublist in v for item in sublist]
@@ -255,26 +252,41 @@ def compute_tsne():
255252
row = 0
256253
for env in env_names:
257254
sz = len(all_data[env]['agent_steps'])
255+
all_data[env]['tsne1'] = reduced[row:row+sz, 0].tolist()
256+
all_data[env]['tsne2'] = reduced[row:row+sz, 1].tolist()
257+
258+
'''
258259
if reduced is not None:
259260
all_data[env]['tsne1'] = reduced[row:row+sz, 0].tolist()
260261
all_data[env]['tsne2'] = reduced[row:row+sz, 1].tolist()
261262
else:
262263
all_data[env]['tsne1'] = np.random.rand(sz).tolist()
263264
all_data[env]['tsne2'] = np.random.rand(sz).tolist()
265+
'''
264266

265267
row += sz
266268
print(f'Env {env} has {sz} points')
267269

268270
for env in all_data:
269271
dat = all_data[env]
270-
dat = {k: v for k, v in dat.items() if k in ALL_KEYS}
272+
dat = {k: v for k, v in dat.items() if isinstance(v, list)
273+
and len(v) > 0 and isinstance(v[0], (int, float))
274+
and (k == 'train/max_grad_norm' or not k.endswith('_norm'))}
271275
all_data[env] = dat
272276
for k, v in dat.items():
273277
try:
274278
print(f'{env}/{k}: {len(v), min(v), max(v)}')
275279
except:
276280
print(f'{env}/{k}: {len(v)}')
277281

282+
for env in all_data:
283+
for k, v in all_data[env].items():
284+
if isinstance(v, list):
285+
try:
286+
all_data[env][k] = ','.join([f'{x:.6g}' for x in v])
287+
except:
288+
breakpoint()
289+
278290
json.dump(all_data, open('constellation/default.json', 'w'))
279291

280292
if __name__ == '__main__':

0 commit comments

Comments
 (0)