@@ -10,7 +10,12 @@
10 | 10 | from fastapi import APIRouter, Depends, Request, Body, Query |
11 | 11 | from fastapi.encoders import jsonable_encoder |
12 | 12 | from fastapi.responses import PlainTextResponse, StreamingResponse, JSONResponse |
13 | | -from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR |
| 13 | +from starlette.status import ( |
| 14 | + HTTP_200_OK, |
| 15 | + HTTP_400_BAD_REQUEST, |
| 16 | + HTTP_404_NOT_FOUND, |
| 17 | + HTTP_500_INTERNAL_SERVER_ERROR, |
| 18 | +) |
14 | 19 | from app.domain import ( |
15 | 20 | Tags, |
16 | 21 | TagsGenerative, |
@@ -35,6 +40,7 @@
35 | 40 | PATH_CHAT_COMPLETIONS = "/v1/chat/completions" |
36 | 41 | PATH_COMPLETIONS = "/v1/completions" |
37 | 42 | PATH_EMBEDDINGS = "/v1/embeddings" |
| 43 | +PATH_MODELS = "/v1/models" |
38 | 44 |
|
39 | 45 | router = APIRouter() |
40 | 46 | config = get_settings() |
@@ -200,7 +206,12 @@ def generate_chat_completions( |
200 | 206 | max_tokens = request_data.max_tokens |
201 | 207 | temperature = request_data.temperature |
202 | 208 | top_p = request_data.top_p |
203 | | - stop_sequences = request_data.stop_sequences |
| 209 | + if isinstance(request_data.stop, str): |
| 210 | + stop_sequences = [request_data.stop] |
| 211 | + elif isinstance(request_data.stop, list): |
| 212 | + stop_sequences = request_data.stop |
| 213 | + else: |
| 214 | + stop_sequences = [] |
204 | 215 | tracking_id = tracking_id or str(uuid.uuid4()) |
205 | 216 |
|
206 | 217 | if not messages: |
@@ -337,12 +348,11 @@ def generate_text_completions( |
337 | 348 | max_tokens = request_data.max_tokens |
338 | 349 | temperature = request_data.temperature |
339 | 350 | top_p = request_data.top_p |
340 | | - stop = request_data.stop |
341 | 351 |
|
342 | | - if isinstance(stop, str): |
343 | | - stop_sequences = [stop] |
344 | | - elif isinstance(stop, list): |
345 | | - stop_sequences = stop |
| 352 | + if isinstance(request_data.stop, str): |
| 353 | + stop_sequences = [request_data.stop] |
| 354 | + elif isinstance(request_data.stop, list): |
| 355 | + stop_sequences = request_data.stop |
346 | 356 | else: |
347 | 357 | stop_sequences = [] |
348 | 358 |
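
Both completion endpoints now coerce OpenAI's `stop` field the same way, since the spec allows it to be a single string, a list of strings, or absent. A minimal sketch of that shared logic, using a hypothetical helper name (`_normalize_stop` is not part of this diff), could look like:

```python
from typing import List, Optional, Union

def _normalize_stop(stop: Optional[Union[str, List[str]]]) -> List[str]:
    """Coerce OpenAI's `stop` field into a list of stop sequences."""
    if isinstance(stop, str):
        return [stop]   # a single stop sequence
    if isinstance(stop, list):
        return stop     # already a list of sequences
    return []           # absent or null: no stop sequences
```

Extracting a helper like this would also remove the duplication between the two hunks above.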
|
@@ -534,6 +544,81 @@ def embed_texts( |
534 | 544 | ) |
535 | 545 |
|
536 | 546 |
|
| 547 | +@router.get( |
| 548 | + PATH_MODELS, |
| 549 | + tags=[Tags.OpenAICompatible], |
| 550 | + dependencies=[Depends(cms_globals.props.current_active_user)], |
| 551 | + description="List available models, similar to OpenAI's /v1/models endpoint", |
| 552 | +) |
| 553 | +def list_models( |
| 554 | + model_service: AbstractModelService = Depends(cms_globals.model_service_dep) |
| 555 | +) -> JSONResponse: |
| 556 | + """ |
| 557 | + Lists all available models, mimicking OpenAI's /v1/models endpoint. |
| 558 | +
|
| 559 | + Args: |
| 560 | + model_service (AbstractModelService): The model service dependency. |
| 561 | +
|
| 562 | + Returns: |
| 563 | + JSONResponse: A response containing the list of models. |
| 564 | + """ |
| 565 | + response = { |
| 566 | + "object": "list", |
| 567 | + "data": [ |
| 568 | + { |
| 569 | + "id": model_service.model_name.replace(" ", "_"), |
| 570 | + "object": "model", |
| 571 | + "created": 0, |
| 572 | + "owned_by": "cms", |
| 573 | + } |
| 574 | + ], |
| 575 | + } |
| 576 | + return JSONResponse(content=response) |
| 577 | + |
| 578 | + |
| 579 | +@router.get( |
| 580 | + PATH_MODELS + "/{model_name}", |
| 581 | + tags=[Tags.OpenAICompatible], |
| 582 | + dependencies=[Depends(cms_globals.props.current_active_user)], |
| 583 | + description="Get a specific model, similar to OpenAI's /v1/models/{model_id} endpoint", |
| 584 | +) |
| 585 | +def get_model( |
| 586 | + model_name: str, |
| 587 | + model_service: AbstractModelService = Depends(cms_globals.model_service_dep) |
| 588 | +) -> JSONResponse: |
| 589 | + """ |
| 590 | + Gets a specific model by ID, mimicking OpenAI's /v1/models/{model_id} endpoint. |
| 591 | +
|
| 592 | + Args: |
| 593 | + model_name (str): The model name to retrieve. |
| 594 | + model_service (AbstractModelService): The model service dependency. |
| 595 | +
|
| 596 | + Returns: |
| 597 | + JSONResponse: A response containing the model details. |
| 598 | + """ |
| 599 | + if model_name != model_service.model_name.replace(" ", "_"): |
| 600 | + error_response = { |
| 601 | + "error": { |
| 602 | + "message": f"The model `{model_name}` does not exist", |
| 603 | + "type": "invalid_request_error", |
| 604 | + "param": None, |
| 605 | + "code": "model_not_found", |
| 606 | + } |
| 607 | + } |
| 608 | + return JSONResponse( |
| 609 | + content=error_response, status_code=HTTP_404_NOT_FOUND) |
| 610 | + response = { |
| 611 | + "id": model_name, |
| 612 | + "object": "model", |
| 613 | + "created": 0, |
| 614 | + "owned_by": "cms", |
| 615 | + "permission": [], |
| 616 | + "root": model_name, |
| 617 | + "parent": None, |
| 618 | + } |
| 619 | + return JSONResponse(content=response) |
| 620 | + |
| 621 | + |
537 | 622 | def _empty_prompt_error() -> Iterable[str]: |
538 | 623 | yield "ERROR: No prompt text provided\n" |
539 | 624 |
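
As a rough illustration of how the new `/v1/models` routes behave end to end. The import paths and the auth override below are assumptions for this sketch, not part of the diff:

```python
from fastapi.testclient import TestClient

from app.api import app  # hypothetical import path for the FastAPI app
from app import cms_globals  # hypothetical import path

# The routes are guarded by cms_globals.props.current_active_user; whether
# this override suffices depends on the project's auth setup.
app.dependency_overrides[cms_globals.props.current_active_user] = lambda: None

client = TestClient(app)

# List models: mirrors OpenAI's {"object": "list", "data": [...]} envelope.
listed = client.get("/v1/models")
assert listed.status_code == 200
model_id = listed.json()["data"][0]["id"]

# Fetching the known model id returns the model object.
assert client.get(f"/v1/models/{model_id}").status_code == 200

# An unknown id yields OpenAI's error envelope with a 404.
missing = client.get("/v1/models/no_such_model")
assert missing.status_code == 404
assert missing.json()["error"]["code"] == "model_not_found"
```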
|
|