Skip to content

Commit fec32bf

Browse files
committed
api: Manage on-demand model config lifecycle
Provide an admin API to expose on-demand model configuration management, including creation, updating, retrieval, listing, and soft-deletion. Configurations are now stored in the database, with versioning support. The Python client is also updated to support these operations. Signed-off-by: Phoevos Kalemkeris <phoevos.kalemkeris@ucl.ac.uk>
1 parent 107b34f commit fec32bf

27 files changed

Lines changed: 3625 additions & 690 deletions

client/cogstack_model_gateway_client/client.py

Lines changed: 424 additions & 2 deletions
Large diffs are not rendered by default.

cogstack_model_gateway/common/config/models.py

Lines changed: 31 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import re
2+
from typing import Any
23

34
from pydantic import BaseModel, Field, field_validator, model_validator
45

@@ -17,7 +18,6 @@ class ResourceLimits(BaseModel):
1718
cpus: str | None = Field(
1819
None,
1920
description="CPU limit as string (e.g., '2.0', '0.5')",
20-
gt=0,
2121
examples=["2.0", "1.5", "0.5"],
2222
)
2323

@@ -51,77 +51,29 @@ class DeployResources(BaseModel):
5151
class DeploySpec(BaseModel):
5252
"""Deployment specification for model containers.
5353
54-
Currently supports resource constraints, but can be extended to include
55-
other Docker Compose deploy options like restart_policy, labels, placement, etc.
54+
Mirrors Docker Compose deploy specification.
5655
"""
5756

5857
resources: DeployResources | None = Field(None, description="Resource limits and reservations")
5958

6059

61-
class OnDemandModel(BaseModel):
62-
"""Configuration for an on-demand model that can be auto-deployed."""
60+
class TrackingMetadata(BaseModel):
61+
"""Model metadata from MLflow tracking server.
6362
64-
service_name: str = Field(
65-
...,
66-
description="Docker service/container name for the model",
67-
examples=["medcat-snomed-large", "medcat-umls-small"],
68-
)
69-
model_uri: str = Field(
70-
...,
71-
description="URI pointing to the model artifact (e.g., MLflow model URI)",
72-
examples=[
73-
"s3://models/medcat/snomed_large_v1.0",
74-
"models:/medcat-snomed/Production",
75-
"runs:/abc123/model",
76-
],
77-
)
78-
idle_ttl: int | None = Field(
79-
None,
80-
description="Time in seconds after which an idle model is removed (overrides default)",
81-
gt=0,
82-
examples=[3600, 7200, 86400],
83-
)
84-
description: str | None = Field(
85-
None,
86-
description="Human-readable description of the model",
87-
examples=["Large SNOMED CT model for clinical NLP"],
88-
)
89-
deploy: DeploySpec = Field(
90-
default_factory=DeploySpec,
91-
description="Deployment specification including resource constraints",
92-
)
93-
94-
@field_validator("service_name")
95-
@classmethod
96-
def validate_service_name(cls, v: str) -> str:
97-
"""Validate service name follows Docker naming constraints.
98-
99-
Docker container names must:
100-
- Start with alphanumeric character
101-
- Contain only alphanumeric, underscore, period, or hyphen
102-
"""
103-
if not re.match(r"^[a-zA-Z0-9][a-zA-Z0-9_.-]*$", v):
104-
raise ValueError(
105-
f"Invalid service name: {v}. Must start with alphanumeric and contain "
106-
"only alphanumeric, underscore, period, or hyphen characters"
107-
)
108-
if len(v) > 255:
109-
raise ValueError(f"Service name too long: {v}. Maximum length is 255 characters")
110-
return v
63+
Dict representation of mlflow.models.ModelInfo, returned by TrackingClient.get_model_metadata().
64+
"""
11165

