@@ -62,7 +62,7 @@ def pareto_idx(steps, costs, scores):
6262
6363 return idxs
6464
65- def cached_load (path , env_name , cache ):
65+ def cached_load (path , env_name , cache , full_dataset = False ):
6666 data = {}
6767 num_metrics = 0
6868 metric_keys = []
@@ -170,21 +170,22 @@ def cached_load(path, env_name, cache):
170170 del data ['metrics/agent_steps' ]
171171
172172 # Filter to pareto
173- steps = data ['agent_steps' ]
174- costs = data ['uptime' ]
175- scores = data ['env/score' ]
176-
177- idxs = pareto_idx (steps , costs , scores )
178- for k in data :
179- try :
180- data [k ] = [data [k ][i ] for i in idxs ]
181- except IndexError :
182- continue
173+ if not full_dataset :
174+ steps = data ['agent_steps' ]
175+ costs = data ['uptime' ]
176+ scores = data ['env/score' ]
177+
178+ idxs = pareto_idx (steps , costs , scores )
179+ for k in data :
180+ try :
181+ data [k ] = [data [k ][i ] for i in idxs ]
182+ except IndexError :
183+ continue
183184
184185 data ['sweep' ] = sweep_metadata
185186 return data
186187
187- def compute_tsne ():
188+ def compute_tsne (full_dataset = False ):
188189 all_data = {}
189190 normed = {}
190191
@@ -196,7 +197,7 @@ def compute_tsne():
196197 env_names = sorted (os .listdir ('logs' ))
197198 for env in env_names :
198199 print ('Loading: ' , env )
199- all_data [env ] = cached_load (f'logs/{ env } /*.json' , env , cache )
200+ all_data [env ] = cached_load (f'logs/{ env } /*.json' , env , cache , full_dataset )
200201
201202 with open (cache_file , 'w' ) as f :
202203 json .dump (cache , f )
@@ -230,7 +231,7 @@ def compute_tsne():
230231 and len (v ) > 0 and isinstance (v [0 ], (int , float ))
231232 and (k == 'train/max_grad_norm' or not k .endswith ('_norm' ))}
232233 all_data [env ] = dat
233- print (f' Env { env } has { len (dat ['env/perf' ])} points' )
234+ print (f" Env { env } has { len (dat ['env/perf' ])} points" )
234235 for k , v in dat .items ():
235236 if 'env/perf' in k or 'score' in k :
236237 print (f'{ env } /{ k } : min={ min (v )} , max={ max (v )} ' )
@@ -243,4 +244,8 @@ def compute_tsne():
243244 json .dump (all_data , open ('resources/constellation/experiments.json' , 'w' ))
244245
245246if __name__ == '__main__' :
246- compute_tsne ()
247+ import argparse
248+ parser = argparse .ArgumentParser ()
249+ parser .add_argument ('--full' , action = 'store_true' , help = 'Include full dataset (no pareto filtering)' )
250+ args = parser .parse_args ()
251+ compute_tsne (full_dataset = args .full )
0 commit comments