Skip to content

Commit 5e0dcd7

Browse files
authored
feat: Add support for more Task Message and Artifact fields in the Vertex Task Store (a2aproject#908)
Add support for the following fields: * Part metadata * Artifact extensions, display_name, description * Message extensions, reference_task_ids * Parts of DataPart are now restored to their original type when read back * Add support for status detail messages in task updates For a2aproject#751
1 parent 8672785 commit 5e0dcd7

6 files changed

Lines changed: 373 additions & 34 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
datapart

src/a2a/contrib/tasks/vertex_task_converter.py

Lines changed: 162 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,35 @@
1111
import base64
1212
import json
1313

14+
from dataclasses import dataclass
15+
from typing import Any
16+
1417
from a2a.types import (
1518
Artifact,
1619
DataPart,
1720
FilePart,
1821
FileWithBytes,
1922
FileWithUri,
23+
Message,
2024
Part,
25+
Role,
2126
Task,
2227
TaskState,
2328
TaskStatus,
2429
TextPart,
2530
)
2631

2732

33+
_ORIGINAL_METADATA_KEY = 'originalMetadata'
34+
_EXTENSIONS_KEY = 'extensions'
35+
_REFERENCE_TASK_IDS_KEY = 'referenceTaskIds'
36+
_PART_METADATA_KEY = 'partMetadata'
37+
_METADATA_VERSION_KEY = '__vertex_compat_v'
38+
_METADATA_VERSION_NUMBER = 1.0
39+
40+
_DATA_PART_MIME_TYPE = 'application/x-a2a-datapart'
41+
42+
2843
_TO_SDK_TASK_STATE = {
2944
vertexai_types.A2aTaskState.STATE_UNSPECIFIED: TaskState.unknown,
3045
vertexai_types.A2aTaskState.SUBMITTED: TaskState.submitted,
@@ -52,6 +67,55 @@ def to_stored_task_state(task_state: TaskState) -> vertexai_types.A2aTaskState:
5267
)
5368

5469

70+
def to_stored_metadata(
71+
original_metadata: dict[str, Any] | None,
72+
extensions: list[str] | None,
73+
reference_task_ids: list[str] | None,
74+
parts: list[Part],
75+
) -> dict[str, Any]:
76+
"""Packs original metadata, extensions, and part types/metadata into a storage dictionary."""
77+
metadata: dict[str, Any] = {_METADATA_VERSION_KEY: _METADATA_VERSION_NUMBER}
78+
if original_metadata:
79+
metadata[_ORIGINAL_METADATA_KEY] = original_metadata
80+
if extensions:
81+
metadata[_EXTENSIONS_KEY] = extensions
82+
if reference_task_ids:
83+
metadata[_REFERENCE_TASK_IDS_KEY] = reference_task_ids
84+
85+
metadata[_PART_METADATA_KEY] = [part.root.metadata for part in parts]
86+
87+
return metadata
88+
89+
90+
@dataclass
91+
class _UnpackedMetadata:
92+
original_metadata: dict[str, Any] | None = None
93+
extensions: list[str] | None = None
94+
reference_task_ids: list[str] | None = None
95+
part_metadata: list[dict[str, Any] | None] | None = None
96+
97+
98+
def to_sdk_metadata(
99+
stored_metadata: dict[str, Any] | None,
100+
) -> _UnpackedMetadata:
101+
"""Unpacks metadata, extensions, and part types/metadata from a storage dictionary."""
102+
if not stored_metadata:
103+
return _UnpackedMetadata()
104+
105+
version = stored_metadata.get(_METADATA_VERSION_KEY)
106+
if version is None:
107+
return _UnpackedMetadata(original_metadata=stored_metadata)
108+
if version > _METADATA_VERSION_NUMBER:
109+
raise ValueError(f'Unsupported metadata version: {version}')
110+
111+
return _UnpackedMetadata(
112+
original_metadata=stored_metadata.get(_ORIGINAL_METADATA_KEY),
113+
extensions=stored_metadata.get(_EXTENSIONS_KEY),
114+
reference_task_ids=stored_metadata.get(_REFERENCE_TASK_IDS_KEY),
115+
part_metadata=stored_metadata.get(_PART_METADATA_KEY),
116+
)
117+
118+
55119
def to_stored_part(part: Part) -> genai_types.Part:
56120
"""Converts a SDK Part to a proto Part."""
57121
if isinstance(part.root, TextPart):
@@ -60,7 +124,7 @@ def to_stored_part(part: Part) -> genai_types.Part:
60124
data_bytes = json.dumps(part.root.data).encode('utf-8')
61125
return genai_types.Part(
62126
inline_data=genai_types.Blob(
63-
mime_type='application/json', data=data_bytes
127+
mime_type=_DATA_PART_MIME_TYPE, data=data_bytes
64128
)
65129
)
66130
if isinstance(part.root, FilePart):
@@ -82,29 +146,41 @@ def to_stored_part(part: Part) -> genai_types.Part:
82146
raise ValueError(f'Unsupported part type: {type(part.root)}')
83147

84148

85-
def to_sdk_part(stored_part: genai_types.Part) -> Part:
149+
def to_sdk_part(
150+
stored_part: genai_types.Part,
151+
part_metadata: dict[str, Any] | None = None,
152+
) -> Part:
86153
"""Converts a proto Part to a SDK Part."""
87154
if stored_part.text:
88-
return Part(root=TextPart(text=stored_part.text))
155+
return Part(
156+
root=TextPart(text=stored_part.text, metadata=part_metadata)
157+
)
89158
if stored_part.inline_data:
159+
mime_type = stored_part.inline_data.mime_type
160+
if mime_type == _DATA_PART_MIME_TYPE:
161+
data_dict = json.loads(stored_part.inline_data.data or b'{}')
162+
return Part(root=DataPart(data=data_dict, metadata=part_metadata))
163+
90164
encoded_bytes = base64.b64encode(
91165
stored_part.inline_data.data or b''
92166
).decode('utf-8')
93167
return Part(
94168
root=FilePart(
95169
file=FileWithBytes(
96-
mime_type=stored_part.inline_data.mime_type,
170+
mime_type=mime_type,
97171
bytes=encoded_bytes,
98-
)
172+
),
173+
metadata=part_metadata,
99174
)
100175
)
101176
if stored_part.file_data:
102177
return Part(
103178
root=FilePart(
104179
file=FileWithUri(
105180
mime_type=stored_part.file_data.mime_type,
106-
uri=stored_part.file_data.file_uri,
107-
)
181+
uri=stored_part.file_data.file_uri or '',
182+
),
183+
metadata=part_metadata,
108184
)
109185
)
110186

