Skip to content

Commit 98554f8

Browse files
authored
Merge pull request #41 from CogStack/llm-gen2
Add micro batching and endpoints for v1 list_models and get_model
2 parents 18c4efa + 184832d commit 98554f8

51 files changed

Lines changed: 1089 additions & 2210 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/api-docs.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
runs-on: ubuntu-latest
1717
strategy:
1818
matrix:
19-
python-version: [ '3.10' ]
19+
python-version: [ '3.11' ]
2020
max-parallel: 1
2121

2222
steps:

.github/workflows/docker.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
- uses: actions/checkout@v4
2020

2121
- name: Lint
22-
run: hadolint --ignore DL3008 --ignore DL3013 --ignore DL3003 --ignore DL4006 docker/Dockerfile* docker/**/Dockerfile*
22+
run: hadolint --ignore DL3008 --ignore DL4006 --ignore DL3006 --ignore SC2046 docker/Dockerfile
2323

2424
build-and-push:
2525
needs: lint
@@ -74,6 +74,9 @@ jobs:
7474
platforms: linux/amd64,linux/arm64
7575
context: .
7676
file: docker/Dockerfile
77+
build-args: |
78+
IMAGE_TYPE=gpu
79+
PIP_EXTRAS=llm
7780
push: true
7881
tags: ${{ steps.cms_meta.outputs.tags }}
7982
labels: ${{ steps.cms_meta.outputs.labels }}

.github/workflows/main.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
- name: Install uv and set Python to ${{ matrix.python-version }}
2525
uses: astral-sh/setup-uv@v6
2626
with:
27-
version: "0.8.10"
27+
version: "0.9.30"
2828
python-version: ${{ matrix.python-version }}
2929
- name: Install dependencies
3030
run: |

.github/workflows/release-gpu.yaml

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
name: release
2+
3+
on:
4+
release:
5+
types: [published]
6+
7+
env:
8+
REGISTRY: docker.io
9+
CMS_GPU_IMAGE_NAME: cogstacksystems/cogstack-modelserve-gpu
10+
11+
jobs:
12+
ensure-branch:
13+
runs-on: ubuntu-latest
14+
outputs:
15+
is-valid: ${{ steps.ensure-branch.outputs.is-valid }}
16+
steps:
17+
- name: Ensures release is from the production branch only
18+
id: ensure-branch
19+
run: |
20+
TARGET_BRANCH="${{ github.event.release.target_commitish }}"
21+
if [ "$TARGET_BRANCH" != "production" ]; then
22+
echo "Only releases from the 'production' branch are allowed but found: $TARGET_BRANCH"
23+
echo "is-valid=false" >> "$GITHUB_OUTPUT"
24+
exit 1
25+
else
26+
echo "Target release branch is: $TARGET_BRANCH"
27+
echo "is-valid=true" >> "$GITHUB_OUTPUT"
28+
fi
29+
30+
qc:
31+
runs-on: ubuntu-latest
32+
needs: ensure-branch
33+
if: needs.ensure-branch.outputs.is-valid == 'true'
34+
steps:
35+
- uses: actions/checkout@v4
36+
- name: Install uv
37+
uses: astral-sh/setup-uv@v5
38+
with:
39+
version: "0.9.30"
40+
python-version: "3.11"
41+
- name: Install dependencies
42+
run: |
43+
uv sync --lock --extra dev --extra docs --extra llm
44+
uv run python -m ensurepip
45+
- name: Run unit tests
46+
run: |
47+
uv run pytest -v tests/app --cov --cov-report=html:coverage_reports #--random-order
48+
- name: Run integration tests
49+
run: |
50+
uv run pytest -s -v tests/integration
51+
52+
release-gpu:
53+
runs-on: ubuntu-latest
54+
needs: [ensure-branch, qc]
55+
if: needs.ensure-branch.outputs.is-valid == 'true'
56+
permissions:
57+
contents: read
58+
packages: write
59+
id-token: write
60+
attestations: write
61+
steps:
62+
- uses: actions/checkout@v4
63+
64+
- name: Set up QEMU
65+
uses: docker/setup-qemu-action@v3
66+
67+
- name: Set up Docker Buildx
68+
uses: docker/setup-buildx-action@v3
69+
70+
- name: Extract the tag
71+
run: |
72+
echo "RELEASE_VERSION=${GITHUB_REF/refs\/tags\/v/}" >> $GITHUB_ENV
73+
74+
- name: Login to Docker Hub
75+
uses: docker/login-action@v3
76+
with:
77+
registry: ${{ env.REGISTRY }}
78+
username: ${{ secrets.DOCKERHUB_USERNAME }}
79+
password: ${{ secrets.DOCKERHUB_TOKEN }}
80+
81+
- name: Extract CMS meta
82+
id: cms_meta
83+
uses: docker/metadata-action@v5
84+
with:
85+
images: ${{ env.REGISTRY }}/${{ env.CMS_GPU_IMAGE_NAME }}
86+
87+
- name: Build and push CMS image
88+
uses: docker/build-push-action@v6
89+
id: build_and_push_cms
90+
with:
91+
platforms: linux/amd64,linux/arm64
92+
context: .
93+
file: docker/Dockerfile
94+
build-args: |
95+
IMAGE_TYPE=gpu
96+
PIP_EXTRAS=llm
97+
push: true
98+
github-token: ${{ github.token }}
99+
tags: |
100+
${{ env.REGISTRY }}/${{ env.CMS_GPU_IMAGE_NAME }}:${{ env.RELEASE_VERSION }}
101+
labels: ${{ steps.cms_meta.outputs.labels }}
102+
103+
- name: Attest CMS image artifacts
104+
uses: actions/attest-build-provenance@v2
105+
with:
106+
subject-name: ${{ env.REGISTRY }}/${{ env.CMS_GPU_IMAGE_NAME }}
107+
subject-digest: ${{ steps.build_and_push_cms.outputs.digest }}
108+
push-to-registry: true
109+
110+
- name: Inspect the released image
111+
run: |
112+
docker pull ${{ env.REGISTRY }}/${{ env.CMS_GPU_IMAGE_NAME }}:${{ env.RELEASE_VERSION }}
113+
docker image inspect ${{ env.REGISTRY }}/${{ env.CMS_GPU_IMAGE_NAME }}:${{ env.RELEASE_VERSION }}

