Skip to content

Commit 9a894aa

Browse files
committed
Fix PSO constructor
1 parent f97b7df commit 9a894aa

2 files changed

Lines changed: 9 additions & 3 deletions

File tree

scripts/run_experiment.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
from pathlib import Path
1010
from typing import Any, Dict, Literal
1111

12+
# Add project root to path for imports
13+
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
14+
1215
from models.decision_tree import DecisionTreeModel
1316
from models.knn import KNNModel
1417

15-
# Add project root to path for imports
16-
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
1718

1819
import numpy as np
1920
import torch

search/particle_swarm.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,12 @@ def __init__(
189189
c2: float = 1.5,
190190
) -> None:
191191
super().__init__(param_space, evaluate_fn, metric_key, seed)
192-
# Vector to param space transformer
192+
self.n_jobs = n_jobs if n_jobs is not None else -1
193+
self.n_particles = n_particles
194+
self.w = w
195+
self.c1 = c1
196+
self.c2 = c2
197+
self._rng = random.Random(seed)
193198
self.transformer = ParameterTransformer(self.param_space)
194199

195200
def run(

0 commit comments

Comments
 (0)