88import torch
99from torch .utils .tensorboard import SummaryWriter
1010
11- from framework .data_utils import (
12- load_cifar10_data ,
13- prepare_data ,
14- split_train_val ,
15- )
11+ from framework .data_utils import prepare_dataset
1612from framework .fitness import calculate_composite_fitness
17- from models .base import get_model_by_name
13+ from models .cnn import CNNModel
14+ from models .decision_tree import DecisionTreeModel
15+ from models .factory import get_model_by_name
16+ from models .knn import KNNModel
1817from search import RandomSearch
1918
2019RANDOM_SEED = 321
@@ -34,34 +33,6 @@ def set_seeds(seed: int):
3433 torch .cuda .manual_seed_all (seed )
3534
3635
37- def prepare_dataset () -> Dict [str , Any ]:
38- ds_dict = load_cifar10_data ()
39- train_images , train_labels = prepare_data (ds_dict , "train" )
40- test_images , test_labels = prepare_data (ds_dict , "test" )
41-
42- X_train , y_train , X_val , y_val = split_train_val (
43- train_images , train_labels , val_ratio = 0.2
44- )
45-
46- def flatten (images ):
47- stacked = np .stack ([np .asarray (img , dtype = np .float32 ) for img in images ])
48- return stacked .reshape (len (images ), - 1 )
49-
50- train_flat = flatten (X_train )
51- val_flat = flatten (X_val )
52- test_flat = flatten (test_images )
53-
54- return {
55- "train_images" : X_train ,
56- "train_labels" : y_train ,
57- "val_images" : X_val ,
58- "val_labels" : y_val ,
59- "test_images" : test_images ,
60- "test_labels" : test_labels ,
61- "train_flat" : train_flat ,
62- "val_flat" : val_flat ,
63- "test_flat" : test_flat ,
64- }
6536
6637
6738def evaluate_model (
@@ -72,10 +43,12 @@ def evaluate_model(
7243 model = get_model_by_name (model_key )
7344
7445 if model_key in {"dt" , "knn" }:
46+ assert isinstance (model , (DecisionTreeModel , KNNModel ))
7547 model .create_model (** params )
7648 model .train (data ["train_flat" ], data ["train_labels" ])
7749 metrics = model .evaluate (data ["val_flat" ], data ["val_labels" ])
7850 elif model_key == "cnn" :
51+ assert isinstance (model , CNNModel )
7952 model .create_model (** params )
8053 model .train (
8154 data ["train_images" ],
0 commit comments