diff --git a/backend/apps/chat/models/chat_model.py b/backend/apps/chat/models/chat_model.py index effc313a..224c4609 100644 --- a/backend/apps/chat/models/chat_model.py +++ b/backend/apps/chat/models/chat_model.py @@ -319,6 +319,11 @@ class ChatMcp(ChatQuestion): token: str +class McpDs(BaseModel): + token: str = Body(description='用户token') + oid: Optional[int | str] = Body(description='组织ID,如果不传则为最后一次登录SQLBot时所使用的组织ID', default=None) + + class ChatStart(BaseModel): username: str = Body(description='用户名') password: str = Body(description='密码') @@ -331,6 +336,8 @@ class McpQuestion(BaseModel): stream: Optional[bool] = Body(description='是否流式输出,默认为true开启, 关闭false则返回JSON对象', default=True) lang: Optional[str] = Body(description='语言:zh-CN|en|ko-KR', default='zh-CN') datasource_id: Optional[int | str] = Body(description='数据源ID,仅当当前对话没有确定数据源时有效', default=None) + oid: Optional[int | str] = Body( + description='组织ID,仅当数据源ID为空时有效,如果不传则为最后一次登录SQLBot时所使用的组织ID', default=None) class AxisObj(BaseModel): diff --git a/backend/apps/mcp/mcp.py b/backend/apps/mcp/mcp.py index bd5f560b..7f2ee5d3 100644 --- a/backend/apps/mcp/mcp.py +++ b/backend/apps/mcp/mcp.py @@ -13,9 +13,9 @@ from apps.chat.api.chat import create_chat, question_answer_inner from apps.chat.models.chat_model import ChatMcp, CreateChat, ChatStart, McpQuestion, McpAssistant, ChatQuestion, \ - ChatFinishStep + ChatFinishStep, McpDs from apps.datasource.crud.datasource import get_datasource_list -from apps.system.crud.user import authenticate +from apps.system.crud.user import authenticate, user_ws_options from apps.system.crud.user import get_db_user from apps.system.models.system_model import UserWsModel from apps.system.models.user import UserModel @@ -81,9 +81,34 @@ def get_user(session: SessionDep, token: str): return session_user -@router.post("/mcp_ds_list", operation_id="mcp_datasource_list") -async def datasource_list(session: SessionDep, token: str): +@router.post("/mcp_start", operation_id="mcp_start") +async def mcp_start(session: SessionDep, chat: ChatStart): + user: BaseUserDTO = authenticate(session=session, account=chat.username, password=chat.password) + if not user: + raise HTTPException(status_code=400, detail="Incorrect account or password") + + if not user.oid or user.oid == 0: + raise HTTPException(status_code=400, detail="No associated workspace, Please contact the administrator") + access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + user_dict = user.to_dict() + t = Token(access_token=create_access_token( + user_dict, expires_delta=access_token_expires + )) + c = create_chat(session, user, CreateChat(origin=1), False) + return {"access_token": t.access_token, "chat_id": c.id} + + +@router.post("/mcp_ws_list", operation_id="mcp_ws_list") +async def ws_list(session: SessionDep, token: str): session_user = get_user(session, token) + return await user_ws_options(session, session_user.id) + + +@router.post("/mcp_ds_list", operation_id="mcp_datasource_list") +async def datasource_list(session: SessionDep, mcp_ds: McpDs): + session_user = get_user(session, mcp_ds.token) + if mcp_ds.oid is not None: + session_user.oid = mcp_ds.oid ds_list = get_datasource_list(session=session, user=session_user) result = [] for item in ds_list: @@ -103,26 +128,11 @@ async def datasource_list(session: SessionDep, token: str): # return session.query(AiModelDetail).all() -@router.post("/mcp_start", operation_id="mcp_start") -async def mcp_start(session: SessionDep, chat: ChatStart): - user: BaseUserDTO = authenticate(session=session, account=chat.username, password=chat.password) - if not user: - raise HTTPException(status_code=400, detail="Incorrect account or password") - - if not user.oid or user.oid == 0: - raise HTTPException(status_code=400, detail="No associated workspace, Please contact the administrator") - access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) - user_dict = user.to_dict() - t = Token(access_token=create_access_token( - user_dict, expires_delta=access_token_expires - )) - c = create_chat(session, user, CreateChat(origin=1), False) - return {"access_token": t.access_token, "chat_id": c.id} - - @router.post("/mcp_question", operation_id="mcp_question") async def mcp_question(session: SessionDep, chat: McpQuestion): session_user = get_user(session, chat.token) + if chat.oid is not None: + session_user.oid = chat.oid ds_id: Optional[int] = None if chat.datasource_id: if isinstance(chat.datasource_id, str):