Skip to content

Commit bd62728

Browse files
committed
rebase knowledgebase tool
1 parent a39bdfb commit bd62728

3 files changed

Lines changed: 96 additions & 8 deletions

File tree

veadk/agent.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,14 @@ def model_post_init(self, __context: Any) -> None:
133133
)
134134

135135
if self.knowledgebase:
136-
from veadk.tools import load_knowledgebase_tool
136+
from veadk.tools.builtin_tools.load_knowledgebase import (
137+
LoadKnowledgebaseTool,
138+
)
137139

138-
load_knowledgebase_tool.knowledgebase = self.knowledgebase
139-
self.tools.append(load_knowledgebase_tool.load_knowledgebase_tool)
140+
load_knowledgebase_tool = LoadKnowledgebaseTool(
141+
knowledgebase=self.knowledgebase
142+
)
143+
self.tools.append(load_knowledgebase_tool)
140144

141145
if self.long_term_memory is not None:
142146
from google.adk.tools import load_memory

veadk/knowledgebase/knowledgebase.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Callable, Literal
15+
from __future__ import annotations
16+
17+
from typing import Any, Callable, Literal, Union
1618

1719
from pydantic import BaseModel, Field
18-
from typing_extensions import Union
1920

2021
from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend
2122
from veadk.knowledgebase.entry import KnowledgebaseEntry
@@ -54,11 +55,11 @@ def _get_backend_cls(backend: str) -> type[BaseKnowledgebaseBackend]:
5455
raise ValueError(f"Unsupported knowledgebase backend: {backend}")
5556

5657

57-
def build_knowledgebase_index(app_name: str):
58-
return f"veadk_kb_{app_name}"
58+
class KnowledgeBase(BaseModel):
59+
name: str = "user_knowledgebase"
5960

61+
description: str = "This knowledgebase stores some user-related information."
6062

61-
class KnowledgeBase(BaseModel):
6263
backend: Union[
6364
Literal["local", "opensearch", "viking", "redis"], BaseKnowledgebaseBackend
6465
] = "local"
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from __future__ import annotations
2+
3+
from google.adk.models.llm_request import LlmRequest
4+
from google.adk.tools.function_tool import FunctionTool
5+
from google.adk.tools.tool_context import ToolContext
6+
from google.genai import types
7+
from pydantic import BaseModel, Field
8+
from typing_extensions import override
9+
10+
from veadk.knowledgebase import KnowledgeBase
11+
from veadk.knowledgebase.entry import KnowledgebaseEntry
12+
from veadk.utils.logger import get_logger
13+
14+
logger = get_logger(__name__)
15+
16+
17+
class LoadKnowledgebaseResponse(BaseModel):
18+
knowledges: list[KnowledgebaseEntry] = Field(default_factory=list)
19+
20+
21+
class LoadKnowledgebaseTool(FunctionTool):
22+
"""A tool that loads the common knowledgebase"""
23+
24+
def __init__(self, knowledgebase: KnowledgeBase):
25+
super().__init__(self.load_knowledgebase)
26+
27+
self.knowledgebase = knowledgebase
28+
29+
if not self.custom_metadata:
30+
self.custom_metadata = {}
31+
self.custom_metadata["backend"] = knowledgebase.backend
32+
33+
@override
34+
def _get_declaration(self) -> types.FunctionDeclaration | None:
35+
return types.FunctionDeclaration(
36+
name=self.name,
37+
description=self.description,
38+
parameters=types.Schema(
39+
type=types.Type.OBJECT,
40+
properties={
41+
"query": types.Schema(
42+
type=types.Type.STRING,
43+
)
44+
},
45+
required=["query"],
46+
),
47+
)
48+
49+
@override
50+
async def process_llm_request(
51+
self,
52+
*,
53+
tool_context: ToolContext,
54+
llm_request: LlmRequest,
55+
) -> None:
56+
await super().process_llm_request(
57+
tool_context=tool_context, llm_request=llm_request
58+
)
59+
# Tell the model about the knowledgebase.
60+
llm_request.append_instructions(
61+
[
62+
f"""
63+
You have a knowledgebase (knowledegebase name is `{self.knowledgebase.name}`, knowledgebase description is `{self.knowledgebase.description}`). You can use it to answer questions. If any questions need
64+
you to look up the knowledgebase, you should call load_knowledgebase function with a query.
65+
"""
66+
]
67+
)
68+
69+
async def load_knowledgebase(
70+
self, query: str, tool_context: ToolContext
71+
) -> LoadKnowledgebaseResponse:
72+
"""Loads the knowledgebase for the user.
73+
74+
Args:
75+
query: The query to load the knowledgebase for.
76+
77+
Returns:
78+
A list of knowledgebase results.
79+
"""
80+
logger.info(f"Search knowledgebase: {self.knowledgebase.name}")
81+
response = self.knowledgebase.search(query)
82+
logger.info(f"Loaded {len(response)} knowledgebase entries for query: {query}")
83+
return LoadKnowledgebaseResponse(knowledges=response)

0 commit comments

Comments
 (0)