@@ -115,15 +191,83 @@ def to_stored_artifact(artifact: Artifact) -> vertexai_types.TaskArtifact:
115191
"""Converts a SDK Artifact to a proto TaskArtifact."""
116192
return vertexai_types.TaskArtifact(
117193
artifact_id=artifact.artifact_id,
194+
display_name=artifact.name,
195+
description=artifact.description,
118196
parts=[to_stored_part(part) for part in artifact.parts],
197+
metadata=to_stored_metadata(
198+
original_metadata=artifact.metadata,
199+
extensions=artifact.extensions,
200+
reference_task_ids=None,
201+
parts=artifact.parts,
202+
),
119203
)
120204

121205

122206
def to_sdk_artifact(stored_artifact: vertexai_types.TaskArtifact) -> Artifact:
123207
"""Converts a proto TaskArtifact to a SDK Artifact."""
208+
unpacked_meta = to_sdk_metadata(stored_artifact.metadata)
209+
part_metadata_list = unpacked_meta.part_metadata or []
210+
211+
parts = []
212+
for i, part in enumerate(stored_artifact.parts or []):
213+
meta: dict[str, Any] | None = None
214+
if i < len(part_metadata_list):
215+
meta = part_metadata_list[i]
216+
parts.append(to_sdk_part(part, part_metadata=meta))
217+
124218
return Artifact(
125219
artifact_id=stored_artifact.artifact_id,
126-
parts=[to_sdk_part(part) for part in stored_artifact.parts],
220+
name=stored_artifact.display_name,
221+
description=stored_artifact.description,
222+
extensions=unpacked_meta.extensions,
223+
metadata=unpacked_meta.original_metadata,
224+
parts=parts,
225+
)
226+
227+
228+
def to_stored_message(
229+
message: Message | None,
230+
) -> vertexai_types.TaskMessage | None:
231+
"""Converts a SDK Message to a proto Message."""
232+
if not message:
233+
return None
234+
role = message.role.value if message.role else ''
235+
return vertexai_types.TaskMessage(
236+
message_id=message.message_id,
237+
role=role,
238+
parts=[to_stored_part(part) for part in message.parts],
239+
metadata=to_stored_metadata(
240+
original_metadata=message.metadata,
241+
extensions=message.extensions,
242+
reference_task_ids=message.reference_task_ids,
243+
parts=message.parts,
244+
),
245+
)
246+
247+
248+
def to_sdk_message(
249+
stored_msg: vertexai_types.TaskMessage | None,
250+
) -> Message | None:
251+
"""Converts a proto Message to a SDK Message."""
252+
if not stored_msg:
253+
return None
254+
unpacked_meta = to_sdk_metadata(stored_msg.metadata)
255+
part_metadata_list = unpacked_meta.part_metadata or []
256+
257+
parts = []
258+
for i, part in enumerate(stored_msg.parts or []):
259+
part_metadata: dict[str, Any] | None = None
260+
if i < len(part_metadata_list):
261+
part_metadata = part_metadata_list[i]
262+
parts.append(to_sdk_part(part, part_metadata=part_metadata))
263+
264+
return Message(
265+
message_id=stored_msg.message_id,
266+
role=Role(stored_msg.role),
267+
extensions=unpacked_meta.extensions,
268+
reference_task_ids=unpacked_meta.reference_task_ids,
269+
metadata=unpacked_meta.original_metadata,
270+
parts=parts,
127271
)
128272

