-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
227 lines (186 loc) · 7.8 KB
/
Copy pathapp.py
File metadata and controls
227 lines (186 loc) · 7.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""Web 端入口 - FastAPI"""
import asyncio
import os
import time
import uuid, shutil
from contextlib import asynccontextmanager
from pathlib import Path
from dotenv import load_dotenv
from fastapi import FastAPI, UploadFile, File, Form, Request, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
# 加载 .env 文件
load_dotenv()
from src.core.simple_classifier import classify, classify_with_llm, SubjectType, SUBJECT_CONFIG, SubjectCategory, get_category
from src.core.model_router import ModelRouter
from src.core.narrator import generate_script, generate_svg_animation, generate_highlight_html, parse_solution
from src.services.deepseek_client import DeepSeekClient
from src.services.qwen_client import QwenClient
from src.services.minimax_tts import MinimaxTTS
from src.services.dashscope_tts import DashScopeTTS
def _cleanup_old_tasks(max_age_seconds: int = 3600):
"""删除超过 max_age_seconds 的旧 task 目录"""
now = time.time()
output_dir = Path("output")
if not output_dir.exists():
return
for d in output_dir.iterdir():
if d.is_dir() and (now - d.stat().st_mtime) > max_age_seconds:
shutil.rmtree(d, ignore_errors=True)
async def _delete_task_dir_after_delay(task_dir: Path, delay: int = 60):
"""等待 delay 秒后删除 task 目录"""
await asyncio.sleep(delay)
shutil.rmtree(task_dir, ignore_errors=True)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler.

    On startup, sweep task directories left behind by earlier runs; there is
    no shutdown work.
    """
    # Startup phase: uses _cleanup_old_tasks' default max age.
    _cleanup_old_tasks()
    # The application serves requests while this generator is suspended.
    yield
app = FastAPI(lifespan=lifespan)

# CORS middleware - lets the frontend be served from other origins.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable (load_dotenv() may supply it from .env). When it is unset or blank,
# every origin is allowed, exactly as before.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # NOTE(review): with the "*" default, Starlette answers cookie-bearing
    # requests by echoing the caller's Origin, i.e. credentialed cross-site
    # requests are accepted from anywhere. Set CORS_ALLOW_ORIGINS in production.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Frontend assets and templates (paths are relative to the working directory).
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Per-task output (uploaded images, synthesized audio); created at import time.
OUTPUT_DIR = Path("output")
OUTPUT_DIR.mkdir(exist_ok=True)
def get_config() -> dict:
    """Collect API credentials from environment variables.

    Returns a dict with keys deepseek_api_key, dashscope_api_key,
    minimax_api_key and minimax_group_id; unset variables yield "".
    """
    # config key -> environment variable it is read from (order preserved)
    env_names = {
        "deepseek_api_key": "DEEPSEEK_API_KEY",
        "dashscope_api_key": "DASHSCOPE_API_KEY",
        "minimax_api_key": "MINIMAX_API_KEY",
        "minimax_group_id": "MINIMAX_GROUP_ID",
    }
    return {key: os.getenv(var, "") for key, var in env_names.items()}
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the single-page frontend from templates/index.html."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
def _safe_upload_name(filename: str) -> str:
    """Reduce a client-supplied upload filename to a bare, traversal-free basename.

    The multipart filename is attacker-controlled; joining it onto a task
    directory unchecked would let e.g. "../../app.py" write outside it.
    """
    # Treat both "/" and "\\" as separators so the result is safe on any OS.
    base = filename.replace("\\", "/").rsplit("/", 1)[-1]
    if base in ("", ".", "..") or "\x00" in base:
        return "upload"
    return base


