diff --git a/lightllm/common/mamba_cache_mem_manager/__init__.py b/lightllm/common/mamba_cache_mem_manager/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 7f4bf2513e..fa865892d7 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -98,6 +98,12 @@ def make_argument_parser() -> argparse.ArgumentParser: default="default_model_name", help="just help to distinguish internal model name, use 'host:port/get_model_name' to get", ) + parser.add_argument( + "--model_owner", + type=str, + default=None, + help="the model owner, if not set, will use lightllm", + ) parser.add_argument( "--model_dir", diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 230da5b369..50d992bf9c 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -19,6 +19,7 @@ import asyncio import collections import time + import uvloop import requests import base64 @@ -57,6 +58,8 @@ ChatCompletionResponse, CompletionRequest, CompletionResponse, + ModelCard, + ModelListResponse, ) from .build_prompt import build_prompt, init_tokenizer @@ -72,6 +75,9 @@ class G_Objs: g_generate_stream_func: Callable = None httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None shared_token_load: TokenLoad = None + # OpenAI-compatible "created" timestamp for /v1/models. + # Should be stable for the lifetime of this server process. 
+ model_created: int = None def set_args(self, args: StartArgs): self.args = args @@ -101,6 +107,8 @@ def set_args(self, args: StartArgs): self.httpserver_manager = HttpServerManager(args=args) dp_size_in_node = max(1, args.dp // args.nnodes) # 兼容多机纯tp的运行模式,这时候 1 // 2 == 0, 需要兼容 self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node) + if self.model_created is None: + self.model_created = int(time.time()) g_objs = G_Objs() @@ -258,6 +266,26 @@ async def completions(request: CompletionRequest, raw_request: Request) -> Respo return resp +@app.get("/v1/models", response_model=ModelListResponse) +@app.post("/v1/models", response_model=ModelListResponse) +async def get_models(raw_request: Request): + model_name = g_objs.args.model_name + max_model_len = g_objs.args.max_req_total_len + if model_name == "default_model_name" and g_objs.args.model_dir: + model_name = os.path.basename(g_objs.args.model_dir.rstrip("/")) + + return ModelListResponse( + data=[ + ModelCard( + id=model_name, + created=g_objs.model_created, + max_model_len=max_model_len, + owned_by=g_objs.args.model_owner, + ) + ] + ) + + @app.get("/tokens") @app.post("/tokens") async def tokens(request: Request): diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index 3d9a6bc8ed..7fc2696135 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -87,7 +87,7 @@ class ToolCall(BaseModel): id: Optional[str] = None index: Optional[int] = None - type: Literal["function"] = "function" + type: Optional[Literal["function"]] = None function: FunctionResponse @@ -370,3 +370,16 @@ class CompletionStreamResponse(BaseModel): @field_validator("id", mode="before") def ensure_id_is_str(cls, v): return str(v) + + +class ModelCard(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "lightllm" + max_model_len: Optional[int] = None + + +class 
ModelListResponse(BaseModel): + object: str = "list" + data: List[ModelCard] diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 9bc4d26eb0..1a17691a95 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -19,7 +19,7 @@ from http import HTTPStatus from PIL import Image import multiprocessing as mp -from typing import Any, AsyncGenerator, Optional, Union, List, Dict +from typing import Any, AsyncGenerator, Optional, Union, List, Dict, Tuple from typing import Callable from lightllm.server import TokenLoad from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect @@ -160,6 +160,22 @@ def _process_tools_stream(index: int, delta: str, parser_dict: Dict, request: Ch return normal_text, calls +def _split_tool_argument_delta(arguments: Optional[str]) -> List[str]: + """Split a complete JSON argument string into OpenAI-style deltas.""" + if not arguments: + return [] + if len(arguments) <= 2: + return [arguments] + if arguments[0] in "{[" and arguments[-1] in "}]": + middle = arguments[1:-1] + chunks = [arguments[0]] + if middle: + chunks.append(middle) + chunks.append(arguments[-1]) + return [chunk for chunk in chunks if chunk] + return [arguments] + + async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Request) -> Response: from .api_http import g_objs @@ -342,6 +358,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req ToolCall( id=tool_id, index=getattr(call_info, "tool_index", None), + type="function", function=FunctionResponse(name=call_info.name, arguments=call_info.parameters), ) ) @@ -371,16 +388,13 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req ) return resp - if sampling_params.n != 1: - return create_error_response(HTTPStatus.BAD_REQUEST, "stream api only support n = 1") - parser_dict = {} reasoning_parser_dict = {} # Streaming case async def stream_results() -> AsyncGenerator[bytes, 
None]: - finish_reason = None - has_emitted_tool_calls = False + has_emitted_tool_calls: Dict[int, bool] = collections.defaultdict(bool) + stream_tool_call_ids: Dict[Tuple[int, int], str] = {} from .req_id_generator import convert_sub_id_to_group_id prompt_tokens = 0 @@ -389,18 +403,19 @@ async def stream_results() -> AsyncGenerator[bytes, None]: prompt_tokens = metadata["prompt_tokens"] completion_tokens += 1 group_request_id = convert_sub_id_to_group_id(sub_req_id) - index = sub_req_id + choice_index = sub_req_id - group_request_id + delta = request_output - finish_reason = finish_status.get_finish_reason() + current_finish_reason = finish_status.get_finish_reason() # Handle reasoning content if get_env_start_args().reasoning_parser and request.separate_reasoning: reasoning_text, delta = _process_reasoning_stream( - index, delta, reasoning_parser_dict, request_output, request + choice_index, delta, reasoning_parser_dict, request_output, request ) if reasoning_text: choice_data = ChatCompletionStreamResponseChoice( - index=0, + index=choice_index, delta=DeltaMessage(role="assistant", reasoning_content=reasoning_text), finish_reason=None, ) @@ -410,18 +425,18 @@ async def stream_results() -> AsyncGenerator[bytes, None]: choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" if request.tool_choice != "none" and request.tools: # parse_increment => returns (normal_text, calls) normal_text, calls = _process_tools_stream( - index=index, delta=delta, parser_dict=parser_dict, request=request + index=choice_index, delta=delta, parser_dict=parser_dict, request=request ) # 1) if there's normal_text, output it as normal content - if normal_text: + if normal_text and (normal_text.strip() or not has_emitted_tool_calls[sub_req_id]): choice_data = ChatCompletionStreamResponseChoice( - index=0, + index=choice_index, delta=DeltaMessage(role="assistant", content=normal_text), 
finish_reason=None, ) @@ -431,87 +446,143 @@ async def stream_results() -> AsyncGenerator[bytes, None]: choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" # 2) if we found calls, we output them as separate chunk(s) history_tool_calls_cnt = _get_history_tool_calls_cnt(request) + fc_parser = parser_dict[choice_index] for call_item in calls: - has_emitted_tool_calls = True + has_emitted_tool_calls[sub_req_id] = True # transform call_item -> FunctionResponse + ToolCall - if finish_reason == "stop": - latest_delta_len = 0 - if isinstance(call_item.parameters, str): - latest_delta_len = len(call_item.parameters) - - expected_call = json.dumps( - parser.multi_format_parser.detectors[0].prev_tool_call_arr[index].get("arguments", {}), - ensure_ascii=False, - ) - actual_call = parser.multi_format_parser.detectors[0].streamed_args_for_tool[index] - if latest_delta_len > 0: - actual_call = actual_call[:-latest_delta_len] - remaining_call = expected_call.replace(actual_call, "", 1) - call_item.parameters = remaining_call + if current_finish_reason == "stop": + det = fc_parser.detector + ti = call_item.tool_index + if ti >= 0 and ti < len(det.prev_tool_call_arr) and ti < len(det.streamed_args_for_tool): + latest_delta_len = 0 + if isinstance(call_item.parameters, str): + latest_delta_len = len(call_item.parameters) + + expected_call = json.dumps( + det.prev_tool_call_arr[ti].get("arguments", {}), + ensure_ascii=False, + ) + actual_call = det.streamed_args_for_tool[ti] + if latest_delta_len > 0: + actual_call = actual_call[:-latest_delta_len] + remaining_call = expected_call.replace(actual_call, "", 1) + call_item.parameters = remaining_call + tool_parser = getattr(g_objs.args, "tool_call_parser", None) or "llama3" + stream_index = getattr(call_item, "tool_index", None) + id_key = (choice_index, stream_index) if call_item.name: - # First chunk: include ID and function name - 
tool_parser = getattr(g_objs.args, "tool_call_parser", None) or "llama3" - tool_call_id = _process_tool_call_id(tool_parser, call_item, history_tool_calls_cnt) + if id_key not in stream_tool_call_ids: + stream_tool_call_ids[id_key] = _process_tool_call_id( + tool_parser, call_item, history_tool_calls_cnt + ) + tool_call_id = stream_tool_call_ids[id_key] function_name = call_item.name else: - # Subsequent chunks: null ID and name for argument deltas - tool_call_id = None + tool_call_id = stream_tool_call_ids.get(id_key) function_name = None - tool_call = ToolCall( - id=tool_call_id, - index=getattr(call_item, "tool_index", None), - function=FunctionResponse( - name=function_name, - arguments=call_item.parameters, - ), - ) - choice_data = ChatCompletionStreamResponseChoice( - index=0, - delta=DeltaMessage(role="assistant", tool_calls=[tool_call]), - finish_reason=None, - ) - chunk = ChatCompletionStreamResponse( - id=group_request_id, - created=created_time, - choices=[choice_data], - model=request.model, - ) - yield f"data: {chunk.model_dump_json()}\n\n" + is_tool_head = call_item.name is not None + + if is_tool_head and call_item.parameters: + head_tool_call = ToolCall( + id=tool_call_id, + index=stream_index, + type="function", + function=FunctionResponse( + name=function_name, + arguments="", + ), + ) + head_choice = ChatCompletionStreamResponseChoice( + index=choice_index, + delta=DeltaMessage(tool_calls=[head_tool_call]), + finish_reason=None, + ) + head_chunk = ChatCompletionStreamResponse( + id=group_request_id, + created=created_time, + choices=[head_choice], + model=request.model, + ) + yield f"data: {head_chunk.model_dump_json(exclude_none=True)}\n\n" + + for arg_delta in _split_tool_argument_delta(call_item.parameters): + arg_tool_call = ToolCall( + index=stream_index, + function=FunctionResponse(arguments=arg_delta), + ) + arg_choice = ChatCompletionStreamResponseChoice( + index=choice_index, + delta=DeltaMessage(tool_calls=[arg_tool_call]), + 
finish_reason=None, + ) + arg_chunk = ChatCompletionStreamResponse( + id=group_request_id, + created=created_time, + choices=[arg_choice], + model=request.model, + ) + yield f"data: {arg_chunk.model_dump_json(exclude_none=True)}\n\n" + else: + tool_call = ToolCall( + id=tool_call_id if is_tool_head else None, + index=stream_index, + type="function" if is_tool_head else None, + function=FunctionResponse( + name=function_name, + arguments=( + (call_item.parameters if call_item.parameters is not None else "") + if is_tool_head + else call_item.parameters + ), + ), + ) + choice_data = ChatCompletionStreamResponseChoice( + index=choice_index, + delta=DeltaMessage(tool_calls=[tool_call]), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse( + id=group_request_id, + created=created_time, + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" else: delta_message = DeltaMessage(role="assistant", content=delta) - stream_choice = ChatCompletionStreamResponseChoice(index=0, delta=delta_message, finish_reason=None) + stream_choice = ChatCompletionStreamResponseChoice( + index=choice_index, delta=delta_message, finish_reason=None + ) stream_resp = ChatCompletionStreamResponse( id=group_request_id, created=created_time, model=request.model, choices=[stream_choice], ) - yield f"data: {stream_resp.model_dump_json()}\n\n" - - # Determine final finish_reason: override to "tool_calls" if tool calls were emitted - if has_emitted_tool_calls and finish_reason == "stop": - finish_reason = "tool_calls" - - # Final empty chunk containing only finish_reason (and role) - if finish_reason is not None: - final_choice = ChatCompletionStreamResponseChoice( - index=0, - delta=DeltaMessage(), - finish_reason=finish_reason, - ) - final_chunk = ChatCompletionStreamResponse( - id=group_request_id, - created=created_time, - model=request.model, - choices=[final_choice], - ) - yield f"data: {final_chunk.model_dump_json()}\n\n" + 
yield f"data: {stream_resp.model_dump_json(exclude_none=True)}\n\n" + + # Emit a per-choice final empty chunk with finish_reason. + if current_finish_reason is not None: + if has_emitted_tool_calls[sub_req_id] and current_finish_reason == "stop": + current_finish_reason = "tool_calls" + final_choice = ChatCompletionStreamResponseChoice( + index=choice_index, + delta=DeltaMessage(), + finish_reason=current_finish_reason, + ) + final_chunk = ChatCompletionStreamResponse( + id=group_request_id, + created=created_time, + model=request.model, + choices=[final_choice], + ) + yield f"data: {final_chunk.model_dump_json(exclude_none=True)}\n\n" if request.stream_options and request.stream_options.include_usage: usage = UsageInfo( @@ -526,7 +597,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: model=request.model, usage=usage, ) - yield f"data: {usage_chunk.model_dump_json()}\n\n" + yield f"data: {usage_chunk.model_dump_json(exclude_none=True)}\n\n" background_tasks = BackgroundTasks() return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) @@ -634,9 +705,6 @@ async def _process_prompts_completion( "Streaming is not supported for batch requests", ) - if sampling_params.n != 1: - return create_error_response(HTTPStatus.BAD_REQUEST, "stream api only support n = 1") - return await _handle_streaming_completion( prompts[0], sampling_params, multimodal_params, raw_request, request, created_time ) @@ -690,6 +758,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: async for sub_req_id, request_output, metadata, finish_status in results_generator: group_request_id = convert_sub_id_to_group_id(sub_req_id) + choice_index = sub_req_id - group_request_id prompt_tokens = metadata["prompt_tokens"] completion_tokens += 1 current_finish_reason = None @@ -704,7 +773,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: output_text = prompt_str + output_text stream_choice = CompletionStreamChoice( - index=0, + 
index=choice_index, text=output_text, finish_reason=current_finish_reason, logprobs=None if request.logprobs is None else {}, diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index 13aab66179..74208ac4b3 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -1846,6 +1846,48 @@ def _parse_function_call(self, function_str: str, tools: List[Tool]) -> Optional parameters=json.dumps(param_dict, ensure_ascii=False), ) + def _build_partial_arguments_json(self, func_name: str, partial_body: str, tools: List[Tool]) -> Optional[str]: + """Build the current argument JSON from a partial XML tool-call body.""" + param_matches = self.parameter_regex.findall(partial_body) + if not param_matches: + return None + + param_config = self._get_param_config(func_name, tools) + param_dict = {} + has_visible_value = False + + for match in param_matches: + try: + idx = match.index(">") + except ValueError: + continue + + param_name = match[:idx].strip() + param_value = match[idx + 1 :] + if param_value.startswith("\n"): + param_value = param_value[1:] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + if param_value.strip(): + has_visible_value = True + elif ( + f"" in partial_body + and f"{param_value}" in partial_body + ): + # Closed empty-string parameter. We can safely emit it. + has_visible_value = True + else: + # Parameter tag is present but its value has not started streaming yet. 
+ continue + + param_dict[param_name] = self._convert_param_value(param_value, param_name, param_config, func_name) + + if not param_dict and not has_visible_value: + return None + + return json.dumps(param_dict, ensure_ascii=False) + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: idx = text.find(self.bot_token) normal_text = text[:idx].strip() if idx != -1 else text @@ -1865,6 +1907,7 @@ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult func_str = match[0] if match[0] else match[1] item = self._parse_function_call(func_str, tools) if item: + item.tool_index = len(calls) calls.append(item) return StreamingParseResult(normal_text=normal_text, calls=calls) @@ -1872,72 +1915,49 @@ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult: """Streaming incremental parsing for Qwen3-Coder XML tool calls.""" self._buffer += new_text - current_text = self._buffer - - if not self.has_tool_call(current_text): - partial_len = self._ends_with_partial_token(current_text, self.bot_token) - if partial_len: - return StreamingParseResult() - self._buffer = "" - cleaned = new_text.replace(self.eot_token, "") - return StreamingParseResult(normal_text=cleaned) - - # Check for complete tool call blocks - if self.eot_token in current_text: - result = self.detect_and_parse(current_text, tools) - last_end = current_text.rfind(self.eot_token) - if last_end != -1: - self._buffer = current_text[last_end + len(self.eot_token) :].lstrip() - else: - self._buffer = "" - self.current_tool_id = -1 - self.current_tool_name_sent = False - return result - - # Partial tool call - try to extract function name for early streaming if not hasattr(self, "_tool_indices"): self._tool_indices = self._get_tool_indices(tools) - calls = [] - tool_call_start = current_text.find(self.bot_token) - if tool_call_start == -1: - return 
StreamingParseResult() + normal_text = "" + calls: List[ToolCallItem] = [] - content_after = current_text[tool_call_start + len(self.bot_token) :] - func_prefix = "") - if gt_pos == -1: - return StreamingParseResult() + if tool_call_start == -1: + partial_len = self._ends_with_partial_token(current_text, self.bot_token) + if partial_len: + return StreamingParseResult(normal_text=normal_text, calls=calls) + if current_text: + normal_text += current_text.replace(self.eot_token, "") + self._buffer = "" + return StreamingParseResult(normal_text=normal_text, calls=calls) - func_name = after_func[:gt_pos].strip() + if tool_call_start > 0: + normal_text += current_text[:tool_call_start] + self._buffer = current_text[tool_call_start:] + current_text = self._buffer - if self.current_tool_id == -1: - self.current_tool_id = 0 - self.prev_tool_call_arr = [] - self.streamed_args_for_tool = [""] + eot_pos = current_text.find(self.eot_token) + if eot_pos == -1: + return StreamingParseResult(normal_text=normal_text, calls=calls) - while len(self.prev_tool_call_arr) <= self.current_tool_id: - self.prev_tool_call_arr.append({}) - while len(self.streamed_args_for_tool) <= self.current_tool_id: - self.streamed_args_for_tool.append("") + complete_block = current_text[: eot_pos + len(self.eot_token)] + func_matches = self.function_regex.findall(complete_block) - if func_name and func_name in self._tool_indices and not self.current_tool_name_sent: - calls.append( - ToolCallItem( - tool_index=self.current_tool_id, - name=func_name, - parameters="", - ) - ) - self.current_tool_name_sent = True - self.prev_tool_call_arr[self.current_tool_id] = {"name": func_name, "arguments": {}} + if self.current_tool_id == -1: + self.current_tool_id = 0 + + for match in func_matches: + func_str = match[0] if match[0] else match[1] + item = self._parse_function_call(func_str, tools) + if item: + item.tool_index = self.current_tool_id + calls.append(item) + self.current_tool_id += 1 - return 
StreamingParseResult(normal_text="", calls=calls)
+        self._buffer = current_text[eot_pos + len(self.eot_token) :].lstrip()
 
 
 class FunctionCallParser:
