@@ -163,7 +163,7 @@ def get_chat_context():
163163)
164164
165165reformulate_prompt = PromptTemplate (
166- input_variables = ["chat_history" , "last_suggested" , "question" , 'classifier' ],
166+ input_variables = ["chat_history" , "last_suggested" , "question" ],
167167 template = """
168168Return JSON only: with keys as "Rewritten" and "Correction" where correction being a dict of <original:corrected> pairs.
169169
@@ -187,7 +187,7 @@ def get_chat_context():
187187
188188- chat_history: ""
189189 question: "denger sign for johndice"
190- -> Rewritten: "WHat are the danger signs for jaundice?", Correction: {"johndice":"jaundice"}
190+ -> Rewritten: "WHat are the danger signs for jaundice?", Correction: {{ "johndice":"jaundice"} }
191191
192192- chat_history: ""
193193 question: "depression"
def judge_sufficiency(query, candidates, judge_llm=llm, threshold_weak=0.25):
    """Rank candidate chunks by an LLM sufficiency judgment and return the best.

    Each candidate is asked to be judged by ``judge_llm`` against ``query``;
    the judge's JSON reply supplies a boolean ``sufficient`` flag and a
    ``topic_match`` label ("strong" / "medium" / "absolutely_not_possible").
    Candidates are then sorted by (topic_match rank, cross-encoder score).

    Parameters:
        query: the user question being answered.
        candidates: list of chunk dicts; each must carry ``meta`` (dict,
            mutated in place to record ``topic_match_score``), ``text``, and
            ``scores["cross"]`` (cross-encoder relevance, used as tiebreak
            and as the fallback criterion when the judge reply is unparsable).
        judge_llm: chat model used as the sufficiency judge.
        threshold_weak: cross-encoder cutoff used only in the fallback path.

    Returns:
        dict with ``answer_chunks`` (top 4 qualified chunks) and
        ``followup_chunks`` (next best 2, drawn from both the non-qualified
        pool and any qualified chunks that did not make the top 4).
    """
    qualified_with_scores = []   # judged sufficient (or strong fallback score)
    followup_chunks_raw = []     # judged insufficient; follow-up candidates
    topic_match_order = {"strong": 3, "medium": 2, "absolutely_not_possible": 1}

    print("len of candidates", len(candidates))
    for c in candidates:  # iterate through ALL candidates
        # BUGFIX: the separator must be a real newline ("\n"), not a literal
        # backslash-n ("\\ n"), or the judge sees one garbled line.
        snippet = f"Source: {c['meta'].get('doc_name', 'unknown')}\nExcerpt: {c['text']}"
        prompt = judge_prompt.format(query=query, context_snippet=snippet)

        resp = judge_llm.invoke([HumanMessage(content=prompt)]).content

        try:
            # BUGFIX: extract from the FIRST "{" to the last "}" so a nested
            # JSON object is captured whole; rfind("{") grabbed only the
            # innermost trailing object and broke parsing on nested replies.
            obj = json.loads(resp[resp.find("{"):resp.rfind("}") + 1])
            print(obj)
            topic_match_label = obj.get("topic_match", "absolutely_not_possible")
            # Record the rank in meta so both sort passes can reuse it.
            c['meta']['topic_match_score'] = topic_match_order.get(topic_match_label, 0)

            if obj.get("sufficient", False):
                qualified_with_scores.append(c)
            else:
                followup_chunks_raw.append(c)
        except Exception:
            # Judge reply unparsable: fall back to the cross-encoder score,
            # treating the topic match as "medium".
            c['meta']['topic_match_score'] = topic_match_order.get("medium", 0)
            if c["scores"]["cross"] > threshold_weak:
                qualified_with_scores.append(c)
            else:
                followup_chunks_raw.append(c)

    # Two-tier sort: topic_match rank first (desc), cross-encoder score second (desc).
    qualified = sorted(qualified_with_scores,
                       key=lambda x: (x['meta'].get('topic_match_score', 0), x["scores"]["cross"]),
                       reverse=True)

    print("BEFORE len of answer_chunks", len(qualified), "BEFORE len of followup_chunks", len(followup_chunks_raw))

    answer_chunks = qualified[:4]  # top 4 from the re-sorted qualified list

    # Ensure every follow-up candidate carries a topic_match_score so the
    # combined sort below is well-defined.
    for c in followup_chunks_raw:
        if 'topic_match_score' not in c['meta']:
            c['meta']['topic_match_score'] = topic_match_order.get("absolutely_not_possible", 0)

    # Qualified chunks that missed the top 4 can still serve as follow-ups.
    remaining_qualified_for_followup = qualified[4:]

    combined_followup_candidates = sorted(followup_chunks_raw + remaining_qualified_for_followup,
                                          key=lambda x: (x['meta'].get('topic_match_score', 0), x["scores"]["cross"]),
                                          reverse=True)

    followup_chunks = combined_followup_candidates[:2]

    print("AFTER len of answer_chunks", len(answer_chunks), "AFTER len of followup_chunks", len(followup_chunks))
    return {"answer_chunks": answer_chunks, "followup_chunks": followup_chunks}
377405def synthesize_answer (query , top_candidates , context_followup , main_llm = llm ):
378406 # Build context from top 3 candidates
379407 sources = []
@@ -398,6 +426,7 @@ def synthesize_answer(query, top_candidates, context_followup, main_llm=llm):
398426
399427 return resp
400428
429+
401430# -------------------- CLASSIFY / REFORMULATE / CHITCHAT --------------------
402431def classify_message (chat_history , user_message ):
403432 prompt = classifier_prompt .format (chat_history = chat_history , question = user_message )
@@ -418,19 +447,6 @@ def classify_message(chat_history, user_message):
418447 return "CHITCHAT" , "greeting heuristic"
419448 return "MEDICAL_QUESTION" , "fallback"
420449
421- # def reformulate_query(chat_history, user_message, last_suggested=""):
422- # print("here")
423- # prompt = reformulate_prompt.format(chat_history=chat_history, last_suggested=last_suggested, question=user_message)
424- # append_debug("[reformulate] sending reformulation prompt")
425- # try:
426- # resp = llm.invoke([HumanMessage(content=prompt)]).content
427- # print("resp is ",resp)
428- # parsed = safe_json_parse(resp)
429- # if parsed:
430- # return parsed.get("Rewritten", user_message), parsed.get("Correction", "")
431- # except Exception as e:
432- # append_debug(f"[reformulate] LLM failed: {e}")
433- # return user_message, ""
434450def reformulate_query (chat_history , user_message ,classify , last_suggested = "" ):
435451 print ("here" )
436452 prompt = f"""
0 commit comments