.github/workflows/release.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ jobs:
3737
- name: Install uv
3838
uses: astral-sh/setup-uv@v5
3939
with:
40-
version: "0.8.10"
41-
python-version: "3.10"
40+
version: "0.9.30"
41+
python-version: "3.11"
4242
- name: Install dependencies
4343
run: |
44-
uv sync --extra dev --extra docs --extra llm
44+
uv sync --lock --extra dev --extra docs --extra llm
4545
uv run python -m ensurepip
4646
- name: Run unit tests
4747
run: |

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ venv/
9191
ENV/
9292
env.bak/
9393
venv.bak/
94-
.env
9594

9695
# Spyder project settings
9796
.spyderproject

app/api/routers/generative.py

Lines changed: 92 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@
1010
from fastapi import APIRouter, Depends, Request, Body, Query
1111
from fastapi.encoders import jsonable_encoder
1212
from fastapi.responses import PlainTextResponse, StreamingResponse, JSONResponse
13-
from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
13+
from starlette.status import (
14+
HTTP_200_OK,
15+
HTTP_400_BAD_REQUEST,
16+
HTTP_500_INTERNAL_SERVER_ERROR,
17+
HTTP_404_NOT_FOUND,
18+
)
1419
from app.domain import (
1520
Tags,
1621
TagsGenerative,
@@ -35,6 +40,7 @@
3540
PATH_CHAT_COMPLETIONS = "/v1/chat/completions"
3641
PATH_COMPLETIONS = "/v1/completions"
3742
PATH_EMBEDDINGS = "/v1/embeddings"
43+
PATH_MODELS = "/v1/models"
3844

3945
router = APIRouter()
4046
config = get_settings()
@@ -200,7 +206,12 @@ def generate_chat_completions(
200206
max_tokens = request_data.max_tokens
201207
temperature = request_data.temperature
202208
top_p = request_data.top_p
203-
stop_sequences = request_data.stop_sequences
209+
if isinstance(request_data.stop, str):
210+
stop_sequences = [request_data.stop]
211+
elif isinstance(request_data.stop, list):
212+
stop_sequences = request_data.stop
213+
else:
214+
stop_sequences = []
204215
tracking_id = tracking_id or str(uuid.uuid4())
205216

206217
if not messages:
@@ -337,12 +348,11 @@ def generate_text_completions(
337348
max_tokens = request_data.max_tokens
338349
temperature = request_data.temperature
339350
top_p = request_data.top_p
340-
stop = request_data.stop
341351

342-
if isinstance(stop, str):
343-
stop_sequences = [stop]
344-
elif isinstance(stop, list):
345-
stop_sequences = stop
352+
if isinstance(request_data.stop, str):
353+
stop_sequences = [request_data.stop]
354+
elif isinstance(request_data.stop, list):
355+
stop_sequences = request_data.stop
346356
else:
347357
stop_sequences = []
348358

@@ -534,6 +544,81 @@ def embed_texts(
534544
)
535545

536546

547+
@router.get(
548+
PATH_MODELS,
549+
tags=[Tags.OpenAICompatible],
550+
dependencies=[Depends(cms_globals.props.current_active_user)],
551+
description="List available models, similar to OpenAI's /v1/models endpoint",
552+
)
553+
def list_models(
554+
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
555+
) -> JSONResponse:
556+
"""
557+
Lists all available models, mimicking OpenAI's /v1/models endpoint.
558+
559+
Args:
560+
model_service (AbstractModelService): The model service dependency.
561+
562+
Returns:
563+
JSONResponse: A response containing the list of models.
564+
"""
565+
response = {
566+
"object": "list",
567+
"data": [
568+
{
569+
"id": model_service.model_name.replace(" ", "_"),
570+
"object": "model",
571+
"created": 0,
572+
"owned_by": "cms",
573+
}
574+
],
575+
}
576+
return JSONResponse(content=response)
577+
578+
579+
@router.get(
580+
PATH_MODELS + "/{model_name}",
581+
tags=[Tags.OpenAICompatible],
582+
dependencies=[Depends(cms_globals.props.current_active_user)],
583+
description="Get a specific model, similar to OpenAI's /v1/models/{model_id} endpoint",
584+
)
585+
def get_model(
586+
model_name: str,
587+
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
588+
) -> JSONResponse:
589+
"""
590+
Gets a specific model by ID, mimicking OpenAI's /v1/models/{model_id} endpoint.
591+
592+
Args:
593+
model_name (str): The model name to retrieve.
594+
model_service (AbstractModelService): The model service dependency.
595+
596+
Returns:
597+
JSONResponse: A response containing the model details.
598+
"""
599+
if model_name != model_service.model_name.replace(" ", "_"):
600+
error_response = {
601+
"error": {
602+
"message": f"The model `{model_name}` does not exist",
603+
"type": "invalid_request_error",
604+
"param": None,
605+
"code": "model_not_found",
606+
}
607+
}
608+
return JSONResponse(content=error_response, status_code=HTTP_404_NOT_FOUND
609+
)
610+
response = {
611+
"id": model_name,
612+
"object": "model",
613+
"created": 0,
614+
"owned_by": "cms",
615+
"permission": [],
616+
"root": model_name,
617+
"parent": None,
618+
}
619+
return JSONResponse(content=response)
620+
621+
537622
def _empty_prompt_error() -> Iterable[str]:
538623
yield "ERROR: No prompt text provided\n"
539624

app/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class Settings(BaseSettings): # type: ignore
3838
HF_PIPELINE_AGGREGATION_STRATEGY: str = "simple" # the strategy used for aggregating the predictions of the Hugging Face NER model
3939
LOG_PER_CONCEPT_ACCURACIES: str = "false" # if "true", per-concept accuracies will be exposed to the metrics scrapper. Switch this on with caution due to the potentially high number of concepts
4040
MEDCAT2_MAPPED_ONTOLOGIES: str = "" # the comma-separated names of ontologies for MedCAT2 to map to
41+
ENABLE_SPDA_ATTN: str = "true" # if "true", attempt to use SPDA attention for HuggingFace LLM loading
4142
DEBUG: str = "false" # if "true", the debug mode is switched on
4243

4344
class Config:

app/domain.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,10 @@ class OpenAIChatCompletionsRequest(BaseModel):
218218
model: str = Field(..., description="The name of the model used for generating the completion")
219219
temperature: float = Field(0.7, description="The temperature of the generated text", ge=0.0, le=1.0)
220220
top_p: float = Field(0.9, description="The top-p value for nucleus sampling", ge=0.0, le=1.0)
221-
stop_sequences: Optional[List[str]] = Field(default=None, description="The list of sequences used to stop the generation")
221+
stop: Optional[Union[str, List[str]]] = Field(
222+
default=None,
223+
description="The single sequence or the list of sequences used to stop the generation",
224+
)
222225

223226

224227
class OpenAIChatCompletionsResponse(BaseModel):
@@ -242,7 +245,7 @@ class OpenAICompletionsRequest(BaseModel):
242245
top_p: float = Field(0.9, description="The top-p value for nucleus sampling", ge=0.0, le=1.0)
243246
stop: Optional[Union[str, List[str]]] = Field(
244247
default=None,
245-
description="The list of sequences used to stop the generation",
248+
description="The single sequence or the list of sequences used to stop the generation",
246249
)
247250

248251

app/envs/.env

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,5 +79,8 @@ TRAINING_HF_TAGGING_SCHEME=flat
7979
# The comma-separated names of ontologies for MedCAT2 to map to
8080
MEDCAT2_MAPPED_ONTOLOGIES=opcs4,icd10
8181

82+
# If "true", attempt to use SPDA attention for Hugging Face LLM loading
83+
ENABLE_SPDA_ATTN=true
84+
8285
# If "true", the debug mode is switched on
8386
DEBUG=false

0 commit comments

Comments
 (0)