112-
@field_validator("model_uri")
113-
@classmethod
114-
def validate_model_uri(cls, v: str) -> str:
115-
"""Validate model URI format."""
116-
if not v or not v.strip():
117-
raise ValueError("Model URI cannot be empty")
118-
# Basic validation - just ensure it's not empty
119-
# More specific validation (s3://, models:/, runs:/) can be added if needed
120-
return v.strip()
66+
uuid: str = Field(..., description="Model UUID")
67+
run_id: str = Field(..., description="MLflow run ID that produced the model")
68+
artifact_path: str = Field(..., description="Path to model artifact within the run")
69+
signature: dict[str, Any] = Field(..., description="Model signature (inputs/outputs/params)")
70+
flavors: dict[str, Any] = Field(..., description="Model flavors (e.g. python_function)")
71+
utc_time_created: str = Field(..., description="UTC timestamp when model was created")
72+
mlflow_version: str = Field(..., description="MLflow version used to log the model")
12173

12274

123-
class AutoDeploymentConfig(BaseModel):
124-
"""Configuration for automatic model deployment behaviour."""
75+
class AutoDeployment(BaseModel):
76+
"""Auto-deployment configuration for on-demand models."""
12577

12678
health_check_timeout: int = Field(
12779
300,
@@ -147,40 +99,11 @@ class AutoDeploymentConfig(BaseModel):
14799
ge=0,
148100
examples=[0, 1, 2, 3],
149101
)
150-
151-
152-
class AutoDeployment(BaseModel):
153-
"""Auto-deployment configuration including behaviour and on-demand models."""
154-
155-
config: AutoDeploymentConfig = Field(
156-
default_factory=AutoDeploymentConfig,
157-
description="Auto-deployment behaviour configuration",
158-
)
159-
on_demand: list[OnDemandModel] = Field(
160-
default_factory=list,
161-
description="List of models available for on-demand deployment",
102+
require_model_uri_validation: bool = Field(
103+
False,
104+
description="Whether to validate that the model URI exists before creating configs",
162105
)
163106

164-
@field_validator("on_demand")
165-
@classmethod
166-
def validate_unique_service_names(cls, v: list[OnDemandModel]) -> list[OnDemandModel]:
167-
"""Ensure all service names are unique."""
168-
service_names = [model.service_name for model in v]
169-
duplicates = [name for name in service_names if service_names.count(name) > 1]
170-
if duplicates:
171-
raise ValueError(
172-
f"Duplicate service names found in on_demand models: {set(duplicates)}"
173-
)
174-
return v
175-
176-
@model_validator(mode="after")
177-
def apply_default_idle_ttl(self) -> "AutoDeployment":
178-
"""Apply default_idle_ttl to on-demand models that don't have an explicit idle_ttl."""
179-
for model in self.on_demand:
180-
if model.idle_ttl is None:
181-
model.idle_ttl = self.config.default_idle_ttl
182-
return self
183-
184107

185108
class ManualDeployment(BaseModel):
186109
"""Configuration for manual model deployments via POST /models API."""
@@ -228,6 +151,14 @@ class ModelsDeployment(BaseModel):
228151
default_factory=StaticDeployment,
229152
description="Configuration for static model management",
230153
)
154+
use_ip_addresses: bool = Field(
155+
False,
156+
description=(
157+
"Use IP addresses instead of container names when attempting to connect to CogStack"
158+
" Model Serve instances. Set to true when components (scheduler, tests) run outside"
159+
" Docker network."
160+
),
161+
)
231162

232163

233164
class ModelsConfig(BaseModel):
@@ -473,20 +404,13 @@ def default_tracking_from_cms(self) -> "Config":
473404
self.tracking = self.cms.tracking.model_copy(deep=True)
474405
return self
475406

476-
def get_on_demand_model(self, service_name: str) -> OnDemandModel | None:
477-
"""Get configuration for a specific on-demand model by service name."""
478-
for model in self.models.deployment.auto.on_demand:
479-
if model.service_name == service_name:
480-
return model
481-
return None
482-
483-
def list_on_demand_models(self) -> list[OnDemandModel]:
484-
"""Get list of all configured on-demand models."""
485-
return self.models.deployment.auto.on_demand
407+
def get_default_idle_ttl(self) -> int:
    """Return the default idle TTL from the auto-deployment configuration."""
    auto_deployment = self.models.deployment.auto
    return auto_deployment.default_idle_ttl
486410

