Skip to content

Commit 2b665d3

Browse files
authored
Merge branch 'PufferAI:4.0' into 4.0
2 parents a120550 + 90d08e0 commit 2b665d3

31 files changed

Lines changed: 808 additions & 967 deletions

build.sh

Lines changed: 33 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ set -e
1010
# ./build.sh breakout --web # Emscripten web build
1111
# ./build.sh breakout --profile # Kernel profiling binary
1212

13-
ENV=${1:?Usage: ./build.sh ENV_NAME [--float] [--debug] [--local|--fast|--web|--profile]}
13+
ENV=${1:?Usage: ./build.sh ENV_NAME [--float] [--debug] [--local|--fast|--web|--profile|--cpu]}
1414
MODE=""
1515
PRECISION=""
1616
DEBUG=""
@@ -22,6 +22,7 @@ for arg in "${@:2}"; do
2222
--fast) MODE=fast ;;
2323
--web) MODE=web ;;
2424
--profile) MODE=profile ;;
25+
--cpu) MODE=cpu; PRECISION="-DPRECISION_FLOAT" ;;
2526
esac
2627
done
2728

@@ -34,6 +35,8 @@ CLANG_WARN="\
3435
-Wno-incompatible-pointer-types-discards-qualifiers \
3536
-Wno-error=array-parameter"
3637

38+
PLATFORM="$(uname -s)"
39+
3740
if [ -n "$DEBUG" ] || [ "$MODE" = "local" ]; then
3841
CLANG_OPT="-g -O0 $CLANG_WARN"
3942
NVCC_OPT="-O0 -g"
@@ -48,8 +51,6 @@ fi
4851
# ============================================================================
4952
# Platform + dependencies
5053
# ============================================================================
51-
52-
PLATFORM="$(uname -s)"
5354
if [ -d "ocean/$ENV" ]; then
5455
SRC_DIR="ocean/$ENV"
5556
elif [ -d "$ENV" ]; then
@@ -196,13 +197,40 @@ if [ "$MODE" = "profile" ]; then
196197
-Xcompiler=-fopenmp \
197198
tests/profile_kernels.cu ini.c \
198199
"$STATIC_LIB" "$RAYLIB_A" \
199-
-lnccl -lnvidia-ml -lcublas -lcurand -lcudnn -lnvToolsExt \
200+
-lnccl -lnvidia-ml -lcublas -lcurand -lcudnn \
200201
-lGL -lm -lpthread -lomp5 \
201202
-o profile
202203
echo "=== Built: ./profile ==="
203204
exit 0
204205
fi
205206

207+
if [ "$MODE" = "cpu" ]; then
208+
echo "=== Compiling bindings_cpu.cpp ==="
209+
g++ -c -fPIC -fopenmp \
210+
-D_GLIBCXX_USE_CXX11_ABI=1 \
211+
-DPLATFORM_DESKTOP \
212+
-std=c++17 \
213+
-I. -Isrc \
214+
-I$PYTHON_INCLUDE -I$PYBIND_INCLUDE \
215+
-DOBS_TENSOR_T=$OBS_TENSOR_T \
216+
$PRECISION $LINK_OPT \
217+
src/bindings_cpu.cpp -o src/bindings_cpu.o
218+
219+
echo "=== Linking $OUTPUT (CPU) ==="
220+
LINK_CMD=(
221+
g++ -shared -fPIC -fopenmp
222+
src/bindings_cpu.o "$STATIC_LIB" "$RAYLIB_A"
223+
-lm -lpthread -lomp5
224+
$LINK_OPT
225+
)
226+
[ "$PLATFORM" = "Linux" ] && LINK_CMD+=(-Bsymbolic-functions)
227+
[ "$PLATFORM" = "Darwin" ] && LINK_CMD+=(-framework Cocoa -framework OpenGL -framework IOKit)
228+
LINK_CMD+=(-o "$OUTPUT")
229+
"${LINK_CMD[@]}"
230+
echo "=== Built: $OUTPUT (CPU) ==="
231+
exit 0
232+
fi
233+
206234
echo "=== Compiling bindings.cu ==="
207235
$NVCC -c -Xcompiler -fPIC \
208236
-Xcompiler=-D_GLIBCXX_USE_CXX11_ABI=1 \
@@ -224,7 +252,7 @@ LINK_CMD=(
224252
src/bindings.o "$STATIC_LIB" "$RAYLIB_A"
225253
-L$CUDA_HOME/lib64
226254
-lcudart -lnccl -lnvidia-ml -lcublas -lcusolver -lcurand -lcudnn
227-
-lnvToolsExt -lomp5
255+
-lomp5
228256
$LINK_OPT
229257
)
230258
[ "$PLATFORM" = "Linux" ] && LINK_CMD+=(-Bsymbolic-functions)

constellation/'

Lines changed: 0 additions & 27 deletions
This file was deleted.

constellation/all_cache.json

Lines changed: 0 additions & 1 deletion
This file was deleted.

constellation/cache_data.py

Lines changed: 40 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -5,15 +5,16 @@
55
import os
66

