Skip to content

Commit 9377406

Browse files
Write few tests
1 parent 0c7270d commit 9377406

1 file changed

Lines changed: 173 additions & 1 deletion

File tree

core/tests.py

Lines changed: 173 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,175 @@
11
from django.test import TestCase
2+
from unittest.mock import patch, MagicMock
23

3-
# Create your tests here.
4+
from .services import SQLGenerationService
5+
6+
7+
class SQLGenerationServiceTests(TestCase):
8+
9+
def test_format_schema_for_prompt(self):
10+
schema = {
11+
'tables': ['users', 'orders'],
12+
'columns': [
13+
{'table_name': 'users', 'column_name': 'id', 'data_type': 'integer'},
14+
{'table_name': 'users', 'column_name': 'name', 'data_type': 'varchar'},
15+
{'table_name': 'orders', 'column_name': 'id', 'data_type': 'integer'},
16+
{'table_name': 'orders', 'column_name': 'user_id', 'data_type': 'integer'}
17+
]
18+
}
19+
20+
result = SQLGenerationService.format_schema_for_prompt(schema)
21+
22+
self.assertIn('Table: users', result)
23+
self.assertIn('Table: orders', result)
24+
self.assertIn('id (integer)', result)
25+
self.assertIn('name (varchar)', result)
26+
27+
def test_format_schema_empty_tables(self):
28+
schema = {
29+
'tables': [],
30+
'columns': []
31+
}
32+
33+
result = SQLGenerationService.format_schema_for_prompt(schema)
34+
35+
self.assertEqual(result, '')
36+
37+
def test_format_schema_single_table(self):
38+
schema = {
39+
'tables': ['products'],
40+
'columns': [
41+
{'table_name': 'products', 'column_name': 'id', 'data_type': 'serial'},
42+
{'table_name': 'products', 'column_name': 'name', 'data_type': 'varchar'},
43+
{'table_name': 'products', 'column_name': 'price', 'data_type': 'numeric'}
44+
]
45+
}
46+
47+
result = SQLGenerationService.format_schema_for_prompt(schema)
48+
49+
self.assertIn('Table: products', result)
50+
self.assertIn('id (serial)', result)
51+
self.assertIn('price (numeric)', result)
52+
53+
@patch('core.services.SQLGenerationService.get_llm')
54+
def test_generate_sql_success(self, mock_get_llm):
55+
mock_chain = MagicMock()
56+
mock_chain.run.return_value = "SELECT * FROM users WHERE id = 1"
57+
58+
with patch('core.services.LLMChain', return_value=mock_chain):
59+
result = SQLGenerationService.generate_sql("Get all users", "Schema info")
60+
61+
self.assertIn("SELECT", result)
62+
63+
@patch('core.services.SQLGenerationService.get_llm')
64+
def test_generate_sql_strips_sql_prefix(self, mock_get_llm):
65+
mock_chain = MagicMock()
66+
mock_chain.run.return_value = "```sql\nSELECT * FROM users\n```"
67+
68+
with patch('core.services.LLMChain', return_value=mock_chain):
69+
result = SQLGenerationService.generate_sql("Get users", "Schema")
70+
71+
self.assertTrue(result.startswith("SELECT"))
72+
73+
@patch('core.services.SQLGenerationService.get_llm')
74+
def test_generate_sql_strips_sql_prefix_lowercase(self, mock_get_llm):
75+
mock_chain = MagicMock()
76+
mock_chain.run.return_value = "sql\nSELECT * FROM users"
77+
78+
with patch('core.services.LLMChain', return_value=mock_chain):
79+
result = SQLGenerationService.generate_sql("Get users", "Schema")
80+
81+
self.assertTrue(result.startswith("SELECT"))
82+
83+
@patch('core.services.SQLGenerationService.get_llm')
84+
def test_generate_sql_strips_backticks(self, mock_get_llm):
85+
mock_chain = MagicMock()
86+
mock_chain.run.return_value = "SELECT * FROM orders"
87+
88+
with patch('core.services.LLMChain', return_value=mock_chain):
89+
result = SQLGenerationService.generate_sql("Get orders", "Schema")
90+
91+
self.assertEqual(result, "SELECT * FROM orders")
92+
93+
@patch('core.services.SQLGenerationService.get_llm')
94+
def test_generate_sql_insert_statement(self, mock_get_llm):
95+
mock_chain = MagicMock()
96+
mock_chain.run.return_value = "INSERT INTO users (name) VALUES ('John')"
97+
98+
with patch('core.services.LLMChain', return_value=mock_chain):
99+
result = SQLGenerationService.generate_sql("Insert John", "Schema")
100+
101+
self.assertIn("INSERT", result)
102+
103+
104+
class EmbeddingServiceUnitTests(TestCase):
105+
106+
@patch('core.services.EmbeddingService.get_model')
107+
def test_embed_text_returns_list(self, mock_get_model):
108+
from .services import EmbeddingService
109+
110+
mock_model = MagicMock()
111+
mock_array = MagicMock()
112+
mock_array.tolist.return_value = [0.1, 0.2, 0.3, 0.4]
113+
mock_model.encode.return_value = mock_array
114+
mock_get_model.return_value = mock_model
115+
116+
result = EmbeddingService.embed_text("test text")
117+
118+
self.assertIsInstance(result, list)
119+
self.assertEqual(len(result), 4)
120+
mock_model.encode.assert_called_once()
121+
122+
@patch('core.services.SentenceTransformer')
123+
def test_embed_text_default_model_name(self, mock_transformer):
124+
from .services import EmbeddingService
125+
126+
mock_model = MagicMock()
127+
mock_array = MagicMock()
128+
mock_array.tolist.return_value = [0.1]
129+
mock_model.encode.return_value = mock_array
130+
mock_transformer.return_value = mock_model
131+
132+
EmbeddingService._model = None
133+
134+
EmbeddingService.embed_text("test")
135+
136+
mock_transformer.assert_called_once_with('all-MiniLM-L6-v2')
137+
138+
@patch('core.services.EmbeddingService.get_model')
139+
def test_embed_text_model_cached(self, mock_get_model):
140+
from .services import EmbeddingService
141+
142+
mock_model = MagicMock()
143+
mock_array = MagicMock()
144+
mock_array.tolist.return_value = [0.1]
145+
mock_model.encode.return_value = mock_array
146+
mock_get_model.return_value = mock_model
147+
148+
EmbeddingService.embed_text("test")
149+
EmbeddingService.embed_text("test2")
150+
151+
self.assertEqual(mock_model.encode.call_count, 2)
152+
153+
@patch('core.services.EmbeddingService.get_model')
154+
def test_embed_documents_multiple_texts(self, mock_get_model):
155+
from .services import EmbeddingService
156+
157+
mock_model = MagicMock()
158+
159+
embedding1 = MagicMock()
160+
embedding1.tolist.return_value = [0.1, 0.2]
161+
embedding2 = MagicMock()
162+
embedding2.tolist.return_value = [0.3, 0.4]
163+
embedding3 = MagicMock()
164+
embedding3.tolist.return_value = [0.5, 0.6]
165+
166+
mock_model.encode.return_value = [embedding1, embedding2, embedding3]
167+
mock_get_model.return_value = mock_model
168+
169+
result = EmbeddingService.embed_documents(["text1", "text2", "text3"])
170+
171+
self.assertIsInstance(result, list)
172+
self.assertEqual(len(result), 3)
173+
self.assertEqual(result[0], [0.1, 0.2])
174+
self.assertEqual(result[1], [0.3, 0.4])
175+
self.assertEqual(result[2], [0.5, 0.6])

0 commit comments

Comments
 (0)