487-
def get_auto_deployment_config(self) -> AutoDeploymentConfig:
411+
def get_auto_deployment_config(self) -> AutoDeployment:
488412
"""Get auto-deployment behaviour configuration."""
489-
return self.models.deployment.auto.config
413+
return self.models.deployment.auto
490414

491415
def get_manual_deployment_config(self) -> ManualDeployment:
492416
"""Get manual deployment configuration."""
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,84 @@
1+
import logging
2+
3+
import docker
4+
from docker.models.containers import Container
5+
6+
from cogstack_model_gateway.common.config import get_config
7+
from cogstack_model_gateway.common.models import ModelDeploymentType
8+
19
PROJECT_NAME_LABEL = "com.docker.compose.project"
210
SERVICE_NAME_LABEL = "com.docker.compose.service"
11+
12+
log = logging.getLogger("cmg.common")
13+
14+
15+
def get_models(all: bool = False, managed_only: bool = False) -> list[dict]:
    """Get model containers with filtering.

    Containers are selected by label: the CMS model label, the Docker Compose project
    label matching the configured CMS project name and, when ``managed_only`` is set,
    the managed-by label/value pair from the configuration.

    Args:
        all: If True, includes paused, restarting, and stopped containers.
            If False, only running (default).
        managed_only: If True, only CMG-managed containers (excludes 'static' deployments).

    Returns:
        List of dicts with: service_name, model_uri, deployment_type, ip_address, container
        object. ``service_name`` falls back to the container name when the Compose service
        label is absent, ``deployment_type`` falls back to the static deployment type when
        the label is missing or empty, and ``ip_address`` is None when the container has no
        entry for the configured CMS network.

    Raises:
        docker.errors.APIError: If Docker API call fails.
    """
    config = get_config()
    client = docker.from_env()

    label_filters = [
        config.labels.cms_model_label,
        f"{PROJECT_NAME_LABEL}={config.cms.project_name}",
    ]
    if managed_only:
        label_filters.append(f"{config.labels.managed_by_label}={config.labels.managed_by_value}")

    containers = client.containers.list(
        all=all,
        filters={"label": label_filters},
        # Each listed container is inspected individually; skip the ones that disappear
        # in between instead of failing the whole listing.
        ignore_removed=True,
    )

    models = []
    for container in containers:
        labels = container.labels
        # `or {}` also covers keys that are present but null, which `.get(key, {})`
        # would hand back as None and break the chained lookup.
        network_settings = container.attrs.get("NetworkSettings") or {}
        networks = network_settings.get("Networks") or {}
        cms_network = networks.get(config.cms.network) or {}

        models.append(
            {
                "service_name": labels.get(SERVICE_NAME_LABEL, container.name),
                "model_uri": labels.get(config.labels.cms_model_uri_label),
                "deployment_type": (
                    labels.get(config.labels.deployment_type_label)
                    or ModelDeploymentType.STATIC.value
                ),
                "ip_address": cms_network.get("IPAddress"),
                "container": container,
            }
        )
    return models
63+
64+
65+
def stop_and_remove_model_container(container: Container) -> None:
    """Stop a model container if it reports as running, then remove it.

    Args:
        container: Docker container object to stop and remove.

    Raises:
        docker.errors.APIError: If container removal fails.
    """
    announcement = (
        f"Stopping and removing container '{container.name}'"
        f" (id={container.id}, status={container.status})"
    )
    log.info(announcement)

    # Only containers whose status is "running" are stopped before removal.
    needs_stop = container.status == "running"
    if needs_stop:
        container.stop()
        log.debug(f"Container {container.name} stopped")

    container.remove()
    log.debug(f"Successfully removed container: {container.name}")

cogstack_model_gateway/common/exceptions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,15 @@ def is_timeout_error(exception: Exception):
6161
retry=retry_if_exception(is_timeout_error),
6262
before_sleep=before_sleep_log(log, logging.WARNING),
6363
)
64+
65+
66+
class ConfigValidationError(ValueError):
    """Signal that a configuration failed validation (400 Bad Request)."""
70+
71+
72+
class ConfigConflictError(ValueError):
    """Signal that a configuration clashes with an existing one (409 Conflict)."""

0 commit comments

Comments
 (0)