Skip to content

Commit 292f98d

Browse files
committed
FIXED GA: Replace Dict for evals caching with a proper thread-safe structure Manager().dict()
1 parent 3877f49 commit 292f98d

1 file changed

Lines changed: 31 additions & 18 deletions

File tree

search/genetic_algorithm.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import math
1010
from joblib import Parallel, delayed
1111
from dataclasses import dataclass
12+
from multiprocessing import Manager
1213

1314
# Logging
1415
from torch.utils.tensorboard import SummaryWriter
@@ -57,7 +58,9 @@ def __init__(
5758
self.radius = radius
5859
self.num_fittest = num_fittest
5960
# Memoization cache: maps trial number to (params, metrics, duration)
60-
self._eval_cache: Dict[int, tuple[Dict[str, Any], Dict[str, float], float]] = {}
61+
# Use Manager().dict() for thread-safe access in parallel execution
62+
self._manager = Manager()
63+
self._eval_cache = self._manager.dict()
6164

6265
def run(self, trials: int, verbose: bool = False, writer: Optional[SummaryWriter] = None):
6366
"""Run the Genetic Algorithm optimization process."""
@@ -99,47 +102,57 @@ def run(self, trials: int, verbose: bool = False, writer: Optional[SummaryWriter
99102
print(f"Population {geneID}/{self.populationSize}: {params}")
100103

101104
# Check memoization cache to avoid re-evaluation based on trial number
102-
if evals_done in self._eval_cache:
105+
cache_key = evals_done
106+
if cache_key in self._eval_cache:
103107
# Use cached result
104-
cached_params, metrics, duration = self._eval_cache[evals_done]
108+
cached_params, metrics, duration = self._eval_cache[cache_key]
105109
if verbose:
106110
print(f" -> Using cached evaluation result for trial {evals_done}")
107111
else:
108112
# Evaluate the fitness of each member in the initial population
109113
start = time.perf_counter()
110114
metrics = self.evaluate_fn(params)
111115
duration = time.perf_counter() - start
112-
# Store in cache by trial number
113-
self._eval_cache[evals_done] = (params, metrics, duration)
116+
# Store in cache by trial number (thread-safe with Manager.dict())
117+
self._eval_cache[cache_key] = (params, metrics, duration)
114118

115119
# trial, params, metrics, duration
116120
if verbose:
117121
print(f"Current Trial has evaluated models {evals_done} times, with: {self.metric_key} = {metrics.get(self.metric_key, 'N/A')}, Duration: {duration:.4f} sec")
118122
results.append((evals_done, params, metrics, duration))
119123
else: # Parallel Processing
120-
def evaluate_population(evals_done, params, eval_cache):
121-
"""Evaluate a single population member."""
124+
def evaluate_population(trial_id, params, eval_cache, evaluate_fn, metric_key, verbose_flag):
125+
"""Evaluate a single population member (thread-safe)."""
122126
# Check memoization cache to avoid re-evaluation based on trial number
123-
if evals_done in eval_cache:
124-
# Use cached result
125-
cached_params, metrics, duration = eval_cache[evals_done]
126-
print(f" -> Using cached evaluation result for trial {evals_done}")
127+
cache_key = trial_id
128+
if cache_key in eval_cache:
129+
# Use cached result (thread-safe read from Manager.dict())
130+
cached_params, metrics, duration = eval_cache[cache_key]
131+
if verbose_flag:
132+
print(f" -> Using cached evaluation result for trial {trial_id}")
127133
else:
128134
start = time.perf_counter()
129-
metrics = self.evaluate_fn(params)
135+
metrics = evaluate_fn(params)
130136
duration = time.perf_counter() - start
131-
# Store in cache by trial number (note: updates to dict are thread-safe for unique keys)
132-
eval_cache[evals_done] = (params, metrics, duration)
137+
# Store in cache by trial number (thread-safe write with Manager.dict())
138+
eval_cache[cache_key] = (params, metrics, duration)
133139

134140
# trial, params, metrics, duration
135-
if verbose:
136-
print(f"Current Trial has evaluated models {evals_done} times, with: {self.metric_key} = {metrics.get(self.metric_key, 'N/A')}, Duration: {duration:.4f} sec")
137-
return (evals_done, params, metrics, duration)
141+
if verbose_flag:
142+
print(f"Current Trial has evaluated models {trial_id} times, with: {metric_key} = {metrics.get(metric_key, 'N/A')}, Duration: {duration:.4f} sec")
143+
return (trial_id, params, metrics, duration)
138144

139145
parallel_verbose = 10 if verbose else 0
140146
# Record results (mainly the individual fitness values) into an iterable structure
141147
results += list(Parallel(n_jobs=self.n_jobs, verbose=parallel_verbose)(
142-
delayed(evaluate_population)(evals_done + idx, params, self._eval_cache) # added idx for uniqueness (avoid parallel crashes)
148+
delayed(evaluate_population)(
149+
evals_done + idx, # trial_id for uniqueness
150+
params,
151+
self._eval_cache, # Thread-safe Manager.dict()
152+
self.evaluate_fn, # Pass function to avoid closure issues
153+
self.metric_key, # Pass metric_key
154+
verbose # Pass verbose flag
155+
)
143156
for idx, params in enumerate(all_params) # each gene and param value in a parameters set
144157
))
145158

0 commit comments

Comments
 (0)