diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index d5a6629366..794c9cf259 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -59,6 +59,7 @@ from .eval_case import SessionInput from .eval_set import EvalSet from .request_intercepter_plugin import _RequestIntercepterPlugin +from .simulation.user_simulator import BaseUserSimulatorConfig from .simulation.user_simulator import Status as UserSimulatorStatus from .simulation.user_simulator import UserSimulator from .simulation.user_simulator_provider import UserSimulatorProvider @@ -263,6 +264,7 @@ async def generate_responses( agent_module_path: str, repeat_num: int = 3, agent_name: str = None, + user_simulator_config: Optional[BaseUserSimulatorConfig] = None, ) -> list[EvalCaseResponses]: """Returns evaluation responses for the given dataset and agent. @@ -273,12 +275,19 @@ async def generate_responses( usually done to remove uncertainty that a single run may bring. agent_name: The name of the agent that should be evaluated. This is usually the sub-agent. + user_simulator_config: Optional configuration for the user simulator. + Only relevant for eval cases that use a `conversation_scenario` (which + are driven by `LlmBackedUserSimulator`); ignored for static + conversations. Pass an `LlmBackedUserSimulatorConfig` to override the + user-simulation model, max invocations, or custom instructions. """ results = [] for eval_case in eval_set.eval_cases: - # assume only static conversations are needed - user_simulator = UserSimulatorProvider().provide(eval_case) + user_simulator = UserSimulatorProvider( + user_simulator_config=user_simulator_config + ).provide(eval_case) + responses = [] for _ in range(repeat_num): response_invocations = await EvaluationGenerator._process_query( diff --git a/tests/unittests/evaluation/test_evaluation_generator.py b/tests/unittests/evaluation/test_evaluation_generator.py index 508b6f5c9c..7c0c6dff5b 100644 --- a/tests/unittests/evaluation/test_evaluation_generator.py +++ b/tests/unittests/evaluation/test_evaluation_generator.py @@ -18,9 +18,12 @@ from google.adk.evaluation.app_details import AgentDetails from google.adk.evaluation.app_details import AppDetails +from google.adk.evaluation.eval_case import EvalCase +from google.adk.evaluation.eval_set import EvalSet from google.adk.evaluation.evaluation_generator import _LiveSession from google.adk.evaluation.evaluation_generator import EvaluationGenerator from google.adk.evaluation.request_intercepter_plugin import _RequestIntercepterPlugin +from google.adk.evaluation.simulation.llm_backed_user_simulator import LlmBackedUserSimulatorConfig from google.adk.evaluation.simulation.user_simulator import NextUserMessage from google.adk.evaluation.simulation.user_simulator import Status as UserSimulatorStatus from google.adk.evaluation.simulation.user_simulator import UserSimulator @@ -538,6 +541,48 @@ async def mock_generate_inferences_side_effect( called_with_content = mock_generate_inferences.call_args.args[3] assert called_with_content.parts[0].text == "message 1" + +class TestGenerateResponses: + """Test cases for EvaluationGenerator.generate_responses method.""" + + @pytest.mark.asyncio + async def test_generate_responses_forwards_llm_backed_user_simulator_config( + self, mocker + ): + """Tests that an LlmBackedUserSimulatorConfig is forwarded to the provider verbatim.""" + mock_provider_cls = mocker.patch( + "google.adk.evaluation.evaluation_generator.UserSimulatorProvider" + ) + mocker.patch( + "google.adk.evaluation.evaluation_generator.EvaluationGenerator._process_query", + new_callable=mocker.AsyncMock, + return_value=[], + ) + + user_simulator_config = LlmBackedUserSimulatorConfig( + model="test-model", + max_allowed_invocations=5, + ) + eval_set = EvalSet( + eval_set_id="test_set", + eval_cases=[EvalCase(eval_id="case_0", conversation=[])], + ) + + await EvaluationGenerator.generate_responses( + eval_set=eval_set, + agent_module_path="some.agent.module", + repeat_num=1, + user_simulator_config=user_simulator_config, + ) + + mock_provider_cls.assert_called_once_with( + user_simulator_config=user_simulator_config + ) + assert ( + mock_provider_cls.call_args.kwargs["user_simulator_config"] + is user_simulator_config + ) + @pytest.mark.asyncio async def test_generates_inferences_with_user_simulator_live( self, mocker, mock_runner, mock_session_service