diff --git a/test/test_api/test_stream_fc.py b/test/test_api/test_stream_fc.py
new file mode 100644
index 0000000000..51f9ed9ae2
--- /dev/null
+++ b/test/test_api/test_stream_fc.py
@@ -0,0 +1,658 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+LightLLM OpenAI streaming function call (tool call) test script.
+
+Usage:
+    # Start LightLLM server first, e.g.:
+    # python -m lightllm.server.api_server --port 8000 --model_dir /path/to/model --tp 1
+
+    # Run all tests:
+    python test/test_api/test_stream_fc.py
+
+    # Specify server url and model:
+    python test/test_api/test_stream_fc.py --base-url http://localhost:8080 --model my_model
+
+    # Run a single test:
+    python test/test_api/test_stream_fc.py --test single
+"""
+
+import argparse
+import json
+import sys
+import traceback
+from typing import Dict, List, Optional
+
+from openai import OpenAI
+
+# ──────────────────────────────────────────────
+# Tool definitions
+# ──────────────────────────────────────────────
+
+WEATHER_TOOL = {
+    "type": "function",
+    "function": {
+        "name": "get_weather",
+        "description": "获取指定城市的天气信息",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "city": {"type": "string", "description": "城市名称,例如:北京、上海"},
+                "unit": {
+                    "type": "string",
+                    "enum": ["celsius", "fahrenheit"],
+                    "description": "温度单位,默认 celsius",
+                },
+            },
+            "required": ["city"],
+        },
+    },
+}
+
+CALCULATOR_TOOL = {
+    "type": "function",
+    "function": {
+        "name": "calculate",
+        "description": "执行数学计算",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "expression": {"type": "string", "description": "数学表达式,例如:2+3*4"},
+            },
+            "required": ["expression"],
+        },
+    },
+}
+
+SEARCH_TOOL = {
+    "type": "function",
+    "function": {
+        "name": "web_search",
+        "description": "在互联网上搜索信息",
+        "parameters": {
+            "type": "object",
+            "properties": {
+ "query": {"type": "string", "description": "搜索关键词"}, + "max_results": {"type": "integer", "description": "最大返回结果数,默认 5"}, + }, + "required": ["query"], + }, + }, +} + +ALL_TOOLS = [WEATHER_TOOL, CALCULATOR_TOOL, SEARCH_TOOL] + + +# ────────────────────────────────────────────── +# Helpers +# ────────────────────────────────────────────── + + +def collect_stream_tool_calls(response) -> Dict: + """ + Consume a streaming chat completion response and reassemble: + - content: concatenated text content + - reasoning_content: concatenated reasoning content + - tool_calls: dict keyed by index -> {id, name, arguments} + - finish_reason: the final finish_reason + - chunks: raw chunk list for inspection + """ + content = "" + reasoning_content = "" + tool_calls: Dict[int, Dict] = {} + finish_reason = None + chunks = [] + + for chunk in response: + chunks.append(chunk) + choice = chunk.choices[0] + delta = choice.delta + + if choice.finish_reason is not None: + finish_reason = choice.finish_reason + + if delta.content: + content += delta.content + + if getattr(delta, "reasoning_content", None): + reasoning_content += delta.reasoning_content + + if delta.tool_calls: + for tc in delta.tool_calls: + idx = tc.index + if idx not in tool_calls: + tool_calls[idx] = {"id": None, "name": "", "arguments": ""} + if tc.id: + tool_calls[idx]["id"] = tc.id + if tc.function.name: + tool_calls[idx]["name"] = tc.function.name + if tc.function.arguments: + tool_calls[idx]["arguments"] += tc.function.arguments + + return { + "content": content, + "reasoning_content": reasoning_content, + "tool_calls": tool_calls, + "finish_reason": finish_reason, + "chunks": chunks, + } + + +def print_result(result: Dict): + """Pretty-print a collected stream result.""" + if result["reasoning_content"]: + print(f" [思考]: {result['reasoning_content'][:200]}...") + if result["content"]: + print(f" [内容]: {result['content']}") + if result["tool_calls"]: + for idx, tc in sorted(result["tool_calls"].items()): + 
print(f" [工具调用 {idx}]: id={tc['id']}, name={tc['name']}, arguments={tc['arguments']}") + print(f" [finish_reason]: {result['finish_reason']}") + print(f" [chunks数量]: {len(result['chunks'])}") + + +def assert_check(condition: bool, msg: str): + """Simple assertion with message.""" + if not condition: + raise AssertionError(f"FAIL: {msg}") + + +# ────────────────────────────────────────────── +# Test cases +# ────────────────────────────────────────────── + + +def test_single_tool_call(client: OpenAI, model: str): + """测试单个工具调用 - 查询天气""" + print("=" * 60) + print("[TEST] 单工具流式调用") + print("=" * 60) + + response = client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "北京今天天气怎么样?"}], + tools=[WEATHER_TOOL], + tool_choice="auto", + stream=True, + temperature=0.0, + max_tokens=1000, + ) + + result = collect_stream_tool_calls(response) + print_result(result) + + # Validate + if result["tool_calls"]: + tc = result["tool_calls"][0] + assert_check(tc["id"] is not None and len(tc["id"]) > 0, "tool_call id 不应为空") + assert_check(tc["name"] == "get_weather", f"期望函数名 get_weather, 实际: {tc['name']}") + args = json.loads(tc["arguments"]) + assert_check("city" in args, "参数中应包含 city 字段") + assert_check( + result["finish_reason"] == "tool_calls", f"finish_reason 应为 tool_calls, 实际: {result['finish_reason']}" + ) + print(" [PASS] 单工具流式调用测试通过\n") + else: + print(" [WARN] 模型未调用工具,可能模型不支持该 tool_call_parser 格式\n") + + +def test_parallel_tool_calls(client: OpenAI, model: str): + """测试并行多工具调用""" + print("=" * 60) + print("[TEST] 并行多工具流式调用") + print("=" * 60) + + response = client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "帮我查一下北京和上海的天气"}], + tools=[WEATHER_TOOL], + tool_choice="auto", + stream=True, + temperature=0.0, + max_tokens=1000, + ) + + result = collect_stream_tool_calls(response) + print_result(result) + + if len(result["tool_calls"]) >= 2: + # 检查每个 tool_call 的完整性 + ids_seen = set() + for idx, tc in 
result["tool_calls"].items(): + assert_check(tc["id"] is not None, f"tool_call[{idx}] id 不应为空") + assert_check(tc["id"] not in ids_seen, f"tool_call id 重复: {tc['id']}") + ids_seen.add(tc["id"]) + assert_check(tc["name"] == "get_weather", f"tool_call[{idx}] 函数名应为 get_weather") + args = json.loads(tc["arguments"]) + assert_check("city" in args, f"tool_call[{idx}] 参数中应包含 city 字段") + + assert_check( + result["finish_reason"] == "tool_calls", f"finish_reason 应为 tool_calls, 实际: {result['finish_reason']}" + ) + print(" [PASS] 并行多工具流式调用测试通过\n") + elif len(result["tool_calls"]) == 1: + print(" [WARN] 模型只调用了 1 个工具(可能不支持并行调用),跳过并行校验\n") + else: + print(" [WARN] 模型未调用工具\n") + + +def test_mixed_content_and_tool_calls(client: OpenAI, model: str): + """测试混合输出:模型先输出文本再调用工具""" + print("=" * 60) + print("[TEST] 文本+工具调用混合输出") + print("=" * 60) + + response = client.chat.completions.create( + model=model, + messages=[ + { + "role": "user", + "content": "先说一句问候语,然后帮我查北京的天气", + } + ], + tools=[WEATHER_TOOL], + tool_choice="auto", + stream=True, + temperature=0.7, + max_tokens=1000, + ) + + result = collect_stream_tool_calls(response) + print_result(result) + + if result["tool_calls"]: + tc = result["tool_calls"][0] + assert_check(tc["name"] == "get_weather", f"期望函数名 get_weather, 实际: {tc['name']}") + args = json.loads(tc["arguments"]) + assert_check("city" in args, "参数中应包含 city 字段") + print(" [PASS] 混合输出测试通过\n") + else: + # 有些模型可能只输出文本或只调用工具 + print(" [WARN] 模型未调用工具\n") + + +def test_tool_choice_required(client: OpenAI, model: str): + """测试 tool_choice=required,模型必须调用工具""" + print("=" * 60) + print("[TEST] tool_choice=required") + print("=" * 60) + + response = client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "你好"}], + tools=ALL_TOOLS, + tool_choice="required", + stream=True, + temperature=0.0, + max_tokens=1000, + ) + + result = collect_stream_tool_calls(response) + print_result(result) + + if result["tool_calls"]: + tc = result["tool_calls"][0] + 
assert_check(tc["id"] is not None, "tool_call id 不应为空") + assert_check(tc["name"] in ["get_weather", "calculate", "web_search"], f"函数名不在预期范围: {tc['name']}") + assert_check( + result["finish_reason"] == "tool_calls", f"finish_reason 应为 tool_calls, 实际: {result['finish_reason']}" + ) + print(" [PASS] tool_choice=required 测试通过\n") + else: + print(" [WARN] tool_choice=required 但模型未调用工具\n") + + +def test_tool_choice_none(client: OpenAI, model: str): + """测试 tool_choice=none,模型不应调用工具""" + print("=" * 60) + print("[TEST] tool_choice=none") + print("=" * 60) + + response = client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "北京今天天气怎么样?"}], + tools=[WEATHER_TOOL], + tool_choice="none", + stream=True, + temperature=0.0, + max_tokens=1000, + ) + + result = collect_stream_tool_calls(response) + print_result(result) + + assert_check(len(result["tool_calls"]) == 0, "tool_choice=none 时不应有工具调用") + assert_check( + result["finish_reason"] in ("stop", "length"), f"finish_reason 应为 stop 或 length, 实际: {result['finish_reason']}" + ) + print(" [PASS] tool_choice=none 测试通过\n") + + +def test_tool_choice_specific_function(client: OpenAI, model: str): + """测试 tool_choice 指定具体函数""" + print("=" * 60) + print("[TEST] tool_choice=指定函数") + print("=" * 60) + + response = client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "1+1等于几"}], + tools=ALL_TOOLS, + tool_choice={"type": "function", "function": {"name": "calculate"}}, + stream=True, + temperature=0.0, + max_tokens=1000, + ) + + result = collect_stream_tool_calls(response) + print_result(result) + + if result["tool_calls"]: + tc = result["tool_calls"][0] + assert_check(tc["name"] == "calculate", f"期望函数名 calculate, 实际: {tc['name']}") + args = json.loads(tc["arguments"]) + assert_check("expression" in args, "参数中应包含 expression 字段") + print(" [PASS] tool_choice 指定函数测试通过\n") + else: + print(" [WARN] 模型未调用指定函数\n") + + +def test_multi_turn_with_tool_result(client: OpenAI, model: 
str): + """测试多轮对话:工具调用 -> 返回结果 -> 模型继续回答""" + print("=" * 60) + print("[TEST] 多轮对话(工具调用+结果回传)") + print("=" * 60) + + # Round 1: 用户提问,模型应调用工具 + print(" --- Round 1: 用户提问 ---") + response = client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "北京今天天气怎么样?"}], + tools=[WEATHER_TOOL], + tool_choice="auto", + stream=True, + temperature=0.0, + max_tokens=1000, + ) + + result1 = collect_stream_tool_calls(response) + print_result(result1) + + if not result1["tool_calls"]: + print(" [SKIP] 模型未调用工具,跳过多轮测试\n") + return + + tc = result1["tool_calls"][0] + tool_call_id = tc["id"] + tool_name = tc["name"] + tool_args = tc["arguments"] + + # Round 2: 传回工具结果,模型应基于结果回答 + print(" --- Round 2: 回传工具结果 ---") + weather_result = json.dumps({"city": "北京", "temperature": 22, "condition": "晴", "humidity": 45}, ensure_ascii=False) + + messages = [ + {"role": "user", "content": "北京今天天气怎么样?"}, + { + "role": "assistant", + "content": result1["content"] if result1["content"] else None, + "tool_calls": [ + { + "id": tool_call_id, + "type": "function", + "function": {"name": tool_name, "arguments": tool_args}, + } + ], + }, + { + "role": "tool", + "tool_call_id": tool_call_id, + "content": weather_result, + }, + ] + + response2 = client.chat.completions.create( + model=model, + messages=messages, + tools=[WEATHER_TOOL], + tool_choice="auto", + stream=True, + temperature=0.0, + max_tokens=1000, + ) + + result2 = collect_stream_tool_calls(response2) + print_result(result2) + + assert_check(len(result2["content"]) > 0, "模型应基于工具结果生成文本回复") + assert_check( + result2["finish_reason"] in ("stop", "length"), + f"finish_reason 应为 stop 或 length, 实际: {result2['finish_reason']}", + ) + print(" [PASS] 多轮对话测试通过\n") + + +def test_stream_chunk_integrity(client: OpenAI, model: str): + """测试流式 chunk 的结构完整性""" + print("=" * 60) + print("[TEST] 流式 chunk 结构完整性校验") + print("=" * 60) + + response = client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": 
"帮我查北京天气"}], + tools=[WEATHER_TOOL], + tool_choice="auto", + stream=True, + temperature=0.0, + max_tokens=1000, + ) + + result = collect_stream_tool_calls(response) + + if not result["tool_calls"]: + print(" [SKIP] 模型未调用工具,跳过 chunk 校验\n") + return + + # Validate chunk structure + has_role_chunk = False + tool_call_name_chunks = 0 + tool_call_arg_chunks = 0 + finish_reason_count = 0 + + for chunk in result["chunks"]: + assert_check(chunk.id is not None, "chunk.id 不应为空") + assert_check( + chunk.object == "chat.completion.chunk", f"chunk.object 应为 chat.completion.chunk, 实际: {chunk.object}" + ) + assert_check(len(chunk.choices) > 0, "chunk.choices 不应为空") + + choice = chunk.choices[0] + delta = choice.delta + + if delta.role == "assistant": + has_role_chunk = True + + if choice.finish_reason is not None: + finish_reason_count += 1 + + if delta.tool_calls: + for tc in delta.tool_calls: + assert_check(tc.index is not None, "tool_call.index 不应为 None") + assert_check( + tc.type is None or tc.type == "function", f"tool_call.type 应为 function 或 None, 实际: {tc.type}" + ) + if tc.function.name: + tool_call_name_chunks += 1 + if tc.function.arguments: + tool_call_arg_chunks += 1 + + print(f" 总 chunks: {len(result['chunks'])}") + print(f" 包含 role 的 chunk: {has_role_chunk}") + print(f" 包含函数名的 chunk: {tool_call_name_chunks}") + print(f" 包含参数的 chunk: {tool_call_arg_chunks}") + print(f" finish_reason chunk: {finish_reason_count}") + + assert_check(has_role_chunk, "应有至少一个 chunk 包含 role=assistant") + assert_check(tool_call_name_chunks >= 1, "应有至少一个 chunk 包含函数名") + assert_check(tool_call_arg_chunks >= 1, "应有至少一个 chunk 包含参数") + assert_check(finish_reason_count >= 1, "应有至少一个 chunk 包含 finish_reason") + + # 验证拼接后的 arguments 是合法 JSON + for idx, tc in result["tool_calls"].items(): + try: + json.loads(tc["arguments"]) + except json.JSONDecodeError as e: + raise AssertionError(f"tool_call[{idx}] arguments 不是合法 JSON: {tc['arguments']}, error: {e}") + + print(" [PASS] 流式 chunk 结构完整性校验通过\n") + + 
def test_multiple_different_tools(client: OpenAI, model: str):
    """One prompt that should trigger two distinct tools (weather + calculator) in parallel."""
    print("=" * 60)
    print("[TEST] 多种工具并行调用")
    print("=" * 60)

    stream = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": "帮我查北京天气,再算一下 123*456 等于多少",
            }
        ],
        tools=[WEATHER_TOOL, CALCULATOR_TOOL],
        tool_choice="auto",
        stream=True,
        temperature=0.0,
        max_tokens=1000,
    )

    result = collect_stream_tool_calls(stream)
    print_result(result)

    call_count = len(result["tool_calls"])
    if call_count >= 2:
        names = {tc["name"] for tc in result["tool_calls"].values()}
        assert_check("get_weather" in names, "应包含 get_weather 调用")
        assert_check("calculate" in names, "应包含 calculate 调用")

        # Every tool_call id must be unique across the response.
        ids = [tc["id"] for tc in result["tool_calls"].values()]
        assert_check(len(ids) == len(set(ids)), f"tool_call id 应唯一, 实际: {ids}")

        # Every arguments payload must parse as JSON.
        for tc in result["tool_calls"].values():
            json.loads(tc["arguments"])

        print(" [PASS] 多种工具并行调用测试通过\n")
    elif call_count == 1:
        print(" [WARN] 模型只调用了 1 个工具,跳过多工具校验\n")
    else:
        print(" [WARN] 模型未调用工具\n")


def test_stream_with_usage(client: OpenAI, model: str):
    """Check usage accounting delivered via stream_options={"include_usage": True}."""
    print("=" * 60)
    print("[TEST] 流式输出 usage 信息")
    print("=" * 60)

    stream = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "帮我查北京天气"}],
        tools=[WEATHER_TOOL],
        tool_choice="auto",
        stream=True,
        stream_options={"include_usage": True},
        temperature=0.0,
        max_tokens=1000,
    )

    usage_info = None
    chunks = []
    for chunk in stream:
        chunks.append(chunk)
        if chunk.usage is not None:
            usage_info = chunk.usage

    if not usage_info:
        print(" [WARN] 未收到 usage 信息(服务端可能不支持 stream_options)\n")
        return

    print(f" prompt_tokens: {usage_info.prompt_tokens}")
    print(f" completion_tokens: {usage_info.completion_tokens}")
    print(f" total_tokens: {usage_info.total_tokens}")
    assert_check(usage_info.prompt_tokens > 0, "prompt_tokens 应 > 0")
    assert_check(usage_info.completion_tokens > 0, "completion_tokens 应 > 0")
    assert_check(
        usage_info.total_tokens == usage_info.prompt_tokens + usage_info.completion_tokens,
        "total_tokens 应等于 prompt + completion",
    )
    print(" [PASS] 流式 usage 信息测试通过\n")


# ──────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────

# Registry mapping CLI test names to their implementations; `--test` picks one,
# otherwise all run in declaration order.
TEST_REGISTRY = {
    "single": test_single_tool_call,
    "parallel": test_parallel_tool_calls,
    "mixed": test_mixed_content_and_tool_calls,
    "required": test_tool_choice_required,
    "none": test_tool_choice_none,
    "specific": test_tool_choice_specific_function,
    "multi_turn": test_multi_turn_with_tool_result,
    "chunk_integrity": test_stream_chunk_integrity,
    "multi_tools": test_multiple_different_tools,
    "usage": test_stream_with_usage,
}


def main():
    """CLI entry point: run the selected (or all) streaming function-call tests.

    Exits with status 1 if any test fails, 0 otherwise.
    """
    parser = argparse.ArgumentParser(description="LightLLM streaming function call test")
    parser.add_argument("--base-url", default="http://localhost:8000/v1", help="LightLLM server base URL")
    parser.add_argument("--model", default="default_model", help="Model name")
    parser.add_argument(
        "--test",
        default=None,
        choices=list(TEST_REGISTRY.keys()),
        help="Run a specific test (default: run all)",
    )
    parser.add_argument("--api-key", default="EMPTY", help="API key (default: EMPTY)")
    args = parser.parse_args()

    client = OpenAI(base_url=args.base_url, api_key=args.api_key)

    print(f"Server: {args.base_url}")
    print(f"Model: {args.model}")
    print()

    tests_to_run = [args.test] if args.test else list(TEST_REGISTRY.keys())
    passed = 0
    failed = 0

    for name in tests_to_run:
        try:
            TEST_REGISTRY[name](client, args.model)
        except AssertionError as e:
            # An assert_check failed: a real behavioral regression.
            print(f" [FAIL] {e}")
            traceback.print_exc()
            failed += 1
            print()
        except Exception as e:
            # Unexpected error (network, protocol, client bug) — also counts as failed.
            print(f" [ERROR] {e}")
            traceback.print_exc()
            failed += 1
            print()
        else:
            passed += 1

    print("=" * 60)
    print(f"结果: {passed} passed, {failed} failed (共 {len(tests_to_run)} 个测试)")
    print("=" * 60)

    sys.exit(1 if failed > 0 else 0)


if __name__ == "__main__":
    main()