
Commit cf61787

drive, tower_climb fixes
1 parent 87e8941 commit cf61787

6 files changed: 59 additions & 388 deletions

constellation/cache_data.py

Lines changed: 20 additions & 15 deletions
@@ -62,7 +62,7 @@ def pareto_idx(steps, costs, scores):
 
     return idxs
 
-def cached_load(path, env_name, cache):
+def cached_load(path, env_name, cache, full_dataset=False):
     data = {}
     num_metrics = 0
     metric_keys = []
@@ -170,21 +170,22 @@ def cached_load(path, env_name, cache):
     del data['metrics/agent_steps']
 
     # Filter to pareto
-    steps = data['agent_steps']
-    costs = data['uptime']
-    scores = data['env/score']
-
-    idxs = pareto_idx(steps, costs, scores)
-    for k in data:
-        try:
-            data[k] = [data[k][i] for i in idxs]
-        except IndexError:
-            continue
+    if not full_dataset:
+        steps = data['agent_steps']
+        costs = data['uptime']
+        scores = data['env/score']
+
+        idxs = pareto_idx(steps, costs, scores)
+        for k in data:
+            try:
+                data[k] = [data[k][i] for i in idxs]
+            except IndexError:
+                continue
 
     data['sweep'] = sweep_metadata
     return data
 
-def compute_tsne():
+def compute_tsne(full_dataset=False):
     all_data = {}
     normed = {}
 
@@ -196,7 +197,7 @@ def compute_tsne():
     env_names = sorted(os.listdir('logs'))
     for env in env_names:
         print('Loading: ', env)
-        all_data[env] = cached_load(f'logs/{env}/*.json', env, cache)
+        all_data[env] = cached_load(f'logs/{env}/*.json', env, cache, full_dataset)
 
     with open(cache_file, 'w') as f:
         json.dump(cache, f)
@@ -230,7 +231,7 @@ def compute_tsne():
             and len(v) > 0 and isinstance(v[0], (int, float))
             and (k == 'train/max_grad_norm' or not k.endswith('_norm'))}
        all_data[env] = dat
-        print(f'Env {env} has {len(dat['env/perf'])} points')
+        print(f"Env {env} has {len(dat['env/perf'])} points")
        for k, v in dat.items():
            if 'env/perf' in k or 'score' in k:
                print(f'{env}/{k}: min={min(v)}, max={max(v)}')
@@ -243,4 +244,8 @@ def compute_tsne():
     json.dump(all_data, open('resources/constellation/experiments.json', 'w'))
 
 if __name__ == '__main__':
-    compute_tsne()
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--full', action='store_true', help='Include full dataset (no pareto filtering)')
+    args = parser.parse_args()
+    compute_tsne(full_dataset=args.full)
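
Two notes on this file. The one-character print change in the fourth hunk is a real bug fix: before Python 3.12, an f-string expression cannot reuse the quote character that delimits the string itself, so the old single-quoted line was a syntax error on most interpreters:

    env, dat = 'drive', {'env/perf': [0.1, 0.2]}
    # print(f'Env {env} has {len(dat['env/perf'])} points')   # SyntaxError before Python 3.12
    print(f"Env {env} has {len(dat['env/perf'])} points")     # valid on all versions

The new flag disables the Pareto filter, e.g. python constellation/cache_data.py --full (assuming the script is run from the repo root). The body of pareto_idx is not shown in this diff; purely as an illustrative sketch, a Pareto-front selection that keeps runs not dominated on (cost, score) could look like the following, with steps accepted only to match the call site:

    def pareto_idx(steps, costs, scores):
        # Keep index i unless some j is at least as cheap and at least as
        # good, and strictly better on one axis (i.e., j dominates i).
        # Sketch only -- the real pareto_idx may use steps and differ.
        idxs = []
        for i in range(len(costs)):
            dominated = any(
                costs[j] <= costs[i] and scores[j] >= scores[i]
                and (costs[j] < costs[i] or scores[j] > scores[i])
                for j in range(len(costs))
            )
            if not dominated:
                idxs.append(i)
        return idxs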

ocean/drive/binding.c

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 #define ACT_SIZES {7, 13}
 #define OBS_TENSOR_T FloatTensor
 
-#define MAP_BINARY_DIR "resources/drive/binaries/training"
+#define MAP_BINARY_DIR "drive_data/binaries"
 
 #define MY_VEC_INIT
 #define Env Drive
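
A note on the path: MAP_BINARY_DIR now matches the --output_dir drive_data/binaries that ocean/drive/dataset.py's Step 2 (below) writes to, so the C binding reads the map binaries from the directory the dataset script produces.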

ocean/drive/dataset.py

Lines changed: 4 additions & 4 deletions
@@ -4,14 +4,14 @@
 https://huggingface.co/datasets/daphne-cornelisse/pufferdrive_womd_train_1000
 
 uv pip install huggingface_hub
-python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='daphne-cornelisse/pufferdrive_womd_train_1000', repo_type='dataset', local_dir='resources/drive/data')"
+python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='daphne-cornelisse/pufferdrive_womd_train_1000', repo_type='dataset', local_dir='drive_data')"
 
 Step 1: Unzip to get folder with .json files
-mkdir -p resources/drive/data/training
-tar xzf resources/drive/data/pufferdrive_womd_train_1000.tar.gz --strip-components=1 -C resources/drive/data/training/
+mkdir -p drive_data/training
+tar xzf drive_data/pufferdrive_womd_train_1000.tar.gz --strip-components=1 -C drive_data/training/
 
 Step 2: Process to map binaries
-python ocean/drive/dataset.py --data_folder resources/drive/data/training --output_dir resources/drive/binaries/training
+python ocean/drive/dataset.py --data_folder drive_data/training --output_dir drive_data/binaries
 
 """
 
 import json
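
The download one-liner in the docstring is dense; the same snapshot_download call, unchanged except for formatting, as a standalone script (repo id and local_dir exactly as in the docstring):

    from huggingface_hub import snapshot_download

    # Fetch the WOMD training scenarios into drive_data/
    # (identical to the docstring's python -c one-liner).
    snapshot_download(
        repo_id='daphne-cornelisse/pufferdrive_womd_train_1000',
        repo_type='dataset',
        local_dir='drive_data',
    )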
