Skip to content

Commit 885152c

Browse files
committed
Implement Tests
1 parent 0383a5c commit 885152c

4 files changed

Lines changed: 229 additions & 1 deletion

File tree

core/model_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def get_voice_classification_model(self):
3030
@classmethod
3131
def _build_cnn(self, input_shape, num_classes):
3232
model = models.Sequential()
33-
model.add(layers.InputLayer(input_shape=input_shape))
33+
model.add(layers.InputLayer(shape=input_shape))
3434
model.add(layers.Conv2D(32, (3, 3), activation='relu'))
3535
model.add(layers.MaxPooling2D((2, 2)))
3636
model.add(layers.Conv2D(64, (3, 3), activation='relu'))

test/test_model_generator.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import pytest
2+
from tensorflow.keras.models import Model
3+
from sklearn.svm import SVC
4+
from core.model_generator import ModelGenerator
5+
6+
def test_get_image_classification_model_cnn():
7+
input_shape = (64, 64, 3)
8+
model = ModelGenerator.get_image_classification_model("CNN (Recommended)", input_shape, num_classes=1)
9+
assert isinstance(model, Model)
10+
assert model.input_shape[1:] == input_shape
11+
assert model.output_shape[-1] == 1
12+
assert model.loss == 'sparse_categorical_crossentropy'
13+
14+
def test_get_image_classification_model_svm():
15+
model = ModelGenerator.get_image_classification_model("SVM", input_shape=(64,64,3))
16+
assert isinstance(model, SVC)
17+
assert model.kernel == "linear"
18+
assert model.probability is True
19+
20+
def test_get_image_classification_model_invalid_algorithm():
21+
with pytest.raises(ValueError):
22+
ModelGenerator.get_image_classification_model("InvalidAlgo", input_shape=(64,64,3))
23+
24+
def test_get_image_segmentation_model_unet():
25+
input_shape = (128, 128, 3)
26+
num_classes = 1
27+
model = ModelGenerator.get_image_segmentation_model("U-Net (Recommended)", input_shape, num_classes)
28+
assert isinstance(model, Model)
29+
assert model.input_shape[1:] == input_shape
30+
assert model.output_shape[-1] == num_classes
31+
# Check loss for binary classification
32+
assert model.loss == 'binary_crossentropy'
33+
34+
def test_get_image_segmentation_model_cnn():
35+
input_shape = (128, 128, 3)
36+
num_classes = 2
37+
model = ModelGenerator.get_image_segmentation_model("CNN", input_shape, num_classes)
38+
assert isinstance(model, Model)
39+
assert model.output_shape[-1] == num_classes
40+
# Multi-class last layer activation softmax
41+
last_layer_activation = model.layers[-1].activation.__name__
42+
assert last_layer_activation == "softmax"
43+
44+
def test_get_image_segmentation_model_invalid_algorithm():
45+
with pytest.raises(ValueError):
46+
ModelGenerator.get_image_segmentation_model("InvalidAlgo", input_shape=(64,64,3))
47+
48+
def test_get_voice_classification_model_returns_svc():
49+
model = ModelGenerator.get_voice_classification_model()
50+
assert isinstance(model, SVC)
51+
assert model.kernel == "linear"
52+
assert model.probability is True

