Skip to content

Commit cecd3b5

Browse files
committed
修复图片重传问题
1 parent 9d7a410 commit cecd3b5

3 files changed

Lines changed: 102 additions & 56 deletions

File tree

api_utils/request_processor.py

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
MODEL_NAME,
2121
SUBMIT_BUTTON_SELECTOR,
2222
)
23+
from config import ONLY_COLLECT_CURRENT_USER_ATTACHMENTS, UPLOAD_FILES_DIR
2324

2425
# --- models模块导入 ---
2526
from models import ChatCompletionRequest, ClientDisconnectedError
@@ -125,6 +126,50 @@ async def _prepare_and_validate_request(
125126
prepared_prompt += f"\n---\n工具执行: {name}\n参数:\n{args}\n结果:\n{result_str}\n"
126127
except Exception:
127128
pass
129+
# 若配置仅收集当前用户消息附件,则在此过滤附件
130+
try:
131+
if ONLY_COLLECT_CURRENT_USER_ATTACHMENTS:
132+
latest_user = None
133+
for msg in reversed(request.messages or []):
134+
if getattr(msg, 'role', None) == 'user':
135+
latest_user = msg
136+
break
137+
if latest_user is not None:
138+
filtered: List[str] = []
139+
from api_utils.utils import extract_data_url_to_local
140+
from urllib.parse import urlparse, unquote
141+
import os
142+
# 收集该条 user 消息上的 data:/file:/绝对路径(存在的)
143+
content = getattr(latest_user, 'content', None)
144+
# 统一从 messages 附件字段抽取
145+
for key in ('attachments', 'images', 'files', 'media'):
146+
arr = getattr(latest_user, key, None)
147+
if not isinstance(arr, list):
148+
continue
149+
for it in arr:
150+
url_value = None
151+
if isinstance(it, str):
152+
url_value = it
153+
elif isinstance(it, dict):
154+
url_value = it.get('url') or it.get('path')
155+
url_value = (url_value or '').strip()
156+
if not url_value:
157+
continue
158+
if url_value.startswith('data:'):
159+
fp = extract_data_url_to_local(url_value)
160+
if fp:
161+
filtered.append(fp)
162+
elif url_value.startswith('file:'):
163+
parsed = urlparse(url_value)
164+
lp = unquote(parsed.path)
165+
if os.path.exists(lp):
166+
filtered.append(lp)
167+
elif os.path.isabs(url_value) and os.path.exists(url_value):
168+
filtered.append(url_value)
169+
images_list = filtered
170+
except Exception:
171+
pass
172+
128173
return prepared_prompt, images_list
129174

130175
async def _handle_response_processing(
@@ -347,11 +392,13 @@ async def _handle_playwright_response(req_id: str, request: ChatCompletionReques
347392
return None
348393

349394

350-
async def _cleanup_request_resources(req_id: str, disconnect_check_task: Optional[asyncio.Task],
351-
completion_event: Optional[Event], result_future: Future,
352-
is_streaming: bool) -> None:
353-
"""清理请求资源"""
354-
from server import logger
395+
async def _cleanup_request_resources(req_id: str, disconnect_check_task: Optional[asyncio.Task],
396+
completion_event: Optional[Event], result_future: Future,
397+
is_streaming: bool) -> None:
398+
"""清理请求资源"""
399+
from server import logger
400+
from config import UPLOAD_FILES_DIR
401+
import os, shutil
355402

356403
if disconnect_check_task and not disconnect_check_task.done():
357404
disconnect_check_task.cancel()
@@ -362,7 +409,16 @@ async def _cleanup_request_resources(req_id: str, disconnect_check_task: Optiona
362409
except Exception as task_clean_err:
363410
logger.error(f"[{req_id}] 清理任务时出错: {task_clean_err}")
364411

365-
logger.info(f"[{req_id}] 处理完成。")
412+
logger.info(f"[{req_id}] 处理完成。")
413+
414+
# 清理本次请求的上传子目录,避免磁盘累积
415+
try:
416+
req_dir = os.path.join(UPLOAD_FILES_DIR, req_id)
417+
if os.path.isdir(req_dir):
418+
shutil.rmtree(req_dir, ignore_errors=True)
419+
logger.info(f"[{req_id}] 已清理请求上传目录: {req_dir}")
420+
except Exception as clean_err:
421+
logger.warning(f"[{req_id}] 清理上传目录失败: {clean_err}")
366422

367423
if is_streaming and completion_event and not completion_event.is_set() and (result_future.done() and result_future.exception() is not None):
368424
logger.warning(f"[{req_id}] 流式请求异常,确保完成事件已设置。")
@@ -406,13 +462,28 @@ async def _process_request_refactored(
406462
await _handle_parameter_cache(req_id, context)
407463

408464
prepared_prompt,image_list = await _prepare_and_validate_request(req_id, request, check_client_disconnected)
465+
# 额外合并顶层与消息级 attachments/files(兼容历史记录)已在下方处理;此处确保路径存在
466+
try:
467+
import os
468+
valid_images = []
469+
for p in image_list:
470+
if isinstance(p, str) and p and os.path.isabs(p) and os.path.exists(p):
471+
valid_images.append(p)
472+
if len(valid_images) != len(image_list):
473+
from server import logger
474+
logger.warning(f"[{req_id}] 过滤掉不存在的附件路径: {set(image_list) - set(valid_images)}")
475+
image_list = valid_images
476+
except Exception:
477+
pass
409478
# 兼容: 顶层与消息级附件字段合并到上传列表(仅 data:/file:/绝对路径)
479+
# 附件来源策略:仅接受当前请求显式提供的 data:/file:/绝对路径(存在的)
410480
try:
481+
from api_utils.utils import extract_data_url_to_local
482+
from urllib.parse import urlparse, unquote
483+
import os
484+
# 顶层 attachments
411485
top_level_atts = getattr(request, 'attachments', None)
412486
if isinstance(top_level_atts, list) and len(top_level_atts) > 0:
413-
from api_utils.utils import extract_data_url_to_local
414-
from urllib.parse import urlparse, unquote
415-
import os
416487
for it in top_level_atts:
417488
url_value = None
418489
if isinstance(it, str):
@@ -423,7 +494,7 @@ async def _process_request_refactored(
423494
if not url_value:
424495
continue
425496
if url_value.startswith('data:'):
426-
fp = extract_data_url_to_local(url_value)
497+
fp = extract_data_url_to_local(url_value, req_id=req_id)
427498
if fp:
428499
image_list.append(fp)
429500
elif url_value.startswith('file:'):
@@ -433,7 +504,7 @@ async def _process_request_refactored(
433504
image_list.append(lp)
434505
elif os.path.isabs(url_value) and os.path.exists(url_value):
435506
image_list.append(url_value)
436-
# 消息级 attachments/images/files/media
507+
# 消息级 attachments/images/files/media(全量收集,但仅保留有效本地/data)
437508
for msg in (request.messages or []):
438509
for key in ('attachments', 'images', 'files', 'media'):
439510
arr = getattr(msg, key, None)
@@ -449,7 +520,7 @@ async def _process_request_refactored(
449520
if not url_value:
450521
continue
451522
if url_value.startswith('data:'):
452-
fp = extract_data_url_to_local(url_value)
523+
fp = extract_data_url_to_local(url_value, req_id=req_id)
453524
if fp:
454525
image_list.append(fp)
455526
elif url_value.startswith('file:'):

api_utils/utils.py

Lines changed: 13 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -234,14 +234,15 @@ def _extension_for_mime(mime_type: str) -> str:
234234
return mapping.get(mime_type, f".{mime_type.split('/')[-1]}" if '/' in mime_type else '.bin')
235235

236236

237-
def extract_data_url_to_local(data_url: str) -> Optional[str]:
237+
def extract_data_url_to_local(data_url: str, req_id: Optional[str] = None) -> Optional[str]:
238238
"""
239239
解析并保存任意 data:URL (data:<mime>;base64,<payload>) 到本地文件,返回文件路径。
240240
支持图片、视频、音频、PDF 等常见类型。
241241
"""
242242
from server import logger
243243
# 允许保存到通用上传目录
244-
output_dir = os.path.join(os.path.dirname(__file__), '..', 'upload_files')
244+
from config import UPLOAD_FILES_DIR
245+
output_dir = UPLOAD_FILES_DIR if req_id is None else os.path.join(UPLOAD_FILES_DIR, req_id)
245246

246247
match = re.match(r"^data:(?P<mime>[^;]+);base64,(?P<data>.*)$", data_url)
247248
if not match:
@@ -261,16 +262,7 @@ def extract_data_url_to_local(data_url: str) -> Optional[str]:
261262
file_extension = _extension_for_mime(mime_type)
262263
output_filepath = os.path.join(output_dir, f"{md5_hash}{file_extension}")
263264

264-
# 每次处理前清理旧文件,确保目录为空
265-
try:
266-
if os.path.isdir(output_dir):
267-
for name in os.listdir(output_dir):
268-
try:
269-
os.remove(os.path.join(output_dir, name))
270-
except Exception:
271-
pass
272-
except Exception:
273-
pass
265+
# 仅按请求粒度清理目录;此处不再删除,以免多附件互相覆盖
274266
os.makedirs(output_dir, exist_ok=True)
275267

276268
if os.path.exists(output_filepath):
@@ -287,10 +279,11 @@ def extract_data_url_to_local(data_url: str) -> Optional[str]:
287279
return None
288280

289281

290-
def save_blob_to_local(raw_bytes: bytes, mime_type: Optional[str] = None, fmt_ext: Optional[str] = None) -> Optional[str]:
282+
def save_blob_to_local(raw_bytes: bytes, mime_type: Optional[str] = None, fmt_ext: Optional[str] = None, req_id: Optional[str] = None) -> Optional[str]:
291283
"""将原始数据保存到 upload_files/ 下,按内容 MD5 命名,扩展名来源于 mime 或显式格式。"""
292284
from server import logger
293-
output_dir = os.path.join(os.path.dirname(__file__), '..', 'upload_files')
285+
from config import UPLOAD_FILES_DIR
286+
output_dir = UPLOAD_FILES_DIR if req_id is None else os.path.join(UPLOAD_FILES_DIR, req_id)
294287
md5_hash = hashlib.md5(raw_bytes).hexdigest()
295288
ext = None
296289
if fmt_ext:
@@ -300,15 +293,7 @@ def save_blob_to_local(raw_bytes: bytes, mime_type: Optional[str] = None, fmt_ex
300293
ext = _extension_for_mime(mime_type)
301294
if not ext:
302295
ext = '.bin'
303-
try:
304-
if os.path.isdir(output_dir):
305-
for name in os.listdir(output_dir):
306-
try:
307-
os.remove(os.path.join(output_dir, name))
308-
except Exception:
309-
pass
310-
except Exception:
311-
pass
296+
# 仅按请求粒度清理目录;此处不再删除,以免多附件互相覆盖
312297
os.makedirs(output_dir, exist_ok=True)
313298
output_filepath = os.path.join(output_dir, f"{md5_hash}{ext}")
314299
if os.path.exists(output_filepath):
@@ -330,19 +315,7 @@ def prepare_combined_prompt(messages: List[Message], req_id: str) -> Tuple[str,
330315
from server import logger
331316

332317
logger.info(f"[{req_id}] (准备提示) 正在从 {len(messages)} 条消息准备组合提示 (包括历史)。")
333-
# 清空上一请求的上传目录(按请求粒度),避免残留文件
334-
try:
335-
upload_dir = os.path.join(os.path.dirname(__file__), '..', 'upload_files')
336-
if os.path.isdir(upload_dir):
337-
for name in os.listdir(upload_dir):
338-
fp = os.path.join(upload_dir, name)
339-
try:
340-
if os.path.isfile(fp):
341-
os.remove(fp)
342-
except Exception:
343-
pass
344-
except Exception:
345-
pass
318+
# 不在此处清空 upload_files;每个请求的文件保存在 UPLOAD_FILES_DIR/<req_id> 子目录,并由 _cleanup_request_resources 在请求处理结束时按请求粒度清理,避免历史附件丢失导致“文件不存在”错误。
346319

347320
combined_parts = []
348321
system_prompt_content: Optional[str] = None
@@ -477,7 +450,7 @@ def prepare_combined_prompt(messages: List[Message], req_id: str) -> Tuple[str,
477450

478451
# 归一化到本地文件列表,并记录日志
479452
if url_value.startswith('data:'):
480-
file_path = extract_data_url_to_local(url_value)
453+
file_path = extract_data_url_to_local(url_value, req_id=req_id)
481454
if file_path:
482455
files_list.append(file_path)
483456
logger.info(f"[{req_id}] (准备提示) 已识别并加入 data:URL 附件: {file_path}")
@@ -527,7 +500,7 @@ def prepare_combined_prompt(messages: List[Message], req_id: str) -> Tuple[str,
527500

528501
if url_value:
529502
if url_value.startswith('data:'):
530-
saved = extract_data_url_to_local(url_value)
503+
saved = extract_data_url_to_local(url_value, req_id=req_id)
531504
if saved:
532505
files_list.append(saved)
533506
logger.info(f"[{req_id}] (准备提示) 已识别并加入音视频 data:URL 附件: {saved}")
@@ -542,15 +515,15 @@ def prepare_combined_prompt(messages: List[Message], req_id: str) -> Tuple[str,
542515
logger.info(f"[{req_id}] (准备提示) 已识别并加入音视频本地附件(绝对路径): {url_value}")
543516
elif data_val:
544517
if isinstance(data_val, str) and data_val.startswith('data:'):
545-
saved = extract_data_url_to_local(data_val)
518+
saved = extract_data_url_to_local(data_val, req_id=req_id)
546519
if saved:
547520
files_list.append(saved)
548521
logger.info(f"[{req_id}] (准备提示) 已识别并加入音视频 data:URL 附件: {saved}")
549522
else:
550523
# 认为是纯 base64 数据
551524
try:
552525
raw = base64.b64decode(data_val)
553-
saved = save_blob_to_local(raw, mime_val, fmt_val)
526+
saved = save_blob_to_local(raw, mime_val, fmt_val, req_id=req_id)
554527
if saved:
555528
files_list.append(saved)
556529
logger.info(f"[{req_id}] (准备提示) 已识别并加入音视频 base64 附件: {saved}")

config/settings.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
AUTH_PROFILES_DIR = os.path.join(os.path.dirname(__file__), '..', 'auth_profiles')
2323
ACTIVE_AUTH_DIR = os.path.join(AUTH_PROFILES_DIR, 'active')
2424
SAVED_AUTH_DIR = os.path.join(AUTH_PROFILES_DIR, 'saved')
25-
LOG_DIR = os.path.join(os.path.dirname(__file__), '..', 'logs')
26-
APP_LOG_FILE_PATH = os.path.join(LOG_DIR, 'app.log')
25+
LOG_DIR = os.path.join(os.path.dirname(__file__), '..', 'logs')
26+
APP_LOG_FILE_PATH = os.path.join(LOG_DIR, 'app.log')
27+
UPLOAD_FILES_DIR = os.path.join(os.path.dirname(__file__), '..', 'upload_files')
2728

2829
def get_environment_variable(key: str, default: str = '') -> str:
2930
"""获取环境变量值"""
@@ -49,6 +50,7 @@ def get_int_env(key: str, default: int = 0) -> int:
4950
NO_PROXY_ENV = os.environ.get('NO_PROXY')
5051

5152
# --- 脚本注入配置 ---
52-
ENABLE_SCRIPT_INJECTION = get_boolean_env('ENABLE_SCRIPT_INJECTION', True)
53+
ENABLE_SCRIPT_INJECTION = get_boolean_env('ENABLE_SCRIPT_INJECTION', True)
54+
ONLY_COLLECT_CURRENT_USER_ATTACHMENTS = get_boolean_env('ONLY_COLLECT_CURRENT_USER_ATTACHMENTS', False)
5355
USERSCRIPT_PATH = get_environment_variable('USERSCRIPT_PATH', 'browser_utils/more_modles.js')
54-
# 注意:MODEL_CONFIG_PATH 已废弃,现在直接从油猴脚本解析模型数据
56+
# 注意:MODEL_CONFIG_PATH 已废弃,现在直接从油猴脚本解析模型数据

0 commit comments

Comments
 (0)