diff --git a/examples/python/a2a_server.py b/examples/python/a2a_server.py index 1f113ae6..91f3e2e8 100644 --- a/examples/python/a2a_server.py +++ b/examples/python/a2a_server.py @@ -1,6 +1,7 @@ # /// script # dependencies = [ # "a2a-sdk[http-server]", +# "mlflow", # "openai", # "uvicorn", # ] @@ -13,13 +14,12 @@ import os import uuid -import json -import datetime -from pathlib import Path from typing import Dict, List, Any, Optional from fastapi import FastAPI, Request, HTTPException, APIRouter from fastapi.responses import JSONResponse from openai import OpenAI +import mlflow +import mlflow.openai from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext from a2a.server.events.event_queue import EventQueue @@ -32,6 +32,28 @@ # Initialize OpenAI client openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) +# MLflow tracing configuration +MLFLOW_EXPERIMENT_NAME = os.getenv("MLFLOW_EXPERIMENT_NAME", "timestep-a2a") +MLFLOW_TRACING_ENABLED = os.getenv("MLFLOW_TRACING_ENABLED", "true").lower() in {"1", "true", "yes"} +_MLFLOW_TRACING_CONFIGURED = False + + +def setup_mlflow_tracing() -> None: + """Configure MLflow tracing for OpenAI calls.""" + global _MLFLOW_TRACING_CONFIGURED + if _MLFLOW_TRACING_CONFIGURED or not MLFLOW_TRACING_ENABLED: + return + + tracking_uri = os.getenv("MLFLOW_TRACKING_URI") + if tracking_uri: + mlflow.set_tracking_uri(tracking_uri) + mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME) + mlflow.openai.autolog() + _MLFLOW_TRACING_CONFIGURED = True + + +setup_mlflow_tracing() + # Agent IDs PERSONAL_ASSISTANT_ID = "00000000-0000-0000-0000-000000000000" WEATHER_ASSISTANT_ID = "FFFFFFFF-FFFF-FFFF-FFFF-FFFFFFFFFFFF" @@ -148,35 +170,6 @@ def build_system_message(agent_id: str, tools: List[Dict[str, Any]]) -> str: # Track all task IDs per agent for listing agent_task_ids: Dict[str, List[str]] = {} -def write_trace(task_id: str, agent_id: str, input_messages: List[Dict], input_tools: List[Dict], output_message: Dict) -> None: - """Write trace to traces/ folder.""" - traces_dir = Path("/workspace/traces") - traces_dir.mkdir(exist_ok=True) - - timestamp = datetime.datetime.now().isoformat().replace(":", "-") - # Use short task_id for filename (first 8 chars) - task_id_short = task_id[:8] if task_id else "unknown" - agent_id_short = agent_id[:8] if agent_id else "unknown" - trace_file = traces_dir / f"{timestamp}_{task_id_short}_{agent_id_short}.json" - - trace = { - "task_id": task_id, - "agent_id": agent_id, - "timestamp": timestamp, - "input": { - "messages": input_messages, - "tools": input_tools, - }, - "output": { - "content": output_message.get("content", ""), - "tool_calls": output_message.get("tool_calls", []), - } - } - - with open(trace_file, "w") as f: - json.dump(trace, f, indent=2) - - class MultiAgentExecutor(AgentExecutor): """Agent executor that uses OpenAI directly and configures tools based on agent_id.""" @@ -291,29 +284,6 @@ async def execute( # Convert OpenAI response to A2A format assistant_content = assistant_message.content or "" - # Capture trace: input messages + output message - output_message_dict = { - "content": assistant_content, - "tool_calls": [ - { - "id": tc.id, - "type": "function", - "function": { - "name": tc.function.name, - "arguments": tc.function.arguments, - }, - } - for tc in tool_calls - ] if tool_calls else [], - } - write_trace( - task_id=task_id or "", - agent_id=self.agent_id, - input_messages=openai_messages_with_system, - input_tools=self.tools or [], - output_message=output_message_dict, - ) - # Build A2A message using helper function # Role.agent is the correct role for assistant messages in A2A a2a_message = create_text_message_object( diff --git a/examples/python/compose.yml b/examples/python/compose.yml index dab7ba62..48b0729c 100644 --- a/examples/python/compose.yml +++ b/examples/python/compose.yml @@ -10,6 +10,9 @@ services: - "8000:8000" environment: - OPENAI_API_KEY=${OPENAI_API_KEY} + - MLFLOW_TRACKING_URI=${MLFLOW_TRACKING_URI} + - MLFLOW_EXPERIMENT_NAME=${MLFLOW_EXPERIMENT_NAME} + - MLFLOW_TRACING_ENABLED=${MLFLOW_TRACING_ENABLED} - UV_CACHE_DIR=/workspace/.cache/uv develop: watch: diff --git a/examples/python/test_client.py b/examples/python/test_client.py index d8bd968c..a3253c13 100644 --- a/examples/python/test_client.py +++ b/examples/python/test_client.py @@ -1,6 +1,8 @@ # /// script # dependencies = [ # "a2a-sdk", +# "mlflow", +# "pandas", # "mcp", # "httpx", # ] @@ -19,6 +21,8 @@ import datetime from pathlib import Path from typing import Dict, Any, List, Optional +import mlflow +import pandas as pd from a2a.client import ClientFactory, ClientConfig from a2a.client.helpers import create_text_message_object from a2a.types import TransportProtocol, Role @@ -35,6 +39,24 @@ PERSONAL_ASSISTANT_ID = "00000000-0000-0000-0000-000000000000" WEATHER_ASSISTANT_ID = "FFFFFFFF-FFFF-FFFF-FFFF-FFFFFFFFFFFF" +# MLflow configuration +MLFLOW_EXPERIMENT_NAME = os.getenv("MLFLOW_EXPERIMENT_NAME", "timestep-evals") +MLFLOW_EVAL_ENABLED = os.getenv("MLFLOW_EVAL_ENABLED", "true").lower() in {"1", "true", "yes"} +_MLFLOW_CONFIGURED = False + + +def setup_mlflow() -> None: + """Configure MLflow tracking for evals.""" + global _MLFLOW_CONFIGURED + if _MLFLOW_CONFIGURED: + return + + tracking_uri = os.getenv("MLFLOW_TRACKING_URI") + if tracking_uri: + mlflow.set_tracking_uri(tracking_uri) + mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME) + _MLFLOW_CONFIGURED = True + def write_task(task: Any, agent_id: str) -> None: """Write task to tasks/ folder in proper A2A Task format.""" @@ -135,6 +157,52 @@ def parse_tool_call(tool_call: Dict[str, Any]) -> tuple[Optional[str], Dict[str, return tool_name, tool_args +def run_mlflow_eval(prompt: str, response: str, agent_id: str, task_id: Optional[str]) -> None: + """Run MLflow evals and log results.""" + if not MLFLOW_EVAL_ENABLED: + return + + try: + setup_mlflow() + eval_df = pd.DataFrame( + [ + { + "inputs": prompt, + "predictions": response, + "targets": "", + } + ] + ) + run_name = f"eval-{agent_id[:8]}-{task_id[:8] if task_id else 'unknown'}" + + with mlflow.start_run(run_name=run_name): + mlflow.set_tags( + { + "a2a.agent_id": agent_id, + "a2a.task_id": task_id or "", + } + ) + mlflow.log_text(prompt, "prompt.txt") + mlflow.log_text(response, "response.txt") + mlflow.log_metric("response_length", float(len(response))) + + try: + from mlflow.metrics.genai import relevance + + mlflow.evaluate( + data=eval_df, + model_type="question-answering", + targets="targets", + predictions="predictions", + extra_metrics=[relevance()], + ) + except Exception as eval_error: + mlflow.log_param("eval_error", str(eval_error)) + print(f"[MLflow eval skipped: {eval_error}]", file=sys.stderr) + except Exception as e: + print(f"[MLflow eval setup failed: {e}]", file=sys.stderr) + + async def mcp_sampling_callback( context: RequestContext["ClientSession", Any], params: mcp_types.CreateMessageRequestParams, @@ -303,8 +371,9 @@ async def run_client_loop( message = create_text_message_object(role="user", content=initial_message) print(f"\n[DEBUG: Starting to send message to A2A server]", file=sys.stderr) - async def process_with_output(a2a_client: Any, message_obj: Any, agent_id: str) -> None: - """Process message stream and print output.""" + async def process_with_output(a2a_client: Any, message_obj: Any, agent_id: str) -> str: + """Process message stream, print output, and return final response.""" + final_message = "" async for event in a2a_client.send_message(message_obj): task = extract_task_from_event(event) print(f"\n[DEBUG: Received task, id={getattr(task, 'id', 'NO_ID')}, type={type(task)}]", file=sys.stderr) @@ -326,6 +395,7 @@ async def process_with_output(a2a_client: Any, message_obj: Any, agent_id: str) if task.status.state.value == "completed": print("\n[Task completed]") + final_message = extract_final_message(task) break if task.status.state.value == "input-required": @@ -349,10 +419,17 @@ async def process_with_output(a2a_client: Any, message_obj: Any, agent_id: str) tool_result_msg.context_id = task.context_id # Recursively process tool result - await process_with_output(a2a_client, tool_result_msg, agent_id) + result_message = await process_with_output(a2a_client, tool_result_msg, agent_id) + if result_message: + final_message = result_message break - - await process_with_output(a2a_client, message, agent_id) + + return final_message.strip() + + final_message = await process_with_output(a2a_client, message, agent_id) + if final_message: + task_id = task_ids[-1] if task_ids else None + run_mlflow_eval(initial_message, final_message, agent_id, task_id) except Exception as e: print(f"\n[Error in client loop: {e}]") raise