-
Notifications
You must be signed in to change notification settings - Fork 64
Expand file tree
/
Copy pathshort_term_memory.py
More file actions
295 lines (244 loc) · 11.2 KB
/
short_term_memory.py
File metadata and controls
295 lines (244 loc) · 11.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Literal
from google.adk.sessions import (
BaseSessionService,
DatabaseSessionService,
InMemorySessionService,
Session,
)
from pydantic import BaseModel, Field, PrivateAttr
from veadk.memory.short_term_memory_backends.mysql_backend import (
MysqlSTMBackend,
)
from veadk.memory.short_term_memory_backends.postgresql_backend import (
PostgreSqlSTMBackend,
)
from veadk.memory.short_term_memory_backends.sqlite_backend import (
SQLiteSTMBackend,
)
from veadk.utils.logger import get_logger
if TYPE_CHECKING:
from google.adk.events import Event
from veadk import Agent
logger = get_logger(__name__)
def wrap_get_session_with_callbacks(obj, callback_fn: Callable):
    """Patch ``obj.get_session`` so ``callback_fn`` runs after every call.

    The current ``get_session`` attribute of ``obj`` is looked up once, wrapped
    in a coroutine function, and the wrapper is stored back on ``obj`` under
    the same attribute name (the object is modified in place).

    Args:
        obj: Any object exposing an awaitable ``get_session`` callable.
        callback_fn: Invoked as ``callback_fn(result, *args, **kwargs)`` where
            ``result`` is whatever ``get_session`` returned (it is passed
            through unfiltered) and ``*args`` / ``**kwargs`` are the arguments
            the caller gave to ``get_session``. The callback may be a plain
            function or an ``async def``; an awaitable return value is awaited
            before the session is handed back to the caller.

    Returns:
        None.
    """
    # Function-scope import: keeps this helper self-contained.
    import inspect

    get_session_fn = getattr(obj, "get_session")

    @wraps(get_session_fn)
    async def wrapper(*args, **kwargs):
        result = await get_session_fn(*args, **kwargs)
        outcome = callback_fn(result, *args, **kwargs)
        # An async callback returns a coroutine; without awaiting it the
        # callback body would silently never execute.
        if inspect.isawaitable(outcome):
            await outcome
        return result

    setattr(obj, "get_session", wrapper)
