Skip to content

Commit 81ab32c

Browse files
authored
Merge pull request #21 from hanjuhn/main
refactor: 관세 예측 기능 코드 리펙토링
2 parents 1351f63 + aeb1b43 commit 81ab32c

11 files changed

Lines changed: 485 additions & 439 deletions

core/shared/router/intent_router.py

Lines changed: 18 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
from core.shared.constants import (
77
TARIFF_SESSION_KEYWORDS,
88
NUMBER_SELECTION_PATTERNS,
9-
TARIFF_PREDICTION_KEYWORDS,
10-
CUSTOMS_TRACKING_KEYWORDS,
119
INTENT_CLASSIFICATION_PROMPT,
1210
INTENT_TYPES,
1311
DEFAULT_INTENT
@@ -16,11 +14,25 @@
1614
def intent_router(state: CustomsAgentState) -> CustomsAgentState:
1715
"""사용자 쿼리의 의도를 분류합니다."""
1816

17+
# 현재 쿼리가 숫자 선택인지 확인
18+
current_query = state["query"].strip()
19+
is_number_selection = False
20+
21+
# 숫자 선택 패턴 확인 (1번, 2번, 3번, 1, 2, 3 등)
22+
for pattern in NUMBER_SELECTION_PATTERNS:
23+
if re.match(pattern, current_query):
24+
is_number_selection = True
25+
break
26+
1927
# 이전 대화에서 관세 예측 중인지 확인
2028
messages = state.get("messages", [])
2129
is_in_tariff_session = False
2230

23-
# 최근 메시지들을 확인하여 관세 예측 세션 중인지 판단
31+
# 1. state의 intent 필드 확인 (가장 중요)
32+
if state.get("intent") == "tariff_prediction":
33+
is_in_tariff_session = True
34+
35+
# 2. 최근 메시지들을 확인하여 관세 예측 세션 중인지 판단
2436
for msg in messages[-5:]: # 최근 5개 메시지 확인
2537
if hasattr(msg, 'content') and isinstance(msg.content, str):
2638
content = msg.content.lower()
@@ -29,7 +41,7 @@ def intent_router(state: CustomsAgentState) -> CustomsAgentState:
2941
is_in_tariff_session = True
3042
break
3143

32-
# 이전 의도가 tariff_prediction이었는지도 확인
44+
# 3. 이전 의도가 tariff_prediction이었는지도 확인
3345
if messages and len(messages) > 0:
3446
last_msg = messages[-1]
3547
if hasattr(last_msg, 'content') and isinstance(last_msg.content, str):
@@ -43,23 +55,6 @@ def intent_router(state: CustomsAgentState) -> CustomsAgentState:
4355
print(state)
4456
return state
4557

46-
# 현재 쿼리가 숫자 선택인지 확인
47-
current_query = state["query"].strip()
48-
is_number_selection = False
49-
50-
# 숫자 선택 패턴 확인 (1번, 2번, 3번, 1, 2, 3 등)
51-
for pattern in NUMBER_SELECTION_PATTERNS:
52-
if re.match(pattern, current_query):
53-
is_number_selection = True
54-
break
55-
56-
# 관세 예측 세션 중이고 숫자 선택이면 무조건 tariff_prediction
57-
if is_in_tariff_session and is_number_selection:
58-
state["intent"] = "tariff_prediction" # type: ignore
59-
state["messages"].append(AIMessage(content=f"의도 분류 완료: {state['intent']} (세션 연속성 유지)"))
60-
print(state)
61-
return state
62-
6358
# 숫자 선택이지만 관세 예측 세션이 아닌 경우에도 tariff_prediction으로 분류
6459
# (HS 코드 직접 입력 등의 경우)
6560
if is_number_selection:
@@ -68,22 +63,7 @@ def intent_router(state: CustomsAgentState) -> CustomsAgentState:
6863
print(state)
6964
return state
7065

71-
# 관세 예측 관련 키워드가 있으면 무조건 tariff_prediction으로 분류
72-
current_query_lower = current_query.lower()
73-
if any(keyword in current_query_lower for keyword in TARIFF_PREDICTION_KEYWORDS):
74-
state["intent"] = "tariff_prediction" # type: ignore
75-
state["messages"].append(AIMessage(content=f"의도 분류 완료: {state['intent']} (관세 예측 키워드 감지)"))
76-
print(state)
77-
return state
78-
79-
# 운송장/배송 관련 키워드가 있으면 customs_tracking으로 분류 (배송은 제외)
80-
if any(keyword in current_query_lower for keyword in CUSTOMS_TRACKING_KEYWORDS):
81-
state["intent"] = "customs_tracking" # type: ignore
82-
state["messages"].append(AIMessage(content=f"의도 분류 완료: {state['intent']} (배송 추적 키워드 감지)"))
83-
print(state)
84-
return state
85-
86-
# 일반적인 의도 분류 (LLM 사용)
66+
# LLM 기반 의도 분류를 수행
8767
llm = get_llm()
8868

8969
try:
@@ -99,6 +79,6 @@ def intent_router(state: CustomsAgentState) -> CustomsAgentState:
9979
intent = DEFAULT_INTENT # 오류 시에도 tariff_prediction으로 분류
10080

10181
state["intent"] = intent # type: ignore
102-
state["messages"].append(AIMessage(content=f"의도 분류 완료: {intent}"))
82+
state["messages"].append(AIMessage(content=f"의도 분류 완료: {intent} (LLM 분류)"))
10383
print(state)
10484
return state

core/tariff_prediction/agent/step_api.py

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,41 +5,34 @@
55
from core.tariff_prediction.tools.calculate_tariff_amount import calculate_tariff_amount
66
from core.tariff_prediction.tools.parse_tariff_result import parse_tariff_result
77
from core.shared.utils.llm import get_llm
8+
from core.tariff_prediction.constants import LLM_PROMPT_TEMPLATES, STEP_API
89

910
def tariff_prediction_step_api(req: TariffPredictionRequest) -> TariffPredictionResponse:
1011
step = req.step
11-
# Step 자동 분류: step이 'auto'이거나 비어 있으면 LLM으로 분류
12-
if not step or step == 'auto':
12+
if not step or step == STEP_API['AUTO_STEP']:
1313
llm = get_llm()
14-
step_prompt = f"""
15-
다음 사용자 입력이 관세 예측 플로우의 어떤 단계에 해당하는지 분류하세요.
16-
- 상품 설명 입력: input
17-
- HS6 코드 선택: hs6_select
18-
- HS10 코드 선택 및 관세 계산: hs10_select
19-
반드시 input, hs6_select, hs10_select 중 하나로만 답하세요.
20-
사용자 입력: {req.product_description or req.hs6_code or req.hs10_code or ''}
21-
"""
14+
user_input = req.product_description or req.hs6_code or req.hs10_code or ''
15+
step_prompt = LLM_PROMPT_TEMPLATES['step_classification'].format(user_input=user_input)
2216
step_result = llm.invoke([{"role": "system", "content": step_prompt}])
2317
step = str(getattr(step_result, 'content', step_result)).strip()
24-
# Step별 분기
25-
if step == "input":
18+
if step == STEP_API['INPUT_STEP']:
2619
# 상품 설명 → HS6 후보 예측
2720
hs6_result = get_hs_classification(req.product_description)
2821
hs6_candidates = parse_hs6_result(hs6_result)
2922
return TariffPredictionResponse(
30-
step="hs6_select",
23+
step=STEP_API['HS6_SELECT_STEP'],
3124
hs6_candidates=hs6_candidates,
32-
message="상품에 해당하는 HS6 코드를 선택해 주세요."
25+
message=STEP_API['HS6_SELECTION_MESSAGE']
3326
)
34-
elif step == "hs6_select":
27+
elif step == STEP_API['HS6_SELECT_STEP']:
3528
# HS6 코드 → HS10 후보 추출
3629
hs10_candidates = generate_hs10_candidates(req.hs6_code)
3730
return TariffPredictionResponse(
38-
step="hs10_select",
31+
step=STEP_API['HS10_SELECT_STEP'],
3932
hs10_candidates=hs10_candidates,
40-
message="HS10 코드 후보를 선택해 주세요."
33+
message=STEP_API['HS10_SELECTION_MESSAGE']
4134
)
42-
elif step == "hs10_select":
35+
elif step == STEP_API['HS10_SELECT_STEP']:
4336
# HS10 코드, 국가, 가격 등 입력받아 관세 계산
4437
result = calculate_tariff_amount.invoke({
4538
"product_code": req.hs10_code,
@@ -49,30 +42,27 @@ def tariff_prediction_step_api(req: TariffPredictionRequest) -> TariffPrediction
4942
"shipping_cost": req.shipping_cost,
5043
"situation": req.scenario
5144
})
52-
45+
5346
# 결과를 문자열로 변환
5447
result_str = str(result)
5548

56-
# 에러 메시지인지 확인 (에러 메시지는 보통 짧고 특정 키워드를 포함)
57-
if result_str.startswith("오류") or result_str.startswith("Error") or "실패" in result_str or "오류" in result_str:
58-
# 에러 메시지
49+
if any(keyword in result_str for keyword in STEP_API['ERROR_KEYWORDS']):
5950
return TariffPredictionResponse(
60-
step="result",
51+
step=STEP_API['RESULT_STEP'],
6152
calculation_result=None,
6253
message=result_str
6354
)
6455
else:
65-
# 성공적인 결과 - 예쁘게 포맷팅
6656
parsed_result = parse_tariff_result(result_str)
6757
formatted_result = parsed_result['formatted_result']
6858

6959
return TariffPredictionResponse(
70-
step="result",
71-
calculation_result=parsed_result, # 딕셔너리 형태로 전달
72-
message=formatted_result # 포맷팅된 결과를 message에 전달
60+
step=STEP_API['RESULT_STEP'],
61+
calculation_result=parsed_result,
62+
message=formatted_result
7363
)
7464
else:
7565
return TariffPredictionResponse(
76-
step="hs6_select",
77-
message="잘못된 요청입니다. 상품 설명을 입력해 주세요."
66+
step=STEP_API['HS6_SELECT_STEP'],
67+
message=STEP_API['DEFAULT_ERROR_MESSAGE']
7868
)

0 commit comments

Comments
 (0)