Skip to content

Commit 80cd743

Browse files
committed
Refactor for clarity
1 parent 6f58927 commit 80cd743

1 file changed

Lines changed: 14 additions & 63 deletions

File tree

search/particle_swarm.py

Lines changed: 14 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,15 @@
66
import time
77
from dataclasses import dataclass
88
from typing import Any, Callable, Dict, List, Mapping, Optional
9-
109
import numpy as np
1110
from joblib import Parallel, delayed
1211
from torch.utils.tensorboard import SummaryWriter
13-
14-
# Import your existing ParamSpace classes
1512
from models.ParamSpace import ParamSpace, ParamType
1613
from .base import Optimizer
1714

1815

19-
# --------------------------------------------------------------------------- #
20-
# Helper: Parameter Transformer
21-
# --------------------------------------------------------------------------- #
2216
class ParameterTransformer:
23-
"""
24-
Handles the translation between the 'Dictionary' world (user params)
25-
and the 'Vector' world (PSO math).
26-
"""
17+
"""Transforms the parameter space into a vector and back."""
2718

2819
def __init__(self, param_space: Mapping[str, ParamSpace]):
2920
self.param_space = param_space
@@ -35,7 +26,6 @@ def __init__(self, param_space: Mapping[str, ParamSpace]):
3526
self.bounds_min: List[float] = []
3627
self.bounds_max: List[float] = []
3728

38-
# We need to know which indices correspond to which logic
3929
self.types: List[ParamType] = []
4030

4131
for name in self.param_names:
@@ -57,16 +47,14 @@ def __init__(self, param_space: Mapping[str, ParamSpace]):
5747
self.types.append(space.param_type)
5848

5949
elif space.param_type in [ParamType.CATEGORICAL, ParamType.BOOLEAN]:
60-
# N Dimensions (One-Hot / Logits), Loose bounds
61-
# We use choices length. For Boolean, choices is [True, False] implicitly in your class
50+
                # One-hot encoding: one dimension (logit) per choice
6251
choices = space.choices
6352
if choices is None:
6453
raise ValueError(f"choices cannot be None for {space.param_type.value} parameter")
6554
n_choices = len(choices)
6655

6756
self.total_dim += n_choices
68-
# Logits technically unbounded, but we clamp to prevent overflow/saturation
69-
# -10 to 10 covers sigmoid ranges 0.00004 to 0.99995
57+
# Clamp to prevent saturation
7058
self.bounds_min.extend([-10.0] * n_choices)
7159
self.bounds_max.extend([10.0] * n_choices)
7260
self.types.extend([space.param_type] * n_choices)
@@ -77,6 +65,7 @@ def __init__(self, param_space: Mapping[str, ParamSpace]):
7765
self.np_bounds_max = np.array(self.bounds_max, dtype=float)
7866

7967
# Velocity limits: 20% of the range
68+
                # The 20% factor is somewhat arbitrary, but works well in practice.
8069
self.vel_limits = (self.np_bounds_max - self.np_bounds_min) * 0.2
8170

8271
def vector_to_params(self, vector: np.ndarray) -> Dict[str, Any]:
@@ -89,28 +78,24 @@ def vector_to_params(self, vector: np.ndarray) -> Dict[str, Any]:
8978
segment = vector[sl]
9079

9180
if space.param_type == ParamType.INTEGER:
92-
# Round to nearest integer and clamp to bounds
9381
if space.min_value is None or space.max_value is None:
9482
raise ValueError("min_value and max_value required for INTEGER parameter")
9583
rounded = int(round(float(segment[0])))
9684
params[name] = max(int(space.min_value), min(int(space.max_value), rounded))
9785

9886
elif space.param_type == ParamType.FLOAT:
99-
# Clamp to bounds
10087
if space.min_value is None or space.max_value is None:
10188
raise ValueError("min_value and max_value required for FLOAT parameter")
10289
val = float(segment[0])
10390
params[name] = float(max(float(space.min_value), min(float(space.max_value), val)))
10491

