Skip to content

Commit c367c83

Browse files
authored
refactor: Enforce ServerCallContext in request handling (#882)
# Description - Make `ServerCallContext` a mandatory parameter across all `TaskStore` implementations (`TaskStore` interface, `DatabaseTaskStore`, and `InMemoryTaskStore`) and update tests. - Make `ServerCallContext` a mandatory parameter in `RequestContext`. Previously, context defaulted to None, which could allow callers to bypass authorization scoping if the context was set to None. By requiring the context, we guarantee that the `owner_resolver` always has the necessary request context to determine scope boundaries. Fixes #718 🦕
1 parent b85d3bb commit c367c83

18 files changed

Lines changed: 318 additions & 223 deletions

src/a2a/contrib/tasks/vertex_task_store.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,7 @@ def __init__(
4444
self._client = client
4545
self._agent_engine_resource_id = agent_engine_resource_id
4646

47-
async def save(
48-
self, task: Task, context: ServerCallContext | None = None
49-
) -> None:
47+
async def save(self, task: Task, context: ServerCallContext) -> None:
5048
"""Saves or updates a task in the store."""
5149
compat_task = to_compat_task(task)
5250
previous_task = await self._get_stored_task(compat_task.id)
@@ -206,7 +204,7 @@ async def _get_stored_task(
206204
return a2a_task
207205

208206
async def get(
209-
self, task_id: str, context: ServerCallContext | None = None
207+
self, task_id: str, context: ServerCallContext
210208
) -> Task | None:
211209
"""Retrieves a task from the database by ID."""
212210
a2a_task = await self._get_stored_task(task_id)
@@ -217,13 +215,11 @@ async def get(
217215
async def list(
218216
self,
219217
params: ListTasksRequest,
220-
context: ServerCallContext | None = None,
218+
context: ServerCallContext,
221219
) -> ListTasksResponse:
222220
"""Retrieves a list of tasks from the store."""
223221
raise NotImplementedError
224222

225-
async def delete(
226-
self, task_id: str, context: ServerCallContext | None = None
227-
) -> None:
223+
async def delete(self, task_id: str, context: ServerCallContext) -> None:
228224
"""The backend doesn't support deleting tasks, so this is not implemented."""
229225
raise NotImplementedError

src/a2a/server/agent_execution/context.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,35 +26,35 @@ class RequestContext:
2626

2727
def __init__( # noqa: PLR0913
2828
self,
29+
call_context: ServerCallContext,
2930
request: SendMessageRequest | None = None,
3031
task_id: str | None = None,
3132
context_id: str | None = None,
3233
task: Task | None = None,
3334
related_tasks: list[Task] | None = None,
34-
call_context: ServerCallContext | None = None,
3535
task_id_generator: IDGenerator | None = None,
3636
context_id_generator: IDGenerator | None = None,
3737
):
3838
"""Initializes the RequestContext.
3939
4040
Args:
41+
call_context: The server call context associated with this request.
4142
request: The incoming `SendMessageRequest` request payload.
4243
task_id: The ID of the task explicitly provided in the request or path.
4344
context_id: The ID of the context explicitly provided in the request or path.
4445
task: The existing `Task` object retrieved from the store, if any.
4546
related_tasks: A list of other tasks related to the current request (e.g., for tool use).
46-
call_context: The server call context associated with this request.
4747
task_id_generator: ID generator for new task IDs. Defaults to UUID generator.
4848
context_id_generator: ID generator for new context IDs. Defaults to UUID generator.
4949
"""
5050
if related_tasks is None:
5151
related_tasks = []
52+
self._call_context = call_context
5253
self._params = request
5354
self._task_id = task_id
5455
self._context_id = context_id
5556
self._current_task = task
5657
self._related_tasks = related_tasks
57-
self._call_context = call_context
5858
self._task_id_generator = (
5959
task_id_generator if task_id_generator else UUIDGenerator()
6060
)
@@ -140,7 +140,7 @@ def configuration(self) -> SendMessageConfiguration | None:
140140
return self._params.configuration if self._params else None
141141

142142
@property
143-
def call_context(self) -> ServerCallContext | None:
143+
def call_context(self) -> ServerCallContext:
144144
"""The server call context associated with this request."""
145145
return self._call_context
146146

@@ -157,22 +157,17 @@ def add_activated_extension(self, uri: str) -> None:
157157
This causes the extension to be indicated back to the client in the
158158
response.
159159
"""
160-
if self._call_context:
161-
self._call_context.activated_extensions.add(uri)
160+
self._call_context.activated_extensions.add(uri)
162161

163162
@property
164163
def tenant(self) -> str:
165164
"""The tenant associated with this request."""
166-
return self._call_context.tenant if self._call_context else ''
165+
return self._call_context.tenant
167166

168167
@property
169168
def requested_extensions(self) -> set[str]:
170169
"""Extensions that the client requested to activate."""
171-
return (
172-
self._call_context.requested_extensions
173-
if self._call_context
174-
else set()
175-
)
170+
return self._call_context.requested_extensions
176171

177172
def _check_or_generate_task_id(self) -> None:
178173
"""Ensures a task ID is present, generating one if necessary."""

src/a2a/server/agent_execution/request_context_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ class RequestContextBuilder(ABC):
1111
@abstractmethod
1212
async def build(
1313
self,
14+
context: ServerCallContext,
1415
params: SendMessageRequest | None = None,
1516
task_id: str | None = None,
1617
context_id: str | None = None,
1718
task: Task | None = None,
18-
context: ServerCallContext | None = None,
1919
) -> RequestContext:
2020
pass

src/a2a/server/agent_execution/simple_request_context_builder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ def __init__(
3535

3636
async def build(
3737
self,
38+
context: ServerCallContext,
3839
params: SendMessageRequest | None = None,
3940
task_id: str | None = None,
4041
context_id: str | None = None,
4142
task: Task | None = None,
42-
context: ServerCallContext | None = None,
4343
) -> RequestContext:
4444
"""Builds the request context for an agent execution.
4545
@@ -48,11 +48,11 @@ async def build(
4848
referenced in `params.message.reference_task_ids` from the `task_store`.
4949
5050
Args:
51+
context: The server call context, containing metadata about the call.
5152
params: The parameters of the incoming message send request.
5253
task_id: The ID of the task being executed.
5354
context_id: The ID of the current execution context.
5455
task: The primary task object associated with the request.
55-
context: The server call context, containing metadata about the call.
5656
5757
Returns:
5858
An instance of RequestContext populated with the provided information
@@ -68,19 +68,19 @@ async def build(
6868
):
6969
tasks = await asyncio.gather(
7070
*[
71-
self._task_store.get(task_id)
71+
self._task_store.get(task_id, context)
7272
for task_id in params.message.reference_task_ids
7373
]
7474
)
7575
related_tasks = [x for x in tasks if x is not None]
7676

7777
return RequestContext(
78+
call_context=context,
7879
request=params,
7980
task_id=task_id,
8081
context_id=context_id,
8182
task=task,
8283
related_tasks=related_tasks,
83-
call_context=context,
8484
task_id_generator=self._task_id_generator,
8585
context_id_generator=self._context_id_generator,
8686
)

src/a2a/server/owner_resolver.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,10 @@
44

55

66
# Definition
7-
OwnerResolver = Callable[[ServerCallContext | None], str]
7+
OwnerResolver = Callable[[ServerCallContext], str]
88

99

1010
# Example Default Implementation
11-
def resolve_user_scope(context: ServerCallContext | None) -> str:
11+
def resolve_user_scope(context: ServerCallContext) -> str:
1212
"""Resolves the owner scope based on the user in the context."""
13-
if not context:
14-
return 'unknown'
15-
# Example: Basic user name. Adapt as needed for your user model.
1613
return context.user.user_name

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,8 @@ async def on_cancel_task(
196196

197197
await self.agent_executor.cancel(
198198
RequestContext(
199-
None,
199+
call_context=context,
200+
request=None,
200201
task_id=task.id,
201202
context_id=task.context_id,
202203
task=task,
@@ -290,7 +291,7 @@ async def _setup_message_execution(
290291
await self._push_config_store.set_info(
291292
task_id,
292293
params.configuration.task_push_notification_config,
293-
context or ServerCallContext(),
294+
context,
294295
)
295296

296297
queue = await self._queue_manager.create_or_tap(task_id)
@@ -504,7 +505,7 @@ async def on_create_task_push_notification_config(
504505
await self._push_config_store.set_info(
505506
task_id,
506507
params,
507-
context or ServerCallContext(),
508+
context,
508509
)
509510

510511
return params
@@ -529,10 +530,7 @@ async def on_get_task_push_notification_config(
529530
raise TaskNotFoundError
530531

531532
push_notification_configs: list[TaskPushNotificationConfig] = (
532-
await self._push_config_store.get_info(
533-
task_id, context or ServerCallContext()
534-
)
535-
or []
533+
await self._push_config_store.get_info(task_id, context) or []
536534
)
537535

538536
for config in push_notification_configs:
@@ -603,7 +601,7 @@ async def on_list_task_push_notification_configs(
603601
raise TaskNotFoundError
604602

605603
push_notification_config_list = await self._push_config_store.get_info(
606-
task_id, context or ServerCallContext()
604+
task_id, context
607605
)
608606

609607
return ListTaskPushNotificationConfigsResponse(
@@ -629,6 +627,4 @@ async def on_delete_task_push_notification_config(
629627
if not task:
630628
raise TaskNotFoundError
631629

632-
await self._push_config_store.delete_info(
633-
task_id, context or ServerCallContext(), config_id
634-
)
630+
await self._push_config_store.delete_info(task_id, context, config_id)

src/a2a/server/tasks/copying_task_store.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,14 @@ class CopyingTaskStoreAdapter(TaskStore):
2424
def __init__(self, underlying_store: TaskStore):
2525
self._store = underlying_store
2626

27-
async def save(
28-
self, task: Task, context: ServerCallContext | None = None
29-
) -> None:
27+
async def save(self, task: Task, context: ServerCallContext) -> None:
3028
"""Saves a copy of the task to the underlying store."""
3129
task_copy = Task()
3230
task_copy.CopyFrom(task)
3331
await self._store.save(task_copy, context)
3432

3533
async def get(
36-
self, task_id: str, context: ServerCallContext | None = None
34+
self, task_id: str, context: ServerCallContext
3735
) -> Task | None:
3836
"""Retrieves a task from the underlying store and returns a copy."""
3937
task = await self._store.get(task_id, context)
@@ -46,16 +44,14 @@ async def get(
4644
async def list(
4745
self,
4846
params: ListTasksRequest,
49-
context: ServerCallContext | None = None,
47+
context: ServerCallContext,
5048
) -> ListTasksResponse:
5149
"""Retrieves a list of tasks from the underlying store and returns a copy."""
5250
response = await self._store.list(params, context)
5351
response_copy = ListTasksResponse()
5452
response_copy.CopyFrom(response)
5553
return response_copy
5654

57-
async def delete(
58-
self, task_id: str, context: ServerCallContext | None = None
59-
) -> None:
55+
async def delete(self, task_id: str, context: ServerCallContext) -> None:
6056
"""Deletes a task from the underlying store."""
6157
await self._store.delete(task_id, context)

src/a2a/server/tasks/database_task_store.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,7 @@ def _from_orm(self, task_model: TaskModel) -> Task:
169169
# Legacy conversion
170170
return compat_task_model_to_core(task_model)
171171

172-
async def save(
173-
self, task: Task, context: ServerCallContext | None = None
174-
) -> None:
172+
async def save(self, task: Task, context: ServerCallContext) -> None:
175173
"""Saves or updates a task in the database for the resolved owner."""
176174
await self._ensure_initialized()
177175
owner = self.owner_resolver(context)
@@ -185,7 +183,7 @@ async def save(
185183
)
186184

187185
async def get(
188-
self, task_id: str, context: ServerCallContext | None = None
186+
self, task_id: str, context: ServerCallContext
189187
) -> Task | None:
190188
"""Retrieves a task from the database by ID, for the given owner."""
191189
await self._ensure_initialized()
@@ -216,7 +214,7 @@ async def get(
216214
async def list(
217215
self,
218216
params: a2a_pb2.ListTasksRequest,
219-
context: ServerCallContext | None = None,
217+
context: ServerCallContext,
220218
) -> a2a_pb2.ListTasksResponse:
221219
"""Retrieves tasks from the database based on provided parameters, for the given owner."""
222220
await self._ensure_initialized()
@@ -315,9 +313,7 @@ async def list(
315313
page_size=page_size,
316314
)
317315

318-
async def delete(
319-
self, task_id: str, context: ServerCallContext | None = None
320-
) -> None:
316+
async def delete(self, task_id: str, context: ServerCallContext) -> None:
321317
"""Deletes a task from the database by ID, for the given owner."""
322318
await self._ensure_initialized()
323319
owner = self.owner_resolver(context)

src/a2a/server/tasks/inmemory_task_store.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ def __init__(
3535
def _get_owner_tasks(self, owner: str) -> dict[str, Task]:
3636
return self.tasks.get(owner, {})
3737

38-
async def save(
39-
self, task: Task, context: ServerCallContext | None = None
40-
) -> None:
38+
async def save(self, task: Task, context: ServerCallContext) -> None:
4139
"""Saves or updates a task in the in-memory store for the resolved owner."""
4240
owner = self.owner_resolver(context)
4341
if owner not in self.tasks:
@@ -50,7 +48,7 @@ async def save(
5048
)
5149

5250
async def get(
53-
self, task_id: str, context: ServerCallContext | None = None
51+
self, task_id: str, context: ServerCallContext
5452
) -> Task | None:
5553
"""Retrieves a task from the in-memory store by ID, for the given owner."""
5654
owner = self.owner_resolver(context)
@@ -77,7 +75,7 @@ async def get(
7775
async def list(
7876
self,
7977
params: a2a_pb2.ListTasksRequest,
80-
context: ServerCallContext | None = None,
78+
context: ServerCallContext,
8179
) -> a2a_pb2.ListTasksResponse:
8280
"""Retrieves a list of tasks from the store, for the given owner."""
8381
owner = self.owner_resolver(context)
@@ -156,9 +154,7 @@ async def list(
156154
page_size=page_size,
157155
)
158156

159-
async def delete(
160-
self, task_id: str, context: ServerCallContext | None = None
161-
) -> None:
157+
async def delete(self, task_id: str, context: ServerCallContext) -> None:
162158
"""Deletes a task from the in-memory store by ID, for the given owner."""
163159
owner = self.owner_resolver(context)
164160
async with self.lock:
@@ -211,28 +207,24 @@ def __init__(
211207
CopyingTaskStoreAdapter(self._impl) if use_copying else self._impl
212208
)
213209

214-
async def save(
215-
self, task: Task, context: ServerCallContext | None = None
216-
) -> None:
210+
async def save(self, task: Task, context: ServerCallContext) -> None:
217211
"""Saves or updates a task in the store."""
218212
await self._store.save(task, context)
219213

220214
async def get(
221-
self, task_id: str, context: ServerCallContext | None = None
215+
self, task_id: str, context: ServerCallContext
222216
) -> Task | None:
223217
"""Retrieves a task from the store by ID."""
224218
return await self._store.get(task_id, context)
225219

226220
async def list(
227221
self,
228222
params: a2a_pb2.ListTasksRequest,
229-
context: ServerCallContext | None = None,
223+
context: ServerCallContext,
230224
) -> a2a_pb2.ListTasksResponse:
231225
"""Retrieves a list of tasks from the store."""
232226
return await self._store.list(params, context)
233227

234-
async def delete(
235-
self, task_id: str, context: ServerCallContext | None = None
236-
) -> None:
228+
async def delete(self, task_id: str, context: ServerCallContext) -> None:
237229
"""Deletes a task from the store by ID."""
238230
await self._store.delete(task_id, context)

0 commit comments

Comments
 (0)