-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathchat_proto.py
More file actions
242 lines (178 loc) · 7.01 KB
/
chat_proto.py
File metadata and controls
242 lines (178 loc) · 7.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
import os
import time
from datetime import datetime, timezone
from typing import Any, Literal, TypedDict
from uuid import uuid4

from pydantic.v1 import UUID4
from uagents import Context, Model, Protocol

from functions import CompanyOverviewRequest, fetch_overview_json
# Address of the AI agent that receives StructuredOutputPrompt requests built
# from the user's chat text. Evaluates to None when the variable is unset.
AI_AGENT_ADDRESS = os.environ.get("AI_AGENT_ADDRESS")
class Metadata(TypedDict):
# primarily used with hte `Resource` model. This field specifies the mime_type of
# resource that is being referenced. A full list can be found at `docs/mime_types.md`
mime_type: str
# the role of the resource
role: str
class TextContent(Model):
    # A plain-text content element of a `ChatMessage`.
    # NOTE: documented with comments, not a docstring, because pydantic copies a
    # class docstring into the model's JSON schema.
    type: Literal["text"]
    # The text of the content as a UTF-8 encoded string. Markdown formatting
    # may be used and will be supported by most clients.
    text: str
class Resource(Model):
    # A reference to an external resource; carried by `ResourceContent`.
    # The URI of the resource.
    uri: str
    # The set of metadata for this resource. For a more detailed description
    # of the available fields, see `docs/metadata.md`.
    metadata: dict[str, str]
class ResourceContent(Model):
    # A content element that refers to one or more `Resource`s.
    type: Literal["resource"]
    # The resource id.
    resource_id: UUID4
    # The resource, or list of resources, for this content. Typically only a
    # single resource is sent; however, accompanying resources such as
    # thumbnails and audio tracks can additionally be referenced.
    #
    # When a list of resources is given, the first element of the list is
    # always considered the primary resource.
    resource: Resource | list[Resource]
class MetadataContent(Model):
    # A content element that carries only metadata.
    type: Literal["metadata"]
    # The set of metadata for this content. For a more detailed description
    # of the available fields, see `docs/metadata.md`.
    metadata: dict[str, str]
class StartSessionContent(Model):
    # Content element tagged "start-session"; presumably sent by a client
    # when it opens a chat session.
    type: Literal["start-session"]
class EndSessionContent(Model):
    # Content element tagged "end-session"; `create_end_session_chat` wraps
    # one in a `ChatMessage` to close the conversation.
    type: Literal["end-session"]
class StartStreamContent(Model):
    # Content element tagged "start-stream"; presumably opens the stream
    # identified by `stream_id`.
    type: Literal["start-stream"]
    # Identifier of the stream being started.
    stream_id: UUID4
class EndStreamContent(Model):
    # Presumably marks the end of the stream identified by `stream_id`.
    # NOTE(review): the `type` literal is "start-stream", identical to
    # `StartStreamContent`. This looks like a copy-paste slip ("end-stream" was
    # probably intended). `StartStreamContent` is listed before this class in
    # `AgentContent`, so incoming end-stream payloads will presumably be parsed
    # as `StartStreamContent`. Left unchanged because editing the field alters
    # the model schema that other agents may match on; fix it together with
    # the protocol specification.
    type: Literal["start-stream"]
    # Identifier of the stream being closed; presumably the id from the
    # matching `StartStreamContent`.
    stream_id: UUID4
# The combined agent content types: every element of `ChatMessage.content` is
# one of these models.
# NOTE(review): under pydantic v1 (which the `pydantic.v1` import suggests),
# Union members are tried left to right and the first successful parse wins,
# so the member order here affects decoding. Keep it unless the wire format
# is deliberately being changed.
AgentContent = (
    TextContent
    | ResourceContent
    | MetadataContent
    | StartSessionContent
    | EndSessionContent
    | StartStreamContent
    | EndStreamContent
)
class ChatMessage(Model):
    # A chat message exchanged over the chat protocol.
    # The timestamp for the message; should be in UTC.
    timestamp: datetime
    # A unique message id generated by the message instigator.
    msg_id: UUID4
    # The list of content elements in the message.
    content: list[AgentContent]
class ChatAcknowledgement(Model):
    # Acknowledges a `ChatMessage` by referencing its `msg_id`.
    # The timestamp for the acknowledgement; should be in UTC.
    timestamp: datetime
    # The msg id that is being acknowledged.
    acknowledged_msg_id: UUID4
    # Optional acknowledgement metadata.
    metadata: dict[str, str] | None = None
