Skip to content

Commit 6d49122

Browse files
authored
chore: remove the use of deprecated types from VertexTaskStore (a2aproject#889)
* vertexai.Part & co. will be replaced soon by genai.Part & co. * It's better to use the more specifically named variants of `Task` and `Status`: `A2aTask` and `A2aTaskStatus`. For a2aproject#751
1 parent 4ebbb2e commit 6d49122

2 files changed

Lines changed: 69 additions & 54 deletions

File tree

src/a2a/contrib/tasks/vertex_task_converter.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
try:
2+
from google.genai import types as genai_types
23
from vertexai import types as vertexai_types
34
except ImportError as e:
45
raise ImportError(
@@ -25,70 +26,70 @@
2526

2627

2728
_TO_SDK_TASK_STATE = {
28-
vertexai_types.State.STATE_UNSPECIFIED: TaskState.unknown,
29-
vertexai_types.State.SUBMITTED: TaskState.submitted,
30-
vertexai_types.State.WORKING: TaskState.working,
31-
vertexai_types.State.COMPLETED: TaskState.completed,
32-
vertexai_types.State.CANCELLED: TaskState.canceled,
33-
vertexai_types.State.FAILED: TaskState.failed,
34-
vertexai_types.State.REJECTED: TaskState.rejected,
35-
vertexai_types.State.INPUT_REQUIRED: TaskState.input_required,
36-
vertexai_types.State.AUTH_REQUIRED: TaskState.auth_required,
29+
vertexai_types.A2aTaskState.STATE_UNSPECIFIED: TaskState.unknown,
30+
vertexai_types.A2aTaskState.SUBMITTED: TaskState.submitted,
31+
vertexai_types.A2aTaskState.WORKING: TaskState.working,
32+
vertexai_types.A2aTaskState.COMPLETED: TaskState.completed,
33+
vertexai_types.A2aTaskState.CANCELLED: TaskState.canceled,
34+
vertexai_types.A2aTaskState.FAILED: TaskState.failed,
35+
vertexai_types.A2aTaskState.REJECTED: TaskState.rejected,
36+
vertexai_types.A2aTaskState.INPUT_REQUIRED: TaskState.input_required,
37+
vertexai_types.A2aTaskState.AUTH_REQUIRED: TaskState.auth_required,
3738
}
3839

3940
_SDK_TO_STORED_TASK_STATE = {v: k for k, v in _TO_SDK_TASK_STATE.items()}
4041

4142

42-
def to_sdk_task_state(stored_state: vertexai_types.State) -> TaskState:
43+
def to_sdk_task_state(stored_state: vertexai_types.A2aTaskState) -> TaskState:
4344
"""Converts a proto A2aTask.State to a TaskState enum."""
4445
return _TO_SDK_TASK_STATE.get(stored_state, TaskState.unknown)
4546

4647

47-
def to_stored_task_state(task_state: TaskState) -> vertexai_types.State:
48+
def to_stored_task_state(task_state: TaskState) -> vertexai_types.A2aTaskState:
4849
"""Converts a TaskState enum to a proto A2aTask.State enum value."""
4950
return _SDK_TO_STORED_TASK_STATE.get(
50-
task_state, vertexai_types.State.STATE_UNSPECIFIED
51+
task_state, vertexai_types.A2aTaskState.STATE_UNSPECIFIED
5152
)
5253

5354

54-
def to_stored_part(part: Part) -> vertexai_types.Part:
55+
def to_stored_part(part: Part) -> genai_types.Part:
5556
"""Converts a SDK Part to a proto Part."""
5657
if isinstance(part.root, TextPart):
57-
return vertexai_types.Part(text=part.root.text)
58+
return genai_types.Part(text=part.root.text)
5859
if isinstance(part.root, DataPart):
5960
data_bytes = json.dumps(part.root.data).encode('utf-8')
60-
return vertexai_types.Part(
61-
inline_data=vertexai_types.Blob(
61+
return genai_types.Part(
62+
inline_data=genai_types.Blob(
6263
mime_type='application/json', data=data_bytes
6364
)
6465
)
6566
if isinstance(part.root, FilePart):
6667
file_content = part.root.file
6768
if isinstance(file_content, FileWithBytes):
6869
decoded_bytes = base64.b64decode(file_content.bytes)
69-
return vertexai_types.Part(
70-
inline_data=vertexai_types.Blob(
70+
return genai_types.Part(
71+
inline_data=genai_types.Blob(
7172
mime_type=file_content.mime_type or '', data=decoded_bytes
7273
)
7374
)
7475
if isinstance(file_content, FileWithUri):
75-
return vertexai_types.Part(
76-
file_data=vertexai_types.FileData(
76+
return genai_types.Part(
77+
file_data=genai_types.FileData(
7778
mime_type=file_content.mime_type or '',
7879
file_uri=file_content.uri,
7980
)
8081
)
8182
raise ValueError(f'Unsupported part type: {type(part.root)}')
8283

8384

84-
def to_sdk_part(stored_part: vertexai_types.Part) -> Part:
85+
def to_sdk_part(stored_part: genai_types.Part) -> Part:
8586
"""Converts a proto Part to a SDK Part."""
8687
if stored_part.text:
8788
return Part(root=TextPart(text=stored_part.text))
8889
if stored_part.inline_data:
89-
encoded_bytes = base64.b64encode(stored_part.inline_data.data).decode(
90-
'utf-8'
91-
)
90+
encoded_bytes = base64.b64encode(
91+
stored_part.inline_data.data or b''
92+
).decode('utf-8')
9293
return Part(
9394
root=FilePart(
9495
file=FileWithBytes(

tests/contrib/tasks/test_vertex_task_converter.py

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
'vertexai', reason='Vertex Task Converter tests require vertexai'
88
)
99
from vertexai import types as vertexai_types
10-
10+
from google.genai import types as genai_types
1111
from a2a.contrib.tasks.vertex_task_converter import (
1212
to_sdk_artifact,
1313
to_sdk_part,
@@ -34,29 +34,39 @@
3434

3535
def test_to_sdk_task_state() -> None:
3636
assert (
37-
to_sdk_task_state(vertexai_types.State.STATE_UNSPECIFIED)
37+
to_sdk_task_state(vertexai_types.A2aTaskState.STATE_UNSPECIFIED)
3838
== TaskState.unknown
3939
)
4040
assert (
41-
to_sdk_task_state(vertexai_types.State.SUBMITTED) == TaskState.submitted
41+
to_sdk_task_state(vertexai_types.A2aTaskState.SUBMITTED)
42+
== TaskState.submitted
43+
)
44+
assert (
45+
to_sdk_task_state(vertexai_types.A2aTaskState.WORKING)
46+
== TaskState.working
4247
)
43-
assert to_sdk_task_state(vertexai_types.State.WORKING) == TaskState.working
4448
assert (
45-
to_sdk_task_state(vertexai_types.State.COMPLETED) == TaskState.completed
49+
to_sdk_task_state(vertexai_types.A2aTaskState.COMPLETED)
50+
== TaskState.completed
4651
)
4752
assert (
48-
to_sdk_task_state(vertexai_types.State.CANCELLED) == TaskState.canceled
53+
to_sdk_task_state(vertexai_types.A2aTaskState.CANCELLED)
54+
== TaskState.canceled
4955
)
50-
assert to_sdk_task_state(vertexai_types.State.FAILED) == TaskState.failed
5156
assert (
52-
to_sdk_task_state(vertexai_types.State.REJECTED) == TaskState.rejected
57+
to_sdk_task_state(vertexai_types.A2aTaskState.FAILED)
58+
== TaskState.failed
5359
)
5460
assert (
55-
to_sdk_task_state(vertexai_types.State.INPUT_REQUIRED)
61+
to_sdk_task_state(vertexai_types.A2aTaskState.REJECTED)
62+
== TaskState.rejected
63+
)
64+
assert (
65+
to_sdk_task_state(vertexai_types.A2aTaskState.INPUT_REQUIRED)
5666
== TaskState.input_required
5767
)
5868
assert (
59-
to_sdk_task_state(vertexai_types.State.AUTH_REQUIRED)
69+
to_sdk_task_state(vertexai_types.A2aTaskState.AUTH_REQUIRED)
6070
== TaskState.auth_required
6171
)
6272
assert to_sdk_task_state(999) == TaskState.unknown # type: ignore
@@ -65,35 +75,39 @@ def test_to_sdk_task_state() -> None:
6575
def test_to_stored_task_state() -> None:
6676
assert (
6777
to_stored_task_state(TaskState.unknown)
68-
== vertexai_types.State.STATE_UNSPECIFIED
78+
== vertexai_types.A2aTaskState.STATE_UNSPECIFIED
6979
)
7080
assert (
7181
to_stored_task_state(TaskState.submitted)
72-
== vertexai_types.State.SUBMITTED
82+
== vertexai_types.A2aTaskState.SUBMITTED
7383
)
7484
assert (
75-
to_stored_task_state(TaskState.working) == vertexai_types.State.WORKING
85+
to_stored_task_state(TaskState.working)
86+
== vertexai_types.A2aTaskState.WORKING
7687
)
7788
assert (
7889
to_stored_task_state(TaskState.completed)
79-
== vertexai_types.State.COMPLETED
90+
== vertexai_types.A2aTaskState.COMPLETED
8091
)
8192
assert (
8293
to_stored_task_state(TaskState.canceled)
83-
== vertexai_types.State.CANCELLED
94+
== vertexai_types.A2aTaskState.CANCELLED
95+
)
96+
assert (
97+
to_stored_task_state(TaskState.failed)
98+
== vertexai_types.A2aTaskState.FAILED
8499
)
85-
assert to_stored_task_state(TaskState.failed) == vertexai_types.State.FAILED
86100
assert (
87101
to_stored_task_state(TaskState.rejected)
88-
== vertexai_types.State.REJECTED
102+
== vertexai_types.A2aTaskState.REJECTED
89103
)
90104
assert (
91105
to_stored_task_state(TaskState.input_required)
92-
== vertexai_types.State.INPUT_REQUIRED
106+
== vertexai_types.A2aTaskState.INPUT_REQUIRED
93107
)
94108
assert (
95109
to_stored_task_state(TaskState.auth_required)
96-
== vertexai_types.State.AUTH_REQUIRED
110+
== vertexai_types.A2aTaskState.AUTH_REQUIRED
97111
)
98112

99113

@@ -155,15 +169,15 @@ class BadPart:
155169

156170

157171
def test_to_sdk_part_text() -> None:
158-
stored_part = vertexai_types.Part(text='hello back')
172+
stored_part = genai_types.Part(text='hello back')
159173
sdk_part = to_sdk_part(stored_part)
160174
assert isinstance(sdk_part.root, TextPart)
161175
assert sdk_part.root.text == 'hello back'
162176

163177

164178
def test_to_sdk_part_inline_data() -> None:
165-
stored_part = vertexai_types.Part(
166-
inline_data=vertexai_types.Blob(
179+
stored_part = genai_types.Part(
180+
inline_data=genai_types.Blob(
167181
mime_type='application/json',
168182
data=b'{"key": "val"}',
169183
)
@@ -177,8 +191,8 @@ def test_to_sdk_part_inline_data() -> None:
177191

178192

179193
def test_to_sdk_part_file_data() -> None:
180-
stored_part = vertexai_types.Part(
181-
file_data=vertexai_types.FileData(
194+
stored_part = genai_types.Part(
195+
file_data=genai_types.FileData(
182196
mime_type='image/jpeg',
183197
file_uri='gs://bucket/image.jpg',
184198
)
@@ -191,7 +205,7 @@ def test_to_sdk_part_file_data() -> None:
191205

192206

193207
def test_to_sdk_part_unsupported() -> None:
194-
stored_part = vertexai_types.Part()
208+
stored_part = genai_types.Part()
195209
with pytest.raises(ValueError, match='Unsupported part:'):
196210
to_sdk_part(stored_part)
197211

@@ -210,7 +224,7 @@ def test_to_stored_artifact() -> None:
210224
def test_to_sdk_artifact() -> None:
211225
stored_artifact = vertexai_types.TaskArtifact(
212226
artifact_id='art-456',
213-
parts=[vertexai_types.Part(text='part_2')],
227+
parts=[genai_types.Part(text='part_2')],
214228
)
215229
sdk_artifact = to_sdk_artifact(stored_artifact)
216230
assert sdk_artifact.artifact_id == 'art-456'
@@ -236,7 +250,7 @@ def test_to_stored_task() -> None:
236250
stored_task = to_stored_task(sdk_task)
237251
assert stored_task.context_id == 'ctx-1'
238252
assert stored_task.metadata == {'foo': 'bar'}
239-
assert stored_task.state == vertexai_types.State.WORKING
253+
assert stored_task.state == vertexai_types.A2aTaskState.WORKING
240254
assert stored_task.output is not None
241255
assert stored_task.output.artifacts is not None
242256
assert len(stored_task.output.artifacts) == 1
@@ -247,13 +261,13 @@ def test_to_sdk_task() -> None:
247261
stored_task = vertexai_types.A2aTask(
248262
name='projects/123/locations/us-central1/agentEngines/456/tasks/task-2',
249263
context_id='ctx-2',
250-
state=vertexai_types.State.COMPLETED,
264+
state=vertexai_types.A2aTaskState.COMPLETED,
251265
metadata={'a': 'b'},
252266
output=vertexai_types.TaskOutput(
253267
artifacts=[
254268
vertexai_types.TaskArtifact(
255269
artifact_id='art-2',
256-
parts=[vertexai_types.Part(text='result')],
270+
parts=[genai_types.Part(text='result')],
257271
)
258272
]
259273
),
@@ -275,7 +289,7 @@ def test_to_sdk_task_no_output() -> None:
275289
stored_task = vertexai_types.A2aTask(
276290
name='tasks/task-3',
277291
context_id='ctx-3',
278-
state=vertexai_types.State.SUBMITTED,
292+
state=vertexai_types.A2aTaskState.SUBMITTED,
279293
metadata=None,
280294
)
281295
sdk_task = to_sdk_task(stored_task)

0 commit comments

Comments
 (0)