Skip to content

Commit fba070a

Browse files
update retrieval test for new logic
1 parent 9e21b43 commit fba070a

1 file changed

Lines changed: 43 additions & 65 deletions

File tree

tests/test_retrieval.py

Lines changed: 43 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,44 @@
11
import pytest
2-
import numpy as np
3-
from retrieval_engine import RetrievalEngine
4-
import json
5-
import tempfile
6-
import os
7-
8-
@pytest.fixture
9-
def temp_faq_file():
10-
"""Crée une FAQ temporaire"""
11-
faq = [
12-
{"id": 1, "question": "Comment créer un compte ?", "answer": "Réponse 1"},
13-
{"id": 2, "question": "Prix de l'abonnement ?", "answer": "Réponse 2"},
14-
{"id": 3, "question": "Livraison internationale ?", "answer": "Réponse 3"}
15-
]
16-
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json') as f:
17-
json.dump(faq, f)
18-
temp_path = f.name
19-
yield temp_path
20-
os.unlink(temp_path)
21-
22-
23-
def test_retrieval_initialization(temp_faq_file):
24-
"""Test l'initialisation du retrieval engine"""
25-
engine = RetrievalEngine(
26-
faq_path=temp_faq_file,
27-
model_name="sentence-transformers/all-MiniLM-L6-v2"
28-
)
29-
assert len(engine.faq_data) == 3
30-
assert engine.question_embeddings is not None
31-
assert engine.question_embeddings.shape[0] == 3
32-
33-
34-
def test_cosine_similarity(temp_faq_file):
35-
"""Test le calcul de similarité cosinus"""
36-
engine = RetrievalEngine(temp_faq_file, "sentence-transformers/all-MiniLM-L6-v2")
37-
vec1 = np.array([1.0, 0.0, 0.0])
38-
vec2 = np.array([1.0, 0.0, 0.0])
39-
sim = engine.cosine_similarity(vec1, vec2)
40-
assert abs(sim - 1.0) < 0.01 # Vecteurs identiques
41-
42-
43-
def test_get_best_match(temp_faq_file):
44-
"""Test la recherche de meilleure correspondance"""
45-
engine = RetrievalEngine(temp_faq_file, "sentence-transformers/all-MiniLM-L6-v2")
46-
result = engine.get_best_match("créer compte utilisateur")
47-
assert result["confidence"] > 0.3
48-
assert "Réponse 1" in result["answer"]
49-
assert result["matched_question"] is not None
50-
51-
52-
def test_get_top_k_matches(temp_faq_file):
53-
"""Test la récupération des top K résultats"""
54-
engine = RetrievalEngine(temp_faq_file, "sentence-transformers/all-MiniLM-L6-v2")
55-
results = engine.get_top_k_matches("prix livraison", k=2)
56-
assert len(results) == 2
57-
assert all("confidence" in r for r in results)
58-
assert results[0]["confidence"] >= results[1]["confidence"]
59-
60-
61-
def test_low_confidence_response(temp_faq_file):
62-
"""Test la gestion des requêtes sans match"""
63-
engine = RetrievalEngine(temp_faq_file, "sentence-transformers/all-MiniLM-L6-v2")
64-
result = engine.get_best_match("xyz question totalement aléatoire abc")
65-
assert "n'ai pas trouvé" in result["answer"].lower()
66-
assert result["matched_question"] is None
2+
from sqlmodel import Session
3+
from app.services.rag_engine import RAGService
4+
from app.db.models import FAQItem
5+
6+
def test_rag_static_rules():
7+
"""Vérifie que les règles statiques (Bonjour, etc.) fonctionnent sans DB."""
8+
engine = RAGService()
9+
# Test Bonjour
10+
result = engine.search("Bonjour")
11+
assert result["confidence"] == 1.0
12+
assert "aider" in result["answer"]
13+
assert result["provider"] == "static_rule"
14+
# Test Merci
15+
result = engine.search("Merci beaucoup")
16+
assert result["confidence"] == 1.0
17+
assert "questions" in result["answer"]
18+
19+
def test_rag_search_nominal(session: Session):
20+
"""Vérifie la recherche vectorielle avec des données en base."""
21+
# Peupler la base de test
22+
faq1 = FAQItem(question="Comment créer un compte ?", answer="Allez sur la page inscription.")
23+
faq2 = FAQItem(question="Quel est le prix ?", answer="C'est 10 euros.")
24+
session.add(faq1)
25+
session.add(faq2)
26+
session.commit()
27+
# Recharger le moteur RAG (synchronisation avec la DB)
28+
engine = RAGService()
29+
engine.reload_from_db(session)
30+
# Tester une recherche pertinente
31+
result = engine.search("créer compte utilisateur")
32+
# Le score devrait être élevé car la question est proche
33+
assert result["confidence"] > 0.7
34+
assert result["answer"] == "Allez sur la page inscription."
35+
assert result["faq_id"] is not None
36+
37+
def test_rag_search_no_match(session: Session):
38+
"""Vérifie le comportement quand rien ne correspond."""
39+
engine = RAGService()
40+
engine.reload_from_db(session) # RAG vide maintenant
41+
result = engine.search("Une question qui n'a aucun sens ici")
42+
# Sans données, la confiance doit être 0 ou très basse
43+
assert result["confidence"] < 0.5
44+
assert result["answer"] is None

0 commit comments

Comments
 (0)