def create_text_chat(text: str) -> ChatMessage:
    """Build a ``ChatMessage`` holding a single text element.

    Args:
        text: Message body; ``TextContent`` allows markdown formatting.

    Returns:
        A new ``ChatMessage`` with a random ``msg_id`` and a timezone-aware
        UTC timestamp.
    """
    return ChatMessage(
        # Aware UTC time: ``datetime.utcnow()`` is deprecated since Python 3.12
        # and yields a naive value that drops the UTC marker.
        timestamp=datetime.now(timezone.utc),
        msg_id=uuid4(),
        content=[TextContent(type="text", text=text)],
    )
def create_end_session_chat() -> ChatMessage:
    """Build a ``ChatMessage`` whose only content is an end-session marker.

    Returns:
        A new ``ChatMessage`` with a random ``msg_id`` and a timezone-aware
        UTC timestamp.
    """
    return ChatMessage(
        # Aware UTC time: ``datetime.utcnow()`` is deprecated since Python 3.12
        # and yields a naive value that drops the UTC marker.
        timestamp=datetime.now(timezone.utc),
        msg_id=uuid4(),
        content=[EndSessionContent(type="end-session")],
    )
# Chat protocol this agent serves (handlers: handle_message, handle_ack).
# NOTE(review): "Protcol" looks like a misspelling. It is left as-is because
# peers may match on the exact name; confirm against the chat protocol spec
# before changing it.
chat_proto = Protocol(version="0.2.1", name="AgentChatProtcol")

# Client side of the structured-output protocol used to talk to the AI agent.
struct_output_client_proto = Protocol(
    version="0.1.0",
    name="StructuredOutputClientProtocol",
)
class StructuredOutputPrompt(Model):
    # Request sent to the AI agent (see `handle_message`), presumably asking
    # it to turn free text into output matching `output_schema`.
    # The free-form text to interpret; here, the user's chat text.
    prompt: str
    # JSON schema the output should follow; here,
    # `CompanyOverviewRequest.schema()`.
    output_schema: dict[str, Any]
class StructuredOutputResponse(Model):
    # Reply to a `StructuredOutputPrompt`.
    # The structured result; this module parses it with
    # `CompanyOverviewRequest.parse_obj`.
    output: dict[str, Any]
@chat_proto.on_message(ChatMessage)
async def handle_message(ctx: Context, sender: str, msg: ChatMessage):
    """Acknowledge an incoming chat message and forward its text for parsing.

    Every message is acknowledged immediately. For each text element, the
    sender is stored under the current session key and the text is sent to
    ``AI_AGENT_ADDRESS`` as a ``StructuredOutputPrompt`` using the
    ``CompanyOverviewRequest`` schema. The structured reply is handled by
    ``handle_structured_output_response``, which looks the sender up again by
    session key. This presumably relies on the reply arriving in the same
    session; confirm this against uAgents session propagation.
    """
    await ctx.send(
        sender,
        ChatAcknowledgement(
            timestamp=datetime.now(timezone.utc), acknowledged_msg_id=msg.msg_id
        ),
    )
    for item in msg.content:
        if isinstance(item, StartSessionContent):
            ctx.logger.info(f"Got a start session message from {sender}")
        elif isinstance(item, EndSessionContent):
            # End-session is a normal protocol message, not "unexpected" content.
            ctx.logger.info(f"Got an end session message from {sender}")
        elif isinstance(item, TextContent):
            ctx.logger.info(f"Got a message from {sender}: {item.text}")
            if not AI_AGENT_ADDRESS:
                # Without a target, ctx.send would fail inside the handler and
                # the user would never get a reply.
                ctx.logger.error(
                    "AI_AGENT_ADDRESS is not set; cannot forward the request"
                )
                await ctx.send(
                    sender,
                    create_text_chat(
                        "Sorry, I couldn't process your request. Please try again later."
                    ),
                )
                continue
            ctx.storage.set(str(ctx.session), sender)
            await ctx.send(
                AI_AGENT_ADDRESS,
                StructuredOutputPrompt(
                    prompt=item.text, output_schema=CompanyOverviewRequest.schema()
                ),
            )
        else:
            ctx.logger.info(f"Got unexpected content from {sender}")
