diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9be71784..7517b816 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,5 +16,6 @@ repos: rev: v1.15.0 hooks: - id: mypy + additional_dependencies: [httpx, pydantic] files: ^(examples/|src/mistralai/|packages/(azure|gcp)/src/mistralai/).*\.py$ exclude: ^src/mistralai/(__init__|sdkhooks|types)\.py$ diff --git a/tests/test_azure_integration.py b/tests/test_azure_integration.py index ac4e38a1..ad6da971 100644 --- a/tests/test_azure_integration.py +++ b/tests/test_azure_integration.py @@ -4,28 +4,23 @@ These tests require credentials and make real API calls. Skip if AZURE_API_KEY env var is not set. -Prerequisites: - 1. Azure API key (stored in Bitwarden at "[MaaS] - Azure Foundry API key") - 2. Tailscale connected via gw-0 exit node - Usage: - AZURE_API_KEY=xxx pytest tests/test_azure_integration.py -v + AZURE_API_KEY=xxx AZURE_SERVER_URL=https://.services.ai.azure.com AZURE_MODEL= AZURE_OCR_MODEL= pytest tests/test_azure_integration.py -v Environment variables: AZURE_API_KEY: API key (required) - AZURE_ENDPOINT: Base URL (default: https://maas-qa-aifoundry.services.ai.azure.com/models) - AZURE_MODEL: Model name (default: maas-qa-ministral-3b) + AZURE_SERVER_URL: Base host URL (required, e.g. https://.services.ai.azure.com) + AZURE_MODEL: Chat model name (required) + AZURE_OCR_MODEL: OCR model name (required) AZURE_API_VERSION: API version (default: 2024-05-01-preview) -Note: AZURE_ENDPOINT should be the base URL without path suffixes. -The SDK appends /chat/completions to this URL. The api_version parameter -is automatically injected as a query parameter by the SDK. - -Available models: - Chat: maas-qa-ministral-3b, maas-qa-mistral-large-3, maas-qa-mistral-medium-2505 - OCR: maas-qa-mistral-document-ai-2505, maas-qa-mistral-document-ai-2512 - (OCR uses a separate endpoint, not tested here) +Note: AZURE_SERVER_URL should be the base host URL without any path suffix. +The SDK appends the correct path per operation type: + - Chat: /models/chat/completions + - OCR: /providers/mistral/azure/ocr +The api_version parameter is automatically injected as a query parameter. """ +import base64 import json import os @@ -33,18 +28,16 @@ # Configuration from env vars AZURE_API_KEY = os.environ.get("AZURE_API_KEY") -AZURE_ENDPOINT = os.environ.get( - "AZURE_ENDPOINT", - "https://maas-qa-aifoundry.services.ai.azure.com/models", -) -AZURE_MODEL = os.environ.get("AZURE_MODEL", "maas-qa-ministral-3b") +AZURE_SERVER_URL = os.environ.get("AZURE_SERVER_URL") +AZURE_MODEL = os.environ.get("AZURE_MODEL") +AZURE_OCR_MODEL = os.environ.get("AZURE_OCR_MODEL") AZURE_API_VERSION = os.environ.get("AZURE_API_VERSION", "2024-05-01-preview") -SKIP_REASON = "AZURE_API_KEY env var required" +SKIP_REASON = "Required env vars: AZURE_API_KEY, AZURE_SERVER_URL, AZURE_MODEL, AZURE_OCR_MODEL" pytestmark = pytest.mark.skipif( - not AZURE_API_KEY, - reason=SKIP_REASON + not all([AZURE_API_KEY, AZURE_SERVER_URL, AZURE_MODEL, AZURE_OCR_MODEL]), + reason=SKIP_REASON, ) # Shared tool definition for tool-call tests @@ -61,15 +54,23 @@ }, } +# Minimal valid PDF for OCR tests (single blank page) +MINIMAL_PDF = ( + b"%PDF-1.0\n1 0 obj<>endobj\n" + b"2 0 obj<>endobj\n" + b"3 0 obj<>endobj\n" + b"trailer<>" +) + @pytest.fixture def azure_client(): - """Create an Azure client with api_version parameter.""" + """Create an Azure client for Foundry Resource endpoints.""" from mistralai.azure.client import MistralAzure assert AZURE_API_KEY is not None, "AZURE_API_KEY must be set" return MistralAzure( api_key=AZURE_API_KEY, - server_url=AZURE_ENDPOINT, + server_url=AZURE_SERVER_URL, api_version=AZURE_API_VERSION, ) @@ -323,6 +324,37 @@ def test_stream_tool_call(self, azure_client): assert tool_call_found, "Expected tool_call delta chunks in stream" +class TestAzureOcr: + """Test OCR endpoint.""" + + def test_basic_ocr(self, azure_client): + """Test OCR processes a document and returns pages.""" + encoded = base64.b64encode(MINIMAL_PDF).decode("utf-8") + res = azure_client.ocr.process( + model=AZURE_OCR_MODEL, + document={ + "type": "document_url", + "document_url": f"data:application/pdf;base64,{encoded}", + }, + ) + assert res is not None + assert res.pages is not None + + @pytest.mark.asyncio + async def test_basic_ocr_async(self, azure_client): + """Test async OCR processes a document and returns pages.""" + encoded = base64.b64encode(MINIMAL_PDF).decode("utf-8") + res = await azure_client.ocr.process_async( + model=AZURE_OCR_MODEL, + document={ + "type": "document_url", + "document_url": f"data:application/pdf;base64,{encoded}", + }, + ) + assert res is not None + assert res.pages is not None + + class TestAzureChatCompleteAsync: """Test async chat completion.""" @@ -401,7 +433,7 @@ def test_sync_context_manager(self): assert AZURE_API_KEY is not None, "AZURE_API_KEY must be set" with MistralAzure( api_key=AZURE_API_KEY, - server_url=AZURE_ENDPOINT, + server_url=AZURE_SERVER_URL, api_version=AZURE_API_VERSION, ) as client: res = client.chat.complete( @@ -420,7 +452,7 @@ async def test_async_context_manager(self): assert AZURE_API_KEY is not None, "AZURE_API_KEY must be set" async with MistralAzure( api_key=AZURE_API_KEY, - server_url=AZURE_ENDPOINT, + server_url=AZURE_SERVER_URL, api_version=AZURE_API_VERSION, ) as client: res = await client.chat.complete_async( diff --git a/tests/test_gcp_integration.py b/tests/test_gcp_integration.py index fe24b8b0..1ed2fecc 100644 --- a/tests/test_gcp_integration.py +++ b/tests/test_gcp_integration.py @@ -6,7 +6,7 @@ Prerequisites: 1. Authenticate with GCP: gcloud auth application-default login - 2. Have "Vertex AI User" role on the project (e.g. model-garden-420509) + 2. Have "Vertex AI User" role on the project The SDK automatically: - Detects credentials via google.auth.default() @@ -19,7 +19,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/mistral Usage: - GCP_PROJECT_ID=model-garden-420509 pytest tests/test_gcp_integration.py -v + GCP_PROJECT_ID= pytest tests/test_gcp_integration.py -v Environment variables: GCP_PROJECT_ID: GCP project ID (required, or auto-detected from credentials)