Skip to content

Commit 171ff1c

Browse files
authored
Merge pull request #169 from CJackHwang/dev
图片多模态支持
2 parents 5aba782 + 69740e8 commit 171ff1c

8 files changed

Lines changed: 129 additions & 10 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ yarn-debug.log*
77
yarn-error.log*
88
pnpm-debug.log*
99
lerna-debug.log*
10+
/upload_images
1011

1112
# Diagnostic reports (https://nodejs.org/api/report.html)
1213
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json

api_utils/request_processor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,7 @@ async def _process_request_refactored(
833833
await _handle_model_switching(req_id, context, check_client_disconnected)
834834
await _handle_parameter_cache(req_id, context)
835835

836-
prepared_prompt = await _prepare_and_validate_request(req_id, request, check_client_disconnected)
836+
prepared_prompt,image_list = await _prepare_and_validate_request(req_id, request, check_client_disconnected)
837837

838838
# 使用PageController处理页面交互
839839
# 注意:聊天历史清空已移至队列处理锁释放后执行
@@ -850,7 +850,7 @@ async def _process_request_refactored(
850850
# 优化:在提交提示前再次检查客户端连接,避免不必要的后台请求
851851
check_client_disconnected("提交提示前最终检查")
852852

853-
await page_controller.submit_prompt(prepared_prompt, check_client_disconnected)
853+
await page_controller.submit_prompt(prepared_prompt,image_list, check_client_disconnected)
854854

855855
# 响应处理仍然需要在这里,因为它决定了是流式还是非流式,并设置future
856856
response_result = await _handle_response_processing(

api_utils/request_processor_backup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def check_client_disconnected(*args):
198198
raise HTTPException(status_code=400, detail=f"[{req_id}] 无效请求: {e}")
199199

200200
# 准备提示
201-
prepared_prompt = prepare_combined_prompt(request.messages, req_id)
201+
prepared_prompt,image_list = prepare_combined_prompt(request.messages, req_id)
202202
check_client_disconnected("After Prompt Prep: ")
203203

204204
# 这里需要添加完整的处理逻辑 - 由于函数太长,暂时返回简化响应

api_utils/utils.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
from typing import Any, Dict, List, Optional, AsyncGenerator
1111
from asyncio import Queue
1212
from models import Message
13-
13+
import re
14+
import base64
15+
import requests
16+
import os
17+
import hashlib
1418

1519

1620
# --- SSE生成函数 ---
@@ -194,6 +198,48 @@ def validate_chat_request(messages: List[Message], req_id: str) -> Dict[str, Opt
194198
}
195199

196200

201+
def extract_base64_to_local(base64_data: str, output_dir: Optional[str] = None) -> Optional[str]:
    """Decode a base64 ``data:image/...`` URL and store the image on disk.

    The file is named after the MD5 digest of the decoded bytes, so the same
    image is written only once and later calls reuse the existing file.

    Args:
        base64_data: A data URL of the form ``data:image/<type>;base64,<payload>``.
            The payload may contain line breaks.
        output_dir: Directory to store the image in. Defaults to the
            ``upload_images`` directory next to this module's parent package.

    Returns:
        The path of the stored image, or ``None`` when the input is not a
        base64 image data URL, cannot be decoded, decodes to nothing, or
        cannot be written.
    """
    if output_dir is None:
        output_dir = os.path.join(os.path.dirname(__file__), '..', 'upload_images')

    # DOTALL lets the payload group span line breaks; without it a wrapped
    # payload would be cut at the first newline and a truncated image saved.
    match = re.match(r"data:image/(\w+);base64,(.*)", base64_data, re.DOTALL)
    if not match:
        print("错误: Base64 数据格式不正确。")
        return None

    image_type, encoded_image_data = match.groups()  # e.g. "png", "jpeg"

    try:
        decoded_image_data = base64.b64decode(encoded_image_data)
    except ValueError as e:
        # binascii.Error (bad padding / invalid data) is a ValueError subclass.
        print(f"错误: Base64 解码失败 - {e}")
        return None

    if not decoded_image_data:
        # An empty payload would otherwise create a zero-byte "image" file.
        print("错误: Base64 数据为空。")
        return None

    # Content-addressed file name: identical images map to the same path.
    md5_hash = hashlib.md5(decoded_image_data).hexdigest()
    output_filepath = os.path.join(output_dir, f"{md5_hash}.{image_type}")

    if os.path.exists(output_filepath):
        print(f"文件已存在,跳过保存: {output_filepath}")
        return output_filepath

    try:
        # Directory creation is part of the guarded I/O so that every
        # filesystem failure is reported the same way (None).
        os.makedirs(output_dir, exist_ok=True)
        with open(output_filepath, "wb") as f:
            f.write(decoded_image_data)
    except OSError as e:
        print(f"错误: 保存文件失败 - {e}")
        return None

    print(f"图片已成功保存到: {output_filepath}")
    return output_filepath
241+
242+
197243
# --- 提示准备函数 ---
198244
def prepare_combined_prompt(messages: List[Message], req_id: str) -> str:
199245
"""准备组合提示"""
@@ -238,6 +284,7 @@ def prepare_combined_prompt(messages: List[Message], req_id: str) -> str:
238284
role = msg.role or 'unknown'
239285
role_prefix_ui = f"{role_map_ui.get(role, role.capitalize())}:\n"
240286
current_turn_parts = [role_prefix_ui]
287+
images_list = []
241288

242289
content = msg.content or ''
243290
content_str = ""
@@ -252,6 +299,15 @@ def prepare_combined_prompt(messages: List[Message], req_id: str) -> str:
252299
text_parts.append(item.text or '')
253300
elif isinstance(item, dict) and item.get('type') == 'text':
254301
text_parts.append(item.get('text', ''))
302+
elif hasattr(item, 'type') and item.type == 'image_url':
303+
image_url_value = item.image_url.url
304+
if image_url_value.startswith("data:image/"):
305+
try:
306+
# 提取 Base64 字符串
307+
image_full_path = extract_base64_to_local(image_url_value)
308+
images_list.append(image_full_path)
309+
except (ValueError, requests.exceptions.RequestException, Exception) as e:
310+
print(f"处理 Base64 图片并上传到 Imgur 失败: {e}")
255311
else:
256312
logger.warning(f"[{req_id}] (准备提示) 警告: 在索引 {i} 的消息中忽略非文本或未知类型的 content item")
257313
content_str = "\n".join(text_parts).strip()
@@ -302,7 +358,7 @@ def prepare_combined_prompt(messages: List[Message], req_id: str) -> str:
302358
preview_text = final_prompt[:300].replace('\n', '\\n')
303359
logger.info(f"[{req_id}] (准备提示) 组合提示长度: {len(final_prompt)}。预览: '{preview_text}...'")
304360

305-
return final_prompt
361+
return final_prompt,images_list
306362

307363

308364
def estimate_tokens(text: str) -> int:

browser_utils/page_controller.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
MAT_CHIP_REMOVE_BUTTON_SELECTOR, TOP_P_INPUT_SELECTOR, SUBMIT_BUTTON_SELECTOR,
1313
CLEAR_CHAT_BUTTON_SELECTOR, CLEAR_CHAT_CONFIRM_BUTTON_SELECTOR, OVERLAY_SELECTOR,
1414
PROMPT_TEXTAREA_SELECTOR, RESPONSE_CONTAINER_SELECTOR, RESPONSE_TEXT_SELECTOR,
15-
EDIT_MESSAGE_BUTTON_SELECTOR
15+
EDIT_MESSAGE_BUTTON_SELECTOR,USE_URL_CONTEXT_SELECTOR,UPLOAD_BUTTON_SELECTOR
1616
)
1717
from config import (
1818
CLICK_TIMEOUT_MS, WAIT_FOR_ELEMENT_TIMEOUT_MS, CLEAR_CHAT_VERIFY_TIMEOUT_MS,
@@ -59,6 +59,30 @@ async def adjust_parameters(self, request_params: Dict[str, Any], page_params_ca
5959
await self._adjust_top_p(top_p_to_set, check_client_disconnected)
6060
await self._check_disconnect(check_client_disconnected, "End Parameter Adjustment")
6161

62+
# 调整URL CONTEXT
63+
await self._open_url_content(check_client_disconnected)
64+
65+
async def _open_url_content(self, check_client_disconnected: Callable):
    """Switch on the page's "URL context" toggle if it is currently off.

    First expands the tools section when its container is not marked as
    expanded, then clicks the URL-context button when its ``aria-checked``
    attribute is ``"false"``. Any exception raised along the way is logged
    and swallowed, so this step never aborts the caller.

    Args:
        check_client_disconnected: Callable forwarded to
            ``self._check_disconnect`` after the toggle has been clicked.
    """
    # NOTE(review): block nesting was reconstructed from a listing without
    # indentation — confirm against the original file.
    try:
        tools_toggle = self.page.locator('button[aria-label="Expand or collapse tools"]')
        # The toggle's grandparent element carries the CSS classes that are
        # inspected for the "expanded" marker.
        tools_container = tools_toggle.locator("xpath=../..")
        container_classes = await tools_container.get_attribute("class")

        # Only act when a class attribute exists and "expanded" is not one
        # of its whitespace-separated class names.
        if container_classes:
            if "expanded" not in container_classes.split():
                await tools_toggle.click(timeout=CLICK_TIMEOUT_MS)
                # Presumably gives the section time to finish expanding.
                await asyncio.sleep(0.5)

        url_context_switch = self.page.locator(USE_URL_CONTEXT_SELECTOR)
        switch_state = await url_context_switch.get_attribute("aria-checked")
        if switch_state == "false":
            await url_context_switch.click(timeout=CLICK_TIMEOUT_MS)
            await self._check_disconnect(check_client_disconnected, "点击URLCONTEXT")
    except Exception as err:
        self.logger.error(f"[{self.req_id}] ❌ 操作USE_URL_CONTEXT_SELECTOR时发生错误:{err}。")
6286

6387
async def _adjust_temperature(self, temperature: float, page_params_cache: dict, params_cache_lock: asyncio.Lock, check_client_disconnected: Callable):
6488
"""调整温度参数。"""
@@ -76,6 +100,7 @@ async def _adjust_temperature(self, temperature: float, page_params_cache: dict,
76100
self.logger.info(f"[{self.req_id}] 请求温度 ({clamped_temp}) 与缓存值 ({cached_temp}) 不一致或缓存中无值。需要与页面交互。")
77101
temp_input_locator = self.page.locator(TEMPERATURE_INPUT_SELECTOR)
78102

103+
79104
try:
80105
await expect_async(temp_input_locator).to_be_visible(timeout=5000)
81106
await self._check_disconnect(check_client_disconnected, "温度调整 - 输入框可见后")
@@ -421,7 +446,7 @@ async def _verify_chat_cleared(self, check_client_disconnected: Callable):
421446
except Exception as verify_err:
422447
self.logger.warning(f"[{self.req_id}] ⚠️ 警告: 清空聊天验证失败 (最后响应容器未隐藏): {verify_err}")
423448

424-
async def submit_prompt(self, prompt: str, check_client_disconnected: Callable):
449+
async def submit_prompt(self, prompt: str,image_list: List, check_client_disconnected: Callable):
425450
"""提交提示到页面。"""
426451
self.logger.info(f"[{self.req_id}] 填充并提交提示 ({len(prompt)} chars)...")
427452
prompt_textarea_locator = self.page.locator(PROMPT_TEXTAREA_SELECTOR)
@@ -446,8 +471,39 @@ async def submit_prompt(self, prompt: str, check_client_disconnected: Callable):
446471
await autosize_wrapper_locator.evaluate('(element, text) => { element.setAttribute("data-value", text); }', prompt)
447472
await self._check_disconnect(check_client_disconnected, "After Input Fill")
448473

474+
# 上传
475+
if len(image_list) > 0:
476+
try:
477+
# 1. 监听文件选择器
478+
# page.expect_file_chooser() 会返回一个上下文管理器
479+
# 当文件选择器出现时,它会得到 FileChooser 对象
480+
function_btn_localtor = self.page.locator('button[aria-label="Insert assets such as images, videos, files, or audio"]')
481+
await function_btn_localtor.click()
482+
#asyncio.sleep(0.5)
483+
async with self.page.expect_file_chooser() as fc_info:
484+
# 2. 点击那个会触发文件选择的普通按钮
485+
upload_btn_localtor = self.page.locator(UPLOAD_BUTTON_SELECTOR)
486+
await upload_btn_localtor.click()
487+
print("点击了 JS 上传按钮,等待文件选择器...")
488+
489+
# 3. 获取文件选择器对象
490+
file_chooser = await fc_info.value
491+
print("文件选择器已出现。")
492+
493+
# 4. 设置要上传的文件
494+
await file_chooser.set_files(image_list)
495+
print(f"已将 '{image_list}' 设置到文件选择器。")
496+
497+
#asyncio.sleep(0.2)
498+
acknow_btn_locator = self.page.locator('button[aria-label="Agree to the copyright acknowledgement"]')
499+
if await acknow_btn_locator.count() > 0:
500+
await acknow_btn_locator.click()
501+
502+
except Exception as e:
503+
print(f"在上传文件时发生错误: {e}")
504+
449505
# 等待发送按钮启用
450-
wait_timeout_ms_submit_enabled = 40000
506+
wait_timeout_ms_submit_enabled = 100000
451507
try:
452508
await self._check_disconnect(check_client_disconnected, "填充提示后等待发送按钮启用 - 前置检查")
453509
await expect_async(submit_button_locator).to_be_enabled(timeout=wait_timeout_ms_submit_enabled)

config/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@
6565
'MAT_CHIP_REMOVE_BUTTON_SELECTOR',
6666
'TOP_P_INPUT_SELECTOR',
6767
'TEMPERATURE_INPUT_SELECTOR',
68+
'USE_URL_CONTEXT_SELECTOR',
69+
'UPLOAD_BUTTON_SELECTOR',
6870

6971
# 设置配置
7072
'DEBUG_LOGS_ENABLED',

config/selectors.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
SUBMIT_BUTTON_SELECTOR = 'button[aria-label="Run"].run-button'
1313
CLEAR_CHAT_BUTTON_SELECTOR = 'button[data-test-clear="outside"][aria-label="Clear chat"]'
1414
CLEAR_CHAT_CONFIRM_BUTTON_SELECTOR = 'button.mdc-button:has-text("Continue")'
15+
UPLOAD_BUTTON_SELECTOR = 'button[aria-label="Upload File"]'
1516

1617
# --- 响应相关选择器 ---
1718
RESPONSE_CONTAINER_SELECTOR = 'ms-chat-turn .chat-turn-container.model'
@@ -39,4 +40,5 @@
3940
STOP_SEQUENCE_INPUT_SELECTOR = 'input[aria-label="Add stop token"]'
4041
MAT_CHIP_REMOVE_BUTTON_SELECTOR = 'mat-chip-set mat-chip-row button[aria-label*="Remove"]'
4142
TOP_P_INPUT_SELECTOR = 'div.settings-item-column:has(h3:text-is("Top P")) input[type="number"].slider-input'
42-
TEMPERATURE_INPUT_SELECTOR = 'div[data-test-id="temperatureSliderContainer"] input[type="number"].slider-input'
43+
TEMPERATURE_INPUT_SELECTOR = 'div[data-test-id="temperatureSliderContainer"] input[type="number"].slider-input'
44+
USE_URL_CONTEXT_SELECTOR = 'button[aria-label="Browse the url context"]'

models/chat.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@ class ToolCall(BaseModel):
1313
type: str = "function"
1414
function: FunctionCall
1515

16+
class ImageURL(BaseModel):
17+
url: str
1618

1719
class MessageContentItem(BaseModel):
1820
type: str
1921
text: Optional[str] = None
20-
22+
image_url: Optional[ImageURL] = None
2123

2224
class Message(BaseModel):
2325
role: str

0 commit comments

Comments
 (0)