# agents/context_expander.py
- import logging  # Added import
+ import logging
+ import re
+ import spacy
+ from collections import defaultdict
from .base import BaseAgent
from utils.chunk_utils import filter_redundant_chunks

- logger = logging.getLogger(__name__)  # Get a logger for this module
+ logger = logging.getLogger(__name__)

class ContextExpansionAgent(BaseAgent):
    """Agent responsible for assessing and expanding retrieval context."""

-     def assess(self, retrieved_chunks: list[dict]) -> dict:
+     def __init__(self):
+         super().__init__()
+         try:
+             self.nlp = spacy.load("en_core_web_sm")
+             logger.info("✅ spaCy model 'en_core_web_sm' loaded successfully.")
+         except OSError:
+             logger.error("❌ Error loading spaCy model 'en_core_web_sm'. Please run: python -m spacy download en_core_web_sm")
+             self.nlp = None
+
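+         # Maps each query string to a list of feedback records (see _update_feedback)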
+         self.feedback = defaultdict(list)
+
+     def assess(self, retrieved_chunks: list[dict], query_analysis: dict) -> dict:
        """Assess if retrieved context is sufficient."""
        print("🧐 Assessing context sufficiency...")
-
+
        if not retrieved_chunks:
            print("⚠️ Assessment: No chunks retrieved, expansion needed.")
            return {"needs_expansion": True, "reason": "No chunks retrieved"}
-
+
        # Check confidence of top chunks
        confidences = [chunk.get("confidence", 0) for chunk in retrieved_chunks]
        avg_confidence = sum(confidences) / len(confidences)
        top_confidence = confidences[0] if confidences else 0
-
+
        # Calculate context coverage
        total_text_length = sum(len(chunk["text"]) for chunk in retrieved_chunks)
-
+
        # Check if we have entities from query in the chunks
-         # This would be populated from query_analysis
-
+         keywords = query_analysis.get("keywords", [])
+         entities = query_analysis.get("entities", [])
+         search_terms = set([k.lower() for k in keywords] + [e.lower() for e in entities])
+         logger.debug(f"Checking context relevance. Search terms: {search_terms}")
+
+         if not search_terms:
+             logger.debug("No keywords/entities found in query analysis, assuming context is relevant.")
+             return {"needs_expansion": False, "reason": "No keywords/entities found"}
+
+         found_relevant_chunk = False
+         for i, chunk in enumerate(retrieved_chunks):
+             text_lower = chunk.get("text", "").lower()
+             for term in search_terms:
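+                 # Whole-word match (\b boundaries) so e.g. "art" does not hit inside "part"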
+                 if re.search(r'\b' + re.escape(term) + r'\b', text_lower):
+                     logger.debug(f"Found relevant term '{term}' in context chunk {i + 1}.")
+                     found_relevant_chunk = True
+                     break
+             if found_relevant_chunk:
+                 break
+
+         if not found_relevant_chunk:
+             logger.warning("No relevant terms found in any context chunk.")
+             return {"needs_expansion": True, "reason": "No relevant terms found"}
+
        # Decision logic
        if top_confidence < 0.4:
            print(f"⚠️ Assessment: Low top confidence ({top_confidence:.2f}), expansion needed.")
            return {"needs_expansion": True, "reason": "Low confidence"}
-
+
        if avg_confidence < 0.3:
            print(f"⚠️ Assessment: Low average confidence ({avg_confidence:.2f}), expansion needed.")
            return {"needs_expansion": True, "reason": "Low average confidence"}
-
+
        if total_text_length < 500:
            print(f"⚠️ Assessment: Short context ({total_text_length} chars), expansion needed.")
            return {"needs_expansion": True, "reason": "Short context"}
-
+
        print(f"✅ Assessment: Context sufficient (Avg conf: {avg_confidence:.2f}, Length: {total_text_length} chars)")
        return {"needs_expansion": False, "reason": "Sufficient confidence and context"}

    def find_contextual_chunks(self, chunks, retriever, max_additional=3):
        """Find chunks that might be contextually related to the given chunks."""
        if not chunks:
            return []
-
+
        # Strategy 1: Find adjacent chunks by page numbers
        pages = [chunk["metadata"].get("page", 0) for chunk in chunks if "metadata" in chunk]
        adjacent_pages = set()
-
+
        for page in pages:
            if page > 0:
                adjacent_pages.add(page - 1)  # Previous page
                adjacent_pages.add(page + 1)  # Next page
-
+
        # Filter out pages we already have
        adjacent_pages = adjacent_pages - set(pages)
-
+
        # Find chunks from adjacent pages
        adjacent_chunks = []
        for i, metadata in enumerate(retriever.metadatas):
@@ -70,26 +104,47 @@ def find_contextual_chunks(self, chunks, retriever, max_additional=3):
                    "confidence": 0.4,  # Lower confidence for adjacent chunks
                    "expansion_method": "adjacent_page"
                })
-
+
        # Strategy 2: Find chunks from same sections
        sections = [chunk["metadata"].get("section", "") for chunk in chunks if "metadata" in chunk]
        sections = [s for s in sections if s]  # Remove empty sections
-
+
        section_chunks = []
        if sections:
            for i, metadata in enumerate(retriever.metadatas):
                if metadata.get("section", "") in sections:
                    # Skip if we already have this chunk
                    if any(retriever.texts[i] == c["text"] for c in chunks + adjacent_chunks):
                        continue
-
+
                    section_chunks.append({
                        "text": retriever.texts[i],
                        "metadata": metadata,
                        "confidence": 0.35,  # Lower confidence for section-based chunks
                        "expansion_method": "same_section"
                    })
-
+
+         # Strategy 3: Use advanced NLP techniques to find related chunks
+         if self.nlp:
+             for chunk in chunks:
+                 doc = self.nlp(chunk["text"].lower())
+                 chunk_entities = [ent.text.lower() for ent in doc.ents]
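+                 # Heuristic: treat grammatical subjects/objects (nsubj, dobj, pobj) as content keywords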
+                 chunk_keywords = [token.text.lower() for token in doc if token.dep_ in ("nsubj", "dobj", "pobj")]
+
+                 for i, metadata in enumerate(retriever.metadatas):
+                     text_lower = retriever.texts[i].lower()
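+                     # Re-parses every stored text for each input chunk; fine for small corpora, consider caching otherwise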
+                     doc = self.nlp(text_lower)
+                     entities = [ent.text.lower() for ent in doc.ents]
+                     keywords = [token.text.lower() for token in doc if token.dep_ in ("nsubj", "dobj", "pobj")]
+
+                     if any(term in entities or term in keywords for term in chunk_entities + chunk_keywords):
+                         # Skip if we already have this chunk
+                         if any(retriever.texts[i] == c["text"] for c in chunks + adjacent_chunks + section_chunks):
+                             continue
+                         section_chunks.append({
+                             "text": retriever.texts[i],
+                             "metadata": metadata,
+                             "confidence": 0.3,  # Lower confidence for NLP-based chunks
+                             "expansion_method": "nlp_related"
+                         })
+
        # Combine and limit additional chunks
        additional_chunks = (adjacent_chunks + section_chunks)[:max_additional]
        print(f"✅ Found {len(additional_chunks)} additional context chunks.")
@@ -98,90 +153,98 @@ def find_contextual_chunks(self, chunks, retriever, max_additional=3):
    def fuse_chunks(self, chunks):
        """Fuse chunks into a coherent context, managing token limits."""
        print("🧩 Fusing chunks into coherent context...")
-
+
        # Sort chunks by confidence
        sorted_chunks = sorted(chunks, key=lambda x: x.get("confidence", 0), reverse=True)
-
+
        # Get metadata for organization
        chunk_metadata = []
        for chunk in sorted_chunks:
            page = chunk["metadata"].get("page", "Unknown")
            section = chunk["metadata"].get("section", "Unknown")
            chunk_metadata.append(f"[Page {page}, Section: {section}]")
-
+
        # Combine text with metadata headers
        fused_text = ""
        for i, chunk in enumerate(sorted_chunks):
            fused_text += f"\n\n--- Excerpt {i + 1}: {chunk_metadata[i]} ---\n\n"
            fused_text += chunk["text"]
-
+
        print(f"✅ Fused {len(sorted_chunks)} chunks into coherent context.")
        return fused_text

    def aggregate_metadata(self, chunks: list[dict]) -> dict:
        """Aggregate metadata from all chunks."""
        print("📊 Aggregating metadata...")
-
+
        # Collect page numbers and section names
        pages = set()
        sections = set()
-
+
        for chunk in chunks:
            metadata = chunk.get("metadata", {})
            if "page" in metadata and metadata["page"]:
                pages.add(metadata["page"])
            if "section" in metadata and metadata["section"]:
                sections.add(metadata["section"])
-
+
        aggregated = {
            "pages": sorted(list(pages)),
            "sections": sorted(list(sections))
        }
-
+
        print(f"✅ Metadata aggregated: {len(pages)} pages, {len(sections)} sections")
        return aggregated

+     def _update_feedback(self, query: str, context_chunks: list[dict], relevance: bool):
+         """Update feedback loop with user interaction data."""
+         self.feedback[query].append({
+             "context_chunks": context_chunks,
+             "relevance": relevance
+         })
+         logger.debug(f"Feedback updated for query: '{query}' with relevance: {relevance}")
+
    def run(self, retrieved_chunks: list[dict], query_analysis: dict, retriever_agent) -> tuple[list[dict], dict]:
        """Assess context, expand if needed, filter redundancy, and fuse chunks."""
        logger.debug(f"Running context expansion/filtering on {len(retrieved_chunks)} chunks.")
        # 1. Assess if the context is sufficient
-         assessment = self.assess(retrieved_chunks)
-
+         assessment = self.assess(retrieved_chunks, query_analysis)
+
        final_chunks = retrieved_chunks.copy()
-
+
        # 2. Expand context if needed
        if assessment["needs_expansion"]:
            print(f"🔍 Expanding context due to: {assessment['reason']}")
-
+
            # If complex query, consider processing sub-queries separately
            if query_analysis.get("needs_decomposition", False):
                print("📋 Complex query detected, expanding context for multiple aspects.")
                # In a full implementation, we might retrieve for each sub-query
                # For now, just get related chunks to the current results
-
+
            # Find related chunks
            additional_chunks = self.find_contextual_chunks(
-                 retrieved_chunks,
+                 retrieved_chunks,
                retriever_agent
            )
-
+
            # Combine original and additional chunks
            expanded_chunks = retrieved_chunks + additional_chunks
-
+
            # 3. Filter redundant chunks using the utility function
            final_chunks = filter_redundant_chunks(expanded_chunks)
-
+
            print(f"✅ Context expansion complete: {len(final_chunks)} chunks after filtering.")
        else:
            print("✅ Original context is sufficient, no expansion needed.")
            # Still filter original chunks for redundancy
            final_chunks = filter_redundant_chunks(retrieved_chunks)
-
+
        # 4. Aggregate metadata from all included chunks
        aggregated_metadata = self.aggregate_metadata(final_chunks)
-
+
        logger.debug(f"Context expansion complete. Final chunks: {len(final_chunks)}")
        # Note: We don't actually fuse the chunks here - that will be handled by the generator
        # when it builds its prompt, using the separate chunks we provide
-
+
        return final_chunks, aggregated_metadata