1+ import numpy as np
2+ import pandas as pd
3+ import pytest
4+ import torch
5+ from torchvision .datasets import ImageFolder
6+ from torchvision import transforms
7+ from torch .utils .data import DataLoader
8+ from core .model_trainer import ModelTrainer
9+ from core .model_generator import ModelGenerator
10+ from unittest .mock import MagicMock
11+
12+ # ----------- UTILITIES -----------
13+
14+ class DummyImageFolder (ImageFolder ):
15+ def __init__ (self , transform = None ):
16+ self .samples = [(None , 0 )] * 10
17+ self .targets = [0 ] * 10
18+ self .classes = ['class0' ]
19+ self .imgs = self .samples
20+ self .transform = transform
21+
22+ def __len__ (self ):
23+ return 10
24+
25+ def __getitem__ (self , index ):
26+ image = torch .randn (1 , 128 , 128 ) # 1-channel image
27+ mask = torch .randint (0 , 2 , (128 , 128 )) # segmentation mask (binary)
28+ return image , mask
29+
30+ # ----------- TEST CASES -----------
31+
32+ def test_train_image_classification_cnn (monkeypatch ):
33+ # Create a dummy tabular dataset
34+ data = pd .DataFrame (np .random .rand (100 , 5 ))
35+ data .iloc [:, - 1 ] = np .random .randint (0 , 3 , size = 100 )
36+ trainer = ModelTrainer (data )
37+
38+ # Patch model.fit to avoid actual training
39+ dummy_model = ModelGenerator .get_image_classification_model ("CNN (Recommended)" , input_shape = (128 , 128 , 1 ), num_classes = 3 )
40+ dummy_model .fit = MagicMock (return_value = None )
41+
42+ monkeypatch .setattr (ModelGenerator , "get_image_classification_model" , lambda * args , ** kwargs : dummy_model )
43+
44+ model = trainer .train ("Image Classification" , "CNN (Recommended)" )
45+ assert model is not None
46+ dummy_model .fit .assert_called_once ()
47+
48+
49+ def test_train_image_classification_svm ():
50+ # Create dummy tabular dataset
51+ data = pd .DataFrame (np .random .rand (50 , 4 ))
52+ data .iloc [:, - 1 ] = np .random .randint (0 , 2 , size = 50 )
53+ trainer = ModelTrainer (data )
54+
55+ model = trainer .train ("Image Classification" , "SVM" )
56+ assert hasattr (model , "fit" )
57+ assert hasattr (model , "predict" )
58+
59+
60+ def test_train_image_segmentation_unet (monkeypatch ):
61+ # Fake image dataset for segmentation
62+ dataset = DummyImageFolder ()
63+ trainer = ModelTrainer (dataset )
64+
65+ # Patch model.fit to skip training
66+ dummy_model = ModelGenerator .get_image_segmentation_model ("U-Net (Recommended)" , input_shape = (128 , 128 , 1 ), num_classes = 1 )
67+ dummy_model .fit = MagicMock (return_value = None )
68+
69+ monkeypatch .setattr (ModelGenerator , "get_image_segmentation_model" , lambda * args , ** kwargs : dummy_model )
70+
71+ model = trainer .train ("Image Segmentation" , "U-Net (Recommended)" )
72+ assert model is not None
73+ dummy_model .fit .assert_called_once ()
74+
75+
76+ def test_train_image_segmentation_cnn (monkeypatch ):
77+ dataset = DummyImageFolder ()
78+ trainer = ModelTrainer (dataset )
79+
80+ dummy_model = ModelGenerator .get_image_segmentation_model ("CNN" , input_shape = (128 , 128 , 1 ), num_classes = 1 )
81+ dummy_model .fit = MagicMock (return_value = None )
82+
83+ monkeypatch .setattr (ModelGenerator , "get_image_segmentation_model" , lambda * args , ** kwargs : dummy_model )
84+
85+ model = trainer .train ("Image Segmentation" , "CNN" )
86+ assert model is not None
87+ dummy_model .fit .assert_called_once ()
88+
89+
90+ def test_train_voice_classification (monkeypatch ):
91+ # Create dummy voice features as tabular data
92+ data = pd .DataFrame (np .random .rand (60 , 20 ))
93+ data .iloc [:, - 1 ] = np .random .randint (0 , 2 , size = 60 )
94+ trainer = ModelTrainer (data )
95+
96+ dummy_model = ModelGenerator .get_voice_classification_model ()
97+ dummy_model .fit = MagicMock (return_value = None )
98+
99+ monkeypatch .setattr (ModelGenerator , "get_voice_classification_model" , lambda : dummy_model )
100+
101+ model = trainer .train ("Voice Classification" , "" )
102+ assert model is not None
103+ dummy_model .fit .assert_called_once ()
104+
105+
106+ def test_invalid_model_type_raises ():
107+ data = pd .DataFrame (np .random .rand (10 , 4 ))
108+ trainer = ModelTrainer (data )
109+
110+ with pytest .raises (ValueError ):
111+ trainer .train ("Unknown Task" , "" )
112+
113+ def test_invalid_task_type_raises ():
114+ data = pd .DataFrame (np .random .rand (10 , 4 ))
115+ trainer = ModelTrainer (data )
116+
117+ with pytest .raises (ValueError ):
118+ trainer .train ("Image Classification" , "UnknownModel" )
119+
120+ def test_invalid_dataset_type_raises ():
121+ invalid_data = "I am not a dataset"
122+ with pytest .raises (TypeError ):
123+ ModelTrainer (invalid_data )
0 commit comments