55import os
66
77env_names = sorted ([
8- # 'breakout',
8+ 'breakout' ,
99 #'impulse_wars',
1010 #'pacman',
1111 #'tetris',
1212 #'g2048',
1313 #'moba',
14- # 'pong',
14+ 'pong' ,
1515 #'tower_climb',
1616 #'grid',
17+ 'connect4' ,
1718 'nmmo3' ,
1819 #'snake',
1920 #'tripletriad'
3536 'train/eps' ,
3637 'train/prio_alpha' ,
3738 'train/prio_beta0' ,
38- # 'train/horizon',
39+ 'train/horizon' ,
3940 'train/replay_ratio' ,
4041 'train/minibatch_size' ,
4142 'policy/hidden_size' ,
@@ -101,10 +102,6 @@ def cached_load(path, env_name, cache):
101102 for k in list (exp ['metrics' ].keys ()):
102103 if 'loss' in k :
103104 del exp ['metrics' ][k ]
104- data_len = len (exp ['metrics' ]['agent_steps' ])
105- if data_len > 100 :
106- print (f'Skipping { fpath } (len={ data_len } )' )
107- continue
108105
109106 if num_metrics == 0 :
110107 num_metrics = len (exp ['metrics' ])
@@ -114,7 +111,6 @@ def cached_load(path, env_name, cache):
114111 metrics = exp ['metrics' ]
115112
116113 if len (metrics ) != num_metrics :
117- breakpoint ()
118114 print (f'Skipping { fpath } (num_metrics={ len (metrics )} != { num_metrics } )' )
119115 continue
120116
@@ -132,7 +128,6 @@ def cached_load(path, env_name, cache):
132128 break
133129
134130 if skip :
135- breakpoint ()
136131 print (f'Skipping { fpath } (bad data)' )
137132 continue
138133
@@ -151,31 +146,34 @@ def cached_load(path, env_name, cache):
151146
152147 for hyper in HYPERS :
153148 prefix , suffix = hyper .split ('/' )
154- if prefix not in sweep_metadata :
155- continue
149+ # if prefix not in sweep_metadata:
150+ # continue
156151
157152 group = sweep_metadata [prefix ]
158- if suffix not in group :
159- continue
153+ # if suffix not in group:
154+ # continue
160155
161- param = group [suffix ]
162156
163157 key = f'{ prefix } /{ suffix } _norm'
164158 if key not in data :
165159 data [key ] = []
166160
167- mmin = param ['min' ]
168- mmax = param ['max' ]
169- dist = param ['distribution' ]
170- val = exp [prefix ][suffix ]
161+ if suffix in group :
162+ param = group [suffix ]
163+ mmin = param ['min' ]
164+ mmax = param ['max' ]
165+ dist = param ['distribution' ]
166+ val = exp [prefix ][suffix ]
171167
172- if 'log' in dist or 'pow2' in dist :
173- mmin = np .log (mmin )
174- mmax = np .log (mmax )
175- val = np .log (val )
168+ if 'log' in dist or 'pow2' in dist :
169+ mmin = np .log (mmin )
170+ mmax = np .log (mmax )
171+ val = np .log (val )
176172
177- norm = (val - mmin ) / (mmax - mmin )
178- data [key ].append ([norm ]* n )
173+ norm = (val - mmin ) / (mmax - mmin )
174+ data [key ].append ([norm ]* n )
175+ else :
176+ data [key ].append ([1 ]* n )
179177
180178 for k , v in data .items ():
181179 data [k ] = [item for sublist in v for item in sublist ]
@@ -190,23 +188,11 @@ def cached_load(path, env_name, cache):
190188 #data['metrics/agent_steps'] = [e/1e6 for e in data['metrics/agent_steps']]
191189 del data ['metrics/agent_steps' ]
192190
193- '''
194- for k, v in data.items():
195- for e in v:
196- if e is None or isinstance(e, str):
197- continue
198- try:
199- if e > 1e9 or e < -1e9:
200- breakpoint()
201- except:
202- breakpoint()
203- '''
204-
205191 # Filter to pareto
192+ '''
206193 steps = data['agent_steps']
207194 costs = data['uptime']
208195 scores = data['env/score']
209- '''
210196 idxs = pareto_idx(steps, costs, scores)
211197 for k in data:
212198 try:
@@ -255,27 +241,42 @@ def compute_tsne():
255241 row = 0
256242 for env in env_names :
257243 sz = len (all_data [env ]['agent_steps' ])
244+ all_data [env ]['tsne1' ] = reduced [row :row + sz , 0 ].tolist ()
245+ all_data [env ]['tsne2' ] = reduced [row :row + sz , 1 ].tolist ()
246+
247+ '''
258248 if reduced is not None:
259249 all_data[env]['tsne1'] = reduced[row:row+sz, 0].tolist()
260250 all_data[env]['tsne2'] = reduced[row:row+sz, 1].tolist()
261251 else:
262252 all_data[env]['tsne1'] = np.random.rand(sz).tolist()
263253 all_data[env]['tsne2'] = np.random.rand(sz).tolist()
254+ '''
264255
265256 row += sz
266257 print (f'Env { env } has { sz } points' )
267258
268259 for env in all_data :
269260 dat = all_data [env ]
270- dat = {k : v for k , v in dat .items () if k in ALL_KEYS }
261+ dat = {k : v for k , v in dat .items () if isinstance (v , list )
262+ and len (v ) > 0 and isinstance (v [0 ], (int , float ))
263+ and (k == 'train/max_grad_norm' or not k .endswith ('_norm' ))}
271264 all_data [env ] = dat
272265 for k , v in dat .items ():
273266 try :
274267 print (f'{ env } /{ k } : { len (v ), min (v ), max (v )} ' )
275268 except :
276269 print (f'{ env } /{ k } : { len (v )} ' )
277270
278- json .dump (all_data , open ('constellation/default.json' , 'w' ))
271+ for env in all_data :
272+ for k , v in all_data [env ].items ():
273+ if isinstance (v , list ):
274+ try :
275+ all_data [env ][k ] = ',' .join ([f'{ x :.6g} ' for x in v ])
276+ except :
277+ breakpoint ()
278+
279+ json .dump (all_data , open ('resources/constellation/experiments.json' , 'w' ))
279280
280281if __name__ == '__main__' :
281282 compute_tsne ()
0 commit comments