Skip to content

Commit 7c09fdc

Browse files
committed
feat: multicast a2a
Signed-off-by: Mauro Sardara <msardara@cisco.com>
1 parent 534396c commit 7c09fdc

10 files changed

Lines changed: 379 additions & 92 deletions

File tree

buf.gen.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ managed:
99
plugins:
1010
# Generate python protobuf related code
1111
# Generates _pb2_slimrpc.py files
12-
- local: protoc-gen-slimrpc-python
12+
- local: /Users/msardara/repos/slim/data-plane/target/release/protoc-gen-slimrpc-python
1313
out: slima2a/types
1414
opt:
1515
- 'types_import=from a2a.grpc import a2a_pb2 as a2a__pb2'
13.4 MB
Binary file not shown.

examples/echo_agent/client.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import asyncio
33
import logging
4+
from typing import cast
45
from uuid import uuid4
56

67
import httpx
@@ -15,6 +16,7 @@
1516
Message,
1617
Part,
1718
Role,
19+
TaskArtifactUpdateEvent,
1820
TextPart,
1921
)
2022
from a2a.utils.constants import (
@@ -30,6 +32,12 @@
3032

3133
BASE_URL = "http://localhost:9999"
3234

35+
# Two echo agent instances to multicast to
36+
AGENT_NAMES = [
37+
"agntcy/demo/echo_agent1",
38+
"agntcy/demo/echo_agent2",
39+
]
40+
3341
logger = logging.getLogger(__name__)
3442

3543

@@ -66,7 +74,7 @@ async def main() -> None:
6674

6775
client_config = ClientConfig(
6876
supported_transports=["JSONRPC", "slimrpc"],
69-
streaming=args.stream,
77+
streaming=True,
7078
httpx_client=httpx_client,
7179
slimrpc_channel_factory=slimrpc_channel_factory(slim_local_app, conn_id),
7280
)
@@ -78,7 +86,7 @@ async def main() -> None:
7886
agent_card: AgentCard
7987
match args.type:
8088
case "slimrpc":
81-
agent_card = minimal_agent_card("agntcy/demo/echo_agent", ["slimrpc"])
89+
agent_card = minimal_agent_card(",".join(AGENT_NAMES), ["slimrpc"])
8290
case "starlette":
8391
agent_card = await fetch_agent_card(
8492
resolver=A2ACardResolver(
@@ -90,6 +98,19 @@ async def main() -> None:
9098
raise ValueError(f"Invalid client type: {args.type}")
9199

92100
client = client_factory.create(card=agent_card)
101+
102+
if args.type == "slimrpc":
103+
# Fetch agent cards from all servers in the group.
104+
transport = cast(SRPCTransport, client._transport) # type: ignore[attr-defined]
105+
async for card in transport.get_all_cards():
106+
logger.info(
107+
f"agent card: {card.model_dump_json(indent=2, exclude_none=True)}"
108+
)
109+
110+
# Fetch the real card from the first server so that
111+
# client._card.capabilities.streaming=True before send_message.
112+
await client.get_card()
113+
93114
logger.info("A2AClient initialized.")
94115

95116
response_text = await send_message(client, args.text)
@@ -106,12 +127,6 @@ def parse_arguments() -> argparse.Namespace:
106127
required=False,
107128
default="ERROR",
108129
)
109-
parser.add_argument(
110-
"--stream",
111-
action="store_true",
112-
required=False,
113-
default=False,
114-
)
115130
parser.add_argument(
116131
"--text",
117132
type=str,
@@ -144,24 +159,25 @@ async def send_message(
144159
)
145160
logger.info(f"associated request ({request_id}) with text: {text}")
146161

147-
output = ""
162+
# Group output by task_id so responses from each server are kept separate.
163+
outputs: dict[str, str] = {}
148164
try:
149165
async for event in client.send_message(request=request):
150166
if isinstance(event, Message):
167+
outputs.setdefault("msg", "")
151168
for part in event.parts:
152169
if isinstance(part.root, TextPart):
153-
output += part.root.text
170+
outputs["msg"] += part.root.text
154171
else:
155172
task, update = event
156173
logger.info(f"task ({task.id}) status: {task.status.state}")
157174

158-
if task.status.state == "completed" and task.artifacts:
159-
for artifact in task.artifacts:
160-
for part in artifact.parts:
161-
if isinstance(part.root, TextPart):
162-
output += part.root.text
163-
164-
if update:
175+
if isinstance(update, TaskArtifactUpdateEvent):
176+
outputs.setdefault(update.task_id, "")
177+
for part in update.artifact.parts:
178+
if isinstance(part.root, TextPart):
179+
outputs[update.task_id] += part.root.text
180+
elif update:
165181
logger.info(f"update: {update.model_dump(mode='json')}")
166182
except Exception as e:
167183
logger.error(
@@ -170,7 +186,7 @@ async def send_message(
170186
)
171187
raise RuntimeError("failed sending message or processing response") from e
172188

173-
return output
189+
return "\n---\n".join(outputs.values())
174190

175191

176192
if __name__ == "__main__":

examples/echo_agent/server.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ async def main() -> None:
5959
service, local_app, local_name, conn_id = await setup_slim_client(
6060
namespace="agntcy",
6161
group="demo",
62-
name="echo_agent",
62+
name=args.instance,
6363
)
6464

6565
# Create server
@@ -90,6 +90,13 @@ def parse_arguments() -> argparse.Namespace:
9090

9191
parser.add_argument("--type", type=str, required=False, default="slimrpc")
9292

93+
parser.add_argument(
94+
"--instance",
95+
type=str,
96+
required=False,
97+
default="echo_agent",
98+
)
99+
93100
parser.add_argument(
94101
"--log-level",
95102
type=str,

examples/travel_planner_agent/client.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,15 @@
77
logging.getLogger("asyncio").setLevel(logging.ERROR)
88

99
# ruff: noqa: E402
10+
from typing import cast
11+
1012
import httpx
11-
from a2a.client import (
12-
Client,
13-
ClientFactory,
14-
minimal_agent_card,
15-
)
13+
from a2a.client import Client, ClientFactory, minimal_agent_card
1614
from a2a.types import (
1715
Message,
1816
Part,
1917
Role,
18+
TaskArtifactUpdateEvent,
2019
TextPart,
2120
)
2221

@@ -27,6 +26,14 @@
2726
slimrpc_channel_factory,
2827
)
2928

29+
# Two travel planner instances to multicast to
30+
AGENT_NAMES = [
31+
"agntcy/demo/travel_planner_agent1",
32+
"agntcy/demo/travel_planner_agent2",
33+
]
34+
35+
logger = logging.getLogger(__name__)
36+
3037

3138
def print_welcome_message() -> None:
3239
print("Welcome to the generic A2A client!")
@@ -51,26 +58,31 @@ async def interact_with_server(client: Client) -> None:
5158
parts=[Part(root=TextPart(text=user_input))],
5259
)
5360

54-
output = ""
61+
# Group output by task_id so responses from each server are kept separate.
62+
outputs: dict[str, str] = {}
5563
try:
5664
async for response in client.send_message(request=request):
5765
if isinstance(response, Message):
66+
outputs.setdefault("msg", "")
5867
for part in response.parts:
5968
if isinstance(part.root, TextPart):
60-
output += part.root.text
69+
outputs["msg"] += part.root.text
6170
else:
62-
task, _ = response
71+
_, update = response
6372

64-
if task.status.state == "completed" and task.artifacts:
65-
for artifact in task.artifacts:
66-
for part in artifact.parts:
67-
if isinstance(part.root, TextPart):
68-
output += part.root.text
73+
if isinstance(update, TaskArtifactUpdateEvent):
74+
outputs.setdefault(update.task_id, "")
75+
for part in update.artifact.parts:
76+
if isinstance(part.root, TextPart):
77+
outputs[update.task_id] += part.root.text
6978

7079
except Exception as e:
7180
raise RuntimeError("failed sending message or processing response") from e
7281

73-
print(output, end="", flush=True)
82+
for i, text in enumerate(outputs.values()):
83+
if i > 0:
84+
print("\n---")
85+
print(text, end="", flush=True)
7486
await asyncio.sleep(0.1)
7587

7688

@@ -96,8 +108,19 @@ async def main() -> None:
96108

97109
# mypy: the register API expects a different callable type; safe to ignore here.
98110
client_factory.register("slimrpc", SRPCTransport.create) # type: ignore
99-
agent_card = minimal_agent_card("agntcy/demo/travel_planner_agent", ["slimrpc"])
100-
client = client_factory.create(card=agent_card)
111+
112+
client = client_factory.create(
113+
card=minimal_agent_card(",".join(AGENT_NAMES), ["slimrpc"])
114+
)
115+
116+
# Fetch agent cards from all servers in the group.
117+
transport = cast(SRPCTransport, client._transport) # type: ignore[attr-defined]
118+
async for card in transport.get_all_cards():
119+
logger.info(f"agent card: {card.model_dump_json(indent=2, exclude_none=True)}")
120+
121+
# Fetch the real card from the first server so that
122+
# client._card.capabilities.streaming=True before send_message.
123+
await client.get_card()
101124

102125
await interact_with_server(client)
103126

examples/travel_planner_agent/server.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import argparse
12
import asyncio
23
import logging
34

@@ -22,6 +23,8 @@
2223

2324

2425
async def main() -> None:
26+
args = parse_arguments()
27+
2528
skill = AgentSkill(
2629
id="travel_planner",
2730
name="travel planner agent",
@@ -52,7 +55,7 @@ async def main() -> None:
5255
service, local_app, local_name, conn_id = await setup_slim_client(
5356
namespace="agntcy",
5457
group="demo",
55-
name="travel_planner_agent",
58+
name=args.instance,
5659
)
5760

5861
# Create server
@@ -67,5 +70,18 @@ async def main() -> None:
6770
await server.serve_async()
6871

6972

73+
def parse_arguments() -> argparse.Namespace:
74+
parser = argparse.ArgumentParser()
75+
76+
parser.add_argument(
77+
"--instance",
78+
type=str,
79+
required=False,
80+
default="travel_planner_agent",
81+
)
82+
83+
return parser.parse_args()
84+
85+
7086
if __name__ == "__main__":
7187
asyncio.run(main())

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ readme = "README.md"
99
requires-python = ">=3.10, <4.0"
1010
dependencies = [
1111
"a2a-sdk[telemetry]==0.3.20",
12-
"slim-bindings~=1.1",
12+
"slim-bindings",
1313
]
1414
classifiers = [
1515
"Development Status :: 3 - Alpha",
@@ -44,6 +44,9 @@ examples = [
4444
[tool.uv]
4545
default-groups = ["linting", "testing", "examples"]
4646

47+
[tool.uv.sources]
48+
slim-bindings = { path = "../slim/data-plane/target/wheels/slim_bindings-1.2.0-py3-none-macosx_11_0_arm64.whl" }
49+
4750
[tool.ruff.format]
4851
exclude = ["*_pb2.py", "*_pb2.pyi", "*_pb2_slimrpc.py"]
4952

0 commit comments

Comments
 (0)