Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/google/adk/evaluation/evaluation_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from .eval_case import SessionInput
from .eval_set import EvalSet
from .request_intercepter_plugin import _RequestIntercepterPlugin
from .simulation.user_simulator import BaseUserSimulatorConfig
from .simulation.user_simulator import Status as UserSimulatorStatus
from .simulation.user_simulator import UserSimulator
from .simulation.user_simulator_provider import UserSimulatorProvider
Expand Down Expand Up @@ -263,6 +264,7 @@ async def generate_responses(
agent_module_path: str,
repeat_num: int = 3,
agent_name: str = None,
user_simulator_config: Optional[BaseUserSimulatorConfig] = None,
) -> list[EvalCaseResponses]:
"""Returns evaluation responses for the given dataset and agent.

Expand All @@ -273,12 +275,19 @@ async def generate_responses(
usually done to remove uncertainty that a single run may bring.
agent_name: The name of the agent that should be evaluated. This is
usually the sub-agent.
user_simulator_config: Optional configuration for the user simulator.
Only relevant for eval cases that use a `conversation_scenario` (which
are driven by `LlmBackedUserSimulator`); ignored for static
conversations. Pass an `LlmBackedUserSimulatorConfig` to override the
user-simulation model, max invocations, or custom instructions.
"""
results = []

for eval_case in eval_set.eval_cases:
# assume only static conversations are needed
user_simulator = UserSimulatorProvider().provide(eval_case)
user_simulator = UserSimulatorProvider(
user_simulator_config=user_simulator_config
).provide(eval_case)

responses = []
for _ in range(repeat_num):
response_invocations = await EvaluationGenerator._process_query(
Expand Down
45 changes: 45 additions & 0 deletions tests/unittests/evaluation/test_evaluation_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@

from google.adk.evaluation.app_details import AgentDetails
from google.adk.evaluation.app_details import AppDetails
from google.adk.evaluation.eval_case import EvalCase
from google.adk.evaluation.eval_set import EvalSet
from google.adk.evaluation.evaluation_generator import _LiveSession
from google.adk.evaluation.evaluation_generator import EvaluationGenerator
from google.adk.evaluation.request_intercepter_plugin import _RequestIntercepterPlugin
from google.adk.evaluation.simulation.llm_backed_user_simulator import LlmBackedUserSimulatorConfig
from google.adk.evaluation.simulation.user_simulator import NextUserMessage
from google.adk.evaluation.simulation.user_simulator import Status as UserSimulatorStatus
from google.adk.evaluation.simulation.user_simulator import UserSimulator
Expand Down Expand Up @@ -538,6 +541,48 @@ async def mock_generate_inferences_side_effect(
called_with_content = mock_generate_inferences.call_args.args[3]
assert called_with_content.parts[0].text == "message 1"


class TestGenerateResponses:
"""Test cases for EvaluationGenerator.generate_responses method."""

@pytest.mark.asyncio
async def test_generate_responses_forwards_llm_backed_user_simulator_config(
self, mocker
):
"""Tests that an LlmBackedUserSimulatorConfig is forwarded to the provider verbatim."""
mock_provider_cls = mocker.patch(
"google.adk.evaluation.evaluation_generator.UserSimulatorProvider"
)
mocker.patch(
"google.adk.evaluation.evaluation_generator.EvaluationGenerator._process_query",
new_callable=mocker.AsyncMock,
return_value=[],
)

user_simulator_config = LlmBackedUserSimulatorConfig(
model="test-model",
max_allowed_invocations=5,
)
eval_set = EvalSet(
eval_set_id="test_set",
eval_cases=[EvalCase(eval_id="case_0", conversation=[])],
)

await EvaluationGenerator.generate_responses(
eval_set=eval_set,
agent_module_path="some.agent.module",
repeat_num=1,
user_simulator_config=user_simulator_config,
)

mock_provider_cls.assert_called_once_with(
user_simulator_config=user_simulator_config
)
assert (
mock_provider_cls.call_args.kwargs["user_simulator_config"]
is user_simulator_config
)

@pytest.mark.asyncio
async def test_generates_inferences_with_user_simulator_live(
self, mocker, mock_runner, mock_session_service
Expand Down
Loading