1212 #'impulse_wars',
1313 #'pacman',
1414 #'tetris',
15- 'g2048' ,
15+ # 'g2048',
1616 #'moba',
17- 'pong' ,
17+ # 'pong',
1818 #'tower_climb',
19- 'grid' ,
20- 'nmmo3' ,
19+ # 'grid',
20+ # 'nmmo3',
2121 #'snake',
2222 #'tripletriad'
2323])
3838 'train/eps' ,
3939 'train/prio_alpha' ,
4040 'train/prio_beta0' ,
41- # 'train/horizon',
41+ 'train/horizon' ,
4242 'train/replay_ratio' ,
4343 'train/minibatch_size' ,
4444 'policy/hidden_size' ,
@@ -65,22 +65,21 @@ def pareto_idx(steps, costs, scores):
6565
6666 return idxs
6767
68- def load_sweep_data (path ):
68+ def cached_load (path , env_name , cache ):
6969 data = {}
70- sweep_metadata = {}
7170 num_metrics = 0
7271 for fpath in glob .glob (path ):
73- if 'cache.json' in fpath :
74- continue
75-
76- with open (fpath , 'r' ) as f :
77- try :
78- exp = json .load (f )
79- except json .decoder .JSONDecodeError :
80- print (f'Skipping { fpath } ' )
81- continue
72+ if fpath in cache :
73+ exp = cache [ fpath ]
74+ else :
75+ with open (fpath , 'r' ) as f :
76+ try :
77+ exp = json .load (f )
78+ except json .decoder .JSONDecodeError :
79+ print (f'Skipping { fpath } ' )
80+ continue
8281
83- sweep_metadata = exp . pop ( 'sweep' )
82+ cache [ fpath ] = exp
8483
8584 data_len = len (exp ['metrics' ]['agent_steps' ])
8685 if data_len > 100 :
@@ -91,7 +90,7 @@ def load_sweep_data(path):
9190 num_metrics = len (exp ['metrics' ])
9291
9392 skip = False
94- metrics = exp . pop ( 'metrics' )
93+ metrics = exp [ 'metrics' ]
9594
9695 if len (metrics ) != num_metrics :
9796 print (f'Skipping { fpath } (num_metrics={ len (metrics )} != { num_metrics } )' )
@@ -120,62 +119,77 @@ def load_sweep_data(path):
120119 breakpoint ()
121120 pass
122121
122+ sweep_metadata = exp ['sweep' ]
123+
123124 for k , v in pufferlib .unroll_nested_dict (exp ):
124125 if k not in data :
125126 data [k ] = []
126127
127128 data [k ].append ([v ]* n )
128129
130+ for hyper in HYPERS :
131+ prefix , suffix = hyper .split ('/' )
132+ if prefix not in sweep_metadata :
133+ continue
134+
135+ group = sweep_metadata [prefix ]
136+ if suffix not in group :
137+ continue
138+
139+ param = group [suffix ]
140+
141+ key = f'{ prefix } /{ suffix } _norm'
142+ if key not in data :
143+ data [key ] = []
144+
145+ mmin = param ['min' ]
146+ mmax = param ['max' ]
147+ dist = param ['distribution' ]
148+ val = exp [prefix ][suffix ]
149+
150+ if 'log' in dist or 'pow2' in dist :
151+ mmin = np .log (mmin )
152+ mmax = np .log (mmax )
153+ val = np .log (val )
154+
155+ norm = (val - mmin ) / (mmax - mmin )
156+ data [key ].append ([norm ]* n )
157+
129158 for k , v in data .items ():
130159 data [k ] = [item for sublist in v for item in sublist ]
131160
132- #steps = data['agent_steps']
133- #costs = data['uptime']
134- #scores = data['env/score']
135- #idxs = pareto_idx(steps, costs, scores)
136161 # Filter to pareto
137- #for k in data:
138- # data[k] = [data[k][i] for i in idxs]
162+ steps = data ['agent_steps' ]
163+ costs = data ['uptime' ]
164+ scores = data ['env/score' ]
165+ idxs = pareto_idx (steps , costs , scores )
166+ for k in data :
167+ data [k ] = [data [k ][i ] for i in idxs ]
139168
140169 data ['sweep' ] = sweep_metadata
141170 return data
142171
143- def cached_sweep_load (path , env_name ):
144- cache_file = os .path .join (path , 'c_cache.json' )
145- if not os .path .exists (cache_file ):
146- data = load_sweep_data (os .path .join (path , '*.json' ))
147- with open (cache_file , 'w' ) as f :
148- json .dump (data , f )
149-
150- with open (cache_file , 'r' ) as f :
151- data = json .load (f )
152-
153- print (f'Loaded { env_name } ' )
154- return data
155-
156172def compute_tsne ():
157173 all_data = {}
158174 normed = {}
159175
176+ cache = {}
177+ cache_file = os .path .join ('cache.json' )
178+ if os .path .exists (cache_file ):
179+ cache = json .load (open (cache_file , 'r' ))
180+
160181 for env in env_names :
161- env_data = cached_sweep_load (f'logs/puffer_{ env } ' , env )
162- sweep_metadata = env_data .pop ('sweep' )
163- all_data [env ] = env_data
182+ all_data [env ] = cached_load (f'logs/puffer_{ env } /*.json' , env , cache )
164183
184+ with open (cache_file , 'w' ) as f :
185+ json .dump (cache , f )
186+
187+ for env in env_names :
188+ env_data = all_data [env ]
165189 normed_env = []
166190 for key in HYPERS :
167- prefix , suffix = key .split ('/' )
168- mmin = sweep_metadata [prefix ][suffix ]['min' ]
169- mmax = sweep_metadata [prefix ][suffix ]['max' ]
170- dat = np .array (env_data [key ])
171-
172- dist = sweep_metadata [prefix ][suffix ]['distribution' ]
173- if 'log' in dist or 'pow2' in dist :
174- mmin = np .log (mmin )
175- mmax = np .log (mmax )
176- dat = np .log (dat )
177-
178- normed_env .append ((dat - mmin ) / (mmax - mmin ))
191+ norm_key = f'{ key } _norm'
192+ normed_env .append (np .array (env_data [norm_key ]))
179193
180194 normed [env ] = np .stack (normed_env , axis = 1 )
181195
@@ -192,7 +206,6 @@ def compute_tsne():
192206 row = 0
193207 for env in env_names :
194208 sz = len (all_data [env ]['agent_steps' ])
195- #all_data[env] = {k: v for k, v in all_data[env].items()}
196209 if reduced is not None :
197210 all_data [env ]['tsne1' ] = reduced [row :row + sz , 0 ].tolist ()
198211 all_data [env ]['tsne2' ] = reduced [row :row + sz , 1 ].tolist ()
@@ -203,7 +216,7 @@ def compute_tsne():
203216 row += sz
204217 print (f'Env { env } has { sz } points' )
205218
206- json .dump (all_data , open ('all_cache .json' , 'w' ))
219+ json .dump (all_data , open ('pufferlib/ocean/constellation/default .json' , 'w' ))
207220
208221if __name__ == '__main__' :
209222 compute_tsne ()
0 commit comments