class ShortTermMemory(BaseModel):
"""Short term memory for agent execution.
The short term memory represents the context of the agent model. All content in the short term memory will be sent to agent model directly, including the system prompt, historical user prompt, and historical model responses.
Attributes:
backend (Literal["local", "mysql", "sqlite", "postgresql", "database"]):
The backend of short term memory:
- `local` for in-memory storage
- `mysql` for mysql / PostgreSQL storage
- `sqlite` for locally sqlite storage
backend_configs (dict): Configuration dict for init short term memory backend.
db_url (str):
Database connection url for init short term memory backend.
For example, `sqlite:///./test.db`. Once set, it will override the `backend` parameter.
local_database_path (str):
Local database path, only used when `backend` is `sqlite`.
Default to `/tmp/veadk_local_database.db`.
after_load_memory_callback (Callable | None):
A callback to be called after loading memory from the backend. The callback function should accept `Session` as an input.
"""
backend: Literal["local", "mysql", "sqlite", "postgresql", "database"] = "local"
backend_configs: dict = Field(default_factory=dict)
db_kwargs: dict = Field(default_factory=dict)
db_url: str = ""
local_database_path: str = "/tmp/veadk_local_database.db"
after_load_memory_callback: Callable | None = None
_session_service: BaseSessionService = PrivateAttr()
def model_post_init(self, __context: Any) -> None:
if self.db_url:
logger.info("The `db_url` is set, ignore `backend` option.")
if self.db_url.count("@") > 1 or self.db_url.count(":") > 2:
logger.warning(
"Multiple `@` or `:` symbols detected in the database URL. "
"Please encode `username` or `password` with `urllib.parse.quote_plus`. "
"Examples: p@ssword→p%40ssword."
)
self._session_service = DatabaseSessionService(
db_url=self.db_url, **self.db_kwargs
)
else:
if self.backend == "database":
logger.warning(
"Backend `database` is deprecated, use `sqlite` to create short term memory."
)
self.backend = "sqlite"
match self.backend:
case "local":
self._session_service = InMemorySessionService()
case "mysql":
self._session_service = MysqlSTMBackend(
db_kwargs=self.db_kwargs, **self.backend_configs
).session_service
case "sqlite":
self._session_service = SQLiteSTMBackend(
local_path=self.local_database_path
).session_service
case "postgresql":
self._session_service = PostgreSqlSTMBackend(
db_kwargs=self.db_kwargs, **self.backend_configs
).session_service
if self.after_load_memory_callback:
wrap_get_session_with_callbacks(
self._session_service, self.after_load_memory_callback
)
@property
def session_service(self) -> BaseSessionService:
return self._session_service
async def create_session(
self,
app_name: str,
user_id: str,
session_id: str,
state: dict | None = None,
) -> Session | None:
"""Create or retrieve a user session.
Short term memory can attempt to create a new session for a given application and user. If a session with the same `session_id` already exists, it will be returned instead of creating a new one.
If the underlying session service is backed by a database (`DatabaseSessionService`), the method first lists all existing sessions for the given `app_name` and `user_id` and logs the number of sessions found. It then checks whether a session with the specified `session_id` already exists:
- If it exists → returns the existing session.
- If it does not exist → creates and returns a new session.
Args:
app_name (str): The name of the application associated with the session.
user_id (str): The unique identifier of the user.
session_id (str): The unique identifier of the session to be created or retrieved.
state (dict | None):
The initial state of the session.
If a session with the given `session_id` already exists,
this argument is ignored and the existing session state is preserved.
Returns:
Session | None: The retrieved or newly created `Session` object, or `None` if the session creation failed.
"""
if isinstance(self._session_service, DatabaseSessionService):
list_sessions_response = await self._session_service.list_sessions(
app_name=app_name, user_id=user_id
)
logger.debug(
f"Loaded {len(list_sessions_response.sessions)} sessions from db {self.db_url}."
)
session = await self._session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
if session:
logger.info(
f"Session {session_id} already exists with app_name={app_name} user_id={user_id}."
)
return session
else:
return await self._session_service.create_session(
app_name=app_name, user_id=user_id, session_id=session_id, state=state
)
async def generate_profile(
self,
app_name: str,
user_id: str,
session_id: str,
events: list["Event"],
) -> list[str]:
import json
from veadk import Agent, Runner
from veadk.memory.types import MemoryProfile
from veadk.utils.misc import write_string_to_file
event_text = ""
for event in events:
event_text += f"- Event id: {event.id}\nEvent content: {event.content}\n"
agent = Agent(
name="memory_summarizer",
description="A summarizer that summarizes the memory events.",
instruction="""Summarize the memory events into different groups according to the event content. An event can belong to multiple groups. You must output the summary in JSON format (Each group should have a simple name (only a-z and _ is allowed), and a list of event ids):
[
{
"name": "",
"event_ids": ["Event id here"]
},
{
"name": "",
"event_ids": ["Event id here"]
}
]""",
model_name="deepseek-v3-2-251201",
output_schema=MemoryProfile,
)
runner = Runner(agent=agent)
response = await runner.run(messages="Events are: \n" + event_text)
# profile path: ./profiles/memory/<app_name>/user_id/session_id/profile_name.json
groups = json.loads(response)
group_names = [group["name"] for group in groups]
for group in groups:
group["event_list"] = []
for event_id in group["event_ids"]:
for event in events:
if event.id == event_id:
group["event_list"].append(event.content.model_dump_json())
write_string_to_file(
content=json.dumps(group_names, ensure_ascii=False),
file_path=f"./profiles/memory/{app_name}/{user_id}/{session_id}/profile_list.json",
)
for group in groups:
write_string_to_file(
content=json.dumps(group, ensure_ascii=False),
file_path=f"./profiles/memory/{app_name}/{user_id}/{session_id}/{group['name']}.json",
)
return group_names
async def compact_history_events(
self,
app_name: str,
user_id: str,
session_id: str,
compact_limit: int,
agent: "Agent",
):
# 1. generate profile
# 2. compact history events
# 3. append instruction and corresponding tool
session = await self.session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
compact_event_num = 0
compact_counter = 0
for event in session.events:
if event.content.role == "user":
compact_counter += 1
if compact_counter > compact_limit:
break
compact_event_num += 1
events_need_compact = session.events[:compact_event_num] # type: ignore
group_names = await self.generate_profile(
app_name=app_name,
user_id=user_id,
session_id=session_id,
events=events_need_compact,
)
# TODO(yaozheng): directly edit the events are not work as expected,
# need to check the reason later
session.events = session.events[compact_event_num:] # type: ignore
logger.debug(f"Compacted {compact_event_num} events.")
agent.instruction += f"""
The session has been compacted for the first {compact_limit} events. The compacted content are divided into following groups:
{group_names}
You can call `load_history_events` to load the compacted events if you need them according to the user's request.
"""
from veadk.tools.load_history_events import load_history_events
agent.tools.append(load_history_events)