Skip to content

Commit b0d4f24

Browse files
committed
Update model handling and data preparation
1 parent 2ed19c5 commit b0d4f24

7 files changed

Lines changed: 96 additions & 112 deletions

File tree

framework/data_utils.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Data loading and preprocessing utilities."""
22

33
from pathlib import Path
4-
from typing import List, Tuple
4+
from typing import Any, Dict, List, Tuple
55
import numpy as np
66
from datasets import load_from_disk
77
from sklearn.model_selection import train_test_split
@@ -133,3 +133,34 @@ def create_dataloaders(
133133
)
134134

135135
return train_loader, val_loader
136+
137+
138+
def prepare_dataset(val_ratio: float = 0.1) -> Dict[str, Any]:
139+
"""Prepare and return the CIFAR-10 dataset"""
140+
ds_dict = load_cifar10_data()
141+
train_images, train_labels = prepare_data(ds_dict, "train")
142+
test_images, test_labels = prepare_data(ds_dict, "test")
143+
144+
X_train, y_train, X_val, y_val = split_train_val(
145+
train_images, train_labels, val_ratio=val_ratio
146+
)
147+
148+
def flatten(images):
149+
stacked = np.stack([np.asarray(img, dtype=np.float32) for img in images])
150+
return stacked.reshape(len(images), -1)
151+
152+
train_flat = flatten(X_train)
153+
val_flat = flatten(X_val)
154+
test_flat = flatten(test_images)
155+
156+
return {
157+
"train_images": X_train,
158+
"train_labels": y_train,
159+
"val_images": X_val,
160+
"val_labels": y_val,
161+
"test_images": test_images,
162+
"test_labels": test_labels,
163+
"train_flat": train_flat,
164+
"val_flat": val_flat,
165+
"test_flat": test_flat,
166+
}

hparam_search.py

Lines changed: 7 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,12 @@
88
import torch
99
from torch.utils.tensorboard import SummaryWriter
1010

11-
from framework.data_utils import (
12-
load_cifar10_data,
13-
prepare_data,
14-
split_train_val,
15-
)
11+
from framework.data_utils import prepare_dataset
1612
from framework.fitness import calculate_composite_fitness
17-
from models.base import get_model_by_name
13+
from models.cnn import CNNModel
14+
from models.decision_tree import DecisionTreeModel
15+
from models.factory import get_model_by_name
16+
from models.knn import KNNModel
1817
from search import RandomSearch
1918

2019
RANDOM_SEED = 321
@@ -34,34 +33,6 @@ def set_seeds(seed: int):
3433
torch.cuda.manual_seed_all(seed)
3534

3635

37-
def prepare_dataset() -> Dict[str, Any]:
38-
ds_dict = load_cifar10_data()
39-
train_images, train_labels = prepare_data(ds_dict, "train")
40-
test_images, test_labels = prepare_data(ds_dict, "test")
41-
42-
X_train, y_train, X_val, y_val = split_train_val(
43-
train_images, train_labels, val_ratio=0.2
44-
)
45-
46-
def flatten(images):
47-
stacked = np.stack([np.asarray(img, dtype=np.float32) for img in images])
48-
return stacked.reshape(len(images), -1)
49-
50-
train_flat = flatten(X_train)
51-
val_flat = flatten(X_val)
52-
test_flat = flatten(test_images)
53-
54-
return {
55-
"train_images": X_train,
56-
"train_labels": y_train,
57-
"val_images": X_val,
58-
"val_labels": y_val,
59-
"test_images": test_images,
60-
"test_labels": test_labels,
61-
"train_flat": train_flat,
62-
"val_flat": val_flat,
63-
"test_flat": test_flat,
64-
}
6536

6637

6738
def evaluate_model(
@@ -72,10 +43,12 @@ def evaluate_model(
7243
model = get_model_by_name(model_key)
7344

7445
if model_key in {"dt", "knn"}:
46+
assert isinstance(model, (DecisionTreeModel, KNNModel))
7547
model.create_model(**params)
7648
model.train(data["train_flat"], data["train_labels"])
7749
metrics = model.evaluate(data["val_flat"], data["val_labels"])
7850
elif model_key == "cnn":
51+
assert isinstance(model, CNNModel)
7952
model.create_model(**params)
8053
model.train(
8154
data["train_images"],

models/base.py

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
"""Abstract interface for models used in the hyperparameter tuning framework."""
22

33
from abc import ABC, abstractmethod
4-
from typing import Dict, Any, Literal, overload
4+
from typing import Dict, Any
5+
56

6-
from models.cnn import CNNModel
7-
from models.decision_tree import DecisionTreeModel
8-
from models.knn import KNNModel
97

108
from .ParamSpace import ParamSpace
119

@@ -40,30 +38,3 @@ def evaluate(self, *args: Any, **kwargs: Any) -> Dict[str, float]:
4038
def get_param_space(self) -> Dict[str, ParamSpace]:
4139
"""Return the searchable hyperparameter space."""
4240
raise NotImplementedError
43-
44-
45-
@overload
46-
def get_model_by_name(model_name: Literal["dt"]) -> DecisionTreeModel:
47-
...
48-
49-
@overload
50-
def get_model_by_name(model_name: Literal["knn"]) -> KNNModel:
51-
...
52-
53-
@overload
54-
def get_model_by_name(model_name: Literal["cnn"]) -> CNNModel:
55-
...
56-
57-
def get_model_by_name(model_name: Literal["dt", "knn", "cnn"]) -> KNNModel | DecisionTreeModel | CNNModel:
58-
models = {
59-
"dt": DecisionTreeModel,
60-
"knn": KNNModel,
61-
"cnn": CNNModel,
62-
}
63-
64-
if model_name not in models:
65-
raise ValueError(
66-
f"Unknown model: {model_name}. Available models: {list(models.keys())}"
67-
)
68-
69-
return models[model_name]()

models/decision_tree.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List
1+
from typing import Any, Dict
22

33
import numpy as np
44
from sklearn.metrics import classification_report, f1_score, roc_auc_score
@@ -21,23 +21,23 @@ def create_model(self, **params: Any) -> None:
2121
self.params.update(params)
2222
self.estimator = DecisionTreeClassifier(**self.params)
2323

24-
def train(self, X_train: List[np.ndarray], y_train: np.ndarray) -> DecisionTreeClassifier:
24+
def train(self, X_train: np.ndarray, y_train: np.ndarray) -> DecisionTreeClassifier:
2525
if self.estimator is None:
2626
self.create_model()
2727
estimator = self.estimator
2828
assert estimator is not None
2929
estimator.fit(X_train, y_train)
3030
return estimator
3131

32-
def predict(self, X: List[np.ndarray]):
32+
def predict(self, X: np.ndarray):
3333
if self.estimator is None:
3434
raise RuntimeError(
3535
"Estimator has not been created. Call create_model() first."
3636
)
3737
check_is_fitted(self.estimator)
3838
return self.estimator.predict(X)
3939

40-
def predict_proba(self, X: List[np.ndarray]):
40+
def predict_proba(self, X: np.ndarray):
4141
if self.estimator is None:
4242
raise RuntimeError(
4343
"Estimator has not been created. Call create_model() first."
@@ -49,7 +49,7 @@ def predict_proba(self, X: List[np.ndarray]):
4949
check_is_fitted(self.estimator)
5050
return self.estimator.predict_proba(X)
5151

52-
def evaluate(self, X_test: List[np.ndarray], y_test: np.ndarray) -> Dict[str, float]:
52+
def evaluate(self, X_test: np.ndarray, y_test: np.ndarray) -> Dict[str, float]:
5353
if self.estimator is None:
5454
raise RuntimeError(
5555
"Estimator has not been created. Call create_model() first."

models/factory.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Factory function for creating model instances by name."""
2+
3+
from typing import Literal, overload
4+
from models.cnn import CNNModel
5+
from models.decision_tree import DecisionTreeModel
6+
from models.knn import KNNModel
7+
8+
@overload
9+
def get_model_by_name(model_name: Literal["dt"]) -> DecisionTreeModel:
10+
...
11+
12+
13+
@overload
14+
def get_model_by_name(model_name: Literal["knn"]) -> KNNModel:
15+
...
16+
17+
18+
@overload
19+
def get_model_by_name(model_name: Literal["cnn"]) -> CNNModel:
20+
...
21+
22+
23+
def get_model_by_name(model_name: Literal["dt", "knn", "cnn"]) -> KNNModel | DecisionTreeModel | CNNModel:
24+
25+
models = {
26+
"dt": DecisionTreeModel,
27+
"knn": KNNModel,
28+
"cnn": CNNModel,
29+
}
30+
31+
if model_name not in models:
32+
raise ValueError(
33+
f"Unknown model: {model_name}. Available models: {list(models.keys())}"
34+
)
35+
36+
return models[model_name]()
37+