test/test_model_trainer.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
import torch
5+
from torchvision.datasets import ImageFolder
6+
from torchvision import transforms
7+
from torch.utils.data import DataLoader
8+
from core.model_trainer import ModelTrainer
9+
from core.model_generator import ModelGenerator
10+
from unittest.mock import MagicMock
11+
12+
# ----------- UTILITIES -----------
13+
14+
class DummyImageFolder(ImageFolder):
15+
def __init__(self, transform=None):
16+
self.samples = [(None, 0)] * 10
17+
self.targets = [0] * 10
18+
self.classes = ['class0']
19+
self.imgs = self.samples
20+
self.transform = transform
21+
22+
def __len__(self):
23+
return 10
24+
25+
def __getitem__(self, index):
26+
image = torch.randn(1, 128, 128) # 1-channel image
27+
mask = torch.randint(0, 2, (128, 128)) # segmentation mask (binary)
28+
return image, mask
29+
30+
# ----------- TEST CASES -----------
31+
32+
def test_train_image_classification_cnn(monkeypatch):
33+
# Create a dummy tabular dataset
34+
data = pd.DataFrame(np.random.rand(100, 5))
35+
data.iloc[:, -1] = np.random.randint(0, 3, size=100)
36+
trainer = ModelTrainer(data)
37+
38+
# Patch model.fit to avoid actual training
39+
dummy_model = ModelGenerator.get_image_classification_model("CNN (Recommended)", input_shape=(128, 128, 1), num_classes=3)
40+
dummy_model.fit = MagicMock(return_value=None)
41+
42+
monkeypatch.setattr(ModelGenerator, "get_image_classification_model", lambda *args, **kwargs: dummy_model)
43+
44+
model = trainer.train("Image Classification", "CNN (Recommended)")
45+
assert model is not None
46+
dummy_model.fit.assert_called_once()
47+
48+
49+
def test_train_image_classification_svm():
50+
# Create dummy tabular dataset
51+
data = pd.DataFrame(np.random.rand(50, 4))
52+
data.iloc[:, -1] = np.random.randint(0, 2, size=50)
53+
trainer = ModelTrainer(data)
54+
55+
model = trainer.train("Image Classification", "SVM")
56+
assert hasattr(model, "fit")
57+
assert hasattr(model, "predict")
58+
59+
60+
def test_train_image_segmentation_unet(monkeypatch):
61+
# Fake image dataset for segmentation
62+
dataset = DummyImageFolder()
63+
trainer = ModelTrainer(dataset)
64+
65+
# Patch model.fit to skip training
66+
dummy_model = ModelGenerator.get_image_segmentation_model("U-Net (Recommended)", input_shape=(128, 128, 1), num_classes=1)
67+
dummy_model.fit = MagicMock(return_value=None)
68+
69+
monkeypatch.setattr(ModelGenerator, "get_image_segmentation_model", lambda *args, **kwargs: dummy_model)
70+
71+
model = trainer.train("Image Segmentation", "U-Net (Recommended)")
72+
assert model is not None
73+
dummy_model.fit.assert_called_once()
74+
75+
76+
def test_train_image_segmentation_cnn(monkeypatch):
77+
dataset = DummyImageFolder()
78+
trainer = ModelTrainer(dataset)
79+
80+
dummy_model = ModelGenerator.get_image_segmentation_model("CNN", input_shape=(128, 128, 1), num_classes=1)
81+
dummy_model.fit = MagicMock(return_value=None)
82+
83+
monkeypatch.setattr(ModelGenerator, "get_image_segmentation_model", lambda *args, **kwargs: dummy_model)
84+
85+
model = trainer.train("Image Segmentation", "CNN")
86+
assert model is not None
87+
dummy_model.fit.assert_called_once()
88+
89+
90+
def test_train_voice_classification(monkeypatch):
91+
# Create dummy voice features as tabular data
92+
data = pd.DataFrame(np.random.rand(60, 20))
93+
data.iloc[:, -1] = np.random.randint(0, 2, size=60)
94+
trainer = ModelTrainer(data)
95+
96+
dummy_model = ModelGenerator.get_voice_classification_model()
97+
dummy_model.fit = MagicMock(return_value=None)
98+
99+
monkeypatch.setattr(ModelGenerator, "get_voice_classification_model", lambda: dummy_model)
100+
101+
model = trainer.train("Voice Classification", "")
102+
assert model is not None
103+
dummy_model.fit.assert_called_once()
104+
105+
106+
def test_invalid_model_type_raises():
107+
data = pd.DataFrame(np.random.rand(10, 4))
108+
trainer = ModelTrainer(data)
109+
110+
with pytest.raises(ValueError):
111+
trainer.train("Unknown Task", "")
112+
113+
def test_invalid_task_type_raises():
114+
data = pd.DataFrame(np.random.rand(10, 4))
115+
trainer = ModelTrainer(data)
116+
117+
with pytest.raises(ValueError):
118+
trainer.train("Image Classification", "UnknownModel")
119+
120+
def test_invalid_dataset_type_raises():
121+
invalid_data = "I am not a dataset"
122+
with pytest.raises(TypeError):
123+
ModelTrainer(invalid_data)

test/uitest_trainer_widget.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import pytest
2+
import pandas as pd
3+
from torchvision.datasets import ImageFolder
4+
from unittest.mock import MagicMock, patch
5+
from PyQt6.QtWidgets import QApplication, QFileDialog
6+
from presentation.widgets.trainer_widget import TrainerWidget
7+
8+
@pytest.fixture
9+
def trainer_widget(qtbot):
10+
"""Create and return a TrainerWidget instance."""
11+
widget = TrainerWidget()
12+
qtbot.addWidget(widget)
13+
widget.show() # optional, can be useful for some tests
14+
return widget
15+
16+
@pytest.fixture
17+
def valid_dataframe_dataset():
18+
return pd.DataFrame({'feature1': [1, 2], 'feature2': [3, 4], 'label': [0, 1]})
19+
20+
@pytest.fixture
21+
def valid_imagefolder_dataset(tmp_path):
22+
class_to_dir = tmp_path / "class_a"
23+
class_to_dir.mkdir()
24+
dummy_file = class_to_dir / "dummy.jpg"
25+
dummy_file.write_text("dummy content")
26+
return ImageFolder(str(tmp_path))
27+
28+
def test_model_selection_shows_algorithms(trainer_widget, qtbot, valid_dataframe_dataset):
29+
trainer_widget.dataset = valid_dataframe_dataset
30+
trainer_widget.model_combobox.setCurrentIndex(1)
31+
qtbot.wait(100)
32+
assert trainer_widget.algorithm_label.isVisible()
33+
34+
def test_train_model_button_behavior(trainer_widget, qtbot, valid_dataframe_dataset):
35+
trainer_widget.dataset = valid_dataframe_dataset
36+
with patch("presentation.widgets.trainer_widget.ModelTrainer.train", return_value=MagicMock()) as mock_train:
37+
trainer_widget.train_model()
38+
qtbot.wait(100)
39+
mock_train.assert_called_once()
40+
41+
def test_save_model_button_behavior(trainer_widget, qtbot, valid_dataframe_dataset, tmp_path, monkeypatch):
42+
trainer_widget.dataset = valid_dataframe_dataset
43+
trainer_widget.model = MagicMock()
44+
45+
fake_path = str(tmp_path / "model.pth")
46+
47+
# Patch QFileDialog.getSaveFileName globally
48+
monkeypatch.setattr(QFileDialog, "getSaveFileName", lambda *args, **kwargs: (fake_path, 'All Files (*)'))
49+
50+
# Patch the actual saving logic to prevent real file I/O and serialization
51+
with patch("presentation.widgets.trainer_widget.save_model_to_file") as mock_save:
52+
trainer_widget.save_model()
53+
mock_save.assert_called_once_with(trainer_widget.model, fake_path)

0 commit comments

Comments
 (0)