129273

@@ -133,6 +277,11 @@ def to_stored_task(task: Task) -> vertexai_types.A2aTask:
133277
context_id=task.context_id,
134278
metadata=task.metadata,
135279
state=to_stored_task_state(task.status.state),
280+
status_details=vertexai_types.TaskStatusDetails(
281+
task_message=to_stored_message(task.status.message)
282+
)
283+
if task.status.message
284+
else None,
136285
output=vertexai_types.TaskOutput(
137286
artifacts=[
138287
to_stored_artifact(artifact)
@@ -144,10 +293,14 @@ def to_stored_task(task: Task) -> vertexai_types.A2aTask:
144293

145294
def to_sdk_task(a2a_task: vertexai_types.A2aTask) -> Task:
146295
"""Converts a proto A2aTask to a SDK Task."""
296+
msg: Message | None = None
297+
if a2a_task.status_details and a2a_task.status_details.task_message:
298+
msg = to_sdk_message(a2a_task.status_details.task_message)
299+
147300
return Task(
148301
id=a2a_task.name.split('/')[-1],
149302
context_id=a2a_task.context_id,
150-
status=TaskStatus(state=to_sdk_task_state(a2a_task.state)),
303+
status=TaskStatus(state=to_sdk_task_state(a2a_task.state), message=msg),
151304
metadata=a2a_task.metadata or {},
152305
artifacts=[
153306
to_sdk_artifact(artifact)

src/a2a/contrib/tasks/vertex_task_store.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,32 @@ def _get_status_change_event(
8080
)
8181
return None
8282

83+
def _get_status_details_change_event(
84+
self,
85+
previous_task: Task,
86+
task: Task,
87+
event_sequence_number: int,
88+
) -> vertexai_types.TaskEvent | None:
89+
if task.status.message != previous_task.status.message:
90+
status_details = (
91+
vertexai_types.TaskStatusDetails(
92+
task_message=vertex_task_converter.to_stored_message(
93+
task.status.message
94+
)
95+
)
96+
if task.status.message
97+
else vertexai_types.TaskStatusDetails()
98+
)
99+
return vertexai_types.TaskEvent(
100+
event_data=vertexai_types.TaskEventData(
101+
status_details_change=vertexai_types.TaskStatusDetailsChange(
102+
new_task_status=status_details,
103+
),
104+
),
105+
event_sequence_number=event_sequence_number,
106+
)
107+
return None
108+
83109
def _get_metadata_change_event(
84110
self, previous_task: Task, task: Task, event_sequence_number: int
85111
) -> vertexai_types.TaskEvent | None:
@@ -158,6 +184,13 @@ async def _update(
158184
events.append(status_event)
159185
event_sequence_number += 1
160186

187+
status_details_event = self._get_status_details_change_event(
188+
previous_task, task, event_sequence_number
189+
)
190+
if status_details_event:
191+
events.append(status_details_event)
192+
event_sequence_number += 1
193+
161194
metadata_event = self._get_metadata_change_event(
162195
previous_task, task, event_sequence_number
163196
)

tests/contrib/tasks/fake_vertex_client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ async def append(
3636
data = event.event_data
3737
if getattr(data, 'state_change', None):
3838
task.state = getattr(data.state_change, 'new_state', task.state)
39+
if getattr(data, 'status_details_change', None):
40+
task.status_details = getattr(
41+
data.status_details_change,
42+
'new_task_status',
43+
getattr(task, 'status_details', None),
44+
)
3945
if getattr(data, 'metadata_change', None):
4046
task.metadata = getattr(
4147
data.metadata_change, 'new_metadata', task.metadata

0 commit comments

Comments
 (0)