10592
elif space.param_type == ParamType.FLOAT_LOG:
106-
# Convert back from log-space and clamp to bounds
10793
if space.min_value is None or space.max_value is None:
10894
raise ValueError("min_value and max_value required for FLOAT_LOG parameter")
10995
exp_val = math.exp(float(segment[0]))
11096
params[name] = float(max(float(space.min_value), min(float(space.max_value), exp_val)))
11197

11298
elif space.param_type in [ParamType.CATEGORICAL, ParamType.BOOLEAN]:
113-
# Argmax of logits -> Index -> Choice
11499
if space.choices is None:
115100
raise ValueError(f"choices cannot be None for {space.param_type.value} parameter")
116101
best_idx = np.argmax(segment)
@@ -119,21 +104,18 @@ def vector_to_params(self, vector: np.ndarray) -> Dict[str, Any]:
119104
return params
120105

121106
def sample_random_vector(self, rng: random.Random) -> np.ndarray:
122-
"""Create a random valid vector in the search space."""
107+
"""Sample a random valid vector in the search space."""
123108
vec = np.zeros(self.total_dim)
124109

110+
        # One-hot (logit) dimensions are initialized uniformly in [-2, 2]
        # rather than across the full [-10, 10] bounds, to avoid saturation.
125111
for i, (b_min, b_max, p_type) in enumerate(zip(self.bounds_min, self.bounds_max, self.types)):
126112
if p_type in [ParamType.CATEGORICAL, ParamType.BOOLEAN]:
127-
# Initialize logits with smaller noise around 0 for fairness
128113
vec[i] = rng.uniform(-2.0, 2.0)
129114
else:
130115
vec[i] = rng.uniform(b_min, b_max)
131116
return vec
132117

133118

134-
# --------------------------------------------------------------------------- #
135-
# Result Data Class
136-
# --------------------------------------------------------------------------- #
137119
@dataclass
138120
class PSOResult:
139121
best_params: Dict[str, Any]
@@ -142,28 +124,17 @@ class PSOResult:
142124
history: List[Dict[str, Any]]
143125

144126

145-
# --------------------------------------------------------------------------- #
146-
# Particle Class
147-
# --------------------------------------------------------------------------- #
148127
class _Particle:
149128
def __init__(
150129
self,
151130
transformer: ParameterTransformer,
152131
rng: random.Random
153132
) -> None:
154133
self.transformer = transformer
155-
156-
# 1. Position: A flat float vector (including logits)
157134
self.position = transformer.sample_random_vector(rng)
158-
159-
# 2. Velocity: Same shape, starts at 0
160135
self.velocity = np.zeros_like(self.position)
161-
162-
# 3. Personal Best
163136
self.p_best_pos = self.position.copy()
164137
self.p_best_score = float("-inf")
165-
166-
# Cache current params to avoid re-decoding constantly
167138
self.current_params_dict = transformer.vector_to_params(self.position)
168139

