Skip to content

Commit ed620d7

Browse files
authored
Add MCP server to integrate CMS capabilities for LLM responses (#38)
* feat: add MCP server to integrate CMS capabilities for LLM responses
1 parent 042dd83 commit ed620d7

35 files changed

Lines changed: 3857 additions & 457 deletions

.github/workflows/main.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
python-version: ${{ matrix.python-version }}
2929
- name: Install dependencies
3030
run: |
31-
uv sync --extra dev --extra docs --extra llm
31+
uv sync --extra dev --extra docs --extra llm --extra mcp
3232
uv run python -m ensurepip
3333
- name: Check types
3434
run: |

app/api/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,10 +347,10 @@ async def init_vllm_engine(app: FastAPI,
347347
)
348348

349349
tokenizer = await engine.get_tokenizer()
350-
vllm_config = await engine.get_vllm_config()
351-
model_config = await engine.get_model_config()
350+
vllm_config = await engine.get_vllm_config() # type: ignore
351+
model_config = await engine.get_model_config() # type: ignore
352352

353-
await init_app_state(engine, vllm_config, app.state, args)
353+
await init_app_state(engine, vllm_config, app.state, args) # type: ignore
354354

355355
async def generate_text(
356356
request: Request,

app/cli/README.md

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ $ cms [OPTIONS] COMMAND [ARGS]...
1010

1111
**Options**:
1212

13+
* `--model-type [medcat_snomed|medcat_umls|medcat_icd10|medcat_opcs4|medcat_deid|anoncat|transformers_deid|huggingface_ner|huggingface_llm]`
14+
* `--host TEXT`
15+
* `--port TEXT`
1316
* `--install-completion`: Install completion for the current shell.
1417
* `--show-completion`: Show completion for the current shell, to copy it or customize the installation.
1518
* `--help`: Show this message and exit.
@@ -24,6 +27,7 @@ $ cms [OPTIONS] COMMAND [ARGS]...
2427
* `export-openapi-spec`: This generates an API document for all...
2528
* `stream`: This groups various stream operations
2629
* `package`: This groups various package operations
30+
* `mcp`: Run the MCP server for accessing CMS...
2731

2832
## `cms serve`
2933

@@ -37,14 +41,17 @@ $ cms serve [OPTIONS]
3741

3842
**Options**:
3943

40-
* `--model-type [medcat_snomed|medcat_umls|medcat_icd10|medcat_opcs4|medcat_deid|anoncat|transformers_deid|huggingface_ner]`: The type of the model to serve [required]
41-
* `--model-path TEXT`: The file path to the model package
44+
* `--model-type [medcat_snomed|medcat_umls|medcat_icd10|medcat_opcs4|medcat_deid|anoncat|transformers_deid|huggingface_ner|huggingface_llm]`: The type of the model to serve [required]
45+
* `--model-path TEXT`: Either the file path to the local model package or the URL to the remote one
4246
* `--mlflow-model-uri models:/MODEL_NAME/ENV`: The URI of the MLflow model to serve
4347
* `--host TEXT`: The hostname of the server [default: 127.0.0.1]
4448
* `--port TEXT`: The port of the server [default: 8000]
4549
* `--model-name TEXT`: The string representation of the model name
4650
* `--streamable / --no-streamable`: Serve the streamable endpoints only [default: no-streamable]
4751
* `--device [default|cpu|cuda|mps]`: The device to serve the model on [default: default]
52+
* `--llm-engine [CMS|vLLM]`: The engine to use for text generation [default: CMS]
53+
* `--load-in-4bit / --no-load-in-4bit`: Load the model in 4-bit precision, used by 'huggingface_llm' models [default: no-load-in-4bit]
54+
* `--load-in-8bit / --no-load-in-8bit`: Load the model in 8-bit precision, used by 'huggingface_llm' models [default: no-load-in-8bit]
4855
* `--debug / --no-debug`: Run in the debug mode
4956
* `--help`: Show this message and exit.
5057

@@ -60,7 +67,7 @@ $ cms train [OPTIONS]
6067

6168
**Options**:
6269

63-
* `--model-type [medcat_snomed|medcat_umls|medcat_icd10|medcat_opcs4|medcat_deid|anoncat|transformers_deid|huggingface_ner]`: The type of the model to train [required]
70+
* `--model-type [medcat_snomed|medcat_umls|medcat_icd10|medcat_opcs4|medcat_deid|anoncat|transformers_deid|huggingface_ner|huggingface_llm]`: The type of the model to train [required]
6471
* `--base-model-path TEXT`: The file path to the base model package to be trained on
6572
* `--mlflow-model-uri models:/MODEL_NAME/ENV`: The URI of the MLflow model to train
6673
* `--training-type [supervised|unsupervised|meta_supervised]`: The type of training [required]
@@ -71,6 +78,8 @@ $ cms train [OPTIONS]
7178
* `--description TEXT`: The description of the training or change logs
7279
* `--model-name TEXT`: The string representation of the model name
7380
* `--device [default|cpu|cuda|mps]`: The device to train the model on [default: default]
81+
* `--load-in-4bit / --no-load-in-4bit`: Load the model in 4-bit precision, used by 'huggingface_llm' models [default: no-load-in-4bit]
82+
* `--load-in-8bit / --no-load-in-8bit`: Load the model in 8-bit precision, used by 'huggingface_llm' models [default: no-load-in-8bit]
7483
* `--debug / --no-debug`: Run in the debug mode
7584
* `--help`: Show this message and exit.
7685

@@ -86,7 +95,7 @@ $ cms register [OPTIONS]
8695

8796
**Options**:
8897

89-
* `--model-type [medcat_snomed|medcat_umls|medcat_icd10|medcat_opcs4|medcat_deid|anoncat|transformers_deid|huggingface_ner]`: The type of the model to register [required]
98+
* `--model-type [medcat_snomed|medcat_umls|medcat_icd10|medcat_opcs4|medcat_deid|anoncat|transformers_deid|huggingface_ner|huggingface_llm]`: The type of the model to register [required]
9099
* `--model-path TEXT`: The file path to the model package [required]
91100
* `--model-name TEXT`: The string representation of the registered model [required]
92101
* `--training-type [supervised|unsupervised|meta_supervised]`: The type of training the model went through
@@ -108,7 +117,7 @@ $ cms export-model-apis [OPTIONS]
108117

109118
**Options**:
110119

111-
* `--model-type [medcat_snomed|medcat_umls|medcat_icd10|medcat_opcs4|medcat_deid|anoncat|transformers_deid|huggingface_ner]`: The type of the model to serve [required]
120+
* `--model-type [medcat_snomed|medcat_umls|medcat_icd10|medcat_opcs4|medcat_deid|anoncat|transformers_deid|huggingface_ner|huggingface_llm]`: The type of the model to serve [required]
112121
* `--add-training-apis / --no-add-training-apis`: Add training APIs to the doc [default: no-add-training-apis]
113122
* `--add-evaluation-apis / --no-add-evaluation-apis`: Add evaluation APIs to the doc [default: no-add-evaluation-apis]
114123
* `--add-previews-apis / --no-add-previews-apis`: Add preview APIs to the doc [default: no-add-previews-apis]
@@ -269,3 +278,47 @@ $ cms package hf-dataset [OPTIONS]
269278
* `--remove-cached / --no-remove-cached`: Whether to remove the downloaded cache after the dataset package is saved [default: no-remove-cached]
270279
* `--trust-remote-code / --no-trust-remote-code`: Whether to trust and use the remote script of the dataset [default: no-trust-remote-code]
271280
* `--help`: Show this message and exit.
281+
282+
## `cms mcp`
283+
284+
Run the MCP server for accessing CMS capabilities
285+
286+
**Usage**:
287+
288+
```console
289+
$ cms mcp [OPTIONS] COMMAND [ARGS]...
290+
```
291+
292+
**Options**:
293+
294+
* `--help`: Show this message and exit.
295+
296+
**Commands**:
297+
298+
* `run`: Run the MCP server for accessing CMS...
299+
300+
### `cms mcp run`
301+
302+
Run the MCP server for accessing CMS capabilities
303+
304+
**Usage**:
305+
306+
```console
307+
$ cms mcp run [OPTIONS]
308+
```
309+
310+
**Options**:
311+
312+
* `--host TEXT`: The hostname of the MCP server [default: 127.0.0.1]
313+
* `--port INTEGER`: The port of the MCP server [default: 8080]
314+
* `--transport TEXT`: The transport type (either 'stdio', 'sse' or 'http') [default: http]
315+
* `--cms-base-url TEXT`: The base URL of the CMS API [default: http://127.0.0.1:8000]
316+
* `--cms-api-key TEXT`: The API key for authenticating with the CMS API
317+
* `--mcp-api-keys TEXT`: Comma-separated API keys for authenticating MCP clients
318+
* `--cms-mcp-oauth-enabled / --no-cms-mcp-oauth-enabled`: Whether to enable OAuth2 authentication for MCP clients
319+
* `--github-client-id TEXT`: The GitHub OAuth2 client ID
320+
* `--github-client-secret TEXT`: The GitHub OAuth2 client secret
321+
* `--google-client-id TEXT`: The Google OAuth2 client ID
322+
* `--google-client-secret TEXT`: The Google OAuth2 client secret
323+
* `--debug / --no-debug`: Run in debug mode
324+
* `--help`: Show this message and exit.

app/cli/cli.py

Lines changed: 99 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import uuid
88
import inspect
99
import warnings
10+
import multiprocessing
1011
import subprocess
1112

1213
current_frame = inspect.currentframe()
@@ -52,9 +53,11 @@
5253

5354
cmd_app = typer.Typer(name="cms", help="CLI for various CogStack ModelServe operations", add_completion=True)
5455
stream_app = typer.Typer(name="stream", help="This groups various stream operations", add_completion=True)
55-
cmd_app.add_typer(stream_app, name="stream")
56+
mcp_app = typer.Typer(name="mcp", help="Run the MCP server for accessing CMS capabilities", add_completion=True)
5657
package_app = typer.Typer(name="package", help="This groups various package operations", add_completion=True)
58+
cmd_app.add_typer(stream_app, name="stream")
5759
cmd_app.add_typer(package_app, name="package")
60+
cmd_app.add_typer(mcp_app, name="mcp")
5861
logging.config.fileConfig(os.path.join(parent_dir, "logging.ini"), disable_existing_loggers=False)
5962

6063
@cmd_app.command("serve", help="This serves various CogStack NLP models")
@@ -69,6 +72,7 @@ def serve_model(
6972
device: Device = typer.Option(Device.DEFAULT.value, help="The device to serve the model on"),
7073
llm_engine: Optional[LlmEngine] = typer.Option(LlmEngine.CMS.value, help="The engine to use for text generation"),
7174
load_in_4bit: Optional[bool] = typer.Option(False, help="Load the model in 4-bit precision, used by 'huggingface_llm' models"),
75+
load_in_8bit: Optional[bool] = typer.Option(False, help="Load the model in 8-bit precision, used by 'huggingface_llm' models"),
7276
debug: Optional[bool] = typer.Option(None, help="Run in the debug mode"),
7377
) -> None:
7478
"""
@@ -87,6 +91,7 @@ def serve_model(
8791
device (Device): The device to serve the model on. Defaults to Device.DEFAULT.
8892
llm_engine (LlmEngine): The inference engine to use. Defaults to LlmEngine.CMS.
8993
load_in_4bit (bool): Load the model in 4-bit precision, used by 'huggingface_llm' models. Defaults to False.
94+
load_in_8bit (bool): Load the model in 8-bit precision, used by 'huggingface_llm' models. Defaults to False.
9095
debug (Optional[bool]): Run in debug mode if set to True.
9196
"""
9297

@@ -138,7 +143,7 @@ def serve_model(
138143
if model_path:
139144
model_service = model_service_dep()
140145
model_service.model_name = model_name
141-
model_service.init_model(load_in_4bit=load_in_4bit)
146+
model_service.init_model(load_in_4bit=load_in_4bit, load_in_8bit=load_in_8bit)
142147
cms_globals.model_manager_dep = ModelManagerDep(model_service)
143148
elif mlflow_model_uri:
144149
model_service = ModelManager.retrieve_model_service_from_uri(mlflow_model_uri, config, dst_model_path)
@@ -191,6 +196,7 @@ def train_model(
191196
model_name: Optional[str] = typer.Option(None, help="The string representation of the model name"),
192197
device: Device = typer.Option(Device.DEFAULT.value, help="The device to train the model on"),
193198
load_in_4bit: Optional[bool] = typer.Option(False, help="Load the model in 4-bit precision, used by 'huggingface_llm' models"),
199+
load_in_8bit: Optional[bool] = typer.Option(False, help="Load the model in 8-bit precision, used by 'huggingface_llm' models"),
194200
debug: Optional[bool] = typer.Option(None, help="Run in the debug mode"),
195201
) -> None:
196202
"""
@@ -211,6 +217,7 @@ def train_model(
211217
model_name (Optional[str]): The optional string representation of the model name.
212218
device (Device): The device to train the model on. Defaults to Device.DEFAULT.
213219
load_in_4bit (bool): Load the model in 4-bit precision, used by 'huggingface_llm' models. Defaults to False.
220+
load_in_8bit (bool): Load the model in 8-bit precision, used by 'huggingface_llm' models. Defaults to False.
214221
debug (Optional[bool]): Run in debug mode if set to True.
215222
"""
216223

@@ -232,7 +239,7 @@ def train_model(
232239
pass
233240
model_service = model_service_dep()
234241
model_service.model_name = model_name if model_name is not None else "CMS model"
235-
model_service.init_model(load_in_4bit=load_in_4bit)
242+
model_service.init_model(load_in_4bit=load_in_4bit, load_in_8bit=load_in_8bit)
236243
elif mlflow_model_uri:
237244
model_service = ModelManager.retrieve_model_service_from_uri(mlflow_model_uri, config, dst_model_path)
238245
model_service.model_name = model_name if model_name is not None else "CMS model"
@@ -495,6 +502,7 @@ def package_model(
495502

496503
model_package_archive = os.path.abspath(os.path.expanduser(output_model_package))
497504
if hf_repo_id:
505+
download_path = None
498506
try:
499507
with tempfile.TemporaryDirectory() as tmp_dir:
500508
if not hf_repo_revision:
@@ -510,15 +518,14 @@ def package_model(
510518
local_dir=tmp_dir,
511519
local_dir_use_symlinks=False,
512520
)
513-
514-
shutil.make_archive(model_package_archive, archive_format.value, download_path)
521+
_make_archive_file(model_package_archive, archive_format.value, download_path)
515522
finally:
516-
if remove_cached:
523+
if remove_cached and download_path:
517524
cached_model_path = os.path.abspath(os.path.join(download_path, "..", ".."))
518525
shutil.rmtree(cached_model_path)
519526
elif cached_model_dir:
520527
cached_model_path = os.path.abspath(os.path.expanduser(cached_model_dir))
521-
shutil.make_archive(model_package_archive, archive_format.value, cached_model_path)
528+
_make_archive_file(model_package_archive, archive_format.value, cached_model_path)
522529

523530
typer.echo(f"Model package saved to {model_package_archive}.{'zip' if archive_format == ArchiveFormat.ZIP else 'tar.gz'}")
524531

@@ -585,6 +592,73 @@ def package_dataset(
585592
typer.echo(f"Dataset package saved to {dataset_package_archive}.{'zip' if archive_format == ArchiveFormat.ZIP else 'tar.gz'}")
586593

587594

595+
@mcp_app.command("run", help="Run the MCP server for accessing CMS capabilities")
596+
def run_mcp_server(
597+
host: str = typer.Option("127.0.0.1", help="The hostname of the MCP server"),
598+
port: int = typer.Option(8080, help="The port of the MCP server"),
599+
transport: str = typer.Option("http", help="The transport type (either 'stdio', 'sse' or 'http')"),
600+
cms_base_url: str = typer.Option("http://127.0.0.1:8000", help="The base URL of the CMS API"),
601+
cms_api_key: str = typer.Option("Bearer", help="The API key for authenticating with the CMS API"),
602+
mcp_api_keys: str = typer.Option("", help="Comma-separated API keys for authenticating MCP clients"),
603+
cms_mcp_oauth_enabled: Optional[bool] = typer.Option(None, help="Whether to enable OAuth2 authentication for MCP clients"),
604+
github_client_id: str = typer.Option("", help="The GitHub OAuth2 client ID"),
605+
github_client_secret: str = typer.Option("", help="The GitHub OAuth2 client secret"),
606+
google_client_id: str = typer.Option("", help="The Google OAuth2 client ID"),
607+
google_client_secret: str = typer.Option("", help="The Google OAuth2 client secret"),
608+
debug: Optional[bool] = typer.Option(None, help="Run in debug mode"),
609+
) -> None:
610+
"""
611+
Runs the CogStack ModelServe MCP server.
612+
613+
This function starts an MCP server that provides AI assistants with tools to interact
614+
with deployed CMS models through the Model Context Protocol interface.
615+
616+
Args:
617+
host (str): The hostname of the MCP server. Defaults to "127.0.0.1".
618+
port (int): The port of the MCP server. Defaults to 8080.
619+
transport (str): The transport type for the MCP server. Can be "stdio", "sse" or "http". Defaults to "http".
620+
cms_base_url (str): The base URL of the CMS API endpoint. Defaults to "http://127.0.0.1:8000".
621+
debug (Optional[bool]): Run in debug mode if set to True.
622+
"""
623+
624+
logger = _get_logger(debug)
625+
logger.info("Starting CMS MCP server...")
626+
627+
os.environ["CMS_BASE_URL"] = cms_base_url
628+
os.environ["CMS_MCP_SERVER_HOST"] = host
629+
os.environ["CMS_MCP_SERVER_PORT"] = str(port)
630+
os.environ["CMS_MCP_TRANSPORT"] = transport.lower()
631+
os.environ["CMS_API_KEY"] = cms_api_key
632+
os.environ["MCP_API_KEYS"] = mcp_api_keys
633+
os.environ["CMS_MCP_OAUTH_ENABLED"] = "true" if cms_mcp_oauth_enabled else "false"
634+
os.environ["GITHUB_CLIENT_ID"] = github_client_id
635+
os.environ["GITHUB_CLIENT_SECRET"] = github_client_secret
636+
os.environ["GOOGLE_CLIENT_ID"] = google_client_id
637+
os.environ["GOOGLE_CLIENT_SECRET"] = google_client_secret
638+
639+
if debug:
640+
os.environ["CMS_MCP_DEV"] = "1"
641+
642+
try:
643+
from app.mcp.server import main
644+
logger.info(f"MCP server starting with transport: {transport}")
645+
logger.info(f"Connected to CMS API at {cms_base_url}")
646+
main()
647+
except ImportError as e:
648+
logger.error(f"Cannot import MCP. Please install it with `pip install cms[mcp]`: {e}")
649+
typer.echo(f"ERROR: Cannot import MCP: {e}")
650+
typer.echo("Please install it with `pip install cms[mcp]`.")
651+
raise typer.Exit(code=1)
652+
except KeyboardInterrupt:
653+
logger.info("MCP server stopped by the user")
654+
typer.echo("MCP server stopped.")
655+
raise typer.Exit(code=0)
656+
except Exception as e:
657+
logger.error(f"Failed to start MCP server: {e}")
658+
typer.echo(f"ERROR: Failed to start MCP server: {e}")
659+
raise typer.Exit(code=1)
660+
661+
588662
@cmd_app.command("build", help="This builds an OCI-compliant image to containerise CMS")
589663
def build_image(
590664
dockerfile_path: str = typer.Option(..., help="The path to the Dockerfile"),
@@ -798,6 +872,24 @@ def _ensure_dst_model_path(model_path: str, parent_dir: str, config: Settings) -
798872
return dst_model_path
799873

800874

875+
def _make_archive_file(base_name: str, format: str, root_dir: str) -> None:
876+
if format == ArchiveFormat.TAR_GZ.value:
877+
try:
878+
result = subprocess.run(["which", "pigz"], capture_output=True, text=True, check=True)
879+
if result.returncode == 0:
880+
num_cores = max(1, multiprocessing.cpu_count() - 1)
881+
compress_program = f"pigz -p {num_cores}"
882+
subprocess.run(
883+
["tar", f"--use-compress-program={compress_program}", "-cf", f"{base_name}.tar.gz", "-C", root_dir, "."],
884+
check=True
885+
)
886+
return
887+
except subprocess.CalledProcessError:
888+
typer.echo("Use non-parallel compression...")
889+
890+
shutil.make_archive(base_name, format, root_dir)
891+
892+
801893
def _get_logger(
802894
debug: Optional[bool] = None,
803895
model_type: Optional[ModelType] = None,

0 commit comments

Comments
 (0)