Commit 00eb78e

add v1/models && stream fc
1 parent f5b4cbd commit 00eb78e

4 files changed

Lines changed: 106 additions & 56 deletions

lightllm/server/api_cli.py

Lines changed: 6 additions & 0 deletions
@@ -98,6 +98,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
         default="default_model_name",
         help="just help to distinguish internal model name, use 'host:port/get_model_name' to get",
     )
+    parser.add_argument(
+        "--model_owner",
+        type=str,
+        default=None,
+        help="the model owner, if not set, will use lightllm",
+    )
 
     parser.add_argument(
         "--model_dir",

lightllm/server/api_http.py

Lines changed: 28 additions & 0 deletions
@@ -19,6 +19,7 @@
 import asyncio
 import collections
 import time
+
 import uvloop
 import requests
 import base64
@@ -57,6 +58,8 @@
     ChatCompletionResponse,
     CompletionRequest,
     CompletionResponse,
+    ModelCard,
+    ModelListResponse,
 )
 from .build_prompt import build_prompt, init_tokenizer
 
@@ -72,6 +75,9 @@ class G_Objs:
     g_generate_stream_func: Callable = None
     httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None
     shared_token_load: TokenLoad = None
+    # OpenAI-compatible "created" timestamp for /v1/models.
+    # Should be stable for the lifetime of this server process.
+    model_created: int = None
 
     def set_args(self, args: StartArgs):
         self.args = args
@@ -101,6 +107,8 @@ def set_args(self, args: StartArgs):
         self.httpserver_manager = HttpServerManager(args=args)
         dp_size_in_node = max(1, args.dp // args.nnodes)  # compatible with multi-node pure tp mode, where 1 // 2 == 0 needs to be handled
         self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node)
+        if self.model_created is None:
+            self.model_created = int(time.time())
 
 
 g_objs = G_Objs()
@@ -258,6 +266,26 @@ async def completions(request: CompletionRequest, raw_request: Request) -> Respo
     return resp
 
 
+@app.get("/v1/models", response_model=ModelListResponse)
+@app.post("/v1/models", response_model=ModelListResponse)
+async def get_models(raw_request: Request):
+    model_name = g_objs.args.model_name
+    max_model_len = g_objs.args.max_req_total_len
+    if model_name == "default_model_name" and g_objs.args.model_dir:
+        model_name = os.path.basename(g_objs.args.model_dir.rstrip("/"))
+
+    return ModelListResponse(
+        data=[
+            ModelCard(
+                id=model_name,
+                created=g_objs.model_created,
+                max_model_len=max_model_len,
+                owned_by=g_objs.args.model_owner,
+            )
+        ]
+    )
+
+
 @app.get("/tokens")
 @app.post("/tokens")
 async def tokens(request: Request):
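
A hedged client-side sketch of the new endpoint. The host and port are illustrative and not part of this commit; only the response shape comes from the ModelListResponse/ModelCard definitions added here:

import requests

# Assumes a lightllm server is already listening on this address (illustrative).
resp = requests.get("http://localhost:8000/v1/models", timeout=5)
resp.raise_for_status()
payload = resp.json()

# Expected shape: {"object": "list", "data": [{"id": ..., "object": "model",
#                  "created": ..., "owned_by": ..., "max_model_len": ...}]}
for card in payload["data"]:
    print(card["id"], card["owned_by"], card["max_model_len"])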

lightllm/server/api_models.py

Lines changed: 13 additions & 0 deletions
@@ -370,3 +370,16 @@ class CompletionStreamResponse(BaseModel):
     @field_validator("id", mode="before")
     def ensure_id_is_str(cls, v):
         return str(v)
+
+
+class ModelCard(BaseModel):
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "lightllm"
+    max_model_len: Optional[int] = None
+
+
+class ModelListResponse(BaseModel):
+    object: str = "list"
+    data: List[ModelCard]
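
The two models serialize to the usual OpenAI list format. A standalone mirror of the definitions above (imports added here only for illustration) shows the resulting JSON:

import time
from typing import List, Optional
from pydantic import BaseModel, Field


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "lightllm"
    max_model_len: Optional[int] = None


class ModelListResponse(BaseModel):
    object: str = "list"
    data: List[ModelCard]


print(ModelListResponse(data=[ModelCard(id="demo-model", max_model_len=8192)]).model_dump_json())
# {"object":"list","data":[{"id":"demo-model","object":"model","created":...,"owned_by":"lightllm","max_model_len":8192}]}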

lightllm/server/api_openai.py

Lines changed: 59 additions & 56 deletions
@@ -19,7 +19,7 @@
 from http import HTTPStatus
 from PIL import Image
 import multiprocessing as mp
-from typing import Any, AsyncGenerator, Optional, Union, List, Dict
+from typing import Any, AsyncGenerator, Optional, Union, List, Dict, Tuple
 from typing import Callable
 from lightllm.server import TokenLoad
 from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect
@@ -371,16 +371,13 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
         )
         return resp
 
-    if sampling_params.n != 1:
-        return create_error_response(HTTPStatus.BAD_REQUEST, "stream api only support n = 1")
-
     parser_dict = {}
     reasoning_parser_dict = {}
 
     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
-        finish_reason = None
-        has_emitted_tool_calls = False
+        has_emitted_tool_calls: Dict[int, bool] = collections.defaultdict(bool)
+        stream_tool_call_ids: Dict[Tuple[int, int], str] = {}
         from .req_id_generator import convert_sub_id_to_group_id
 
         prompt_tokens = 0
@@ -389,18 +386,19 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
             prompt_tokens = metadata["prompt_tokens"]
             completion_tokens += 1
             group_request_id = convert_sub_id_to_group_id(sub_req_id)
-            index = sub_req_id
+            choice_index = sub_req_id - group_request_id
+
             delta = request_output
-            finish_reason = finish_status.get_finish_reason()
+            current_finish_reason = finish_status.get_finish_reason()
 
             # Handle reasoning content
             if get_env_start_args().reasoning_parser and request.separate_reasoning:
                 reasoning_text, delta = _process_reasoning_stream(
-                    index, delta, reasoning_parser_dict, request_output, request
+                    choice_index, delta, reasoning_parser_dict, request_output, request
                 )
                 if reasoning_text:
                     choice_data = ChatCompletionStreamResponseChoice(
-                        index=0,
+                        index=choice_index,
                         delta=DeltaMessage(role="assistant", reasoning_content=reasoning_text),
                         finish_reason=None,
                     )
@@ -415,13 +413,13 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
             if request.tool_choice != "none" and request.tools:
                 # parse_increment => returns (normal_text, calls)
                 normal_text, calls = _process_tools_stream(
-                    index=index, delta=delta, parser_dict=parser_dict, request=request
+                    index=choice_index, delta=delta, parser_dict=parser_dict, request=request
                 )
 
                 # 1) if there's normal_text, output it as normal content
                 if normal_text:
                     choice_data = ChatCompletionStreamResponseChoice(
-                        index=0,
+                        index=choice_index,
                         delta=DeltaMessage(role="assistant", content=normal_text),
                         finish_reason=None,
                     )
@@ -435,32 +433,39 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
 
                 # 2) if we found calls, we output them as separate chunk(s)
                 history_tool_calls_cnt = _get_history_tool_calls_cnt(request)
+                fc_parser = parser_dict[choice_index]
                 for call_item in calls:
-                    has_emitted_tool_calls = True
+                    has_emitted_tool_calls[sub_req_id] = True
                     # transform call_item -> FunctionResponse + ToolCall
-                    if finish_reason == "stop":
-                        latest_delta_len = 0
-                        if isinstance(call_item.parameters, str):
-                            latest_delta_len = len(call_item.parameters)
-
-                        expected_call = json.dumps(
-                            parser.multi_format_parser.detectors[0].prev_tool_call_arr[index].get("arguments", {}),
-                            ensure_ascii=False,
-                        )
-                        actual_call = parser.multi_format_parser.detectors[0].streamed_args_for_tool[index]
-                        if latest_delta_len > 0:
-                            actual_call = actual_call[:-latest_delta_len]
-                        remaining_call = expected_call.replace(actual_call, "", 1)
-                        call_item.parameters = remaining_call
+                    if current_finish_reason == "stop":
+                        det = fc_parser.detector
+                        ti = call_item.tool_index
+                        if ti >= 0 and ti < len(det.prev_tool_call_arr) and ti < len(det.streamed_args_for_tool):
+                            latest_delta_len = 0
+                            if isinstance(call_item.parameters, str):
+                                latest_delta_len = len(call_item.parameters)
+
+                            expected_call = json.dumps(
+                                det.prev_tool_call_arr[ti].get("arguments", {}),
+                                ensure_ascii=False,
+                            )
+                            actual_call = det.streamed_args_for_tool[ti]
+                            if latest_delta_len > 0:
+                                actual_call = actual_call[:-latest_delta_len]
+                            remaining_call = expected_call.replace(actual_call, "", 1)
+                            call_item.parameters = remaining_call
 
+                    tool_parser = getattr(g_objs.args, "tool_call_parser", None) or "llama3"
+                    id_key = (choice_index, call_item.tool_index)
                     if call_item.name:
-                        # First chunk: include ID and function name
-                        tool_parser = getattr(g_objs.args, "tool_call_parser", None) or "llama3"
-                        tool_call_id = _process_tool_call_id(tool_parser, call_item, history_tool_calls_cnt)
+                        if id_key not in stream_tool_call_ids:
+                            stream_tool_call_ids[id_key] = _process_tool_call_id(
+                                tool_parser, call_item, history_tool_calls_cnt
+                            )
+                        tool_call_id = stream_tool_call_ids[id_key]
                         function_name = call_item.name
                     else:
-                        # Subsequent chunks: null ID and name for argument deltas
-                        tool_call_id = None
+                        tool_call_id = stream_tool_call_ids.get(id_key)
                         function_name = None
 
                     tool_call = ToolCall(
@@ -472,7 +477,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
                         ),
                     )
                     choice_data = ChatCompletionStreamResponseChoice(
-                        index=0,
+                        index=choice_index,
                         delta=DeltaMessage(role="assistant", tool_calls=[tool_call]),
                         finish_reason=None,
                     )
@@ -485,7 +490,9 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
                     yield f"data: {chunk.model_dump_json()}\n\n"
             else:
                 delta_message = DeltaMessage(role="assistant", content=delta)
-                stream_choice = ChatCompletionStreamResponseChoice(index=0, delta=delta_message, finish_reason=None)
+                stream_choice = ChatCompletionStreamResponseChoice(
+                    index=choice_index, delta=delta_message, finish_reason=None
+                )
                 stream_resp = ChatCompletionStreamResponse(
                     id=group_request_id,
                     created=created_time,
@@ -494,24 +501,22 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
                 )
                 yield f"data: {stream_resp.model_dump_json()}\n\n"
 
-        # Determine final finish_reason: override to "tool_calls" if tool calls were emitted
-        if has_emitted_tool_calls and finish_reason == "stop":
-            finish_reason = "tool_calls"
-
-        # Final empty chunk containing only finish_reason (and role)
-        if finish_reason is not None:
-            final_choice = ChatCompletionStreamResponseChoice(
-                index=0,
-                delta=DeltaMessage(),
-                finish_reason=finish_reason,
-            )
-            final_chunk = ChatCompletionStreamResponse(
-                id=group_request_id,
-                created=created_time,
-                model=request.model,
-                choices=[final_choice],
-            )
-            yield f"data: {final_chunk.model_dump_json()}\n\n"
+            # Emit a per-choice final empty chunk with finish_reason.
+            if current_finish_reason is not None:
+                if has_emitted_tool_calls[sub_req_id] and current_finish_reason == "stop":
+                    current_finish_reason = "tool_calls"
+                final_choice = ChatCompletionStreamResponseChoice(
+                    index=choice_index,
+                    delta=DeltaMessage(),
+                    finish_reason=current_finish_reason,
+                )
+                final_chunk = ChatCompletionStreamResponse(
+                    id=group_request_id,
+                    created=created_time,
+                    model=request.model,
+                    choices=[final_choice],
+                )
+                yield f"data: {final_chunk.model_dump_json()}\n\n"
 
         if request.stream_options and request.stream_options.include_usage:
             usage = UsageInfo(
@@ -634,9 +639,6 @@ async def _process_prompts_completion(
                 "Streaming is not supported for batch requests",
             )
 
-        if sampling_params.n != 1:
-            return create_error_response(HTTPStatus.BAD_REQUEST, "stream api only support n = 1")
-
         return await _handle_streaming_completion(
             prompts[0], sampling_params, multimodal_params, raw_request, request, created_time
         )
@@ -690,6 +692,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
 
         async for sub_req_id, request_output, metadata, finish_status in results_generator:
             group_request_id = convert_sub_id_to_group_id(sub_req_id)
+            choice_index = sub_req_id - group_request_id
             prompt_tokens = metadata["prompt_tokens"]
             completion_tokens += 1
             current_finish_reason = None
@@ -704,7 +707,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
                 output_text = prompt_str + output_text
 
             stream_choice = CompletionStreamChoice(
-                index=0,
+                index=choice_index,
                 text=output_text,
                 finish_reason=current_finish_reason,
                 logprobs=None if request.logprobs is None else {},
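
With the n != 1 guard removed, every streamed chunk now carries its per-choice index (sub_req_id - group_request_id), and each choice gets its own final chunk with finish_reason, so a client can demultiplex the stream by index. A hedged client sketch; the endpoint path, host, model name and n value are illustrative, not part of this commit:

import json
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "demo-model",
        "messages": [{"role": "user", "content": "Hi"}],
        "n": 2,
        "stream": True,
    },
    stream=True,
    timeout=60,
)

texts = {}
for line in resp.iter_lines():
    if not line or not line.startswith(b"data: "):
        continue
    data = line[len(b"data: "):]
    if data == b"[DONE]":  # defensive; depends on how the server terminates the stream
        break
    chunk = json.loads(data)
    choice = chunk["choices"][0]
    idx = choice["index"]
    texts[idx] = texts.get(idx, "") + (choice["delta"].get("content") or "")
    if choice.get("finish_reason"):
        print(f"choice {idx} finished: {choice['finish_reason']}")

print(texts)  # one accumulated string per choice index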
