-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgraph_retriever_nolinenode.py
More file actions
266 lines (212 loc) · 9.87 KB
/
graph_retriever_nolinenode.py
File metadata and controls
266 lines (212 loc) · 9.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
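"""
V4 graph retriever - simplified version (w/o LineNode)
Ablation experiment: removes the LineNode evolution chain in order to demonstrate LineNode's contribution
Simplified design:
- Retrieves through MENTIONED_IN edges instead of LineNode
- No evolution-chain information
"""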
from typing import List, Dict, Optional
from pathlib import Path
from knowledge_graph_nolinenode import KnowledgeGraphNoLineNode
from entity_extractor import QueryExtractor
from llm_client import get_llm_client
from config import STM_CONFIG, SYSTEM_CONFIG, LPM_CONFIG
class EvoGraphRetrieverNoLineNode:
    """
    Simplified retriever (w/o LineNode).

    Retrieval pipeline:
    1. LPM: extract entities from the query and match them against the
       entities already present in the graph.
    2. Collect all Notes related to those entities (via MENTIONED_IN edges).
    3. LLM rerank: let the LLM directly pick the most relevant top-k.
    """

    # Upper bound on the notes taken from the whole store when neither the
    # entity lookup nor the keyword search produced any candidate.
    _FALLBACK_CANDIDATE_LIMIT = 100

    def __init__(self, kg: "KnowledgeGraphNoLineNode", debug_log_path: Optional[Path] = None):
        """
        Args:
            kg: Graph store queried for entities and notes.
            debug_log_path: Optional file that debug lines are appended to.
                When None, `_debug_log` is a no-op.
        """
        self.kg = kg
        self.query_extractor = QueryExtractor()
        self.llm = get_llm_client()
        self.top_k = STM_CONFIG.get("top_k", 5)
        self.debug = SYSTEM_CONFIG.get("debug", False)
        self.debug_log_path = debug_log_path
        # Opened lazily on the first debug write; released by close().
        self._debug_file = None

    def _debug_log(self, msg: str):
        """Append one line to the debug log (no-op without a log path)."""
        if not self.debug_log_path:
            return
        if self._debug_file is None:
            self._debug_file = open(self.debug_log_path, "a", encoding="utf-8")
        self._debug_file.write(msg + "\n")
        # Flush per line so the log is readable while a run is in progress.
        self._debug_file.flush()

    def close(self) -> None:
        """Close the debug log file if it was opened; safe to call repeatedly."""
        if self._debug_file is not None:
            self._debug_file.close()
            self._debug_file = None

    @staticmethod
    def _merge_unique(target: List[dict], seen_ids: set, notes: List[dict]) -> int:
        """Append notes whose id is not in `seen_ids`; return how many were added."""
        added = 0
        for note in notes:
            note_id = note["id"]
            if note_id not in seen_ids:
                seen_ids.add(note_id)
                target.append(note)
                added += 1
        return added

    def retrieve(self, query: str) -> List[dict]:
        """
        Simplified two-layer retrieval.

        Candidates are gathered from entity-linked notes, then topped up by a
        keyword search, and finally from the whole store if nothing matched.
        The LLM then reranks the candidates.

        Returns:
            The retrieved Notes. Empty only when the store has no notes.
        """
        self._debug_log(f"\n{'='*60}")
        self._debug_log(f"[DEBUG] Query: {query}")

        # ========== LPM layer: entity extraction ==========
        existing_entities = self._get_existing_entities()
        self._debug_log(f"[DEBUG] Existing entities count: {len(existing_entities)}")
        entity_preview = [e['name'] for e in existing_entities[:10]]
        entity_suffix = '...' if len(existing_entities) > 10 else ''
        self._debug_log(f"[DEBUG] Existing entities: {entity_preview}{entity_suffix}")

        query_entities, query_keywords = self.query_extractor.extract(query, existing_entities)
        self._debug_log(f"[DEBUG] LLM extracted entities: {query_entities}")
        self._debug_log(f"[DEBUG] LLM extracted keywords: {query_keywords}")

        resolved_entities = self._resolve_entities(query_entities)
        self._debug_log(f"[DEBUG] Resolved entities: {resolved_entities}")
        if query_entities:
            rate = len(resolved_entities) / len(query_entities) * 100
            self._debug_log(
                f"[DEBUG] Resolution rate: {len(resolved_entities)}/{len(query_entities)} = {rate:.1f}%"
            )
        else:
            self._debug_log("[DEBUG] No entities to resolve")

        # ========== Collect candidate Notes ==========
        all_notes = []
        seen_ids = set()

        # Strategy 1: entity-related Notes (via MENTIONED_IN edges)
        if resolved_entities:
            entity_notes = self._collect_entity_notes(resolved_entities)
            self._merge_unique(all_notes, seen_ids, entity_notes)
            self._debug_log(f"[DEBUG] Entity notes: {len(entity_notes)}")

        # Strategy 2: supplement with a keyword search
        if query_keywords:
            keyword_notes = self._search_by_keywords(query_keywords)
            added = self._merge_unique(all_notes, seen_ids, keyword_notes)
            self._debug_log(f"[DEBUG] Keyword notes added: {added}")

        # Strategy 3: still no candidates, so fall back to the whole store
        if not all_notes:
            self._debug_log("[DEBUG] No candidates, falling back to all notes")
            all_notes = self.kg.get_all_notes()[:self._FALLBACK_CANDIDATE_LIMIT]

        self._debug_log(f"[DEBUG] Total candidates: {len(all_notes)}")
        id_preview = [n['id'] for n in all_notes[:10]]
        id_suffix = '...' if len(all_notes) > 10 else ''
        self._debug_log(f"[DEBUG] Candidate IDs: {id_preview}{id_suffix}")

        if not all_notes:
            self._debug_log("[DEBUG] ❌ No notes found, returning empty!")
            return []

        # ========== LLM Rerank ==========
        selected_ids = self._llm_rerank(query, all_notes)
        self._debug_log(f"[DEBUG] LLM rerank selected: {selected_ids}")

        if selected_ids:
            selected_notes = self.kg.get_notes_by_ids(selected_ids)
            if selected_notes:
                return selected_notes
            # The lookup found nothing for the chosen ids; do not hand the
            # caller an empty result while usable candidates exist.
            self._debug_log("[DEBUG] Selected IDs not found in graph, using fallback")

        # Fallback: the first top_k candidates in their collected order
        return all_notes[:self.top_k]

    def _get_existing_entities(self) -> List[dict]:
        """
        Build the LPM-layer entity index.

        Returns:
            One dict per entity with the keys name/type/summary/aliases,
            capped at LPM_CONFIG["max_entities_in_prompt"] (default 50).
        """
        entities = self.kg.get_all_entities()
        max_entities = LPM_CONFIG.get("max_entities_in_prompt", 50)
        return [
            {
                "name": e.get("name", ""),
                "type": e.get("type", ""),
                "summary": e.get("summary", ""),
                "aliases": e.get("aliases", []),
            }
            for e in entities[:max_entities]
        ]

    def _resolve_entities(self, entity_names: List[str]) -> List[str]:
        """
        Resolve entity names to canonical names.

        Names for which `kg.resolve_alias` returns a falsy value are dropped.
        A missing (None) input is treated as an empty list.
        """
        resolved = []
        for name in entity_names or []:
            canonical = self.kg.resolve_alias(name)
            if canonical:
                resolved.append(canonical)
                self._debug_log(f"[DEBUG]   ✓ '{name}' -> '{canonical}'")
            else:
                self._debug_log(f"[DEBUG]   ✗ '{name}' -> NOT FOUND")
        return resolved

    def _search_by_keywords(self, keywords: List[str], limit: int = 50) -> List[dict]:
        """
        Search Notes by keyword (case-insensitive substring match).

        Args:
            keywords: Terms to look for; blank or non-string entries are ignored.
            limit: Scanning stops once this many notes have matched.

        Returns:
            Matching notes in store order.
        """
        if not keywords:
            return []

        # Lowercase once, and drop blank terms: an empty string is a substring
        # of every text and would otherwise match the entire store.
        needles = [kw.lower() for kw in keywords if isinstance(kw, str) and kw.strip()]
        if not needles:
            return []

        matched = []
        for note in self.kg.get_all_notes():
            # NOTE(review): only the "text" field is searched, while the rerank
            # prompt reads "user"/"assistant" - confirm notes carry "text".
            text = (note.get("text") or "").lower()
            if any(needle in text for needle in needles):
                matched.append(note)
                if len(matched) >= limit:
                    break
        return matched

    def _collect_entity_notes(self, entity_names: List[str]) -> List[dict]:
        """
        Collect the Notes related to the given entities (simplified version).

        Retrieves through MENTIONED_IN edges instead of LineNode.

        For one entity, all of its notes are returned. For several entities
        the notes shared by all of them are returned; when nothing is shared,
        the smallest non-empty per-entity note set is used instead.

        Returns:
            Notes ordered by their "seq" value (a missing seq counts as 0).
        """
        if not entity_names:
            return []

        # Two names may resolve to the same canonical entity; query it once and
        # let it take the single-entity path.
        unique_names = list(dict.fromkeys(entity_names))

        # Collect the notes of each entity
        entity_note_map = {}
        all_notes_by_id = {}
        for entity_name in unique_names:
            # Simplified get_entity_notes (via MENTIONED_IN edges)
            entity_notes = self.kg.get_entity_notes(entity_name, limit=100) or []
            note_ids = set()
            for note in entity_notes:
                note_ids.add(note["id"])
                all_notes_by_id[note["id"]] = note
            entity_note_map[entity_name] = note_ids

        self._debug_log(f"[DEBUG] Entity note counts: {[(e, len(ids)) for e, ids in entity_note_map.items()]}")

        # Single entity: return its notes directly
        if len(unique_names) == 1:
            notes = list(all_notes_by_id.values())
            self._debug_log(f"[DEBUG] Single entity, returning {len(notes)} notes")
            notes.sort(key=lambda x: x.get("seq", 0))
            return notes

        # Multiple entities: compute the intersection
        note_id_sets = list(entity_note_map.values())
        common_ids = note_id_sets[0].intersection(*note_id_sets[1:])
        self._debug_log(f"[DEBUG] Intersection of {len(unique_names)} entities: {len(common_ids)} notes")

        if common_ids:
            chosen_ids = common_ids
            self._debug_log(f"[DEBUG] Using intersection: {len(chosen_ids)} notes")
        else:
            # Empty intersection: fall back to the smallest set. Entities with
            # no notes are skipped, otherwise the fallback would always pick
            # an empty set and discard the notes of every other entity.
            non_empty = {e: ids for e, ids in entity_note_map.items() if ids}
            if not non_empty:
                return []
            min_entity = min(non_empty, key=lambda e: len(non_empty[e]))
            chosen_ids = non_empty[min_entity]
            self._debug_log(f"[DEBUG] Intersection empty, fallback to smallest set (entity={min_entity}): {len(chosen_ids)} notes")

        # Walk the dict (first-seen order) rather than the id set so that notes
        # sharing a seq value come out in a reproducible order.
        notes = [note for nid, note in all_notes_by_id.items() if nid in chosen_ids]
        notes.sort(key=lambda x: x.get("seq", 0))
        return notes

    def _select_known_ids(self, raw_ids: list, candidates: List[dict]) -> list:
        """Keep at most top_k distinct ids from `raw_ids` that name a candidate."""
        # Compare on the string form, but return the candidate's own id so the
        # value has the same type as on the non-LLM code paths.
        known = {str(c["id"]): c["id"] for c in candidates}
        picked = []
        used = set()
        for raw in raw_ids:
            key = str(raw)
            if key not in known:
                key = key.strip()
            if key in known and key not in used:
                used.add(key)
                picked.append(known[key])
                if len(picked) >= self.top_k:
                    break
        return picked

    def _llm_rerank(self, query: str, candidates: List[dict]) -> List[str]:
        """
        LLM rerank - let the LLM directly pick the most relevant Notes.

        The LLM is skipped when there are no more than top_k candidates.
        Ids the LLM returns that do not name a candidate, and repeated ids,
        are discarded. If no usable id remains, the first top_k candidates
        are selected.

        Returns:
            Ids of the selected notes, at most top_k of them.
        """
        if len(candidates) <= self.top_k:
            return [c["id"] for c in candidates]

        # Build the candidate listing
        candidate_texts = []
        for note in candidates:
            user_text = note.get('user', '')
            assistant_text = note.get('assistant', '')
            text = f"[{note['id']}] {note.get('session_date', '')}: User: {user_text} | Assistant: {assistant_text}"
            candidate_texts.append(text)

        prompt = f"""Select the {self.top_k} most relevant notes to answer the question.
Question: {query}
Candidate Notes:
{chr(10).join(candidate_texts)}
Return JSON with the IDs of the {self.top_k} most relevant notes:
{{"top_ids": ["N1", "N2", ...]}}
Only return valid JSON."""

        result = self.llm.call_json(prompt)

        # Accept both {"top_ids": [...]} and a bare list of ids.
        raw_ids = result.get("top_ids") if isinstance(result, dict) else result
        if isinstance(raw_ids, list):
            chosen = self._select_known_ids(raw_ids, candidates)
            if chosen:
                return chosen

        if self.debug:
            print("[EvoRetriever] Rerank failed, using fallback")
        return [c["id"] for c in candidates[:self.top_k]]
# Keep the same class-name aliases as the baseline
# (presumably so this module can stand in for the baseline retriever without
# changing the importing code - confirm against the callers).
EvoGraphRetriever = EvoGraphRetrieverNoLineNode
GraphRetriever = EvoGraphRetrieverNoLineNode
HybridRetriever = EvoGraphRetrieverNoLineNode
if __name__ == "__main__":
    # Manual smoke test: run a few sample queries and print the ids of the
    # notes that were retrieved for each one.
    kg = KnowledgeGraphNoLineNode()
    try:
        retriever = EvoGraphRetrieverNoLineNode(kg)
        test_queries = [
            "When did Jon lose his job as a banker?",
            "Which city have both Jean and John visited?",
            "What does Jon's dance studio offer?",
        ]
        for query in test_queries:
            print(f"\n{'='*60}")
            print(f"Query: {query}")
            results = retriever.retrieve(query)
            print(f"Results: {[r['id'] for r in results]}")
    finally:
        # Close the graph even when a query raises part-way through the run.
        kg.close()