-
Notifications
You must be signed in to change notification settings - Fork 63
Expand file tree
/
Copy pathcore_engine.py
More file actions
145 lines (121 loc) · 4.58 KB
/
core_engine.py
File metadata and controls
145 lines (121 loc) · 4.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import json
from typing import Any, Callable
from .registry import ToolRegistry
class VannaBase:
def __init__(
self,
config: dict,
llm_stub: Callable[[list[dict]], str] | None = None,
impl: Any | None = None,
):
self._config = config or {}
self._impl = impl
self._llm_stub = llm_stub
def _ensure_impl(self):
if self._impl is None:
from vanna.chromadb import ChromaDB_VectorStore
from vanna.vertexai import VertexAI_Chat
class _Impl(ChromaDB_VectorStore, VertexAI_Chat):
def __init__(self, config: dict):
ChromaDB_VectorStore.__init__(self, config=config)
VertexAI_Chat.__init__(self, config=config)
self._impl = _Impl(self._config)
return self._impl
def _submit_prompt(self, messages: list[dict]) -> str:
if self._llm_stub is not None:
return self._llm_stub(messages)
impl = self._ensure_impl()
return impl.submit_prompt(messages)
def _get_related_documentation(self, question: str):
impl = self._ensure_impl()
return impl.get_related_documentation(question)
def _get_similar_question_sql(self, question: str):
impl = self._ensure_impl()
return impl.get_similar_question_sql(question)
class IntentVanna(VannaBase):
    """Semantic parser mapping a natural-language question to an intent envelope.

    Retrieved documentation and similar question/SQL examples are folded into
    a system prompt. The model's reply is parsed as JSON and validated, and
    the intent is resolved to a tool name through ``ToolRegistry``.
    """

    def generate_envelope(self, question: str) -> dict[str, Any]:
        """Parse *question* into a routed intent envelope.

        Returns:
            A dict with keys ``intent``, ``payload`` (a dict), ``next_tool``,
            ``error``, ``confidence`` and ``debug_context``.

        Ordinary exceptions are not propagated. Any failure (retrieval, model
        call, invalid JSON, failed validation, tool lookup) yields the fallback
        envelope: ``intent="unknown"``, ``next_tool="unknown_tool"``, and the
        exception text in ``error``.
        """
        try:
            docs = self._get_related_documentation(question)
            examples = self._get_similar_question_sql(question)
            raw_resp = self._submit_prompt(
                [
                    {
                        "role": "system",
                        "content": self._build_system_prompt(docs, examples),
                    },
                    {"role": "user", "content": question},
                ]
            )
            intent, payload, confidence = self._parse_envelope(raw_resp)
            debug_context = {
                "docs": _summarize_debug(docs),
                "examples": _summarize_debug(examples),
            }
            return {
                "intent": intent,
                "payload": payload,
                "next_tool": self._resolve_tool_name(intent),
                "error": None,
                "confidence": confidence,
                "debug_context": debug_context,
            }
        except Exception as exc:
            # Deliberate best-effort boundary: callers always receive an
            # envelope and route "unknown_tool" instead of handling exceptions.
            return {
                "intent": "unknown",
                "payload": {},
                "next_tool": "unknown_tool",
                "error": str(exc),
                "confidence": None,
                "debug_context": None,
            }

    @staticmethod
    def _build_system_prompt(docs: Any, examples: Any) -> str:
        """Render the system prompt that embeds retrieved docs and examples."""
        return (
            "Role: Semantic Parser.\n"
            "Task: Map query to JSON based on Knowledge.\n"
            f"Knowledge: {docs}\n"
            f"Examples: {examples}\n"
            "Output: JSON (IntentEnvelope)\n"
        )

    @staticmethod
    def _strip_code_fence(text: Any) -> Any:
        """Remove a surrounding Markdown code fence (e.g. ```json ... ```), if present.

        Non-string input and unfenced text are returned unchanged.
        """
        if not isinstance(text, str):
            return text
        stripped = text.strip()
        if not stripped.startswith("```"):
            return text
        # Drop the opening fence line, which may carry a language tag like "json".
        _, _, body = stripped.partition("\n")
        body = body.rstrip()
        if body.endswith("```"):
            body = body[:-3]
        return body

    @classmethod
    def _parse_envelope(cls, raw_resp: Any) -> tuple[str, dict[str, Any], Any]:
        """Decode and validate the model reply.

        Returns:
            ``(intent, payload, confidence)``.

        Raises:
            ValueError: if the reply is not a JSON object, ``intent`` is missing
                or is not a string, or ``payload`` is present but not an object.
            json.JSONDecodeError / TypeError: from ``json.loads`` on malformed
                or non-text input.
        """
        envelope = json.loads(cls._strip_code_fence(raw_resp))
        if not isinstance(envelope, dict):
            raise ValueError("IntentEnvelope must be a dict")
        intent = envelope.get("intent")
        if not intent:
            raise ValueError("Missing intent")
        if not isinstance(intent, str):
            raise ValueError("intent must be a string")
        # A missing or null payload becomes an empty dict. Any other
        # non-object payload is rejected so the success path matches the
        # fallback envelope's dict-shaped payload.
        payload = envelope.get("payload") or {}
        if not isinstance(payload, dict):
            raise ValueError("payload must be a dict")
        return intent, payload, envelope.get("confidence")

    @staticmethod
    def _resolve_tool_name(intent: str) -> str:
        """Map *intent* to a tool name via ``ToolRegistry``.

        If the registry's lookup reports "unknown_tool", its private
        ``_intent_map`` (when present) is consulted as a fallback.
        """
        tool_name = ToolRegistry.get_tool_name(intent)
        if tool_name == "unknown_tool":
            intent_map = getattr(ToolRegistry, "_intent_map", {})
            if isinstance(intent_map, dict) and intent in intent_map:
                tool_name = intent_map[intent]
        return tool_name
