Skip to content

Commit a46b776

Browse files
Added code with changes to chit-chat history
1 parent acb57d1 commit a46b776

1 file changed

Lines changed: 47 additions & 31 deletions

File tree

new_architecture_v4_cursor.py

Lines changed: 47 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def get_chat_context():
163163
)
164164

165165
reformulate_prompt = PromptTemplate(
166-
input_variables=["chat_history", "last_suggested", "question",'classifier'],
166+
input_variables=["chat_history", "last_suggested", "question"],
167167
template="""
168168
Return JSON only: with keys as "Rewritten" and "Correction" where correction being a dict of <original:corrected> pairs.
169169
@@ -187,7 +187,7 @@ def get_chat_context():
187187
188188
- chat_history: ""
189189
question: "denger sign for johndice"
190-
-> Rewritten: "WHat are the danger signs for jaundice?", Correction: {"johndice":"jaundice"}
190+
-> Rewritten: "WHat are the danger signs for jaundice?", Correction: {{"johndice":"jaundice"}}
191191
192192
- chat_history: ""
193193
question: "depression"
@@ -343,37 +343,65 @@ def judge_sufficiency(query, candidates, judge_llm=llm, threshold_weak=0.25):
343343
Return the top 4 qualified chunks for answering,
344344
and next 2 for follow-up suggestion.
345345
"""
346-
347-
qualified = []
348-
followup_chunks=[]
346+
347+
qualified_with_scores = [] # Store qualified chunks along with their cross-encoder scores and topic_match
348+
followup_chunks_raw = [] # Store non-qualified chunks for potential follow-up
349+
topic_match_order = {"strong": 3, "medium": 2, "absolutely_not_possible": 1}
350+
349351
print("len of candidates",len(candidates))
350-
for c in candidates: # inspect up to 12
351-
snippet = f"Source: {c['meta'].get('doc_name','unknown')}\nExcerpt: {c['text']}"
352+
for c in candidates: # inspect up to 12. Iterate through all candidates initially
353+
snippet = f"Source: {c['meta'].get('doc_name','unknown')}\\nExcerpt: {c['text']}"
352354
prompt = judge_prompt.format(query=query, context_snippet=snippet)
353355

354356
resp = judge_llm.invoke([HumanMessage(content=prompt)]).content
355-
#print(candidates, resp)
356357

357358
try:
358359
obj = json.loads(resp[resp.rfind("{"):resp.rfind("}")+1])
359360
print(obj)
361+
topic_match_label = obj.get("topic_match", "absolutely_not_possible")
362+
# Store topic_match_score in the chunk's meta for easier sorting
363+
c['meta']['topic_match_score'] = topic_match_order.get(topic_match_label, 0) # Default to 0 for unknown/error
364+
360365
if obj.get("sufficient", False):
361-
qualified.append(c)
366+
qualified_with_scores.append(c) # Add to qualified list
362367
else:
363-
followup_chunks.append(c)
368+
followup_chunks_raw.append(c)
364369
except Exception:
370+
# Fallback based on cross-encoder score if LLM judge fails
371+
# Assign a default topic_match_score (e.g., 'medium' equivalent if LLM fails to parse)
372+
c['meta']['topic_match_score'] = topic_match_order.get("medium", 0)
365373
if c["scores"]["cross"] > threshold_weak:
366-
qualified.append(c)
374+
qualified_with_scores.append(c)
367375
else:
368-
followup_chunks.append(c)
376+
followup_chunks_raw.append(c)
377+
378+
# NEW: Sort qualified chunks first by topic_match_score (desc), then by cross score (desc)
379+
qualified = sorted(qualified_with_scores,
380+
key=lambda x: (x['meta'].get('topic_match_score', 0), x["scores"]["cross"]),
381+
reverse=True)
382+
383+
print("BEFORE len of answer_chunks",len(qualified),"BEFORE len of followup_chunks",len(followup_chunks_raw))
384+
385+
answer_chunks = qualified[:4] # Take top 4 from the re-sorted qualified list
369386

370-
print("BEFORE len of answer_chunks",len(qualified),"BEFORE len of followup_chunks",len(followup_chunks))
371-
if len(followup_chunks)==0:
372-
followup_chunks=qualified[-2:]
373-
qualified=qualified[:-2]
374-
print("AFTER len of answer_chunks",len(qualified),"AFTER len of followup_chunks",len(followup_chunks))
375-
return {"answer_chunks": qualified[:4], "followup_chunks": followup_chunks[:2]}
387+
# Ensure all followup candidates also have a topic_match_score for consistent sorting
388+
for c in followup_chunks_raw:
389+
if 'topic_match_score' not in c['meta']:
390+
c['meta']['topic_match_score'] = topic_match_order.get("absolutely_not_possible", 0) # Default for raw fallbacks
376391

392+
# Now, combine any remaining qualified chunks with the initially non-qualified ones for follow-up
393+
# This ensures higher-scored but not-top-4 answer chunks can still be follow-ups
394+
remaining_qualified_for_followup = qualified[4:]
395+
396+
# Sort these combined candidates using the same two-tier logic
397+
combined_followup_candidates = sorted(followup_chunks_raw + remaining_qualified_for_followup,
398+
key=lambda x: (x['meta'].get('topic_match_score', 0), x["scores"]["cross"]),
399+
reverse=True)
400+
401+
followup_chunks = combined_followup_candidates[:2]
402+
403+
print("AFTER len of answer_chunks",len(answer_chunks),"AFTER len of followup_chunks",len(followup_chunks))
404+
return {"answer_chunks": answer_chunks, "followup_chunks": followup_chunks}
377405
def synthesize_answer(query, top_candidates, context_followup, main_llm=llm):
378406
# Build context from top 3 candidates
379407
sources = []
@@ -398,6 +426,7 @@ def synthesize_answer(query, top_candidates, context_followup, main_llm=llm):
398426

399427
return resp
400428

429+
401430
# -------------------- CLASSIFY / REFORMULATE / CHITCHAT --------------------
402431
def classify_message(chat_history, user_message):
403432
prompt = classifier_prompt.format(chat_history=chat_history, question=user_message)
@@ -418,19 +447,6 @@ def classify_message(chat_history, user_message):
418447
return "CHITCHAT", "greeting heuristic"
419448
return "MEDICAL_QUESTION", "fallback"
420449

421-
# def reformulate_query(chat_history, user_message, last_suggested=""):
422-
# print("here")
423-
# prompt = reformulate_prompt.format(chat_history=chat_history, last_suggested=last_suggested, question=user_message)
424-
# append_debug("[reformulate] sending reformulation prompt")
425-
# try:
426-
# resp = llm.invoke([HumanMessage(content=prompt)]).content
427-
# print("resp is ",resp)
428-
# parsed = safe_json_parse(resp)
429-
# if parsed:
430-
# return parsed.get("Rewritten", user_message), parsed.get("Correction", "")
431-
# except Exception as e:
432-
# append_debug(f"[reformulate] LLM failed: {e}")
433-
# return user_message, ""
434450
def reformulate_query(chat_history, user_message,classify, last_suggested=""):
435451
print("here")
436452
prompt = f"""

0 commit comments

Comments
 (0)