55import os
66
77env_names = sorted ([
8- # 'breakout',
8+ 'breakout' ,
99 #'impulse_wars',
1010 #'pacman',
1111 #'tetris',
1212 #'g2048',
1313 #'moba',
14- # 'pong',
14+ 'pong' ,
1515 #'tower_climb',
1616 #'grid',
1717 'nmmo3' ,
3535 'train/eps' ,
3636 'train/prio_alpha' ,
3737 'train/prio_beta0' ,
38- # 'train/horizon',
38+ 'train/horizon' ,
3939 'train/replay_ratio' ,
4040 'train/minibatch_size' ,
4141 'policy/hidden_size' ,
@@ -101,10 +101,6 @@ def cached_load(path, env_name, cache):
101101 for k in list (exp ['metrics' ].keys ()):
102102 if 'loss' in k :
103103 del exp ['metrics' ][k ]
104- data_len = len (exp ['metrics' ]['agent_steps' ])
105- if data_len > 100 :
106- print (f'Skipping { fpath } (len={ data_len } )' )
107- continue
108104
109105 if num_metrics == 0 :
110106 num_metrics = len (exp ['metrics' ])
@@ -114,7 +110,6 @@ def cached_load(path, env_name, cache):
114110 metrics = exp ['metrics' ]
115111
116112 if len (metrics ) != num_metrics :
117- breakpoint ()
118113 print (f'Skipping { fpath } (num_metrics={ len (metrics )} != { num_metrics } )' )
119114 continue
120115
@@ -132,7 +127,6 @@ def cached_load(path, env_name, cache):
132127 break
133128
134129 if skip :
135- breakpoint ()
136130 print (f'Skipping { fpath } (bad data)' )
137131 continue
138132
@@ -151,31 +145,34 @@ def cached_load(path, env_name, cache):
151145
152146 for hyper in HYPERS :
153147 prefix , suffix = hyper .split ('/' )
154- if prefix not in sweep_metadata :
155- continue
148+ # if prefix not in sweep_metadata:
149+ # continue
156150
157151 group = sweep_metadata [prefix ]
158- if suffix not in group :
159- continue
152+ # if suffix not in group:
153+ # continue
160154
161- param = group [suffix ]
162155
163156 key = f'{ prefix } /{ suffix } _norm'
164157 if key not in data :
165158 data [key ] = []
166159
167- mmin = param ['min' ]
168- mmax = param ['max' ]
169- dist = param ['distribution' ]
170- val = exp [prefix ][suffix ]
160+ if suffix in group :
161+ param = group [suffix ]
162+ mmin = param ['min' ]
163+ mmax = param ['max' ]
164+ dist = param ['distribution' ]
165+ val = exp [prefix ][suffix ]
171166
172- if 'log' in dist or 'pow2' in dist :
173- mmin = np .log (mmin )
174- mmax = np .log (mmax )
175- val = np .log (val )
167+ if 'log' in dist or 'pow2' in dist :
168+ mmin = np .log (mmin )
169+ mmax = np .log (mmax )
170+ val = np .log (val )
176171
177- norm = (val - mmin ) / (mmax - mmin )
178- data [key ].append ([norm ]* n )
172+ norm = (val - mmin ) / (mmax - mmin )
173+ data [key ].append ([norm ]* n )
174+ else :
175+ data [key ].append ([1 ]* n )
179176
180177 for k , v in data .items ():
181178 data [k ] = [item for sublist in v for item in sublist ]
@@ -255,26 +252,41 @@ def compute_tsne():
255252 row = 0
256253 for env in env_names :
257254 sz = len (all_data [env ]['agent_steps' ])
255+ all_data [env ]['tsne1' ] = reduced [row :row + sz , 0 ].tolist ()
256+ all_data [env ]['tsne2' ] = reduced [row :row + sz , 1 ].tolist ()
257+
258+ '''
258259 if reduced is not None:
259260 all_data[env]['tsne1'] = reduced[row:row+sz, 0].tolist()
260261 all_data[env]['tsne2'] = reduced[row:row+sz, 1].tolist()
261262 else:
262263 all_data[env]['tsne1'] = np.random.rand(sz).tolist()
263264 all_data[env]['tsne2'] = np.random.rand(sz).tolist()
265+ '''
264266
265267 row += sz
266268 print (f'Env { env } has { sz } points' )
267269
268270 for env in all_data :
269271 dat = all_data [env ]
270- dat = {k : v for k , v in dat .items () if k in ALL_KEYS }
272+ dat = {k : v for k , v in dat .items () if isinstance (v , list )
273+ and len (v ) > 0 and isinstance (v [0 ], (int , float ))
274+ and (k == 'train/max_grad_norm' or not k .endswith ('_norm' ))}
271275 all_data [env ] = dat
272276 for k , v in dat .items ():
273277 try :
274278 print (f'{ env } /{ k } : { len (v ), min (v ), max (v )} ' )
275279 except :
276280 print (f'{ env } /{ k } : { len (v )} ' )
277281
282+ for env in all_data :
283+ for k , v in all_data [env ].items ():
284+ if isinstance (v , list ):
285+ try :
286+ all_data [env ][k ] = ',' .join ([f'{ x :.6g} ' for x in v ])
287+ except :
288+ breakpoint ()
289+
278290 json .dump (all_data , open ('constellation/default.json' , 'w' ))
279291
280292if __name__ == '__main__' :
0 commit comments