-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathconftest.py
More file actions
190 lines (156 loc) · 4.63 KB
/
conftest.py
File metadata and controls
190 lines (156 loc) · 4.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""
Shared pytest fixtures for the testing suite.
This module contains common fixtures that can be used across all test modules.
"""
import pytest
import tempfile
import shutil
from pathlib import Path
from unittest.mock import Mock, patch
import os
@pytest.fixture
def temp_dir():
    """Yield a fresh temporary directory as a ``Path`` and remove it afterwards.

    Cleanup is best-effort: ``ignore_errors=True`` keeps teardown from
    raising when the test has already deleted or moved the directory.
    """
    temp_path = tempfile.mkdtemp()
    try:
        yield Path(temp_path)
    finally:
        # Runs on normal fixture teardown and also if the generator is
        # closed early, so the directory is never left behind.
        shutil.rmtree(temp_path, ignore_errors=True)
@pytest.fixture
def temp_file():
    """Yield the ``Path`` of a new, empty temporary file and delete it afterwards.

    The descriptor returned by ``mkstemp`` is closed right away, so only the
    path is handed to the test. Deletion tolerates the test having already
    removed the file.
    """
    temp_fd, temp_name = tempfile.mkstemp()
    os.close(temp_fd)
    temp_path = Path(temp_name)
    try:
        yield temp_path
    finally:
        # EAFP instead of exists()+unlink(): no race window, and no
        # second Path construction.
        try:
            temp_path.unlink()
        except FileNotFoundError:
            pass  # the test already removed the file
@pytest.fixture
def mock_wandb():
    """Patch ``wandb.init``/``wandb.log``/``wandb.finish`` so no real logging happens.

    Yields:
        dict: keys ``'init'``, ``'log'`` and ``'finish'``. When wandb is
        importable these are the active patches, and ``init`` returns a
        ``Mock``. When wandb is not importable they are plain ``Mock``
        objects.
    """
    import importlib

    # Only the import is guarded. Wrapping the ``yield`` itself in
    # ``try/except ImportError`` would catch errors unrelated to
    # availability, and could reach a second ``yield``.
    try:
        importlib.import_module('wandb')
    except ImportError:
        # wandb not available, so there is nothing to patch: hand out stand-ins.
        yield {
            'init': Mock(),
            'log': Mock(),
            'finish': Mock(),
        }
        return

    with patch('wandb.init') as mock_init, \
            patch('wandb.log') as mock_log, \
            patch('wandb.finish') as mock_finish:
        mock_init.return_value = Mock()
        yield {
            'init': mock_init,
            'log': mock_log,
            'finish': mock_finish,
        }
@pytest.fixture
def mock_tensorflow():
    """Replace ``tensorflow`` and its keras submodules with mocks for one test.

    The mocks are linked parent-to-child: ``tensorflow.keras`` is the very
    object registered as ``sys.modules['tensorflow.keras']``, and likewise
    for ``models`` and ``layers``. Code that writes
    ``tf.keras.layers.Dense`` and code that writes
    ``from tensorflow.keras.layers import Dense`` therefore share one mock.
    ``patch.dict`` restores ``sys.modules`` when the test finishes.
    """
    tf_mock = Mock()
    # Mock caches auto-created children, so these attribute reads return the
    # same objects every time they are accessed.
    keras_mock = tf_mock.keras
    fake_modules = {
        'tensorflow': tf_mock,
        'tensorflow.keras': keras_mock,
        'tensorflow.keras.models': keras_mock.models,
        'tensorflow.keras.layers': keras_mock.layers,
    }
    with patch.dict('sys.modules', fake_modules):
        yield
@pytest.fixture
def sample_params():
    """Build a fresh nested parameter mapping with model, training and data sections."""
    model_section = dict(
        name='test_model',
        layers=[128, 64, 32],
        activation='relu',
    )
    training_section = dict(epochs=10, batch_size=32, learning_rate=0.001)
    data_section = dict(path='/path/to/data', train_split=0.8, val_split=0.2)
    return {
        'model': model_section,
        'training': training_section,
        'data': data_section,
    }
@pytest.fixture
def sample_config_file(temp_dir):
    """Write a YAML config with the same values as ``sample_params`` and return its path.

    Args:
        temp_dir: directory from the ``temp_dir`` fixture; the file lives
            inside it.

    Returns:
        Path: ``temp_dir / "config.yaml"``.
    """
    import textwrap

    # dedent lets the YAML literal be indented along with the code, while
    # each key still ends up nested under its section in the written file.
    config_content = textwrap.dedent("""\
        model:
          name: test_model
          layers: [128, 64, 32]
          activation: relu
        training:
          epochs: 10
          batch_size: 32
          learning_rate: 0.001
        data:
          path: /path/to/data
          train_split: 0.8
          val_split: 0.2
        """)
    config_file = temp_dir / "config.yaml"
    config_file.write_text(config_content, encoding="utf-8")
    return config_file
@pytest.fixture
def mock_huggingface_hub():
    """Patch the Hugging Face Hub login, upload and download entry points.

    Yields:
        dict: ``'login'``, ``'upload'`` and ``'download'`` mapped to the
        active mocks for ``huggingface_hub.login``,
        ``huggingface_hub.upload_file`` and
        ``huggingface_hub.hf_hub_download``.

    ``patch`` swaps the attribute on the ``huggingface_hub`` module, so only
    lookups made through the module at call time are intercepted. A name
    bound earlier via ``from huggingface_hub import ...`` keeps the real
    object.
    """
    # Single-file downloads are exposed as ``hf_hub_download``. The module
    # has no ``download_file`` attribute, and ``patch`` (without
    # ``create=True``) refuses to patch a missing attribute.
    with patch('huggingface_hub.login') as mock_login, \
            patch('huggingface_hub.upload_file') as mock_upload, \
            patch('huggingface_hub.hf_hub_download') as mock_download:
        yield {
            'login': mock_login,
            'upload': mock_upload,
            'download': mock_download,
        }
@pytest.fixture
def mock_dvc():
    """Patch ``dvc.repo.Repo`` and yield the object every ``Repo(...)`` call returns."""
    patcher = patch('dvc.repo.Repo')
    repo_class = patcher.start()
    try:
        repo_instance = Mock()
        # Any construction of Repo during the test hands back this instance.
        repo_class.return_value = repo_instance
        yield repo_instance
    finally:
        patcher.stop()
@pytest.fixture
def sample_data():
    """Return small, reproducible random arrays for exercising ML models.

    Returns:
        dict: ``'X_train'`` (100x10 floats in [0, 1)), ``'y_train'``
        (100 integer labels, each 0 or 1), ``'X_test'`` (20x10 floats in
        [0, 1)) and ``'y_test'`` (20 integer labels, each 0 or 1).
    """
    import numpy as np

    # A locally seeded generator yields identical data on every run, so
    # tests are not flaky, and it leaves numpy's global random state
    # untouched for other tests.
    rng = np.random.default_rng(42)

    # Generate sample training data
    X_train = rng.random((100, 10))
    # dtype=int matches the default dtype of the legacy np.random.randint.
    y_train = rng.integers(0, 2, size=(100,), dtype=int)
    # Generate sample test data
    X_test = rng.random((20, 10))
    y_test = rng.integers(0, 2, size=(20,), dtype=int)

    return {
        'X_train': X_train,
        'y_train': y_train,
        'X_test': X_test,
        'y_test': y_test,
    }
@pytest.fixture(scope="session")
def test_env_vars():
    """Export test-friendly environment variables for the whole session.

    Yields the mapping of variable names to the values that were set. On
    teardown each variable is put back to its previous value, or removed
    if it was not set before.

    NOTE(review): setting PYTHONPATH at runtime presumably affects only
    child processes, not this interpreter's ``sys.path``. Confirm that is
    the intent.
    """
    test_vars = {
        'WANDB_MODE': 'offline',
        'TF_CPP_MIN_LOG_LEVEL': '3',
        'PYTHONPATH': str(Path.cwd()),
    }
    # Snapshot first; None marks a variable that was absent.
    saved = {name: os.environ.get(name) for name in test_vars}
    os.environ.update(test_vars)

    yield test_vars

    # Restore original environment variables
    for name, previous in saved.items():
        if previous is not None:
            os.environ[name] = previous
            continue
        os.environ.pop(name, None)
@pytest.fixture
def mock_typer_app():
"""Mock Typer application for CLI testing."""
from unittest.mock import Mock
return Mock()
@pytest.fixture(autouse=True)
def setup_test_environment(test_env_vars):
    """Apply the session test environment to every test automatically.

    The body does nothing on purpose. Declaring ``test_env_vars`` as a
    parameter of this autouse fixture is what makes pytest activate that
    session-scoped fixture for each test.
    """
    return None