Skip to content

Commit cd42327

Browse files
authored
Merge pull request #23 from hanjuhn/main
fix: intent router 수정
2 parents a3331ba + d34354e commit cd42327

1 file changed

Lines changed: 80 additions & 53 deletions

File tree

Lines changed: 80 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from langchain_core.messages import SystemMessage, AIMessage
22
import re
3+
from typing import List, Optional
34

45
from core.shared.states.states import CustomsAgentState
56
from core.shared.utils.llm import get_llm
@@ -13,80 +14,106 @@
1314
SESSION_CHECK_MESSAGE_COUNT
1415
)
1516

16-
def intent_router(state: CustomsAgentState) -> CustomsAgentState:
17-
"""사용자 쿼리의 의도를 분류합니다."""
18-
19-
# 현재 쿼리가 숫자 선택인지 확인
20-
current_query = state["query"].strip()
21-
is_number_selection = False
22-
23-
# 숫자 선택 패턴 확인 (1번, 2번, 3번, 1, 2, 3 등)
24-
for pattern in NUMBER_SELECTION_PATTERNS:
25-
if re.match(pattern, current_query):
26-
is_number_selection = True
27-
break
28-
29-
# 질문 형태 감지
30-
is_question = any(re.search(pattern, current_query) for pattern in QUESTION_PATTERNS)
31-
32-
# 이전 대화에서 관세 예측 중인지 확인
33-
messages = state.get("messages", [])
34-
is_in_tariff_session = False
35-
17+
18+
def _is_number_selection(query: str) -> bool:
19+
"""숫자 선택 패턴인지 확인합니다."""
20+
return any(re.match(pattern, query) for pattern in NUMBER_SELECTION_PATTERNS)
21+
22+
23+
def _is_question(query: str) -> bool:
24+
"""질문 형태인지 확인합니다."""
25+
return any(re.search(pattern, query) for pattern in QUESTION_PATTERNS)
26+
27+
28+
def _is_in_tariff_session(state: CustomsAgentState) -> bool:
29+
"""관세 예측 세션 중인지 확인합니다."""
3630
# 1. state의 intent 필드 확인 (가장 중요)
3731
if state.get("intent") == "tariff_prediction":
38-
is_in_tariff_session = True
32+
return True
33+
34+
messages = state.get("messages", [])
35+
if not messages:
36+
return False
3937

4038
# 2. 최근 메시지들을 확인하여 관세 예측 세션 중인지 판단
41-
for msg in messages[-SESSION_CHECK_MESSAGE_COUNT:]: # 최근 메시지 확인
39+
recent_messages = messages[-SESSION_CHECK_MESSAGE_COUNT:]
40+
for msg in recent_messages:
4241
if hasattr(msg, 'content') and isinstance(msg.content, str):
4342
content = msg.content.lower()
44-
# 관세 예측 관련 키워드가 있거나 HS 코드 선택 메시지가 있으면 관세 예측 세션으로 판단
4543
if any(keyword in content for keyword in TARIFF_SESSION_KEYWORDS):
46-
is_in_tariff_session = True
47-
break
44+
return True
45+
46+
# 3. 마지막 메시지에서 세션 상태 확인
47+
last_msg = messages[-1]
48+
if hasattr(last_msg, 'content') and isinstance(last_msg.content, str):
49+
content = last_msg.content.lower()
50+
if 'tariff_prediction' in content or '관세 예측 세션' in content:
51+
return True
52+
53+
return False
54+
55+
56+
def _classify_with_llm(query: str) -> str:
57+
"""LLM을 사용하여 의도를 분류합니다."""
58+
try:
59+
llm = get_llm()
60+
result = llm.invoke([
61+
SystemMessage(content=INTENT_CLASSIFICATION_PROMPT.format(query=query))
62+
])
63+
64+
intent = str(result.content).strip()
65+
return intent if intent in INTENT_TYPES else DEFAULT_INTENT
66+
67+
except Exception:
68+
# LLM 분류 실패 시 기본값 사용
69+
return DEFAULT_INTENT
70+
71+
72+
def _add_classification_message(state: CustomsAgentState, intent: str, reason: str) -> None:
73+
"""의도 분류 완료 메시지를 추가합니다."""
74+
state["messages"].append(
75+
AIMessage(content=f"의도 분류 완료: {intent} ({reason})")
76+
)
77+
78+
79+
def intent_router(state: CustomsAgentState) -> CustomsAgentState:
80+
"""사용자 쿼리의 의도를 분류합니다."""
4881

