99from pathlib import Path
1010from typing import Any , Dict , Literal
1111
12+ from models .decision_tree import DecisionTreeModel
13+ from models .knn import KNNModel
14+
1215# Add project root to path for imports
1316sys .path .insert (0 , str (Path (__file__ ).resolve ().parents [1 ]))
1417
1518import numpy as np
1619import torch
1720
18- from framework .data_utils import (
19- load_cifar10_data ,
20- prepare_data ,
21- split_train_val ,
22- )
21+ from framework .data_utils import prepare_dataset
2322from framework .fitness import calculate_composite_fitness
24- from models .base import get_model_by_name
25- from models .cnn import TrainingConfig
23+ from models .factory import get_model_by_name
24+ from models .cnn import CNNModel , TrainingConfig
2625from search import RandomSearch , GeneticAlgorithm , ParticleSwarmOptimization
2726from dataclasses import replace
2827
@@ -40,35 +39,6 @@ def set_seeds(seed: int):
4039 torch .cuda .manual_seed_all (seed )
4140
4241
43- def prepare_dataset () -> Dict [str , Any ]:
44- """Prepare and return the CIFAR-10 dataset."""
45- ds_dict = load_cifar10_data ()
46- train_images , train_labels = prepare_data (ds_dict , "train" )
47- test_images , test_labels = prepare_data (ds_dict , "test" )
48-
49- X_train , y_train , X_val , y_val = split_train_val (
50- train_images , train_labels , val_ratio = 0.1
51- )
52-
53- def flatten (images ):
54- stacked = np .stack ([np .asarray (img , dtype = np .float32 ) for img in images ])
55- return stacked .reshape (len (images ), - 1 )
56-
57- train_flat = flatten (X_train )
58- val_flat = flatten (X_val )
59- test_flat = flatten (test_images )
60-
61- return {
62- "train_images" : X_train ,
63- "train_labels" : y_train ,
64- "val_images" : X_val ,
65- "val_labels" : y_val ,
66- "test_images" : test_images ,
67- "test_labels" : test_labels ,
68- "train_flat" : train_flat ,
69- "val_flat" : val_flat ,
70- "test_flat" : test_flat ,
71- }
7242
7343
7444def evaluate_model (
@@ -81,10 +51,12 @@ def evaluate_model(
8151 model = get_model_by_name (model_key )
8252
8353 if model_key in {"dt" , "knn" }:
54+ assert isinstance (model , (DecisionTreeModel , KNNModel ))
8455 model .create_model (** params )
8556 model .train (data ["train_flat" ], data ["train_labels" ])
8657 metrics = model .evaluate (data ["val_flat" ], data ["val_labels" ])
8758 elif model_key == "cnn" :
59+ assert isinstance (model , CNNModel )
8860 model .create_model (** params )
8961 default_config = TrainingConfig ()
9062 config = replace (
0 commit comments