169140
def update_velocity(
@@ -175,7 +146,7 @@ def update_velocity(
175146
r2: np.ndarray,
176147
g_best_pos: np.ndarray
177148
) -> None:
178-
# Standard PSO Math (Works for logits too!)
149+
# Standard PSO
179150
# v = w*v + c1*r1*(p_best - x) + c2*r2*(g_best - x)
180151

181152
cognitive = c1 * r1 * (self.p_best_pos - self.position)
@@ -194,20 +165,15 @@ def move(self) -> None:
194165
self.position += self.velocity
195166

196167
# Clamp position to valid bounds
197-
# For Logits, this prevents values like 1e9 which kill gradients
198168
self.position = np.clip(
199169
self.position,
200170
self.transformer.np_bounds_min,
201171
self.transformer.np_bounds_max
202172
)
203173

204-
# Update dictionary representation
205174
self.current_params_dict = self.transformer.vector_to_params(self.position)
206175

207176

208-
# --------------------------------------------------------------------------- #
209-
# Main Optimizer Class
210-
# --------------------------------------------------------------------------- #
211177
class ParticleSwarmOptimization(Optimizer):
212178
def __init__(
213179
self,
@@ -216,15 +182,14 @@ def __init__(
216182
metric_key: str = "accuracy",
217183
seed: Optional[int] = None,
218184
n_jobs: int | None = 1,
219-
# PSO Hyperparameters
185+
# PSO Hyperparams
220186
n_particles: int = 10,
221187
w: float = 0.5,
222188
c1: float = 1.5,
223189
c2: float = 1.5,
224190
) -> None:
225191
super().__init__(param_space, evaluate_fn, metric_key, seed)
226-
227-
# Initialize the Transformer
192+
# Vector to param space transformer
228193
self.transformer = ParameterTransformer(self.param_space)
229194

230195
def run(
@@ -245,33 +210,25 @@ def run(
245210
else:
246211
print(f"Using {self.n_jobs} parallel workers")
247212

248-
# ----------------------------------------------------------------- #
249-
# State Initialization
250-
# ----------------------------------------------------------------- #
251213
history: List[Dict[str, Any]] = []
252214

253-
# Global Best
254215
g_best_pos: Optional[np.ndarray] = None
255216
g_best_score = float("-inf")
256217
g_best_metrics: Dict[str, float] = {}
257218
g_best_params: Dict[str, Any] = {}
258219

259-
# Spawn Swarm
260220
swarm = [
261221
_Particle(self.transformer, self._rng)
262222
for _ in range(self.n_particles)
263223
]
264224

265-
# ----------------------------------------------------------------- #
266-
# Optimization Loop
267-
# ----------------------------------------------------------------- #
268225
evals_done = 0
269226
generation = 0
270227

271228
while evals_done < trials:
272229
generation += 1
273230

274-
# 1. Update Kinematics (Skip gen 0)
231+
# Update Kinematics (Skip gen 0)
275232
if evals_done > 0 and g_best_pos is not None:
276233
for p in swarm:
277234
# Random vectors for stochasticity
@@ -281,15 +238,10 @@ def run(
281238
p.update_velocity(self.w, self.c1, self.c2, r1, r2, g_best_pos)
282239
p.move()
283240

284-
# 2. Select particles to evaluate (Budget Check)
285241
remaining = trials - evals_done
286-
# If remaining budget < n_particles, just eval the first 'remaining' ones
287242
current_batch = swarm[:remaining]
288-
289-
# 3. Prepare Configs
290243
configs = [p.current_params_dict for p in current_batch]
291244

292-
# 4. Evaluate
293245
if self.n_jobs == 1:
294246
results = []
295247
for cfg in configs:
@@ -309,19 +261,19 @@ def _eval_wrapper(c):
309261
delayed(_eval_wrapper)(c) for c in configs
310262
)
311263

312-
# 5. Update Knowledge
264+
# Update Knowledge
313265
for i, (metrics, duration) in enumerate(results):
314266
p = current_batch[i]
315267
evals_done += 1
316268

317269
score = metrics.get(self.metric_key, float("-inf"))
318270

319-
# Update Personal Best
271+
# Update personal bests
320272
if score > p.p_best_score:
321273
p.p_best_score = score
322274
p.p_best_pos = p.position.copy()
323275

324-
# Update Global Best
276+
# Update global bests
325277
if score > g_best_score:
326278
g_best_score = score
327279
g_best_pos = p.position.copy()
@@ -331,7 +283,6 @@ def _eval_wrapper(c):
331283
if verbose:
332284
print(f" Gen {generation}: New Best {self.metric_key}={score:.4f}")
333285

334-
# History & Logging
335286
rec = {
336287
"trial": evals_done,
337288
"params": p.current_params_dict.copy(),
@@ -348,7 +299,7 @@ def _eval_wrapper(c):
348299
if evals_done >= trials:
349300
break
350301

351-
# Sort history by trial ID
302+
# Sort history by trial number
352303
history.sort(key=lambda x: x["trial"])
353304

354305
return PSOResult(

0 commit comments

Comments
 (0)