Skip to content

Commit 5dba00a

Browse files
Copilotkmock930
andcommitted
Fix critical bugs in crossover and parallel execution
Co-authored-by: kmock930 <78272416+kmock930@users.noreply.github.com>
1 parent b12f19c commit 5dba00a

1 file changed

Lines changed: 17 additions & 11 deletions

File tree

search/genetic_algorithm.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ def run(self, trials: int, verbose: bool = False, writer: Optional[SummaryWriter
110110
self._eval_cache[evals_done] = (params, metrics, duration)
111111

112112
# trial, params, metrics, duration
113-
print(f"Current Trial has evaluated models {evals_done} times, with: {self.metric_key} = {metrics.get(self.metric_key, 'N/A')}, Duration: {duration:.4f} sec")
113+
if verbose:
114+
print(f"Current Trial has evaluated models {evals_done} times, with: {self.metric_key} = {metrics.get(self.metric_key, 'N/A')}, Duration: {duration:.4f} sec")
114115
results.append((evals_done, params, metrics, duration))
115116
else: # Parallel Processing
116117
def evaluate_population(evals_done, params, eval_cache):
@@ -128,14 +129,15 @@ def evaluate_population(evals_done, params, eval_cache):
128129
eval_cache[evals_done] = (params, metrics, duration)
129130

130131
# trial, params, metrics, duration
131-
print(f"Current Trial has evaluated models {evals_done} times, with: {self.metric_key} = {metrics.get(self.metric_key, 'N/A')}, Duration: {duration:.4f} sec")
132+
if verbose:
133+
print(f"Current Trial has evaluated models {evals_done} times, with: {self.metric_key} = {metrics.get(self.metric_key, 'N/A')}, Duration: {duration:.4f} sec")
132134
return (evals_done, params, metrics, duration)
133135

134136
parallel_verbose = 10 if verbose else 0
135137
# Record results (mainly the individual fitness values) into an iterable structure
136138
results += list(Parallel(n_jobs=self.n_jobs, verbose=parallel_verbose)(
137-
delayed(evaluate_population)(evals_done, params, self._eval_cache)
138-
for _, params in enumerate(all_params, start=1) # each gene and param value in a parameters set
139+
delayed(evaluate_population)(evals_done + idx, params, self._eval_cache)
140+
for idx, params in enumerate(all_params) # each gene and param value in a parameters set
139141
))
140142

141143
# After Evaluation, Update the number of evaluations done
@@ -414,13 +416,17 @@ def _crossover(self, parent1: Dict[str, Any], parent2: Dict[str, Any], n: int =
414416
)
415417
)
416418

417-
# Swapping - to produce a newly unseen solution
418-
for crossover_point in crossover_points: # adapt to n-point crossover
419-
for paramInd, key in enumerate(parent1.keys()): # key is the parameter name in string format
420-
if paramInd < crossover_point:
421-
child[key] = parent1[key]
422-
else:
423-
child[key] = parent2[key]
419+
# Swapping - to produce a newly unseen solution using n-point crossover
420+
# Alternate between parents at each crossover point
421+
current_parent = 0 # Start with parent1
422+
next_point_idx = 0
423+
for paramInd, key in enumerate(parent1.keys()): # key is the parameter name in string format
424+
# Check if we've reached the next crossover point
425+
if next_point_idx < len(crossover_points) and paramInd >= crossover_points[next_point_idx]:
426+
current_parent = 1 - current_parent # Switch parent
427+
next_point_idx += 1
428+
# Assign value from current parent
429+
child[key] = parent1[key] if current_parent == 0 else parent2[key]
424430
return child
425431

426432
def _mutate(self, individual: Dict[str, Any], mutation_rate: float = 0.01, mutation_strength: float = 0.1) -> Dict[str, Any]:

0 commit comments

Comments
 (0)