def _synthesize_audio(cfg: dict, subject_cfg: dict, script: str, audio_path: str) -> bool:
    """Best-effort TTS: try DashScope first, then MiniMax.

    Returns True if either provider wrote ``audio_path``. Provider failures
    are printed and swallowed so a TTS outage does not fail the request.
    """
    if cfg.get("dashscope_api_key"):
        try:
            tts = DashScopeTTS(cfg["dashscope_api_key"])
            # Map voice name to DashScope voice
            tts_voice = subject_cfg.get("dashscope_voice", "female-yunxi")
            tts.synthesize(script, voice=tts_voice, output_path=audio_path)
            return True
        except Exception as e:
            print(f"[TTS] DashScope TTS failed, trying MiniMax: {e}")
    if cfg.get("minimax_api_key") and cfg.get("minimax_group_id"):
        try:
            tts = MinimaxTTS(cfg["minimax_api_key"], cfg["minimax_group_id"])
            tts.synthesize(script, voice=subject_cfg["voice"], output_path=audio_path)
            return True
        except Exception as e:
            print(f"[TTS] MiniMax TTS failed: {e}")
    return False


def _run_solve_pipeline(
    cfg: dict,
    task_id: str,
    task_dir: Path,
    text: str,
    image: UploadFile | None,
) -> dict:
    """Run the synchronous solve pipeline for one task; return the JSON payload.

    Steps: get the question (OCR of ``image`` if given, else ``text``),
    classify it, solve it, render an animation, then best-effort TTS.
    Exceptions propagate to the caller.
    """
    qwen = QwenClient(cfg["dashscope_api_key"])
    deepseek = DeepSeekClient(cfg["deepseek_api_key"])
    router = ModelRouter(deepseek, qwen)

    # 1. Question text: OCR the uploaded image when there is one.
    if image is not None:
        img_path = task_dir / _safe_upload_name(image.filename)
        with open(img_path, "wb") as f:
            shutil.copyfileobj(image.file, f)
        question = qwen.ocr(str(img_path))
    else:
        question = text

    # 2. Classify: LLM classifier when a DashScope key is set, else the local one.
    dashscope_key = cfg.get("dashscope_api_key", "")
    if dashscope_key:
        subject = classify_with_llm(question, dashscope_key)
    else:
        subject = classify(question)
    subject_cfg = SUBJECT_CONFIG[subject]
    label = subject_cfg["label"]

    # 3. Solve
    solution = router.solve(question, subject)
    parsed = parse_solution(solution)

    # 4. Animation
    anim_type = subject_cfg["anim_type"]
    if anim_type == "svg":
        animation = generate_svg_animation(solution)
    else:
        animation = generate_highlight_html(solution)

    # 5. TTS (DashScope preferred, MiniMax as fallback)
    script = generate_script(solution, subject)
    audio_path = str(task_dir / "audio.mp3")
    audio_url = None
    if _synthesize_audio(cfg, subject_cfg, script, audio_path):
        audio_url = f"/output/{task_id}/audio.mp3"

    return {
        "question": question,
        "subject": label,
        "solution": solution,
        "thinking": parsed["thinking"],
        "answer": parsed["answer"],
        "animation": animation,
        "anim_type": anim_type,
        "audio_url": audio_url,
        "tag_class": subject_cfg["tag_class"],
        "task_id": task_id,
    }


@app.post("/api/solve")
async def solve(
    text: str = Form(""),
    image: UploadFile | None = File(None),
):
    """Solve a question supplied as text or as an uploaded image.

    Returns the payload built by ``_run_solve_pipeline``; 400 when neither
    input is given; 500 (naming the exception type) if any step fails.
    """
    upload = image if image is not None and image.filename else None
    question_text = text.strip()
    # Validate before touching the disk so rejected requests leave no empty
    # task directory behind (stale ones are otherwise only swept at startup).
    if upload is None and not question_text:
        return JSONResponse({"error": "请输入题目或上传图片"}, status_code=400)

    task_id = uuid.uuid4().hex[:8]
    task_dir = OUTPUT_DIR / task_id
    try:
        cfg = get_config()
        task_dir.mkdir(exist_ok=True)
        # The pipeline is a chain of synchronous (presumably network-bound)
        # client calls; run it in a worker thread so one request does not
        # block the event loop for every other client.
        payload = await asyncio.to_thread(
            _run_solve_pipeline, cfg, task_id, task_dir, question_text, upload
        )
        return JSONResponse(payload)
    except Exception as e:
        import traceback
        traceback.print_exc()
        # A failed task is never served, so drop its directory now rather
        # than leaving it until the next startup sweep.
        shutil.rmtree(task_dir, ignore_errors=True)
        return JSONResponse({"error": f"API调用失败,请检查网络或API密钥: {type(e).__name__}"}, status_code=500)