49-
# 3. 이전 의도가 tariff_prediction이었는지도 확인
50-
if messages and len(messages) > 0:
51-
last_msg = messages[-1]
52-
if hasattr(last_msg, 'content') and isinstance(last_msg.content, str):
53-
if 'tariff_prediction' in last_msg.content or '관세 예측 세션' in last_msg.content:
54-
is_in_tariff_session = True
82+
current_query = state["query"].strip()
83+
if not current_query:
84+
state["intent"] = DEFAULT_INTENT
85+
_add_classification_message(state, DEFAULT_INTENT, "빈 쿼리")
86+
return state
87+
88+
# 세션 연속성 확인
89+
is_in_tariff_session = _is_in_tariff_session(state)
5590

5691
# 관세 예측 세션 중이면 무조건 tariff_prediction으로 분류
5792
if is_in_tariff_session:
58-
state["intent"] = "tariff_prediction" # type: ignore
59-
state["messages"].append(AIMessage(content=f"의도 분류 완료: {state['intent']} (관세 예측 세션 연속성 유지)"))
93+
state["intent"] = "tariff_prediction"
94+
_add_classification_message(state, "tariff_prediction", "관세 예측 세션 연속성 유지")
6095
return state
6196

97+
# 패턴 기반 분류
98+
is_number_selection = _is_number_selection(current_query)
99+
is_question = _is_question(current_query)
100+
62101
# 질문 형태이면서 관세 예측 세션이 아닌 경우 QnA로 분류
63-
if is_question and not is_in_tariff_session:
64-
state["intent"] = "qna" # type: ignore
65-
state["messages"].append(AIMessage(content=f"의도 분류 완료: {state['intent']} (질문 형태 감지)"))
102+
if is_question:
103+
state["intent"] = "qna"
104+
_add_classification_message(state, "qna", "질문 형태 감지")
66105
return state
67106

68107
# 숫자 선택이지만 관세 예측 세션이 아닌 경우에도 tariff_prediction으로 분류
69108
# (HS 코드 직접 입력 등의 경우)
70109
if is_number_selection:
71-
state["intent"] = "tariff_prediction" # type: ignore
72-
state["messages"].append(AIMessage(content=f"의도 분류 완료: {state['intent']} (숫자 선택 감지)"))
110+
state["intent"] = "tariff_prediction"
111+
_add_classification_message(state, "tariff_prediction", "숫자 선택 감지")
73112
return state
74113

75114
# LLM 기반 의도 분류를 수행
76-
llm = get_llm()
77-
78-
try:
79-
result = llm.invoke([
80-
SystemMessage(content=INTENT_CLASSIFICATION_PROMPT.format(query=state["query"]))
81-
])
82-
83-
intent = str(result.content).strip()
84-
if intent not in INTENT_TYPES:
85-
intent = DEFAULT_INTENT # 기본값을 tariff_prediction으로 변경
86-
except Exception as e:
87-
# LLM 분류 실패 시 기본값 사용
88-
intent = DEFAULT_INTENT # 오류 시에도 tariff_prediction으로 분류
115+
intent = _classify_with_llm(current_query)
116+
state["intent"] = intent
117+
_add_classification_message(state, intent, "LLM 분류")
89118

90-
state["intent"] = intent # type: ignore
91-
state["messages"].append(AIMessage(content=f"의도 분류 완료: {intent} (LLM 분류)"))
92119
return state

0 commit comments

Comments
 (0)