
Commit 9cdb534
fix
1 parent cbf6f60 commit 9cdb534
2 files changed: 242 additions & 71 deletions

Binary file changed (3.74 KB): not shown.

agents/query_analyzer.py

Lines changed: 242 additions & 71 deletions
@@ -1,74 +1,245 @@
-# agents/query_analyzer.py
-import spacy
-import logging # Added import
-from .base import BaseAgent
-import re
+import logging
 import time
+import faiss
+import pickle
+import numpy as np
+import itertools
+import re # Import re
+from .base import BaseAgent
+from gemini_utils import embed_text
+from utils.text_utils import simple_keyword_score, simple_entity_score, section_relevance_score
+from config import Config
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_HYBRID_INITIAL_TOP_K = Config.RETRIEVER_INITIAL_K
+DEFAULT_HYBRID_FINAL_TOP_K = Config.RETRIEVER_FINAL_K
+
+class RetrieverAgent(BaseAgent):
+    """Agent responsible for retrieving and re-ranking relevant text chunks."""
+    def __init__(self, index_path="faiss_index.index", metadata_path="faiss_metadata.pkl"):
+        logger.info(f"💾 Loading FAISS index from: {index_path}")
+        try:
+            self.index = faiss.read_index(index_path)
+            logger.info(f"✅ FAISS index loaded successfully. Index dimension: {self.index.d}, Total vectors: {self.index.ntotal}")
+        except Exception as e:
+            logger.error(f"❌ Failed to load FAISS index: {e}", exc_info=True)
+            raise
+        logger.info(f"💾 Loading metadata from: {metadata_path}")
+        try:
+            with open(metadata_path, "rb") as f:
+                self.metadatas = pickle.load(f)
+            # Pre-extract texts for faster access if needed elsewhere
+            self.texts = [m.pop('text', '') for m in self.metadatas] # Extract text and remove from metadata dict
+            logger.info(f"✅ Metadata loaded successfully. Number of entries: {len(self.metadatas)}")
+            if len(self.metadatas) != self.index.ntotal:
+                logger.warning(f"⚠️ Mismatch between index size ({self.index.ntotal}) and metadata count ({len(self.metadatas)}).")
+        except Exception as e:
+            logger.error(f"❌ Failed to load metadata: {e}", exc_info=True)
+            raise
 
-# Load the spaCy model once when the class is instantiated
-try:
-    nlp = spacy.load("en_core_web_sm")
-    print("✅ spaCy model 'en_core_web_sm' loaded successfully.")
-except OSError:
-    print("❌ Error loading spaCy model 'en_core_web_sm'.")
-    print(" Please run: python -m spacy download en_core_web_sm")
-    nlp = None # Set nlp to None if loading fails
-
-logger = logging.getLogger(__name__) # Get a logger for this module
-
-class QueryAnalyzerAgent(BaseAgent):
-    """Agent responsible for analyzing the user query."""
-    def run(self, query: str, chat_history: list = None) -> dict: # Add chat_history parameter
-        start_time = time.time()
-        logger.debug(f"Analyzing query: '{query}' with history: {chat_history is not None}") # Log if history is present
-        if not nlp:
-            logger.warning("spaCy model not loaded, falling back to basic analysis.")
-            # Fallback basic extraction (similar to previous web.py logic)
-            keywords = re.findall(r'"(.*?)"|\b[A-Z][a-zA-Z]+\b', query)
-            entities = re.findall(r'\b[A-Z][a-zA-Z]+(?:\s+[A-Z][a-zA-Z]+)*\b', query)
-            keywords = list(set([k.strip().lower() for k in keywords if k]))
-            entities = list(set([e.strip() for e in entities if len(e.split()) > 1 or e in keywords]))
+    def re_rank_chunks(self, initial_results, query, query_analysis):
+        """Re-rank chunks based on multiple factors using utility functions."""
+        rerank_start_time = time.time()
+        logger.info("⚖️ Re-ranking retrieved chunks...")
+        if not initial_results:
+            logger.warning("No initial results to re-rank.")
+            return []
+
+        keywords = query_analysis.get("keywords", [])
+        entities = query_analysis.get("entities", [])
+        query_type = query_analysis.get("query_type", "unknown")
+        intent_type = query_analysis.get("intent_type", "new_topic") # Get intent
+        topic_keywords = query_analysis.get("topic_keywords", []) # Get topic keywords
+        topic_entities = query_analysis.get("topic_entities", []) # Get topic entities
+
+        query_keywords_set = set(keywords)
+        topic_terms_set = set(topic_keywords + topic_entities) # Combine topic terms
+
+        logger.debug(f"Re-ranking based on -> Query Keywords: {keywords}, Entities: {entities}, Type: {query_type}, Intent: {intent_type}, Topic Terms: {topic_terms_set}")
+
+        # --- Tuned Weights ---
+        # Adjust weights based on intent? (Example)
+        if intent_type in ["follow_up", "clarification"] and topic_terms_set:
+            logger.debug("Adjusting weights for follow-up/clarification intent.")
+            weights = {
+                "semantic": 0.15, # Slightly lower semantic weight for current query
+                "keyword": 0.4, # Keep keyword weight
+                "entity": 0.25, # Keep entity weight
+                "topic": 0.2, # Add weight for topic relevance
+                "section": 0.0
+            }
         else:
-            # TODO: Incorporate chat_history into spaCy analysis if needed
-            # For now, just process the current query
-            doc = nlp(query)
-
-            # Extract Named Entities (GPE, PERSON, ORG, LOC, EVENT, DATE etc.)
-            entities = list(set([ent.text.strip() for ent in doc.ents if ent.label_ in ["GPE", "PERSON", "ORG", "LOC", "EVENT", "DATE", "FAC", "PRODUCT", "WORK_OF_ART"]]))
-
-            # Extract Keywords (Noun chunks and Proper Nouns)
-            keywords = list(set([chunk.text.lower().strip() for chunk in doc.noun_chunks]))
-            # Add proper nouns that might not be part of chunks or recognized entities
-            keywords.extend([token.text.lower().strip() for token in doc if token.pos_ == "PROPN" and token.text not in entities])
-            # Remove duplicates that might exist between entities and keywords after lowercasing
-            keywords = list(set(keywords))
-            # Optional: Remove very short keywords if needed
-            # keywords = [kw for kw in keywords if len(kw) > 2]
-
-        # Determine Query Type (Keep existing logic)
-        query_lower = query.lower()
-        query_type = "unknown"
-        if "cause" in query_lower or "why" in query_lower or "effect" in query_lower or "impact" in query_lower:
-            query_type = "causal/analytical"
-        elif "compare" in query_lower or "difference" in query_lower or "similar" in query_lower or "contrast" in query_lower:
-            query_type = "comparative"
-        elif re.match(r"^(what|who|when|where|which)\s+(is|was|are|were|did|do|does)\b", query_lower) or \
-             re.match(r"^(define|describe|explain|list)\b", query_lower):
-            query_type = "factual"
-        # Add more rules if needed
-
-        analysis = {
-            "original_query": query, # Add the original query here
-            "keywords": keywords,
-            "entities": entities,
-            "query_type": query_type,
-            # Optionally include history info if used
-            # "history_considered": chat_history is not None
-        }
-
-        end_time = time.time()
-        # Log the extracted information
-        logger.debug(f"Analysis Results: Keywords: {analysis['keywords']}, Entities: {analysis['entities']}, Query Type: {analysis['query_type']}")
-        logger.debug(f"Analysis Time: {end_time - start_time:.4f}s")
-
-        return analysis
+            weights = {
+                "semantic": 0.2,
+                "keyword": 0.5,
+                "entity": 0.3,
+                "topic": 0.0, # No topic weight for new topics
+                "section": 0.0
+            }
+        # ---------------------
+
+        # Normalize semantic scores (FAISS distances are lower for better matches)
+        max_faiss_dist = max(r["score"] for r in initial_results) if initial_results else 1.0
+        if max_faiss_dist <= 0: # Avoid division by zero
+            max_faiss_dist = 1.0
+
+        logger.debug(f"Re-ranking {len(initial_results)} chunks...")
+        for i, result in enumerate(initial_results):
+            text_lower = self.texts[result["index"]].lower() # Get text using index
+            result["text"] = self.texts[result["index"]] # Add full text back for generator
+            result["metadata"] = self.metadatas[result["index"]] # Add metadata back
+
+            result["semantic_score"] = max(0.0, 1.0 - (max(0.0, result["score"]) / max_faiss_dist))
+            # Use utility functions for scoring
+            result["keyword_score"] = simple_keyword_score(text_lower, query_keywords_set)
+            result["entity_score"] = simple_entity_score(text_lower, entities)
+            result["section_score"] = section_relevance_score(result["metadata"], query_type)
+            # Add topic score if applicable
+            result["topic_score"] = simple_keyword_score(text_lower, topic_terms_set) if weights["topic"] > 0 else 0.0
+
+            combined_score = (
+                weights["semantic"] * result["semantic_score"] +
+                weights["keyword"] * result["keyword_score"] +
+                weights["entity"] * result["entity_score"] +
+                weights["topic"] * result["topic_score"] # Include topic score
+                # + weights["section"] * result["section_score"] # Section score currently unused
+            )
+            result["combined_score"] = combined_score
+
+            # Confidence calculation (can be refined)
+            if combined_score > 0.75:
+                confidence = 0.95
+            elif combined_score > 0.6:
+                confidence = 0.8
+            elif combined_score > 0.45:
+                confidence = 0.65
+            elif combined_score > 0.3:
+                confidence = 0.5
+            else:
+                confidence = 0.3
+            result["confidence"] = confidence
+
+        # Sort by combined score
+        ranked_results = sorted(initial_results, key=lambda x: x["combined_score"], reverse=True)
+
+        # Filter based on presence of *query* keywords/entities (important!)
+        logger.info(f"🔍 Filtering {len(ranked_results)} re-ranked chunks for *query* keyword/entity presence...")
+        filtered_results = []
+        query_terms_lower = {k.lower() for k in keywords} | {e.lower() for e in entities}
+
+        # If the query itself has no terms, but it's a follow-up, rely on topic terms for filtering?
+        # Or maybe skip filtering if query terms are absent? Let's skip for now.
+        if not query_terms_lower and intent_type not in ["follow_up", "clarification"]:
+            logger.warning("⚠️ No keywords or entities found in query analysis, and not a follow-up. Skipping filtering.")
+            filtered_results = ranked_results
+        elif not query_terms_lower and intent_type in ["follow_up", "clarification"]:
+            logger.warning("⚠️ No keywords or entities in query, but it's a follow-up/clarification. Filtering based on *topic* terms.")
+            filter_terms = {t.lower() for t in topic_terms_set} # Use topic terms for filtering
+            if not filter_terms:
+                logger.warning("⚠️ No topic terms found either. Skipping filtering.")
+                filtered_results = ranked_results
+            else:
+                for result in ranked_results:
+                    text_lower = result["text"].lower()
+                    # Check for topic terms instead of query terms
+                    if any(re.search(r'\b' + re.escape(term) + r'\b', text_lower) for term in filter_terms):
+                        filtered_results.append(result)
+        else:
+            # Standard filtering based on query terms
+            filter_terms = query_terms_lower
+            for result in ranked_results:
+                text_lower = result["text"].lower()
+                if any(re.search(r'\b' + re.escape(term) + r'\b', text_lower) for term in filter_terms):
+                    filtered_results.append(result)
+
+
+        logger.info(f"✅ Filtered down to {len(filtered_results)} chunks containing relevant terms.")
+        logger.debug("Top 5 Filtered & Re-ranked Chunks (Combined | Sem | Key | Ent | Top | Conf | Page):")
+        for i, r in enumerate(filtered_results[:5]):
+            page = r.get("metadata", {}).get("page", "?")
+            logger.debug(f"{i+1}. Score={r['combined_score']:.3f} (S:{r['semantic_score']:.2f} K:{r['keyword_score']:.2f} E:{r['entity_score']:.2f} T:{r['topic_score']:.2f}) | Conf={r['confidence']:.2f} | Page={page} | Text: {r['text'][:100]}...")
+
+        total_rerank_time = time.time() - rerank_start_time
+        logger.info(f"Step 2b: Re-ranking & Filtering took: {total_rerank_time:.4f}s")
+        return filtered_results
+
+
+    def _simple_expand_query(self, query_analysis: dict, max_expansions: int = 2) -> list[str]:
+        """Generates simple query variations based on keywords and entities."""
+        expansions = []
+        keywords = query_analysis.get("keywords", [])
+        entities = query_analysis.get("entities", [])
+        # Consider adding topic terms if it's a follow-up with few query terms?
+        intent_type = query_analysis.get("intent_type", "new_topic")
+        topic_keywords = query_analysis.get("topic_keywords", [])
+        topic_entities = query_analysis.get("topic_entities", [])
+
+        terms = list(set(entities + keywords))
+
+        # If few terms in query but it's a follow-up, add topic terms to expansion base
+        if len(terms) < 2 and intent_type in ["follow_up", "clarification"]:
+            logger.debug("Expanding query using topic terms for follow-up.")
+            terms.extend(topic_keywords)
+            terms.extend(topic_entities)
+            terms = list(set(terms)) # Ensure uniqueness
+
+        if not terms:
+            return []
+
+        # Prioritize entities for combinations
+        priority_terms = entities if entities else keywords
+        other_terms = keywords if entities else []
+
+        # Generate pairs (priority x other, priority x priority)
+        pairs = []
+        if priority_terms and other_terms:
+            pairs.extend(list(itertools.product(priority_terms, other_terms)))
+        if len(priority_terms) >= 2:
+            pairs.extend(list(itertools.combinations(priority_terms, 2)))
+
+        # Add single terms if not enough pairs
+        if len(pairs) < max_expansions:
+            pairs.extend([(t,) for t in terms]) # Add single terms
+
+        # Create expansion strings
+        for pair in pairs:
+            expansions.append(" ".join(pair))
+            if len(expansions) >= max_expansions:
+                break
+
+        # Fallback: if still no expansions, use top terms directly
+        if not expansions and terms:
+            expansions.extend(terms[:max_expansions])
+
+        unique_expansions = list(dict.fromkeys(expansions)) # Maintain order while making unique
+        logger.debug(f"Generated query expansions: {unique_expansions[:max_expansions]}")
+        return unique_expansions[:max_expansions]
+
+
+    def run(self, query: str, query_analysis: dict, initial_top_k: int = DEFAULT_HYBRID_INITIAL_TOP_K, final_top_k: int = 5):
+        """Retrieves chunks using semantic search (with expansion), filters and re-ranks them."""
+        run_start_time = time.time()
+        logger.info(f"🔎 Running hybrid retrieval for: '{query}' (Initial K={initial_top_k}, Final K={final_top_k})")
+        logger.debug(f"Query Analysis for Retrieval: {query_analysis}") # Log full analysis
+
+        expansion_start_time = time.time()
+        # Use original query if analysis didn't refine, otherwise use refined
+        query_to_expand = query_analysis.get("original_query", query) # Use original for expansion base
+        expanded_queries = self._simple_expand_query(query_analysis)
+        all_queries = [query_to_expand] + expanded_queries # Include original query
+
+        query_embeddings = []
+        for q in all_queries:
+            emb = embed_text(q)
+            if emb:
+                query_embeddings.append(np.array([emb]).astype("float32"))
+            else:
+                logger.warning(f"Failed to generate embedding for query variant: '{q}'")
+
+        if not query_embeddings:
+            logger.error("Failed to generate any query embeddings.")
+            return []
+
+        expansion_time = time.time() - expansion_start_time
+        logger.info(f"Step 2a: Query expansion

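The re-ranking relies on simple_keyword_score, simple_entity_score and section_relevance_score imported from utils/text_utils, which are not shown in this commit. As a rough sketch of the kind of term-overlap scoring they are used for here (an assumption about their behaviour, not the project's actual implementation):

# Assumed behaviour only; the real utils/text_utils.py is not part of this diff.
import re

def simple_keyword_score(text_lower: str, keywords: set) -> float:
    """Fraction of the given keywords found as whole words in the lowercased chunk text."""
    if not keywords:
        return 0.0
    hits = sum(1 for kw in keywords
               if re.search(r'\b' + re.escape(kw.lower()) + r'\b', text_lower))
    return hits / len(keywords)

def simple_entity_score(text_lower: str, entities: list) -> float:
    """Fraction of the extracted entities mentioned in the chunk text."""
    if not entities:
        return 0.0
    hits = sum(1 for ent in entities if ent.lower() in text_lower)
    return hits / len(entities)

Scores in the 0-1 range like these would be consistent with the weight sets above (which sum to roughly 1) and with the confidence thresholds applied to combined_score.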
0 commit comments