|
9 | 9 | import math |
10 | 10 | from joblib import Parallel, delayed |
11 | 11 | from dataclasses import dataclass |
| 12 | +from multiprocessing import Manager |
12 | 13 |
|
13 | 14 | # Logging |
14 | 15 | from torch.utils.tensorboard import SummaryWriter |
@@ -57,7 +58,9 @@ def __init__( |
57 | 58 | self.radius = radius |
58 | 59 | self.num_fittest = num_fittest |
59 | 60 | # Memoization cache: maps trial number to (params, metrics, duration) |
60 | | - self._eval_cache: Dict[int, tuple[Dict[str, Any], Dict[str, float], float]] = {} |
| 61 | + # Use Manager().dict() for process-safe sharing across joblib worker processes
| 62 | + self._manager = Manager() |
| 63 | + self._eval_cache = self._manager.dict() |
61 | 64 |
|
62 | 65 | def run(self, trials: int, verbose: bool = False, writer: Optional[SummaryWriter] = None): |
63 | 66 | """Run the Genetic Algorithm optimization process.""" |
@@ -99,47 +102,57 @@ def run(self, trials: int, verbose: bool = False, writer: Optional[SummaryWriter |
99 | 102 | print(f"Population {geneID}/{self.populationSize}: {params}") |
100 | 103 |
|
101 | 104 | # Check memoization cache to avoid re-evaluation based on trial number |
102 | | - if evals_done in self._eval_cache: |
| 105 | + cache_key = evals_done |
| 106 | + if cache_key in self._eval_cache: |
103 | 107 | # Use cached result |
104 | | - cached_params, metrics, duration = self._eval_cache[evals_done] |
| 108 | + cached_params, metrics, duration = self._eval_cache[cache_key] |
105 | 109 | if verbose: |
106 | 110 | print(f" -> Using cached evaluation result for trial {evals_done}") |
107 | 111 | else: |
108 | 112 | # Evaluate the fitness of each member in the initial population |
109 | 113 | start = time.perf_counter() |
110 | 114 | metrics = self.evaluate_fn(params) |
111 | 115 | duration = time.perf_counter() - start |
112 | | - # Store in cache by trial number |
113 | | - self._eval_cache[evals_done] = (params, metrics, duration) |
| 116 | + # Store in cache by trial number (process-safe via Manager.dict())
| 117 | + self._eval_cache[cache_key] = (params, metrics, duration) |
114 | 118 |
|
115 | 119 | # trial, params, metrics, duration |
116 | 120 | if verbose: |
117 | 121 | print(f"Current Trial has evaluated models {evals_done} times, with: {self.metric_key} = {metrics.get(self.metric_key, 'N/A')}, Duration: {duration:.4f} sec") |
118 | 122 | results.append((evals_done, params, metrics, duration)) |
119 | 123 | else: # Parallel Processing |
120 | | - def evaluate_population(evals_done, params, eval_cache): |
121 | | - """Evaluate a single population member.""" |
| 124 | + def evaluate_population(trial_id, params, eval_cache, evaluate_fn, metric_key, verbose_flag): |
| 125 | + """Evaluate a single population member (safe to call from parallel worker processes)."""
122 | 126 | # Check memoization cache to avoid re-evaluation based on trial number |
123 | | - if evals_done in eval_cache: |
124 | | - # Use cached result |
125 | | - cached_params, metrics, duration = eval_cache[evals_done] |
126 | | - print(f" -> Using cached evaluation result for trial {evals_done}") |
| 127 | + cache_key = trial_id |
| 128 | + if cache_key in eval_cache: |
| 129 | + # Use cached result (process-safe read from the Manager.dict() proxy)
| 130 | + cached_params, metrics, duration = eval_cache[cache_key] |
| 131 | + if verbose_flag: |
| 132 | + print(f" -> Using cached evaluation result for trial {trial_id}") |
127 | 133 | else: |
128 | 134 | start = time.perf_counter() |
129 | | - metrics = self.evaluate_fn(params) |
| 135 | + metrics = evaluate_fn(params) |
130 | 136 | duration = time.perf_counter() - start |
131 | | - # Store in cache by trial number (note: updates to dict are thread-safe for unique keys) |
132 | | - eval_cache[evals_done] = (params, metrics, duration) |
| 137 | + # Store in cache by trial number (process-safe write via the Manager.dict() proxy)
| 138 | + eval_cache[cache_key] = (params, metrics, duration) |
133 | 139 |
|
134 | 140 | # trial, params, metrics, duration |
135 | | - if verbose: |
136 | | - print(f"Current Trial has evaluated models {evals_done} times, with: {self.metric_key} = {metrics.get(self.metric_key, 'N/A')}, Duration: {duration:.4f} sec") |
137 | | - return (evals_done, params, metrics, duration) |
| 141 | + if verbose_flag: |
| 142 | + print(f"Current Trial has evaluated models {trial_id} times, with: {metric_key} = {metrics.get(metric_key, 'N/A')}, Duration: {duration:.4f} sec") |
| 143 | + return (trial_id, params, metrics, duration) |
138 | 144 |
|
139 | 145 | parallel_verbose = 10 if verbose else 0 |
140 | 146 | # Record results (mainly the individual fitness values) into an iterable structure |
141 | 147 | results += list(Parallel(n_jobs=self.n_jobs, verbose=parallel_verbose)( |
142 | | - delayed(evaluate_population)(evals_done + idx, params, self._eval_cache) # added idx for uniqueness (avoid parallel crashes) |
| 148 | + delayed(evaluate_population)( |
| 149 | + evals_done + idx, # trial_id for uniqueness |
| 150 | + params, |
| 151 | + self._eval_cache, # Process-safe Manager.dict() proxy, shared with workers
| 152 | + self.evaluate_fn, # Pass function to avoid closure issues |
| 153 | + self.metric_key, # Pass metric_key |
| 154 | + verbose # Pass verbose flag |
| 155 | + ) |
143 | 156 | for idx, params in enumerate(all_params) # each gene and param value in a parameters set |
144 | 157 | )) |
145 | 158 |
|
|
0 commit comments