class SQLVanna(VannaBase):
    """Text-to-SQL engine that delegates generation and execution to the backend."""

    def generate_sql(self, question: str) -> str:
        """Generate SQL for *question* with the backend's ``generate_sql``."""
        return self._ensure_impl().generate_sql(question=question)

    def generate_sql_from_context(self, context: str) -> str:
        """Generate SQL, passing *context* verbatim as the question text."""
        return self.generate_sql(question=context)

    def run_sql(self, sql: str):
        """Execute *sql* on the backend and return its result unchanged."""
        backend = self._ensure_impl()
        return backend.run_sql(sql)
class MockVannaImpl:
    """In-memory stand-in for the Vanna backend, intended for tests.

    Retrieval and prompt calls return the canned strings given at
    construction. SQL generation, execution and training return fixed values.
    Every call ignores its arguments.
    """

    def __init__(
        self,
        docs: str = "DOCS",
        examples: str = "EXAMPLES",
        response: str = '{"intent":"query_metric","payload":{"metric":"revenue"}}',
    ):
        self._docs, self._examples, self._response = docs, examples, response

    def get_related_documentation(self, question: str):
        """Return the canned documentation string."""
        return self._docs

    def get_similar_question_sql(self, question: str):
        """Return the canned examples string."""
        return self._examples

    def submit_prompt(self, messages: list[dict]):
        """Return the canned model response."""
        return self._response

    def generate_sql(self, question: str) -> str:
        """Return a fixed trivial query."""
        return "SELECT 1"

    def run_sql(self, sql: str):
        """Return a freshly built single-row result."""
        row = {"ok": True}
        return [row]

    def train(self, **kwargs):
        """Accept any training arguments and report success."""
        return True
def _summarize_debug(value: Any, limit: int = 400):
text = str(value)
if len(text) <= limit:
return text
return text[:limit]