Skip to content

Commit f5396ed

Browse files
committed
Add RS demo notebook
1 parent b0d4f24 commit f5396ed

1 file changed

Lines changed: 305 additions & 0 deletions

File tree

notebooks/rs_training.ipynb

Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "d90a756a",
6+
"metadata": {},
7+
"source": [
8+
"# Model Training Demo: Random Search Hyperparameter Tuning on a CIFAR-10 Sample"
9+
]
10+
},
11+
{
12+
"cell_type": "code",
13+
"execution_count": 1,
14+
"id": "63bd092a",
15+
"metadata": {},
16+
"outputs": [],
17+
"source": [
18+
"# Imports\n",
19+
"\n",
20+
"from typing import Literal\n",
21+
"import os\n",
22+
"import sys\n",
23+
"import numpy as np\n",
24+
"\n",
25+
"sys.path.append(os.path.join(os.path.curdir, \"..\"))\n",
26+
"\n",
27+
"from search.random_search import RandomSearch\n",
28+
"from scripts.run_experiment import prepare_dataset\n",
29+
"from models.cnn import CNNModel, TrainingConfig\n",
30+
"from models.factory import get_model_by_name\n",
31+
"from models.decision_tree import DecisionTreeModel\n",
32+
"from models.knn import KNNModel\n"
33+
]
34+
},
35+
{
36+
"cell_type": "code",
37+
"execution_count": 2,
38+
"id": "70994585",
39+
"metadata": {},
40+
"outputs": [
41+
{
42+
"name": "stdout",
43+
"output_type": "stream",
44+
"text": [
45+
"CIFAR-10 sample image count: 100\n",
46+
"CIFAR-10 sample label shape: (100,)\n",
47+
"Individual image shape: (32, 32)\n",
48+
"CIFAR-10 pixel value ranges: [0.0, 1.0]\n",
49+
"Image data type: float32\n",
50+
"Label data type: int64\n",
51+
"Sample datasets loaded successfully!\n",
52+
"CIFAR-10 sample: 100 train, 50 validation\n"
53+
]
54+
}
55+
],
56+
"source": [
57+
"# Load CIFAR-10 dataset\n",
58+
"\n",
59+
"cifar10_data = prepare_dataset()\n",
60+
"\n",
61+
"# Smaller samples for demo (100 train, 50 validation)\n",
62+
"SAMPLE_SIZE = 100\n",
63+
"VAL_SAMPLE_SIZE = 50\n",
64+
"\n",
65+
"# Sample from the prepared data\n",
66+
"np.random.seed(42)\n",
67+
"train_indices = np.random.choice(len(cifar10_data['train_images']), SAMPLE_SIZE, replace=False)\n",
68+
"val_indices = np.random.choice(len(cifar10_data['val_images']), VAL_SAMPLE_SIZE, replace=False)\n",
69+
"\n",
70+
"# CNN uses raw images (list[np.ndarray]); X_test is sampled from the validation images\n",
71+
"X_train: list[np.ndarray] = [cifar10_data['train_images'][i] for i in train_indices]\n",
72+
"X_test: list[np.ndarray] = [cifar10_data['val_images'][i] for i in val_indices]\n",
73+
"\n",
74+
"# sklearn uses flattened arrays\n",
75+
"X_train_flat: np.ndarray = cifar10_data['train_flat'][train_indices]\n",
76+
"X_test_flat: np.ndarray = cifar10_data['val_flat'][val_indices]\n",
77+
"\n",
78+
"# Labels are shared by the raw-image and flattened inputs\n",
79+
"y_train = cifar10_data['train_labels'][train_indices]\n",
80+
"y_test = cifar10_data['val_labels'][val_indices]\n",
81+
"\n",
82+
"# Inspect sample shapes\n",
83+
"print(f\"CIFAR-10 sample image count: {len(X_train)}\")\n",
84+
"print(f\"CIFAR-10 sample label shape: {y_train.shape}\")\n",
85+
"print(f\"Individual image shape: {X_train[0].shape}\")\n",
86+
"\n",
87+
"pixel_min = np.min([np.min(img) for img in X_train])\n",
88+
"pixel_max = np.max([np.max(img) for img in X_train])\n",
89+
"print(f\"CIFAR-10 pixel value ranges: [{pixel_min}, {pixel_max}]\")\n",
90+
"\n",
91+
"# Show data types\n",
92+
"print(f\"Image data type: {X_train[0].dtype}\")\n",
93+
"print(f\"Label data type: {y_train.dtype}\")\n",
94+
"\n",
95+
"print(\"Sample datasets loaded successfully!\")\n",
96+
"print(f\"CIFAR-10 sample: {len(X_train)} train, {len(X_test)} validation\")\n"
97+
]
98+
},
99+
{
100+
"cell_type": "markdown",
101+
"id": "22d798a6",
102+
"metadata": {},
103+
"source": [
104+
"## Hyperparameter Search Test"
105+
]
106+
},
107+
{
108+
"cell_type": "code",
109+
"execution_count": 3,
110+
"id": "367be266",
111+
"metadata": {},
112+
"outputs": [],
113+
"source": [
114+
"def quick_hyperparameter_test(\n",
115+
" model_keys: list[Literal['dt', 'knn', 'cnn']],\n",
116+
" X_train: list[np.ndarray],\n",
117+
" y_train: np.ndarray,\n",
118+
" X_test: list[np.ndarray],\n",
119+
" y_test: np.ndarray,\n",
120+
" X_train_flat: np.ndarray,\n",
121+
" X_test_flat: np.ndarray,\n",
122+
" dataset_name: str = \"Dataset\",\n",
123+
" trials: int = 5\n",
124+
"):\n",
125+
" \"\"\"Run RandomSearch for each model key and return, per model display name, the best params, best score, metrics, trials, and history.\"\"\"\n",
126+
" # Map model keys to display names\n",
127+
" model_key_to_name = {\n",
128+
" \"dt\": \"Decision Tree\",\n",
129+
" \"knn\": \"K-Nearest Neighbors\",\n",
130+
" \"cnn\": \"Convolutional Neural Network\",\n",
131+
" }\n",
132+
" print(f\"Starting quick hyperparameter test on {dataset_name}\")\n",
133+
" print(\"=\" * 60)\n",
134+
" print(f\"Using RandomSearch with {trials} trials per model\")\n",
135+
" results = {}\n",
136+
" for model_key in model_keys:\n",
137+
" model_name = model_key_to_name.get(model_key, model_key)\n",
138+
" print(f\"\\nTesting {model_name}...\")\n",
139+
" # Get model and parameter space\n",
140+
" model = get_model_by_name(model_key)\n",
141+
" param_space = model.get_param_space()\n",
142+
" # Create evaluation function for this model\n",
143+
" def evaluate_params(params):\n",
144+
" # Create fresh model instance\n",
145+
" model_instance = get_model_by_name(model_key)\n",
146+
" if model_key == \"cnn\":\n",
147+
" assert isinstance(model_instance, CNNModel)\n",
148+
" X_train_prep = X_train\n",
149+
" X_test_prep = X_test\n",
150+
" y_train_prep, y_test_prep = y_train, y_test\n",
151+
" # Separate CNN-specific params from training config params\n",
152+
" cnn_params = {}\n",
153+
" training_config_params = {}\n",
154+
" for param_name, param_value in params.items():\n",
155+
" if param_name in ['batch_size', 'learning_rate', 'optimizer', 'weight_decay']:\n",
156+
" training_config_params[param_name] = param_value\n",
157+
" else:\n",
158+
" cnn_params[param_name] = param_value\n",
159+
" # Create model with CNN architecture params\n",
160+
" model_instance.create_model(**cnn_params)\n",
161+
" # Create training config with training params\n",
162+
" config = TrainingConfig(epochs=5, **training_config_params)\n",
163+
" # Train the CNN; its train() also receives the held-out split and a TrainingConfig\n",
164+
" model_instance.train(X_train_prep, y_train_prep, X_test_prep, y_test_prep, config=config, verbose=False)\n",
165+
" # Evaluate the CNN on the same held-out split that was passed to train()\n",
166+
" return model_instance.evaluate(X_test_prep, y_test_prep)\n",
167+
" else:\n",
168+
" assert isinstance(model_instance, (DecisionTreeModel, KNNModel))\n",
169+
" # sklearn models\n",
170+
" X_train_prep = X_train_flat\n",
171+
" X_test_prep = X_test_flat\n",
172+
" y_train_prep, y_test_prep = y_train, y_test\n",
173+
" # Create model with params, then train\n",
174+
" model_instance.create_model(**params)\n",
175+
" model_instance.train(X_train_prep, y_train_prep)\n",
176+
" # Evaluate sklearn models\n",
177+
" return model_instance.evaluate(X_test_prep, y_test_prep)\n",
178+
" # Create and run RandomSearch (sequential)\n",
179+
" random_search = RandomSearch(\n",
180+
" param_space=param_space,\n",
181+
" evaluate_fn=evaluate_params,\n",
182+
" metric_key=\"accuracy\",\n",
183+
" seed=42, # For reproducibility\n",
184+
" n_jobs=1 # Sequential execution\n",
185+
" )\n",
186+
" # Run the search\n",
187+
" search_result = random_search.run(trials=trials, verbose=True)\n",
188+
" # Store results for this model\n",
189+
" results[model_name] = {\n",
190+
" \"best_params\": search_result.best_params,\n",
191+
" \"best_score\": search_result.best_metrics.get(\"accuracy\", 0.0),\n",
192+
" \"metrics\": search_result.best_metrics,\n",
193+
" \"trials\": search_result.trials,\n",
194+
" \"history\": search_result.history\n",
195+
" }\n",
196+
" print(f\"Best params: {search_result.best_params}\")\n",
197+
" print(f\"Best score: {search_result.best_metrics.get('accuracy', 0.0):.4f}\")\n",
198+
" print(\"\\n\" + \"=\" * 60)\n",
199+
" print(\"Quick Hyperparameter Test Summary:\")\n",
200+
" for model_name, result in results.items():\n",
201+
" score = result.get(\"best_score\")\n",
202+
" print(f\"{model_name}: Best Score = {score:.4f}\")\n",
203+
" return results\n"
204+
]
205+
},
206+
{
207+
"cell_type": "code",
208+
"execution_count": 4,
209+
"id": "72b263d8",
210+
"metadata": {},
211+
"outputs": [
212+
{
213+
"name": "stdout",
214+
"output_type": "stream",
215+
"text": [
216+
"\n",
217+
"============================================================\n",
218+
"Testing all models on CIFAR-10 sample...\n",
219+
"Starting quick hyperparameter test on CIFAR-10\n",
220+
"============================================================\n",
221+
"Using RandomSearch with 5 trials per model\n",
222+
"\n",
223+
"Testing Decision Tree...\n",
224+
"Running 5 trials...\n",
225+
"Optimizing for metric: accuracy\n",
226+
"Trial 1/5: {'max_depth': 6, 'min_samples_split': 2, 'min_samples_leaf': 5, 'criterion': 'gini'}\n",
227+
"Trial 2/5: {'max_depth': 10, 'min_samples_split': 6, 'min_samples_leaf': 2, 'criterion': 'gini'}\n",
228+
"Trial 3/5: {'max_depth': 16, 'min_samples_split': 3, 'min_samples_leaf': 1, 'criterion': 'gini'}\n",
229+
"Trial 4/5: {'max_depth': 9, 'min_samples_split': 9, 'min_samples_leaf': 9, 'criterion': 'gini'}\n",
230+
"Trial 5/5: {'max_depth': 20, 'min_samples_split': 8, 'min_samples_leaf': 9, 'criterion': 'entropy'}\n",
231+
" -> New best! accuracy=0.2600\n",
232+
"Best params: {'max_depth': 6, 'min_samples_split': 2, 'min_samples_leaf': 5, 'criterion': 'gini'}\n",
233+
"Best score: 0.2600\n",
234+
"\n",
235+
"Testing K-Nearest Neighbors...\n",
236+
"Running 5 trials...\n",
237+
"Optimizing for metric: accuracy\n",
238+
"Trial 1/5: {'n_neighbors': 23, 'weights': 'uniform', 'metric': 'minkowski'}\n",
239+
"Trial 2/5: {'n_neighbors': 26, 'weights': 'distance', 'metric': 'manhattan'}\n",
240+
"Trial 3/5: {'n_neighbors': 10, 'weights': 'uniform', 'metric': 'minkowski'}\n",
241+
"Trial 4/5: {'n_neighbors': 24, 'weights': 'uniform', 'metric': 'chebyshev'}\n",
242+
"Trial 5/5: {'n_neighbors': 4, 'weights': 'uniform', 'metric': 'minkowski'}\n",
243+
" -> New best! accuracy=0.0800\n",
244+
" -> New best! accuracy=0.1200\n",
245+
"Best params: {'n_neighbors': 10, 'weights': 'uniform', 'metric': 'minkowski'}\n",
246+
"Best score: 0.1200\n",
247+
"\n",
248+
"Testing Convolutional Neural Network...\n",
249+
"Running 5 trials...\n",
250+
"Optimizing for metric: accuracy\n",
251+
"Trial 1/5: {'kernel_size': 5, 'stride': 1, 'learning_rate': 1.188590529831906e-05, 'batch_size': 64, 'weight_decay': 0.002448918538034762, 'optimizer': 'AdamW'}\n",
252+
"Trial 2/5: {'kernel_size': 5, 'stride': 1, 'learning_rate': 0.0010717622652265692, 'batch_size': 16, 'weight_decay': 0.005904925124490396, 'optimizer': 'AdamW'}\n",
253+
"Trial 3/5: {'kernel_size': 3, 'stride': 1, 'learning_rate': 4.5280782614269235e-05, 'batch_size': 16, 'weight_decay': 0.00561245062938613, 'optimizer': 'SGD'}\n",
254+
"Trial 4/5: {'kernel_size': 3, 'stride': 2, 'learning_rate': 0.0005858643226824373, 'batch_size': 16, 'weight_decay': 0.007588073671297673, 'optimizer': 'AdamW'}\n",
255+
"Trial 5/5: {'kernel_size': 5, 'stride': 2, 'learning_rate': 0.00010489421799219316, 'batch_size': 32, 'weight_decay': 0.002153137621075888, 'optimizer': 'SGD'}\n",
256+
" -> New best! accuracy=0.0800\n",
257+
" -> New best! accuracy=0.1000\n",
258+
"Best params: {'kernel_size': 5, 'stride': 2, 'learning_rate': 0.00010489421799219316, 'batch_size': 32, 'weight_decay': 0.002153137621075888, 'optimizer': 'SGD'}\n",
259+
"Best score: 0.1000\n",
260+
"\n",
261+
"============================================================\n",
262+
"Quick Hyperparameter Test Summary:\n",
263+
"Decision Tree: Best Score = 0.2600\n",
264+
"K-Nearest Neighbors: Best Score = 0.1200\n",
265+
"Convolutional Neural Network: Best Score = 0.1000\n"
266+
]
267+
}
268+
],
269+
"source": [
270+
"# Test on CIFAR-10 only\n",
271+
"print(\"\\n\" + \"=\" * 60)\n",
272+
"print(\"Testing all models on CIFAR-10 sample...\")\n",
273+
"\n",
274+
"# List of model keys to test\n",
275+
"model_keys: list[Literal['dt', 'knn', 'cnn']] = [\"dt\", \"knn\", \"cnn\"]\n",
276+
"\n",
277+
"cifar_results = quick_hyperparameter_test(\n",
278+
" model_keys, X_train, y_train, X_test, y_test, X_train_flat, X_test_flat, \"CIFAR-10\", trials=5\n",
279+
")\n",
280+
"\n"
281+
]
282+
}
283+
],
284+
"metadata": {
285+
"kernelspec": {
286+
"display_name": ".venv",
287+
"language": "python",
288+
"name": "python3"
289+
},
290+
"language_info": {
291+
"codemirror_mode": {
292+
"name": "ipython",
293+
"version": 3
294+
},
295+
"file_extension": ".py",
296+
"mimetype": "text/x-python",
297+
"name": "python",
298+
"nbconvert_exporter": "python",
299+
"pygments_lexer": "ipython3",
300+
"version": "3.13.7"
301+
}
302+
},
303+
"nbformat": 4,
304+
"nbformat_minor": 5
305+
}

0 commit comments

Comments
 (0)