Skip to content

Commit 72a1007

Browse files
fix: properly handle unset and zero history length (#717)
According to https://a2a-protocol.org/latest/specification/#324-history-length-semantics. It changes behavior so the fix was postponed till 1.0. After changing to proto passing `.history_length` would not work anymore due to the way how proto generated code works - optional values are still translated to language defaults to avoid `None`s, while presence should be checked via `HasField` - done in this PR. Fixes #573 --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent e71ac62 commit 72a1007

3 files changed

Lines changed: 103 additions & 17 deletions

File tree

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,7 @@ async def on_get_task(
127127
if not task:
128128
raise ServerError(error=TaskNotFoundError())
129129

130-
# Apply historyLength parameter if specified
131-
return apply_history_length(task, params.history_length)
130+
return apply_history_length(task, params)
132131

133132
async def on_list_tasks(
134133
self,
@@ -141,7 +140,7 @@ async def on_list_tasks(
141140
if not params.include_artifacts:
142141
task.ClearField('artifacts')
143142

144-
updated_task = apply_history_length(task, params.history_length)
143+
updated_task = apply_history_length(task, params)
145144
if updated_task is not task:
146145
task.CopyFrom(updated_task)
147146

@@ -380,9 +379,7 @@ async def push_notification_callback() -> None:
380379
if isinstance(result, Task):
381380
self._validate_task_id_match(task_id, result.id)
382381
if params.configuration:
383-
result = apply_history_length(
384-
result, params.configuration.history_length
385-
)
382+
result = apply_history_length(result, params.configuration)
386383

387384
await self._send_push_notification_if_needed(task_id, result_aggregator)
388385

src/a2a/utils/task.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import uuid
55

66
from base64 import b64decode, b64encode
7+
from typing import Literal, Protocol, runtime_checkable
78

89
from a2a.types.a2a_pb2 import (
910
Artifact,
@@ -81,27 +82,57 @@ def completed_task(
8182
)
8283

8384

84-
def apply_history_length(task: Task, history_length: int | None) -> Task:
85+
@runtime_checkable
86+
class HistoryLengthConfig(Protocol):
87+
"""Protocol for configuration arguments containing history_length field."""
88+
89+
history_length: int
90+
91+
def HasField(self, field_name: Literal['history_length']) -> bool: # noqa: N802 -- Protobuf generated code
92+
"""Checks if a field is set.
93+
94+
This method name matches the generated Protobuf code.
95+
"""
96+
...
97+
98+
99+
def apply_history_length(
100+
task: Task, config: HistoryLengthConfig | None
101+
) -> Task:
85102
"""Applies history_length parameter on task and returns a new task object.
86103
87104
Args:
88105
task: The original task object with complete history
89-
history_length: History length configuration value
106+
config: Configuration object containing 'history_length' field and HasField method.
90107
91108
Returns:
92109
A new task object with limited history
110+
111+
See Also:
112+
https://a2a-protocol.org/latest/specification/#324-history-length-semantics
93113
"""
94-
# Apply historyLength parameter if specified
95-
if history_length is not None and history_length > 0 and task.history:
96-
# Limit history to the most recent N messages
97-
limited_history = list(task.history[-history_length:])
98-
# Create a new task instance with limited history
114+
if config is None or not config.HasField('history_length'):
115+
return task
116+
117+
history_length = config.history_length
118+
119+
if history_length == 0:
120+
if not task.history:
121+
return task
99122
task_copy = Task()
100123
task_copy.CopyFrom(task)
101-
# Clear and re-add history items
102-
del task_copy.history[:]
103-
task_copy.history.extend(limited_history)
124+
task_copy.ClearField('history')
104125
return task_copy
126+
127+
if history_length > 0 and task.history:
128+
if len(task.history) <= history_length:
129+
return task
130+
131+
task_copy = Task()
132+
task_copy.CopyFrom(task)
133+
del task_copy.history[:-history_length]
134+
return task_copy
135+
105136
return task
106137

107138

tests/utils/test_task.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,17 @@
55

66
import pytest
77

8-
from a2a.types.a2a_pb2 import Artifact, Message, Part, Role, TaskState
8+
from a2a.types.a2a_pb2 import (
9+
Artifact,
10+
Message,
11+
Part,
12+
Role,
13+
TaskState,
14+
GetTaskRequest,
15+
SendMessageConfiguration,
16+
)
917
from a2a.utils.task import (
18+
apply_history_length,
1019
completed_task,
1120
decode_page_token,
1221
encode_page_token,
@@ -213,5 +222,54 @@ def test_decode_page_token_fails(self):
213222
)
214223

215224

225+
class TestApplyHistoryLength(unittest.TestCase):
226+
def setUp(self):
227+
self.history = [
228+
Message(
229+
message_id=str(i),
230+
role=Role.ROLE_USER,
231+
parts=[Part(text=f'msg {i}')],
232+
)
233+
for i in range(5)
234+
]
235+
artifacts = [Artifact(artifact_id='a1', parts=[Part(text='a')])]
236+
self.task = completed_task(
237+
task_id='t1',
238+
context_id='c1',
239+
artifacts=artifacts,
240+
history=self.history,
241+
)
242+
243+
def test_none_config_returns_full_history(self):
244+
result = apply_history_length(self.task, None)
245+
self.assertEqual(len(result.history), 5)
246+
self.assertEqual(result.history, self.history)
247+
248+
def test_unset_history_length_returns_full_history(self):
249+
result = apply_history_length(self.task, GetTaskRequest())
250+
self.assertEqual(len(result.history), 5)
251+
self.assertEqual(result.history, self.history)
252+
253+
def test_positive_history_length_truncates(self):
254+
result = apply_history_length(
255+
self.task, GetTaskRequest(history_length=2)
256+
)
257+
self.assertEqual(len(result.history), 2)
258+
self.assertEqual(result.history, self.history[-2:])
259+
260+
def test_large_history_length_returns_full_history(self):
261+
result = apply_history_length(
262+
self.task, GetTaskRequest(history_length=10)
263+
)
264+
self.assertEqual(len(result.history), 5)
265+
self.assertEqual(result.history, self.history)
266+
267+
def test_zero_history_length_returns_empty_history(self):
268+
result = apply_history_length(
269+
self.task, SendMessageConfiguration(history_length=0)
270+
)
271+
self.assertEqual(len(result.history), 0)
272+
273+
216274
if __name__ == '__main__':
217275
unittest.main()

0 commit comments

Comments
 (0)