@chat_proto.on_message(ChatAcknowledgement)
async def handle_ack(ctx: Context, sender: str, msg: ChatAcknowledgement):
    """Log an incoming ``ChatAcknowledgement``; no other action is taken."""
    acked_id = msg.acknowledged_msg_id
    ctx.logger.info(f"Got an acknowledgement from {sender} for {acked_id}")
# Cached company overviews older than this many seconds (24 hours) are re-fetched.
_OVERVIEW_CACHE_TTL_SECONDS = 24 * 60 * 60


def _format_company_overview(overview: dict[str, Any]) -> str:
    """Render a company-overview mapping as the text reply sent to the user.

    Raises:
        KeyError: if ``overview`` lacks any field shown in the reply.
    """
    return (
        f"Company: {overview['Name']} ({overview['Symbol']})\n"
        f"Exchange: {overview['Exchange']} | Currency: {overview['Currency']}\n"
        f"Industry: {overview['Industry']} | Sector: {overview['Sector']}\n"
        f"Market Cap: {overview['Currency']} {overview['MarketCapitalization']}\n"
        f"PE Ratio: {overview['PERatio']} | EPS: {overview['EPS']}\n"
        f"Website: {overview['OfficialSite']}\n\n"
        f"Description: {overview['Description']}"
    )


def _is_fresh_cache_entry(entry: Any, now: int) -> bool:
    """Return True if ``entry`` is a stored overview younger than the TTL.

    Entries that are not dicts, or that lack a numeric ``timestamp``, are
    treated as stale so they get re-fetched instead of raising.
    """
    if not isinstance(entry, dict):
        return False
    stamp = entry.get("timestamp")
    return (
        isinstance(stamp, (int, float))
        and now - stamp < _OVERVIEW_CACHE_TTL_SECONDS
    )


@struct_output_client_proto.on_message(StructuredOutputResponse)
async def handle_structured_output_response(
    ctx: Context, sender: str, msg: StructuredOutputResponse
):
    """Reply to the user's company-overview request.

    Steps:
    - Look up the chat sender stored under the current session key by
      ``handle_message``.
    - Parse the structured output into a ``CompanyOverviewRequest``.
    - Reply with the overview for its ticker. A storage-cached overview is
      used if it is younger than 24 hours; otherwise ``fetch_overview_json``
      is called and the result is cached with a timestamp.

    A successful reply is followed by an end-session message. Failures are
    reported to the user as text instead of propagating out of the handler.
    """
    session_sender = ctx.storage.get(str(ctx.session))
    if session_sender is None:
        ctx.logger.error(
            "Discarding message because no session sender found in storage"
        )
        return

    error_reply = "Sorry, I couldn't process your request. Please try again later."

    try:
        prompt = CompanyOverviewRequest.parse_obj(msg.output)
    except Exception as err:
        # Output did not match the requested schema; tell the user rather
        # than letting the handler raise.
        ctx.logger.error(f"Could not parse structured output: {err}")
        await ctx.send(session_sender, create_text_chat(error_reply))
        return

    cached = ctx.storage.get(prompt.ticker)
    if _is_fresh_cache_entry(cached, int(time.time())):
        # Do not mutate `cached`: the storage backend may return its live
        # object, and mutating it would corrupt the stored entry for later
        # requests.
        await ctx.send(
            session_sender, create_text_chat(_format_company_overview(cached))
        )
        await ctx.send(session_sender, create_end_session_chat())
        return

    try:
        overview = fetch_overview_json(prompt.ticker)
    except Exception as err:
        ctx.logger.error(err)
        await ctx.send(session_sender, create_text_chat(error_reply))
        return

    if "error" in overview:
        await ctx.send(session_sender, create_text_chat(str(overview["error"])))
        return

    try:
        reply = _format_company_overview(overview)
    except KeyError as err:
        # Incomplete payload: report it instead of crashing, and keep it out
        # of the cache so the next request fetches again.
        ctx.logger.error(f"Overview for {prompt.ticker} is missing field {err}")
        await ctx.send(session_sender, create_text_chat(error_reply))
        return

    # Cache a copy stamped with the fetch time; the stamp drives the
    # freshness check above and is never shown to the user.
    ctx.storage.set(prompt.ticker, {**overview, "timestamp": int(time.time())})
    await ctx.send(session_sender, create_text_chat(reply))
    await ctx.send(session_sender, create_end_session_chat())