diff --git a/cookbook/cds_discharge_summarizer_hf_chat.py b/cookbook/cds_discharge_summarizer_hf_chat.py index e1b6b528..3bb67805 100644 --- a/cookbook/cds_discharge_summarizer_hf_chat.py +++ b/cookbook/cds_discharge_summarizer_hf_chat.py @@ -1,25 +1,43 @@ +#!/usr/bin/env python3 +""" +Discharge Note Summarizer (LangChain + HuggingFace Chat) + +CDS Hooks service that summarises discharge notes using a HuggingFace +chat model via LangChain (DeepSeek R1 by default). + +Requirements: + pip install healthchain langchain-core langchain-huggingface python-dotenv + # HUGGINGFACEHUB_API_TOKEN env var required + +Run: + python cookbook/cds_discharge_summarizer_hf_chat.py + # POST /cds/cds-services/discharge-summarizer + # Docs at: http://localhost:8000/docs +""" + import os import getpass -from healthchain.gateway import HealthChainAPI, CDSHooksService -from healthchain.pipeline import SummarizationPipeline -from healthchain.models import CDSRequest, CDSResponse +from dotenv import load_dotenv from langchain_huggingface.llms import HuggingFaceEndpoint from langchain_huggingface import ChatHuggingFace from langchain_core.prompts import PromptTemplate from langchain_core.output_parsers import StrOutputParser -from dotenv import load_dotenv +from healthchain.gateway import HealthChainAPI, CDSHooksService +from healthchain.pipeline import SummarizationPipeline +from healthchain.models import CDSRequest, CDSResponse load_dotenv() -if not os.getenv("HUGGINGFACEHUB_API_TOKEN"): - os.environ["HUGGINGFACEHUB_API_TOKEN"] = getpass.getpass("Enter your token: ") +def create_chain(): + if not os.getenv("HUGGINGFACEHUB_API_TOKEN"): + os.environ["HUGGINGFACEHUB_API_TOKEN"] = getpass.getpass( + "Enter your HuggingFace token: " + ) - -def create_summarization_chain(): hf = HuggingFaceEndpoint( repo_id="deepseek-ai/DeepSeek-R1-0528", task="text-generation", @@ -27,72 +45,56 @@ def create_summarization_chain(): do_sample=False, repetition_penalty=1.03, ) - model = ChatHuggingFace(llm=hf) - - template = """ - You are a discharge planning assistant for hospital operations. - Provide a concise, objective summary focusing on actionable items - for care coordination, including appointments, medications, and - follow-up instructions. Format as bullet points with no preamble.\n'''{text}''' - """ - prompt = PromptTemplate.from_template(template) - + prompt = PromptTemplate.from_template( + "You are a discharge planning assistant for hospital operations. " + "Provide a concise, objective summary focusing on actionable items " + "for care coordination, including appointments, medications, and " + "follow-up instructions. Format as bullet points with no preamble.\n'''{text}'''" + ) return prompt | model | StrOutputParser() -# Create the healthcare application -app = HealthChainAPI( - title="Discharge Note Summarizer", - description="AI-powered discharge note summarization service", -) - -chain = create_summarization_chain() -pipeline = SummarizationPipeline.load( - chain, source="langchain", template_path="templates/cds_card_template.json" -) - -# Create CDS Hooks service -cds = CDSHooksService() +def create_app() -> HealthChainAPI: + chain = create_chain() + pipeline = SummarizationPipeline.load( + chain, source="langchain", template_path="templates/cds_card_template.json" + ) + cds = CDSHooksService() + @cds.hook("encounter-discharge", id="discharge-summarizer") + def discharge_summarizer(request: CDSRequest) -> CDSResponse: + return pipeline.process_request(request) -@cds.hook("encounter-discharge", id="discharge-summarizer") -def discharge_summarizer(request: CDSRequest) -> CDSResponse: - result = pipeline.process_request(request) - return result + app = HealthChainAPI( + title="Discharge Note Summarizer", + description="AI-powered discharge note summarization service", + port=8000, + service_type="cds-hooks", + ) + app.register_service(cds, path="/cds") + return app -# Register the CDS service -app.register_service(cds, path="/cds") +app = create_app() if __name__ == "__main__": - import uvicorn import threading - from healthchain.sandbox import SandboxClient - # Start the API server in a separate thread - def start_api(): - uvicorn.run(app, port=8000) - - api_thread = threading.Thread(target=start_api, daemon=True) + api_thread = threading.Thread(target=app.run, daemon=True) api_thread.start() - # Create sandbox client and load test data client = SandboxClient( url="http://localhost:8000/cds/cds-services/discharge-summarizer", workflow="encounter-discharge", ) - # Load discharge notes from CSV client.load_free_text( csv_path="data/discharge_notes.csv", column_name="text", ) - # Send requests and get responses responses = client.send_requests() - - # Save results client.save_results("./output/") try: diff --git a/cookbook/cds_discharge_summarizer_hf_trf.py b/cookbook/cds_discharge_summarizer_hf_trf.py index 6d08332c..0e438105 100644 --- a/cookbook/cds_discharge_summarizer_hf_trf.py +++ b/cookbook/cds_discharge_summarizer_hf_trf.py @@ -1,71 +1,79 @@ +#!/usr/bin/env python3 +""" +Discharge Note Summarizer (Transformer) + +CDS Hooks service that summarises discharge notes using a fine-tuned +HuggingFace transformer model (PEGASUS). + +Requirements: + pip install healthchain transformers torch python-dotenv + # HUGGINGFACEHUB_API_TOKEN env var required + +Run: + python cookbook/cds_discharge_summarizer_hf_trf.py + # POST /cds/cds-services/discharge-summarizer + # Docs at: http://localhost:8000/docs +""" + import os import getpass +from dotenv import load_dotenv + from healthchain.gateway import HealthChainAPI, CDSHooksService from healthchain.pipeline import SummarizationPipeline from healthchain.models import CDSRequest, CDSResponse -from dotenv import load_dotenv - load_dotenv() -if not os.getenv("HUGGINGFACEHUB_API_TOKEN"): - os.environ["HUGGINGFACEHUB_API_TOKEN"] = getpass.getpass("Enter your token: ") - - -# Create the healthcare application -app = HealthChainAPI( - title="Discharge Note Summarizer", - description="AI-powered discharge note summarization service", -) +def create_pipeline() -> SummarizationPipeline: + if not os.getenv("HUGGINGFACEHUB_API_TOKEN"): + os.environ["HUGGINGFACEHUB_API_TOKEN"] = getpass.getpass( + "Enter your HuggingFace token: " + ) + return SummarizationPipeline.from_model_id( + "google/pegasus-xsum", source="huggingface", task="summarization" + ) -# Initialize pipeline -pipeline = SummarizationPipeline.from_model_id( - "google/pegasus-xsum", source="huggingface", task="summarization" -) -# Create CDS Hooks service -cds = CDSHooksService() +def create_app() -> HealthChainAPI: + pipeline = create_pipeline() + cds = CDSHooksService() + @cds.hook("encounter-discharge", id="discharge-summarizer") + def discharge_summarizer(request: CDSRequest) -> CDSResponse: + return pipeline.process_request(request) -@cds.hook("encounter-discharge", id="discharge-summarizer") -def discharge_summarizer(request: CDSRequest) -> CDSResponse: - result = pipeline.process_request(request) - return result + app = HealthChainAPI( + title="Discharge Note Summarizer", + description="AI-powered discharge note summarization service", + port=8000, + service_type="cds-hooks", + ) + app.register_service(cds, path="/cds") + return app -# Register the CDS service -app.register_service(cds, path="/cds") +app = create_app() if __name__ == "__main__": - import uvicorn import threading - from healthchain.sandbox import SandboxClient - # Start the API server in a separate thread - def start_api(): - uvicorn.run(app, port=8000) - - api_thread = threading.Thread(target=start_api, daemon=True) + api_thread = threading.Thread(target=app.run, daemon=True) api_thread.start() - # Create sandbox client and load test data client = SandboxClient( url="http://localhost:8000/cds/cds-services/discharge-summarizer", workflow="encounter-discharge", ) - # Load discharge notes from CSV client.load_free_text( csv_path="data/discharge_notes.csv", column_name="text", ) - # Send requests and get responses responses = client.send_requests() - - # Save results client.save_results("./output/") try: diff --git a/cookbook/multi_ehr_data_aggregation.py b/cookbook/multi_ehr_data_aggregation.py index 85050829..5068b7bb 100644 --- a/cookbook/multi_ehr_data_aggregation.py +++ b/cookbook/multi_ehr_data_aggregation.py @@ -13,7 +13,7 @@ - Cerner Open Sandbox: No auth needed Run: -- python data_aggregation.py + python cookbook/multi_ehr_data_aggregation.py """ from typing import List @@ -23,7 +23,7 @@ from healthchain.fhir.r4b import Bundle, Condition, Annotation from healthchain.gateway import FHIRGateway, HealthChainAPI -from healthchain.gateway.clients.fhir.base import FHIRAuthConfig +from healthchain.gateway.clients import FHIRAuthConfig from healthchain.pipeline import Pipeline from healthchain.io.containers import Document from healthchain.fhir import merge_bundles @@ -100,15 +100,16 @@ def get_unified_patient(patient_id: str, sources: List[str]) -> Bundle: return doc.fhir.bundle - app = HealthChainAPI() - app.register_gateway(gateway) + app = HealthChainAPI( + title="Multi-EHR Data Aggregation", + description="Aggregate patient data from multiple FHIR sources", + port=8888, + service_type="fhir-gateway", + ) + app.register_gateway(gateway, path="/fhir") return app if __name__ == "__main__": - import uvicorn - - app = create_app() - uvicorn.run(app, port=8888) - # Runs at: http://127.0.0.1:8888/ + create_app().run() diff --git a/cookbook/notereader_clinical_coding_fhir.py b/cookbook/notereader_clinical_coding_fhir.py index dd677e25..0c0c5595 100644 --- a/cookbook/notereader_clinical_coding_fhir.py +++ b/cookbook/notereader_clinical_coding_fhir.py @@ -4,13 +4,11 @@ Demonstrates FHIR-native pipelines, legacy system integration, and multi-source data handling. Requirements: -- pip install healthchain -- pip install scispacy -- pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_sm-0.5.4.tar.gz -- pip install python-dotenv + pip install healthchain scispacy python-dotenv + pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_sm-0.5.4.tar.gz Run: -- python notereader_clinical_coding_fhir.py # Demo and start server + python cookbook/notereader_clinical_coding_fhir.py """ import logging @@ -21,7 +19,7 @@ from healthchain.fhir import add_provenance_metadata from healthchain.gateway.api import HealthChainAPI from healthchain.gateway.fhir import FHIRGateway -from healthchain.gateway.clients.fhir.base import FHIRAuthConfig +from healthchain.gateway.clients import FHIRAuthConfig from healthchain.gateway.soap import NoteReaderService from healthchain.io import CdaAdapter, Document from healthchain.models import CdaRequest @@ -104,7 +102,12 @@ def ai_coding_workflow(request: CdaRequest): return cda_response # Register services - app = HealthChainAPI(title="Epic CDI Service with FHIR integration") + app = HealthChainAPI( + title="Epic CDI Service", + description="Clinical document intelligence with FHIR and NoteReader integration", + port=8000, + service_type="fhir-gateway", + ) app.register_gateway(fhir_gateway, path="/fhir") app.register_service(note_service, path="/notereader") @@ -117,18 +120,12 @@ def ai_coding_workflow(request: CdaRequest): if __name__ == "__main__": import threading - import uvicorn - from time import sleep from healthchain.sandbox import SandboxClient - # Start server - def run_server(): - uvicorn.run(app, port=8000, log_level="warning") - - server_thread = threading.Thread(target=run_server, daemon=True) + server_thread = threading.Thread(target=app.run, daemon=True) server_thread.start() - sleep(2) # Wait for startup + sleep(2) # Create sandbox client for testing client = SandboxClient( diff --git a/cookbook/sepsis_cds_hooks.py b/cookbook/sepsis_cds_hooks.py index ff70be67..cf1ea7a2 100644 --- a/cookbook/sepsis_cds_hooks.py +++ b/cookbook/sepsis_cds_hooks.py @@ -115,7 +115,12 @@ def sepsis_alert(request: CDSRequest) -> CDSResponse: return CDSResponse(cards=[]) - app = HealthChainAPI(title="Sepsis CDS Hooks") + app = HealthChainAPI( + title="Sepsis CDS Hooks", + description="Real-time sepsis risk alerts via CDS Hooks", + port=8000, + service_type="cds-hooks", + ) app.register_service(cds, path="/cds") return app @@ -126,15 +131,10 @@ def sepsis_alert(request: CDSRequest) -> CDSResponse: if __name__ == "__main__": import threading - import uvicorn from time import sleep from healthchain.sandbox import SandboxClient - # Start server - def run_server(): - uvicorn.run(app, port=8000, log_level="warning") - - server = threading.Thread(target=run_server, daemon=True) + server = threading.Thread(target=app.run, daemon=True) server.start() sleep(2) diff --git a/cookbook/sepsis_fhir_batch.py b/cookbook/sepsis_fhir_batch.py index 231406c7..69232500 100644 --- a/cookbook/sepsis_fhir_batch.py +++ b/cookbook/sepsis_fhir_batch.py @@ -5,6 +5,9 @@ Query patients from a FHIR server, batch run sepsis predictions, and write RiskAssessment resources back. Demonstrates real FHIR server integration. +Requirements: + pip install healthchain joblib xgboost python-dotenv + Setup: 1. Extract and upload demo patients: python scripts/extract_mimic_demo_patients.py --minimal --upload @@ -24,7 +27,7 @@ from healthchain.fhir.r4b import Patient, Observation, RiskAssessment from healthchain.gateway import HealthChainAPI, FHIRGateway -from healthchain.gateway.clients.fhir.base import FHIRAuthConfig +from healthchain.gateway.clients import FHIRAuthConfig from healthchain.fhir import merge_bundles from healthchain.io import Dataset from healthchain.pipeline import Pipeline @@ -159,7 +162,11 @@ def create_app(): gateway.add_source("epic", EPIC_URL) logger.info("✓ Epic configured") - app = HealthChainAPI(title="Sepsis Batch Screening") + app = HealthChainAPI( + title="Sepsis Batch Screening", + description="Batch sepsis risk screening against a live FHIR server", + service_type="fhir-gateway", + ) app.register_gateway(gateway, path="/fhir") return app, gateway diff --git a/healthchain/cli.py b/healthchain/cli.py index 3846a78c..aad282b9 100644 --- a/healthchain/cli.py +++ b/healthchain/cli.py @@ -168,8 +168,8 @@ def patient_view(request: CDSRequest) -> CDSResponse: import os from typing import List -from fhir.resources.bundle import Bundle -from fhir.resources.condition import Condition +from healthchain.fhir.r4b import Bundle +from healthchain.fhir.r4b import Condition from healthchain.gateway import FHIRGateway, HealthChainAPI from healthchain.fhir import merge_bundles @@ -272,6 +272,19 @@ def _make_healthchain_yaml(name: str, service_type: str) -> str: site: name: "" environment: development # development | staging | production + +# FHIR data sources — declare sources here, credentials stay in .env +# sources: +# medplum: +# env_prefix: MEDPLUM # reads MEDPLUM_CLIENT_ID, MEDPLUM_BASE_URL etc. +# epic: +# env_prefix: EPIC # reads EPIC_CLIENT_ID, EPIC_BASE_URL etc. + +# LLM provider (used by app.py or cookbooks via config.llm.to_langchain()) +# llm: +# provider: anthropic # anthropic | openai | google | huggingface +# model: claude-opus-4-6 +# max_tokens: 512 """ @@ -413,6 +426,7 @@ def sandbox_run( config = AppConfig.load() resolved_output = output or (config.data.output_dir if config else "./output") + resolved_from_path = from_path or (config.data.patients_dir if config else None) print(f"\n{_BOLD}{_CYAN}◆ Sandbox{_RST} {_DIM}{url}{_RST}") print(f" {_CYAN}workflow {_RST}{workflow}") @@ -423,10 +437,10 @@ def sandbox_run( print(f"\n{_RED}Error:{_RST} {e}") return - if from_path: - print(f"\n{_DIM}Loading from {from_path}...{_RST}") + if resolved_from_path: + print(f"\n{_DIM}Loading from {resolved_from_path}...{_RST}") try: - client.load_from_path(from_path) + client.load_from_path(resolved_from_path) except (FileNotFoundError, ValueError) as e: print(f"{_RED}Error loading data:{_RST} {e}") return @@ -526,14 +540,16 @@ def _section(s: str) -> str: auth_col = _GREEN if config.security.auth != "none" else _AMBER print(f"{_key('auth ')}{auth_col}{config.security.auth}{_RST}") tls_val = ( - _val_on("enabled") if config.security.tls.enabled else _val_off("disabled") + _val_on("enabled") if config.security.tls.enabled else f"{_DIM}disabled{_RST}" ) print(f"{_key('TLS ')}{tls_val}") origins = ", ".join(config.security.allowed_origins) print(f"{_key('origins ')}{_DIM}{origins}{_RST}") print(_section("Compliance")) - hipaa_val = _val_on("enabled") if config.compliance.hipaa else _val_off("disabled") + hipaa_val = ( + _val_on("enabled") if config.compliance.hipaa else f"{_DIM}disabled{_RST}" + ) print(f"{_key('HIPAA ')}{hipaa_val}") if config.compliance.hipaa: print(f"{_key('audit log ')}{_BOLD}{config.compliance.audit_log}{_RST}") @@ -546,6 +562,18 @@ def _section(s: str) -> str: else: print(f" {_DIM}disabled{_RST}") + if config.sources: + print(_section("Sources")) + for source_name, source in config.sources.items(): + print( + f"{_key(f'{source_name:<12}')}{_DIM}env_prefix={source.env_prefix}{_RST}" + ) + + if config.llm: + print(_section("LLM")) + print(f"{_key('provider ')}{config.llm.provider}") + print(f"{_key('model ')}{_BOLD}{config.llm.model}{_RST}") + print() diff --git a/healthchain/config/appconfig.py b/healthchain/config/appconfig.py index 9fa7bdde..d3da5eb7 100644 --- a/healthchain/config/appconfig.py +++ b/healthchain/config/appconfig.py @@ -8,7 +8,7 @@ import logging from pathlib import Path -from typing import List, Optional +from typing import Dict, List, Optional import yaml from pydantic import BaseModel, field_validator @@ -18,6 +18,56 @@ _CONFIG_FILENAME = "healthchain.yaml" +class SourceConfig(BaseModel): + """A FHIR data source. Credentials are loaded from environment variables.""" + + env_prefix: str # e.g. "MEDPLUM" reads MEDPLUM_CLIENT_ID, MEDPLUM_BASE_URL etc. + + def to_fhir_auth_config(self): + """Instantiate FHIRAuthConfig by reading env vars for this source's prefix.""" + from healthchain.gateway.clients.fhir.base import FHIRAuthConfig + + return FHIRAuthConfig.from_env(self.env_prefix) + + +class LLMConfig(BaseModel): + """LLM provider settings. API key is read from the standard env var for each provider.""" + + provider: str = "anthropic" # anthropic | openai | google | huggingface + model: str = "claude-opus-4-6" + max_tokens: int = 512 + + @field_validator("provider") + @classmethod + def validate_provider(cls, v: str) -> str: + allowed = {"anthropic", "openai", "google", "huggingface"} + if v not in allowed: + raise ValueError(f"provider must be one of: {', '.join(sorted(allowed))}") + return v + + def to_langchain(self): + """Instantiate the configured LangChain chat model.""" + if self.provider == "anthropic": + from langchain_anthropic import ChatAnthropic + + return ChatAnthropic(model=self.model, max_tokens=self.max_tokens) + elif self.provider == "openai": + from langchain_openai import ChatOpenAI + + return ChatOpenAI(model=self.model, max_tokens=self.max_tokens) + elif self.provider == "google": + from langchain_google_genai import ChatGoogleGenerativeAI + + return ChatGoogleGenerativeAI( + model=self.model, max_output_tokens=self.max_tokens + ) + elif self.provider == "huggingface": + from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint + + llm = HuggingFaceEndpoint(repo_id=self.model, max_new_tokens=self.max_tokens) + return ChatHuggingFace(llm=llm) + + class ServiceConfig(BaseModel): type: str = "cds-hooks" port: int = 8000 @@ -84,6 +134,8 @@ class AppConfig(BaseModel): compliance: ComplianceConfig = ComplianceConfig() eval: EvalConfig = EvalConfig() site: SiteConfig = SiteConfig() + sources: Dict[str, SourceConfig] = {} + llm: Optional[LLMConfig] = None @classmethod def from_yaml(cls, path: Path) -> "AppConfig": diff --git a/healthchain/gateway/api/app.py b/healthchain/gateway/api/app.py index bfd74573..4d4d437c 100644 --- a/healthchain/gateway/api/app.py +++ b/healthchain/gateway/api/app.py @@ -6,6 +6,7 @@ """ import logging +import os import re from contextlib import asynccontextmanager @@ -104,6 +105,8 @@ def _print_startup_banner( gateways: dict, services: dict, docs_url: str, + port: int = 8000, + service_type: Optional[str] = None, config=None, config_path: Optional[str] = None, ) -> None: @@ -116,23 +119,30 @@ def _print_startup_banner( LOGO_COL = 38 # ── resolve status values from config or sensible defaults ── - svc_type = (config.service.type if config else None) or ( - list({**gateways, **services}.keys())[0] - if {**gateways, **services} - else "unknown" + svc_type = ( + service_type + or (config.service.type if config else None) + or ( + list({**gateways, **services}.keys())[0] + if {**gateways, **services} + else "unknown" + ) ) env = config.site.environment if config else "development" - port = str(config.service.port if config else 8000) + port = str(port) site = config.site.name if config else None auth = config.security.auth if config else "none" tls = config.security.tls.enabled if config else False hipaa = config.compliance.hipaa if config else False eval_enabled = config.eval.enabled if config else False eval_provider = config.eval.provider if config else "mlflow" + # Check registered gateways first, then fall back to env var presence fhir_configured = any( - hasattr(gw, "sources") and getattr(gw, "sources", None) + hasattr(gw, "connection_manager") + and gw.connection_manager + and getattr(gw.connection_manager, "sources", None) for gw in gateways.values() - ) + ) or any(k.endswith("_CLIENT_ID") for k in os.environ) status: list[str] = [ f"\033[1m\033[38;2;255;121;198m{title}\033[0m \033[2mv{version}\033[0m", @@ -237,6 +247,8 @@ def __init__( title: str = "HealthChain API", description: str = "Healthcare Integration API", version: str = "1.0.0", + port: Optional[int] = None, + service_type: Optional[str] = None, enable_cors: bool = True, enable_events: bool = True, event_dispatcher: Optional[EventDispatcher] = None, @@ -262,6 +274,10 @@ def __init__( **kwargs, ) + # Display metadata for banner (when running outside healthchain serve) + self._port = port + self._service_type = service_type + # Gateway and service registries self.gateways = {} self.services = {} @@ -597,12 +613,15 @@ async def _startup(self) -> None: from healthchain.config.appconfig import AppConfig config = AppConfig.load() + port = self._port or (config.service.port if config else 8000) _print_startup_banner( title=config.name if config else self.title, version=config.version if config else self.version, gateways=self.gateways, services=self.services, - docs_url=self.docs_url or "http://localhost:8000/docs", + docs_url=self.docs_url or "/docs", + port=port, + service_type=self._service_type, config=config, config_path="./healthchain.yaml" if config else None, ) @@ -616,6 +635,25 @@ async def _startup(self) -> None: except Exception as e: logger.warning(f"Failed to initialize {name}: {e}") + def run(self, host: str = "0.0.0.0", **kwargs) -> None: + """Run the application with uvicorn. + + Convenience wrapper for local development and cookbooks. For production, + use `healthchain serve` which reads healthchain.yaml for TLS, port, etc. + + Args: + host: Host to bind to (default: 0.0.0.0) + **kwargs: Passed through to uvicorn.run (e.g. reload=True, workers=4) + + Example: + app = HealthChainAPI(title="My App", port=8888) + app.run() + app.run(reload=True) # with hot reload + """ + import uvicorn + + uvicorn.run(self, host=host, port=self._port or 8000, **kwargs) + async def _shutdown(self) -> None: """Handle graceful shutdown.""" for name, component in {**self.services, **self.gateways}.items(): diff --git a/tests/config_manager/test_appconfig.py b/tests/config_manager/test_appconfig.py index 5ef01c52..dd13e084 100644 --- a/tests/config_manager/test_appconfig.py +++ b/tests/config_manager/test_appconfig.py @@ -2,7 +2,7 @@ import pytest -from healthchain.config.appconfig import AppConfig +from healthchain.config.appconfig import AppConfig, LLMConfig def test_appconfig_loads_valid_yaml(tmp_path): @@ -111,6 +111,79 @@ def test_appconfig_tls_config_parsed(tmp_path): assert config.security.tls.key_path == "./certs/key.pem" +def test_llmconfig_valid_providers(): + """LLMConfig accepts all supported providers.""" + for provider in ("anthropic", "openai", "google", "huggingface"): + config = LLMConfig(provider=provider) + assert config.provider == provider + + +def test_llmconfig_invalid_provider_raises(): + """LLMConfig raises ValidationError for unsupported providers.""" + with pytest.raises(Exception): + LLMConfig(provider="cohere") + + +def test_llmconfig_defaults(): + """LLMConfig has sensible defaults.""" + config = LLMConfig() + assert config.provider == "anthropic" + assert config.model == "claude-opus-4-6" + assert config.max_tokens == 512 + + +def test_appconfig_llm_parsed(tmp_path): + """AppConfig parses llm section into LLMConfig correctly.""" + (tmp_path / "healthchain.yaml").write_text( + """ +llm: + provider: openai + model: gpt-4o + max_tokens: 1024 +""" + ) + config = AppConfig.from_yaml(tmp_path / "healthchain.yaml") + + assert config.llm.provider == "openai" + assert config.llm.model == "gpt-4o" + assert config.llm.max_tokens == 1024 + + +def test_appconfig_llm_defaults_to_none(tmp_path): + """AppConfig.llm is None when not specified in healthchain.yaml.""" + (tmp_path / "healthchain.yaml").write_text("name: minimal-app\n") + config = AppConfig.from_yaml(tmp_path / "healthchain.yaml") + + assert config.llm is None + + +def test_appconfig_sources_parsed(tmp_path): + """AppConfig parses sources section into SourceConfig correctly.""" + (tmp_path / "healthchain.yaml").write_text( + """ +sources: + medplum: + env_prefix: MEDPLUM + epic: + env_prefix: EPIC +""" + ) + config = AppConfig.from_yaml(tmp_path / "healthchain.yaml") + + assert "medplum" in config.sources + assert config.sources["medplum"].env_prefix == "MEDPLUM" + assert "epic" in config.sources + assert config.sources["epic"].env_prefix == "EPIC" + + +def test_appconfig_sources_defaults_to_empty(tmp_path): + """AppConfig.sources is empty dict when not specified.""" + (tmp_path / "healthchain.yaml").write_text("name: minimal-app\n") + config = AppConfig.from_yaml(tmp_path / "healthchain.yaml") + + assert config.sources == {} + + def test_appconfig_eval_track_events_parsed(tmp_path): """AppConfig parses eval.track list correctly.""" config_file = tmp_path / "healthchain.yaml"