import math
import random
import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Mapping, Optional

import numpy as np
from joblib import Parallel, delayed
from torch.utils.tensorboard import SummaryWriter

from models.ParamSpace import ParamSpace, ParamType
from .base import Optimizer
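
# Note: the Optimizer base class is assumed to store param_space, evaluate_fn
# and metric_key, and to seed an internal random.Random exposed as self._rng;
# both are used below but never assigned in this file.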


class ParameterTransformer:
    """Translate between parameter dicts and the flat vectors PSO moves.

    Encoding: INTEGER and FLOAT take one bounded dimension each, FLOAT_LOG
    takes one dimension in log-space, and CATEGORICAL/BOOLEAN take one logit
    dimension per choice (decoded via argmax).
    """

    def __init__(self, param_space: Mapping[str, ParamSpace]):
        self.param_space = param_space
        self.param_names = list(param_space.keys())

        # Bookkeeping for the flat-vector layout
        self.slices: Dict[str, slice] = {}
        self.total_dim = 0

        self.bounds_min: List[float] = []
        self.bounds_max: List[float] = []
        self.types: List[ParamType] = []

        for name in self.param_names:
            space = self.param_space[name]
            start = self.total_dim

            if space.param_type in [ParamType.INTEGER, ParamType.FLOAT, ParamType.FLOAT_LOG]:
                # 1 dimension; FLOAT_LOG is searched in log-space so steps
                # are scale-free
                if space.min_value is None or space.max_value is None:
                    raise ValueError(f"min_value and max_value required for {space.param_type.value} parameter")
                lo, hi = float(space.min_value), float(space.max_value)
                if space.param_type == ParamType.FLOAT_LOG:
                    lo, hi = math.log(lo), math.log(hi)
                self.total_dim += 1
                self.bounds_min.append(lo)
                self.bounds_max.append(hi)
                self.types.append(space.param_type)

            elif space.param_type in [ParamType.CATEGORICAL, ParamType.BOOLEAN]:
                # One dimension per choice: the vector stores one logit per
                # option (for BOOLEAN the choices are True/False)
                choices = space.choices
                if choices is None:
                    raise ValueError(f"choices cannot be None for {space.param_type.value} parameter")
                n_choices = len(choices)

                self.total_dim += n_choices
                # Logits are unbounded in principle; clamp to [-10, 10] to
                # prevent overflow/saturation
                self.bounds_min.extend([-10.0] * n_choices)
                self.bounds_max.extend([10.0] * n_choices)
                self.types.extend([space.param_type] * n_choices)
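
                # e.g. a hypothetical choices=["adam", "sgd"] occupies two
                # logit dimensions, each clamped to [-10, 10]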

            # Record which slice of the flat vector belongs to this parameter
            self.slices[name] = slice(start, self.total_dim)

        self.np_bounds_min = np.array(self.bounds_min, dtype=float)
        self.np_bounds_max = np.array(self.bounds_max, dtype=float)

        # Velocity limits: 20% of each dimension's range; the factor is a
        # heuristic, but it works well in practice
        self.vel_limits = (self.np_bounds_max - self.np_bounds_min) * 0.2
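
        # e.g. a logit dimension clamped to [-10, 10] can move at most 4.0
        # per iteration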

    def vector_to_params(self, vector: np.ndarray) -> Dict[str, Any]:
        """Decode a flat vector back into a valid parameter dictionary."""
        params: Dict[str, Any] = {}

        for name in self.param_names:
            space = self.param_space[name]
            sl = self.slices[name]
            segment = vector[sl]

            if space.param_type == ParamType.INTEGER:
                # Round to the nearest integer and clamp to bounds
                if space.min_value is None or space.max_value is None:
                    raise ValueError("min_value and max_value required for INTEGER parameter")
                rounded = int(round(float(segment[0])))
                params[name] = max(int(space.min_value), min(int(space.max_value), rounded))

            elif space.param_type == ParamType.FLOAT:
                # Clamp to bounds
                if space.min_value is None or space.max_value is None:
                    raise ValueError("min_value and max_value required for FLOAT parameter")
                val = float(segment[0])
                params[name] = float(max(float(space.min_value), min(float(space.max_value), val)))

            elif space.param_type == ParamType.FLOAT_LOG:
                # Convert back from log-space, then clamp to bounds
                if space.min_value is None or space.max_value is None:
                    raise ValueError("min_value and max_value required for FLOAT_LOG parameter")
                exp_val = math.exp(float(segment[0]))
                params[name] = float(max(float(space.min_value), min(float(space.max_value), exp_val)))

            elif space.param_type in [ParamType.CATEGORICAL, ParamType.BOOLEAN]:
                # Argmax of the logits -> index -> choice
                if space.choices is None:
                    raise ValueError(f"choices cannot be None for {space.param_type.value} parameter")
                best_idx = np.argmax(segment)
                params[name] = space.choices[best_idx]

        return params

    def sample_random_vector(self, rng: random.Random) -> np.ndarray:
        """Sample a random valid vector in the search space."""
        vec = np.zeros(self.total_dim)

        for i, (b_min, b_max, p_type) in enumerate(zip(self.bounds_min, self.bounds_max, self.types)):
            if p_type in [ParamType.CATEGORICAL, ParamType.BOOLEAN]:
                # Initialize logits with smaller noise around 0 so every
                # choice starts with a fair chance
                vec[i] = rng.uniform(-2.0, 2.0)
            else:
                vec[i] = rng.uniform(b_min, b_max)
        return vec


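# Illustrative round-trip (a sketch; `space` stands for any
# {name: ParamSpace} mapping):
#
#     t = ParameterTransformer(space)
#     vec = t.sample_random_vector(random.Random(0))
#     params = t.vector_to_params(vec)  # always a valid config dict

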
@dataclass
class PSOResult:
    best_params: Dict[str, Any]
    best_score: float
    best_metrics: Dict[str, float]
    history: List[Dict[str, Any]]


class _Particle:
    def __init__(
        self,
        transformer: ParameterTransformer,
        rng: random.Random
    ) -> None:
        self.transformer = transformer

        # Position is one flat float vector (logit dims included); velocity
        # has the same shape and starts at zero
        self.position = transformer.sample_random_vector(rng)
        self.velocity = np.zeros_like(self.position)

        # Personal best
        self.p_best_pos = self.position.copy()
        self.p_best_score = float("-inf")

        # Cache the decoded params so we don't re-decode constantly
        self.current_params_dict = transformer.vector_to_params(self.position)

    def update_velocity(
        self,
        w: float,
        c1: float,
        c2: float,
        r1: np.ndarray,
        r2: np.ndarray,
        g_best_pos: np.ndarray
    ) -> None:
        # Standard PSO update (the math works on logit dimensions too):
        # v = w*v + c1*r1*(p_best - x) + c2*r2*(g_best - x)

        cognitive = c1 * r1 * (self.p_best_pos - self.position)
        social = c2 * r2 * (g_best_pos - self.position)
        self.velocity = w * self.velocity + cognitive + social

        # Cap each velocity component at vel_limits so particles cannot
        # overshoot the whole range in one step
        self.velocity = np.clip(
            self.velocity,
            -self.transformer.vel_limits,
            self.transformer.vel_limits,
        )

    def move(self) -> None:
        self.position += self.velocity

        # Clamp position to valid bounds; for logits this prevents runaway
        # values (e.g. 1e9) that would saturate decoding
        self.position = np.clip(
            self.position,
            self.transformer.np_bounds_min,
            self.transformer.np_bounds_max
        )

        # Refresh the dictionary representation
        self.current_params_dict = self.transformer.vector_to_params(self.position)


class ParticleSwarmOptimization(Optimizer):
    def __init__(
        self,
        param_space: Mapping[str, ParamSpace],
        evaluate_fn: Callable[[Dict[str, Any]], Dict[str, float]],
        metric_key: str = "accuracy",
        seed: Optional[int] = None,
        n_jobs: int | None = 1,
        # PSO hyperparameters: w is the inertia weight; c1 and c2 weight the
        # pull toward the personal and global bests
        n_particles: int = 10,
        w: float = 0.5,
        c1: float = 1.5,
        c2: float = 1.5,
    ) -> None:
        super().__init__(param_space, evaluate_fn, metric_key, seed)
        self.n_jobs = n_jobs
        self.n_particles = n_particles
        self.w = w
        self.c1 = c1
        self.c2 = c2

        # Translates between param dicts and flat PSO vectors
        self.transformer = ParameterTransformer(self.param_space)

    def run(
        self,
        trials: int,
        verbose: bool = True,
        writer: Optional[SummaryWriter] = None,
    ) -> PSOResult:
        if self.n_jobs == 1:
            print("Running evaluations sequentially")
        else:
            print(f"Using {self.n_jobs} parallel workers")

        history: List[Dict[str, Any]] = []

        # Global best across the whole swarm
        g_best_pos: Optional[np.ndarray] = None
        g_best_score = float("-inf")
        g_best_metrics: Dict[str, float] = {}
        g_best_params: Dict[str, Any] = {}

        swarm = [
            _Particle(self.transformer, self._rng)
            for _ in range(self.n_particles)
        ]

        evals_done = 0
        generation = 0

        while evals_done < trials:
            generation += 1

            # Update kinematics (the initial random generation skips this)
            if evals_done > 0 and g_best_pos is not None:
                for p in swarm:
                    # Fresh random vectors for stochasticity
                    r1 = np.random.rand(self.transformer.total_dim)
                    r2 = np.random.rand(self.transformer.total_dim)
                    p.update_velocity(self.w, self.c1, self.c2, r1, r2, g_best_pos)
                    p.move()

            # Budget check: if fewer than n_particles evaluations remain,
            # only evaluate the first `remaining` particles
            remaining = trials - evals_done
            current_batch = swarm[:remaining]
            configs = [p.current_params_dict for p in current_batch]

            # Evaluate the batch; each result is a (metrics, duration) pair
            if self.n_jobs == 1:
                results = []
                for cfg in configs:
                    start = time.time()
                    metrics = self.evaluate_fn(cfg)
                    results.append((metrics, time.time() - start))
            else:
                def _eval_wrapper(c):
                    start = time.time()
                    m = self.evaluate_fn(c)
                    return m, time.time() - start

                results = Parallel(n_jobs=self.n_jobs)(
                    delayed(_eval_wrapper)(c) for c in configs
                )

            # Fold the results back into swarm knowledge
            for i, (metrics, duration) in enumerate(results):
                p = current_batch[i]
                evals_done += 1

                score = metrics.get(self.metric_key, float("-inf"))

                # Update the personal best
                if score > p.p_best_score:
                    p.p_best_score = score
                    p.p_best_pos = p.position.copy()

                # Update the global best
                if score > g_best_score:
                    g_best_score = score
                    g_best_pos = p.position.copy()
                    g_best_metrics = dict(metrics)
                    g_best_params = p.current_params_dict.copy()
                    if verbose:
                        print(f"  Gen {generation}: New Best {self.metric_key}={score:.4f}")

                # History & logging
                rec = {
                    "trial": evals_done,
                    "params": p.current_params_dict.copy(),
                    "score": score,
                    "metrics": metrics,
                    "duration": duration,
                }
                history.append(rec)

                if writer is not None:
                    writer.add_scalar(self.metric_key, score, evals_done)

                if evals_done >= trials:
                    break

        # Sort history by trial number (parallel batches may finish out of
        # order)
        history.sort(key=lambda x: x["trial"])

        return PSOResult(
            best_params=g_best_params,
            best_score=g_best_score,
            best_metrics=g_best_metrics,
            history=history,
        )
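

# Illustrative usage (a sketch only: `train_and_eval` and the ParamSpace
# constructor arguments shown here are hypothetical and may differ in your
# codebase):
#
#     space = {
#         "lr": ParamSpace(param_type=ParamType.FLOAT_LOG, min_value=1e-5, max_value=1e-1),
#         "optimizer": ParamSpace(param_type=ParamType.CATEGORICAL, choices=["adam", "sgd"]),
#     }
#     pso = ParticleSwarmOptimization(space, train_and_eval, metric_key="accuracy", seed=42)
#     result = pso.run(trials=50)
#     print(result.best_params, result.best_score)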