Skip to content

Commit f97b7df

Browse files
authored
Merge pull request #4 from AI-Enabled-Software-Testing/pso-and-experiments
Update KNN params, add PSO, create experiment runner
2 parents 23bb979 + aa167e7 commit f97b7df

14 files changed

Lines changed: 1185 additions & 526 deletions

framework/data_utils.py

Lines changed: 32 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,11 @@
11
"""Data loading and preprocessing utilities."""
22

33
from pathlib import Path
4-
from typing import List, Tuple
4+
from typing import Any, Dict, List, Tuple
55
import numpy as np
66
from datasets import load_from_disk
77
from sklearn.model_selection import train_test_split
88
from torch.utils.data import DataLoader
9-
from PIL import Image
109

1110
from framework import utils
1211
from framework.datasets import CIFAR10Dataset
@@ -116,14 +115,12 @@ def create_dataloaders(
116115
X_val: List[np.ndarray],
117116
y_val: np.ndarray,
118117
batch_size: int,
119-
num_workers: int = 2,
120118
) -> Tuple[DataLoader, DataLoader]:
121119
train_dataset = CIFAR10Dataset(X_train, y_train)
122120
train_loader = DataLoader(
123121
train_dataset,
124122
batch_size=batch_size,
125123
shuffle=True,
126-
num_workers=num_workers,
127124
pin_memory=utils.is_cuda_available(),
128125
)
129126

@@ -132,8 +129,38 @@ def create_dataloaders(
132129
val_dataset,
133130
batch_size=batch_size,
134131
shuffle=False,
135-
num_workers=num_workers,
136132
pin_memory=utils.is_cuda_available(),
137133
)
138134

139135
return train_loader, val_loader
136+
137+
138+
def prepare_dataset(val_ratio: float = 0.1) -> Dict[str, Any]:
139+
"""Prepare and return the CIFAR-10 dataset"""
140+
ds_dict = load_cifar10_data()
141+
train_images, train_labels = prepare_data(ds_dict, "train")
142+
test_images, test_labels = prepare_data(ds_dict, "test")
143+
144+
X_train, y_train, X_val, y_val = split_train_val(
145+
train_images, train_labels, val_ratio=val_ratio
146+
)
147+
148+
def flatten(images):
149+
stacked = np.stack([np.asarray(img, dtype=np.float32) for img in images])
150+
return stacked.reshape(len(images), -1)
151+
152+
train_flat = flatten(X_train)
153+
val_flat = flatten(X_val)
154+
test_flat = flatten(test_images)
155+
156+
return {
157+
"train_images": X_train,
158+
"train_labels": y_train,
159+
"val_images": X_val,
160+
"val_labels": y_val,
161+
"test_images": test_images,
162+
"test_labels": test_labels,
163+
"train_flat": train_flat,
164+
"val_flat": val_flat,
165+
"test_flat": test_flat,
166+
}

framework/fitness.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,12 @@
11
def calculate_composite_fitness(metrics: dict[str, float]) -> float:
22
"""Calculate composite fitness score from evaluation metrics."""
33
# Extract metrics
4-
f1_macro = metrics.get("f1_macro", 0.0)
5-
recall_macro = metrics.get("recall_macro", 0.0)
6-
roc_auc = metrics.get("roc_auc", 0.0)
7-
precision_macro = metrics.get("precision_macro", 0.0)
8-
accuracy = metrics.get("accuracy", 0.0)
9-
f1_micro = metrics.get("f1_micro", 0.0)
4+
f1_macro = metrics["f1_macro"]
5+
recall_macro = metrics["recall_macro"]
6+
roc_auc = metrics["roc_auc"]
7+
precision_macro = metrics["precision_macro"]
8+
accuracy = metrics["accuracy"]
9+
f1_micro = metrics["f1_micro"]
1010

1111
# Composite fitness
1212
composite_fitness = (

hparam_search.py

Lines changed: 7 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -8,13 +8,12 @@
88
import torch
99
from torch.utils.tensorboard import SummaryWriter
1010

11-
from framework.data_utils import (
12-
load_cifar10_data,
13-
prepare_data,
14-
split_train_val,
15-
)
11+
from framework.data_utils import prepare_dataset
1612
from framework.fitness import calculate_composite_fitness
17-
from models.base import get_model_by_name
13+
from models.cnn import CNNModel
14+
from models.decision_tree import DecisionTreeModel
15+
from models.factory import get_model_by_name
16+
from models.knn import KNNModel
1817
from search import RandomSearch
1918

2019
RANDOM_SEED = 321
@@ -34,34 +33,6 @@ def set_seeds(seed: int):
3433
torch.cuda.manual_seed_all(seed)
3534

3635

37-
def prepare_dataset() -> Dict[str, Any]:
38-
ds_dict = load_cifar10_data()
39-
train_images, train_labels = prepare_data(ds_dict, "train")
40-
test_images, test_labels = prepare_data(ds_dict, "test")
41-
42-
X_train, y_train, X_val, y_val = split_train_val(
43-
train_images, train_labels, val_ratio=0.2
44-
)
45-
46-
def flatten(images):
47-
stacked = np.stack([np.asarray(img, dtype=np.float32) for img in images])
48-
return stacked.reshape(len(images), -1)
49-
50-
train_flat = flatten(X_train)
51-
val_flat = flatten(X_val)
52-
test_flat = flatten(test_images)
53-
54-
return {
55-
"train_images": X_train,
56-
"train_labels": y_train,
57-
"val_images": X_val,
58-
"val_labels": y_val,
59-
"test_images": test_images,
60-
"test_labels": test_labels,
61-
"train_flat": train_flat,
62-
"val_flat": val_flat,
63-
"test_flat": test_flat,
64-
}
6536

6637

6738
def evaluate_model(
@@ -72,10 +43,12 @@ def evaluate_model(
7243
model = get_model_by_name(model_key)
7344

7445
if model_key in {"dt", "knn"}:
46+
assert isinstance(model, (DecisionTreeModel, KNNModel))
7547
model.create_model(**params)
7648
model.train(data["train_flat"], data["train_labels"])
7749
metrics = model.evaluate(data["val_flat"], data["val_labels"])
7850
elif model_key == "cnn":
51+
assert isinstance(model, CNNModel)
7952
model.create_model(**params)
8053
model.train(
8154
data["train_images"],

models/base.py

Lines changed: 3 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,9 @@
11
"""Abstract interface for models used in the hyperparameter tuning framework."""
22

33
from abc import ABC, abstractmethod
4-
from typing import Dict, Any, Literal
4+
from typing import Dict, Any
5+
6+
57

68
from .ParamSpace import ParamSpace
79

@@ -36,23 +38,3 @@ def evaluate(self, *args: Any, **kwargs: Any) -> Dict[str, float]:
3638
def get_param_space(self) -> Dict[str, ParamSpace]:
3739
"""Return the searchable hyperparameter space."""
3840
raise NotImplementedError
39-
40-
41-
def get_model_by_name(model_name: Literal["dt", "knn", "cnn"]) -> BaseModel:
42-
"""Factory function to get model by name."""
43-
from models.decision_tree import DecisionTreeModel
44-
from models.knn import KNNModel
45-
from models.cnn import CNNModel
46-
47-
models = {
48-
"dt": DecisionTreeModel,
49-
"knn": KNNModel,
50-
"cnn": CNNModel,
51-
}
52-
53-
if model_name not in models:
54-
raise ValueError(
55-
f"Unknown model: {model_name}. Available models: {list(models.keys())}"
56-
)
57-
58-
return models[model_name]()

0 commit comments

Comments
 (0)