@app.post("/api/ask")
async def ask_question(
    task_id: str = Form(""),
    question: str = Form(""),
    subject: str = Form(""),
    original_question: str = Form(""),
    solution: str = Form(""),
    thinking: str = Form(""),
    answer: str = Form(""),
):
    """Follow-up assistant: answer ``question`` in the context of a prior solution.

    Form fields:
        task_id: accepted but not used by this handler.
        question: the follow-up question.
        subject, original_question, solution, thinking, answer: context
            (presumably echoed from an earlier /api/solve response) that is
            forwarded to DeepSeekClient.ask as a dict.

    Returns {"answer": ...}; 400 if the DeepSeek key is missing or the
    question is blank; 500 (naming the exception type) on failure.
    """
    try:
        cfg = get_config()
        # Check the key before constructing a client that could not work.
        if not cfg.get("deepseek_api_key"):
            return JSONResponse({"error": "DeepSeek API Key 未配置"}, status_code=400)
        # Don't spend a model call on an empty question.
        if not question.strip():
            return JSONResponse({"error": "请输入问题"}, status_code=400)
        deepseek = DeepSeekClient(cfg["deepseek_api_key"])
        context = {
            "subject": subject,
            "question": original_question,
            "solution": solution,
            "thinking": thinking,
            "answer": answer,
        }
        # DeepSeekClient.ask is a plain (non-async) call; run it in a worker
        # thread so a slow model response does not block the event loop.
        response = await asyncio.to_thread(deepseek.ask, question, context)
        return JSONResponse({"answer": response})
    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse({"error": f"API调用失败: {type(e).__name__}"}, status_code=500)
@app.get("/output/{task_id}/{filename}")
async def serve_output(task_id: str, filename: str, background_tasks: BackgroundTasks):
    """Serve a generated file from a task directory and schedule its cleanup.

    After a file is served, the whole task directory is scheduled for
    deletion 60 seconds later via ``_delete_task_dir_after_delay``.

    Returns 400 for a malformed task_id or filename, 403 if the resolved path
    is not directly inside the task directory, 404 if no such regular file.
    """
    # task_ids are produced as uuid4().hex[:8]: exactly 8 lowercase hex chars.
    # (str.isalnum() would also accept non-hex and non-ASCII letters/digits.)
    if len(task_id) != 8 or not set(task_id) <= set("0123456789abcdef"):
        return JSONResponse({"error": "invalid task_id"}, status_code=400)
    # Reject path separators, parent references, "." (the task dir itself)
    # and NUL (which makes Path.resolve() raise).
    if filename == "." or any(bad in filename for bad in ("..", "/", "\\", "\x00")):
        return JSONResponse({"error": "invalid filename"}, status_code=400)
    task_dir = OUTPUT_DIR / task_id
    path = (task_dir / filename).resolve()
    # Compare path components rather than string prefixes: a prefix test on
    # ".../output" would also accept a sibling such as ".../output_evil/x".
    if path.parent != task_dir.resolve():
        return JSONResponse({"error": "access denied"}, status_code=403)
    # is_file() rather than exists(): never hand a directory to FileResponse.
    if not path.is_file():
        return JSONResponse({"error": "not found"}, status_code=404)
    background_tasks.add_task(_delete_task_dir_after_delay, task_dir, 60)
    return FileResponse(path)