6 changes: 6 additions & 0 deletions lightllm/server/api_cli.py
@@ -98,6 +98,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
default="default_model_name",
help="just help to distinguish internal model name, use 'host:port/get_model_name' to get",
)
    parser.add_argument(
        "--model_owner",
        type=str,
        default=None,
        help="the model owner reported by /v1/models; falls back to 'lightllm' when unset",
    )

    parser.add_argument(
        "--model_dir",
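For reference, a minimal standalone sketch of the new flag's behavior (a plain argparse replica for illustration, not the actual lightllm parser):

```python
import argparse

# Hypothetical replica of the --model_owner flag added above.
parser = argparse.ArgumentParser()
parser.add_argument("--model_owner", type=str, default=None)

assert parser.parse_args(["--model_owner", "my-team"]).model_owner == "my-team"
assert parser.parse_args([]).model_owner is None  # /v1/models then falls back to "lightllm"
```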
28 changes: 28 additions & 0 deletions lightllm/server/api_http.py
@@ -19,6 +19,7 @@
import asyncio
import collections
import time

import uvloop
import requests
import base64
@@ -57,6 +58,8 @@
    ChatCompletionResponse,
    CompletionRequest,
    CompletionResponse,
    ModelCard,
    ModelListResponse,
)
from .build_prompt import build_prompt, init_tokenizer

@@ -72,6 +75,9 @@ class G_Objs:
    g_generate_stream_func: Callable = None
    httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None
    shared_token_load: TokenLoad = None
    # OpenAI-compatible "created" timestamp for /v1/models.
    # Should be stable for the lifetime of this server process.
    model_created: int = None

    def set_args(self, args: StartArgs):
        self.args = args
@@ -101,6 +107,8 @@ def set_args(self, args: StartArgs):
        self.httpserver_manager = HttpServerManager(args=args)
        dp_size_in_node = max(1, args.dp // args.nnodes)  # compatibility with multi-node pure-tp mode, where args.dp // args.nnodes can be 1 // 2 == 0
        self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node)
        if self.model_created is None:
            self.model_created = int(time.time())


g_objs = G_Objs()
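The None-guard above keeps the timestamp stable: set_args can run again without moving "created", which OpenAI clients treat as a fixed model attribute. A minimal sketch of the pattern, using only the standard library:

```python
import time

class Server:  # hypothetical stand-in for G_Objs
    model_created: int = None

    def set_args(self):
        if self.model_created is None:  # set once, on first configuration
            self.model_created = int(time.time())

s = Server()
s.set_args()
first = s.model_created
s.set_args()  # re-running configuration does not move the timestamp
assert s.model_created == first
```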
@@ -258,6 +266,26 @@ async def completions(request: CompletionRequest, raw_request: Request) -> Response:
    return resp


@app.get("/v1/models", response_model=ModelListResponse)
@app.post("/v1/models", response_model=ModelListResponse)
async def get_models(raw_request: Request):
[Contributor review comment – medium] The raw_request parameter is unused in the get_models function and can be removed to improve code clarity.

Suggested change:
-async def get_models(raw_request: Request):
+async def get_models():

    model_name = g_objs.args.model_name
    max_model_len = g_objs.args.max_req_total_len
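    # Fall back to the model directory's basename when no explicit --model_name was given.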
    if model_name == "default_model_name" and g_objs.args.model_dir:
        model_name = os.path.basename(g_objs.args.model_dir.rstrip("/"))

    return ModelListResponse(
        data=[
            ModelCard(
                id=model_name,
                created=g_objs.model_created,
                max_model_len=max_model_len,
                owned_by=g_objs.args.model_owner or "lightllm",  # --model_owner defaults to None; avoid failing ModelCard's str validation
            )
        ]
    )


@app.get("/tokens")
@app.post("/tokens")
async def tokens(request: Request):
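Once the server is running, the new endpoint can be exercised with any HTTP client; a minimal check, assuming a server listening on localhost:8000 (hypothetical host and port):

```python
import requests

resp = requests.get("http://localhost:8000/v1/models", timeout=5)
resp.raise_for_status()
for card in resp.json()["data"]:
    print(card["id"], card["owned_by"], card["created"], card.get("max_model_len"))
```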
15 changes: 14 additions & 1 deletion lightllm/server/api_models.py
@@ -87,7 +87,7 @@ class ToolCall(BaseModel):

    id: Optional[str] = None
    index: Optional[int] = None
-    type: Literal["function"] = "function"
+    type: Optional[Literal["function"]] = None
    function: FunctionResponse
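A standalone sketch of what the loosened type field allows (pydantic v2 assumed; FunctionResponse is reduced to a hypothetical stub here): streamed tool-call deltas often omit "type", and the old Literal default would have stamped every chunk as "function".

```python
from typing import Literal, Optional
from pydantic import BaseModel

class FunctionResponse(BaseModel):  # hypothetical stub for illustration
    name: Optional[str] = None
    arguments: Optional[str] = None

class ToolCall(BaseModel):
    id: Optional[str] = None
    index: Optional[int] = None
    type: Optional[Literal["function"]] = None
    function: FunctionResponse

# A streaming delta carrying only incremental arguments:
chunk = ToolCall(index=0, function=FunctionResponse(arguments='{"x": 1'))
assert chunk.type is None  # with the old default this would have been "function"
```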


@@ -370,3 +370,16 @@ class CompletionStreamResponse(BaseModel):
    @field_validator("id", mode="before")
    def ensure_id_is_str(cls, v):
        return str(v)


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "lightllm"
    max_model_len: Optional[int] = None


class ModelListResponse(BaseModel):
    object: str = "list"
    data: List[ModelCard]
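For reference, the JSON these models serialize to mirrors OpenAI's GET /v1/models response; a self-contained sketch assuming pydantic v2:

```python
import time
from typing import List, Optional
from pydantic import BaseModel, Field

class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "lightllm"
    max_model_len: Optional[int] = None

class ModelListResponse(BaseModel):
    object: str = "list"
    data: List[ModelCard]

resp = ModelListResponse(data=[ModelCard(id="my-model", max_model_len=8192)])
print(resp.model_dump_json())
# {"object":"list","data":[{"id":"my-model","object":"model","created":...,"owned_by":"lightllm","max_model_len":8192}]}
```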