models/knn.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List
1+
from typing import Any, Dict
22

33
import numpy as np
44
from sklearn.metrics import classification_report, f1_score, roc_auc_score
@@ -20,23 +20,23 @@ def create_model(self, **params: Any) -> None:
2020
self.params.update(params)
2121
self.estimator = KNeighborsClassifier(**self.params)
2222

23-
def train(self, X_train: List[np.ndarray], y_train: np.ndarray) -> KNeighborsClassifier:
23+
def train(self, X_train: np.ndarray, y_train: np.ndarray) -> KNeighborsClassifier:
2424
if self.estimator is None:
2525
self.create_model()
2626
estimator = self.estimator
2727
assert estimator is not None
2828
estimator.fit(X_train, y_train)
2929
return estimator
3030

31-
def predict(self, X: List[np.ndarray]):
31+
def predict(self, X: np.ndarray):
3232
if self.estimator is None:
3333
raise RuntimeError(
3434
"Estimator has not been created. Call create_model() first."
3535
)
3636
check_is_fitted(self.estimator)
3737
return self.estimator.predict(X)
3838

39-
def predict_proba(self, X: List[np.ndarray]):
39+
def predict_proba(self, X: np.ndarray):
4040
if self.estimator is None:
4141
raise RuntimeError(
4242
"Estimator has not been created. Call create_model() first."
@@ -48,7 +48,7 @@ def predict_proba(self, X: List[np.ndarray]):
4848
check_is_fitted(self.estimator)
4949
return self.estimator.predict_proba(X)
5050

51-
def evaluate(self, X_test: List[np.ndarray], y_test: np.ndarray) -> Dict[str, float]:
51+
def evaluate(self, X_test: np.ndarray, y_test: np.ndarray) -> Dict[str, float]:
5252
if self.estimator is None:
5353
raise RuntimeError(
5454
"Estimator has not been created. Call create_model() first."

scripts/run_experiment.py

Lines changed: 8 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,19 @@
99
from pathlib import Path
1010
from typing import Any, Dict, Literal
1111

12+
from models.decision_tree import DecisionTreeModel
13+
from models.knn import KNNModel
14+
1215
# Add project root to path for imports
1316
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
1417

1518
import numpy as np
1619
import torch
1720

18-
from framework.data_utils import (
19-
load_cifar10_data,
20-
prepare_data,
21-
split_train_val,
22-
)
21+
from framework.data_utils import prepare_dataset
2322
from framework.fitness import calculate_composite_fitness
24-
from models.base import get_model_by_name
25-
from models.cnn import TrainingConfig
23+
from models.factory import get_model_by_name
24+
from models.cnn import CNNModel, TrainingConfig
2625
from search import RandomSearch, GeneticAlgorithm, ParticleSwarmOptimization
2726
from dataclasses import replace
2827

@@ -40,35 +39,6 @@ def set_seeds(seed: int):
4039
torch.cuda.manual_seed_all(seed)
4140

4241

43-
def prepare_dataset() -> Dict[str, Any]:
44-
"""Prepare and return the CIFAR-10 dataset."""
45-
ds_dict = load_cifar10_data()
46-
train_images, train_labels = prepare_data(ds_dict, "train")
47-
test_images, test_labels = prepare_data(ds_dict, "test")
48-
49-
X_train, y_train, X_val, y_val = split_train_val(
50-
train_images, train_labels, val_ratio=0.1
51-
)
52-
53-
def flatten(images):
54-
stacked = np.stack([np.asarray(img, dtype=np.float32) for img in images])
55-
return stacked.reshape(len(images), -1)
56-
57-
train_flat = flatten(X_train)
58-
val_flat = flatten(X_val)
59-
test_flat = flatten(test_images)
60-
61-
return {
62-
"train_images": X_train,
63-
"train_labels": y_train,
64-
"val_images": X_val,
65-
"val_labels": y_val,
66-
"test_images": test_images,
67-
"test_labels": test_labels,
68-
"train_flat": train_flat,
69-
"val_flat": val_flat,
70-
"test_flat": test_flat,
71-
}
7242

7343

7444
def evaluate_model(
@@ -81,10 +51,12 @@ def evaluate_model(
8151
model = get_model_by_name(model_key)
8252

8353
if model_key in {"dt", "knn"}:
54+
assert isinstance(model, (DecisionTreeModel, KNNModel))
8455
model.create_model(**params)
8556
model.train(data["train_flat"], data["train_labels"])
8657
metrics = model.evaluate(data["val_flat"], data["val_labels"])
8758
elif model_key == "cnn":
59+
assert isinstance(model, CNNModel)
8860
model.create_model(**params)
8961
default_config = TrainingConfig()
9062
config = replace(

0 commit comments

Comments
 (0)