77
env_names = sorted([
8-
#'breakout',
8+
'breakout',
99
#'impulse_wars',
1010
#'pacman',
1111
#'tetris',
1212
#'g2048',
1313
#'moba',
14-
#'pong',
14+
'pong',
1515
#'tower_climb',
1616
#'grid',
17+
'connect4',
1718
'nmmo3',
1819
#'snake',
1920
#'tripletriad'
@@ -35,7 +36,7 @@
3536
'train/eps',
3637
'train/prio_alpha',
3738
'train/prio_beta0',
38-
#'train/horizon',
39+
'train/horizon',
3940
'train/replay_ratio',
4041
'train/minibatch_size',
4142
'policy/hidden_size',
@@ -101,10 +102,6 @@ def cached_load(path, env_name, cache):
101102
for k in list(exp['metrics'].keys()):
102103
if 'loss' in k:
103104
del exp['metrics'][k]
104-
data_len = len(exp['metrics']['agent_steps'])
105-
if data_len > 100:
106-
print(f'Skipping {fpath} (len={data_len})')
107-
continue
108105

109106
if num_metrics == 0:
110107
num_metrics = len(exp['metrics'])
@@ -114,7 +111,6 @@ def cached_load(path, env_name, cache):
114111
metrics = exp['metrics']
115112

116113
if len(metrics) != num_metrics:
117-
breakpoint()
118114
print(f'Skipping {fpath} (num_metrics={len(metrics)} != {num_metrics})')
119115
continue
120116

@@ -132,7 +128,6 @@ def cached_load(path, env_name, cache):
132128
break
133129

134130
if skip:
135-
breakpoint()
136131
print(f'Skipping {fpath} (bad data)')
137132
continue
138133

@@ -151,31 +146,34 @@ def cached_load(path, env_name, cache):
151146

152147
for hyper in HYPERS:
153148
prefix, suffix = hyper.split('/')
154-
if prefix not in sweep_metadata:
155-
continue
149+
#if prefix not in sweep_metadata:
150+
# continue
156151

157152
group = sweep_metadata[prefix]
158-
if suffix not in group:
159-
continue
153+
#if suffix not in group:
154+
# continue
160155

161-
param = group[suffix]
162156

163157
key = f'{prefix}/{suffix}_norm'
164158
if key not in data:
165159
data[key] = []
166160

167-
mmin = param['min']
168-
mmax = param['max']
169-
dist = param['distribution']
170-
val = exp[prefix][suffix]
161+
if suffix in group:
162+
param = group[suffix]
163+
mmin = param['min']
164+
mmax = param['max']
165+
dist = param['distribution']
166+
val = exp[prefix][suffix]
171167

172-
if 'log' in dist or 'pow2' in dist:
173-
mmin = np.log(mmin)
174-
mmax = np.log(mmax)
175-
val = np.log(val)
168+
if 'log' in dist or 'pow2' in dist:
169+
mmin = np.log(mmin)
170+
mmax = np.log(mmax)
171+
val = np.log(val)
176172

177-
norm = (val - mmin) / (mmax - mmin)
178-
data[key].append([norm]*n)
173+
norm = (val - mmin) / (mmax - mmin)
174+
data[key].append([norm]*n)
175+
else:
176+
data[key].append([1]*n)
179177

180178
for k, v in data.items():
181179
data[k] = [item for sublist in v for item in sublist]
@@ -190,23 +188,11 @@ def cached_load(path, env_name, cache):
190188
#data['metrics/agent_steps'] = [e/1e6 for e in data['metrics/agent_steps']]
191189
del data['metrics/agent_steps']
192190

193-
'''
194-
for k, v in data.items():
195-
for e in v:
196-
if e is None or isinstance(e, str):
197-
continue
198-
try:
199-
if e > 1e9 or e < -1e9:
200-
breakpoint()
201-
except:
202-
breakpoint()
203-
'''
204-
205191
# Filter to pareto
192+
'''
206193
steps = data['agent_steps']
207194
costs = data['uptime']
208195
scores = data['env/score']
209-
'''
210196
idxs = pareto_idx(steps, costs, scores)
211197
for k in data:
212198
try:
@@ -255,27 +241,42 @@ def compute_tsne():
255241
row = 0
256242
for env in env_names:
257243
sz = len(all_data[env]['agent_steps'])
244+
all_data[env]['tsne1'] = reduced[row:row+sz, 0].tolist()
245+
all_data[env]['tsne2'] = reduced[row:row+sz, 1].tolist()
246+
247+
'''
258248
if reduced is not None:
259249
all_data[env]['tsne1'] = reduced[row:row+sz, 0].tolist()
260250
all_data[env]['tsne2'] = reduced[row:row+sz, 1].tolist()
261251
else:
262252
all_data[env]['tsne1'] = np.random.rand(sz).tolist()
263253
all_data[env]['tsne2'] = np.random.rand(sz).tolist()
254+
'''
264255

265256
row += sz
266257
print(f'Env {env} has {sz} points')
267258

268259
for env in all_data:
269260
dat = all_data[env]
270-
dat = {k: v for k, v in dat.items() if k in ALL_KEYS}
261+
dat = {k: v for k, v in dat.items() if isinstance(v, list)
262+
and len(v) > 0 and isinstance(v[0], (int, float))
263+
and (k == 'train/max_grad_norm' or not k.endswith('_norm'))}
271264
all_data[env] = dat
272265
for k, v in dat.items():
273266
try:
274267
print(f'{env}/{k}: {len(v), min(v), max(v)}')
275268
except:
276269
print(f'{env}/{k}: {len(v)}')
277270

278-
json.dump(all_data, open('constellation/default.json', 'w'))
271+
for env in all_data:
272+
for k, v in all_data[env].items():
273+
if isinstance(v, list):
274+
try:
275+
all_data[env][k] = ','.join([f'{x:.6g}' for x in v])
276+
except:
277+
breakpoint()
278+
279+
json.dump(all_data, open('resources/constellation/experiments.json', 'w'))
279280

280281
if __name__ == '__main__':
281282
compute_tsne()

0 commit comments

Comments
 (0)