import json
import re
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any

from graphon.entities import GraphInitParams
from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import (
    BuiltinNodeTypes,
    NodeExecutionType,
    WorkflowNodeExecutionMetadataKey,
    WorkflowNodeExecutionStatus,
)
from graphon.model_runtime.entities import LLMMode, LLMUsage, ModelPropertyKey, PromptMessageRole
from graphon.model_runtime.memory import PromptMessageMemory
from graphon.model_runtime.utils.encoders import jsonable_encoder
from graphon.node_events import ModelInvokeCompletedEvent, NodeRunResult
from graphon.nodes.base.entities import VariableSelector
from graphon.nodes.base.node import Node
from graphon.nodes.base.variable_template_parser import VariableTemplateParser
from graphon.nodes.llm import (
    LLMNode,
    LLMNodeChatModelMessage,
    LLMNodeCompletionModelPromptTemplate,
    llm_utils,
)
from graphon.nodes.llm.file_saver import LLMFileSaver
from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol
from graphon.nodes.protocols import HttpClientProtocol
from graphon.template_rendering import Jinja2TemplateRenderer
from graphon.utils.json_in_md_parser import parse_and_check_json_markdown

from .entities import QuestionClassifierNodeData
from .exc import InvalidModelTypeError
from .template_prompts import (
    QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
    QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2,
    QUESTION_CLASSIFIER_COMPLETION_PROMPT,
    QUESTION_CLASSIFIER_SYSTEM_PROMPT,
    QUESTION_CLASSIFIER_USER_PROMPT_1,
    QUESTION_CLASSIFIER_USER_PROMPT_2,
    QUESTION_CLASSIFIER_USER_PROMPT_3,
)

if TYPE_CHECKING:
    from graphon.file.models import File
    from graphon.runtime import GraphRuntimeState


class _PassthroughPromptMessageSerializer:
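    """Fallback prompt-message serializer used when none is injected.

    It ignores the model mode and returns the prompt messages unchanged as a list.
    """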
    def serialize(self, *, model_mode: Any, prompt_messages: Sequence[Any]) -> Any:
        _ = model_mode
        return list(prompt_messages)


class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
    node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
    execution_type = NodeExecutionType.BRANCH

    _file_outputs: list["File"]
    _llm_file_saver: LLMFileSaver
    _prompt_message_serializer: PromptMessageSerializerProtocol
    _model_instance: PreparedLLMProtocol
    _memory: PromptMessageMemory | None
    _template_renderer: Jinja2TemplateRenderer

    def __init__(
        self,
        id: str,
        config: NodeConfigDict,
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
        *,
        credentials_provider: object | None = None,
        model_factory: object | None = None,
        model_instance: PreparedLLMProtocol,
        http_client: HttpClientProtocol,
        template_renderer: Jinja2TemplateRenderer,
        memory: PromptMessageMemory | None = None,
        llm_file_saver: LLMFileSaver,
        prompt_message_serializer: PromptMessageSerializerProtocol | None = None,
    ):
        super().__init__(
            id=id,
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        # LLM file outputs, used for MultiModal outputs.
        self._file_outputs = []
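        # The following injected dependencies are accepted by the constructor but not used by this node.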
        _ = credentials_provider, model_factory, http_client
        self._model_instance = model_instance
        self._memory = memory
        self._template_renderer = template_renderer
        self._llm_file_saver = llm_file_saver
        self._prompt_message_serializer = prompt_message_serializer or _PassthroughPromptMessageSerializer()

    @classmethod
    def version(cls):
        return "1"

    @staticmethod
    def _default_class_label(index: int) -> str:
        return f"CLASS {index}"

    def _run(self):
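        """Classify the input query into one of the configured classes via an LLM call.

        The id of the matched class is returned as ``edge_source_handle``, which
        determines the branch the workflow follows after this node.
        """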
        node_data = self.node_data
        variable_pool = self.graph_runtime_state.variable_pool

        # extract variables
        variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None
        query = variable.value if variable else None
        variables = {"query": query}

        # fetch model instance
        model_instance = self._model_instance
        # Resolve variable references in string-typed completion params
        model_instance.parameters = llm_utils.resolve_completion_params_variables(
            model_instance.parameters, variable_pool
        )
        memory = self._memory

        # fetch instruction
        node_data.instruction = node_data.instruction or ""
        node_data.instruction = variable_pool.convert_template(node_data.instruction).text

        files = (
            llm_utils.fetch_files(
                variable_pool=variable_pool,
                selector=node_data.vision.configs.variable_selector,
            )
            if node_data.vision.enabled
            else []
        )

        # fetch prompt messages
        rest_token = self._calculate_rest_token(
            node_data=node_data,
            query=query or "",
            model_instance=model_instance,
            context="",
        )
        prompt_template = self._get_prompt_template(
            node_data=node_data,
            query=query or "",
            memory=memory,
            max_token_limit=rest_token,
        )
        # Some models (e.g. Gemma, Mistral) enforce strict role alternation (user/assistant/user/assistant...).
        # If both self._get_prompt_template and llm_utils.fetch_prompt_messages appended a user prompt,
        # two consecutive user prompts would be generated, causing a model error.
        # To avoid this, sys_query is set to an empty string so that only one user prompt is appended at the end.
        prompt_messages, stop = llm_utils.fetch_prompt_messages(
            prompt_template=prompt_template,
            sys_query="",
            memory=memory,
            model_instance=model_instance,
            stop=model_instance.stop,
            sys_files=files,
            vision_enabled=node_data.vision.enabled,
            vision_detail=node_data.vision.configs.detail,
            variable_pool=variable_pool,
            jinja2_variables=[],
            template_renderer=self._template_renderer,
        )

        result_text = ""
        usage = LLMUsage.empty_usage()
        finish_reason = None

        try:
            # handle invoke result
            generator = LLMNode.invoke_llm(
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop,
                structured_output_enabled=False,
                structured_output=None,
                file_saver=self._llm_file_saver,
                file_outputs=self._file_outputs,
                node_id=self._node_id,
                node_type=self.node_type,
            )
            for event in generator:
                if isinstance(event, ModelInvokeCompletedEvent):
                    result_text = event.text
                    usage = event.usage
                    finish_reason = event.finish_reason
                    break

            rendered_classes = [
                c.model_copy(update={"name": variable_pool.convert_template(c.name).text}) for c in node_data.classes
            ]
            category_name = rendered_classes[0].name
            category_id = rendered_classes[0].id
            category_label = rendered_classes[0].label or self._default_class_label(1)
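            # Reasoning models may wrap their answer in <think>...</think>; strip that before
            # parsing. The remaining text is expected to be (possibly markdown-fenced) JSON
            # such as {"category_name": "...", "category_id": "..."} (illustrative; only these
            # two keys are read below).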
if "<think>" in result_text:
result_text = re.sub(r"<think[^>]*>[\s\S]*?</think>", "", result_text, flags=re.IGNORECASE)
result_text_json = parse_and_check_json_markdown(result_text, [])
# result_text_json = json.loads(result_text.strip('```JSON\n'))
if "category_name" in result_text_json and "category_id" in result_text_json:
category_id_result = result_text_json["category_id"]
classes = rendered_classes
classes_map = {
class_.id: {
"name": class_.name,
"label": class_.label or self._default_class_label(index + 1),
}
for index, class_ in enumerate(classes)
}
category_ids = [_class.id for _class in classes]
if category_id_result in category_ids:
category_name = classes_map[category_id_result]["name"]
category_label = classes_map[category_id_result]["label"]
category_id = category_id_result

            process_data = {
                "model_mode": node_data.model.mode,
                "prompts": self._prompt_message_serializer.serialize(
                    model_mode=node_data.model.mode, prompt_messages=prompt_messages
                ),
                "usage": jsonable_encoder(usage),
                "finish_reason": finish_reason,
                "model_provider": model_instance.provider,
                "model_name": model_instance.model_name,
            }
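            # Node outputs; the values shown here are illustrative:
            # {"class_name": "billing", "class_label": "CLASS 1", "class_id": "1", "usage": {...}}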
            outputs = {
                "class_name": category_name,
                "class_label": category_label,
                "class_id": category_id,
                "usage": jsonable_encoder(usage),
            }

            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=variables,
                process_data=process_data,
                outputs=outputs,
                edge_source_handle=category_id,
                metadata={
                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                    WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
                    WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
                },
                llm_usage=usage,
            )
        except ValueError as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
                metadata={
                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                    WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
                    WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
                },
                llm_usage=usage,
            )

    @property
    def model_instance(self) -> PreparedLLMProtocol:
        return self._model_instance

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: QuestionClassifierNodeData,
    ) -> Mapping[str, Sequence[str]]:
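        """Map this node's variable names to their variable selectors.

        Keys are prefixed with the node id. Illustrative example: a node "qc"
        whose query selector is ["sys", "query"] yields
        {"qc.query": ["sys", "query"]}, plus one entry per variable referenced
        in the instruction template.
        """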
        # graph_config is not used in this node type
        variable_mapping = {"query": node_data.query_variable_selector}

        variable_selectors: list[VariableSelector] = []
        if node_data.instruction:
            variable_template_parser = VariableTemplateParser(template=node_data.instruction)
            variable_selectors.extend(variable_template_parser.extract_variable_selectors())
        for variable_selector in variable_selectors:
            variable_mapping[variable_selector.variable] = list(variable_selector.value_selector)

        variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
        return variable_mapping

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        """
        Get the default config of this node.

        :param filters: filter by node config parameters (not used in this implementation).
        :return: the default node configuration.
        """
        # The filters parameter is not used in this node type.
        return {"type": "question-classifier", "config": {"instructions": ""}}

    def _calculate_rest_token(
        self,
        node_data: QuestionClassifierNodeData,
        query: str,
        model_instance: PreparedLLMProtocol,
        context: str | None,
    ) -> int:
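        """Estimate the token budget left for conversation history.

        Roughly: model context size minus the configured max_tokens minus the
        tokens already consumed by the classification prompt, floored at zero.
        Falls back to 2000 when the model's context size is unknown.
        """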
        model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)

        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
        prompt_messages, _ = llm_utils.fetch_prompt_messages(
            prompt_template=prompt_template,
            sys_query="",
            sys_files=[],
            context=context or "",
            memory=None,
            model_instance=model_instance,
            stop=model_instance.stop,
            memory_config=node_data.memory,
            vision_enabled=False,
            vision_detail=node_data.vision.configs.detail,
            variable_pool=self.graph_runtime_state.variable_pool,
            jinja2_variables=[],
            template_renderer=self._template_renderer,
        )

        rest_tokens = 2000
        model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens:
            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

            max_tokens = 0
            for parameter_rule in model_schema.parameter_rules:
                if parameter_rule.name == "max_tokens" or (
                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
                ):
                    max_tokens = (
                        model_instance.parameters.get(parameter_rule.name)
                        or model_instance.parameters.get(parameter_rule.use_template or "")
                    ) or 0

            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
            rest_tokens = max(rest_tokens, 0)
        return rest_tokens

    def _get_prompt_template(
        self,
        node_data: QuestionClassifierNodeData,
        query: str,
        memory: PromptMessageMemory | None,
        max_token_limit: int = 2000,
    ):
        model_mode = LLMMode(node_data.model.mode)
        classes = node_data.classes
        categories = []
        for class_ in classes:
            category = {"category_id": class_.id, "category_name": class_.name}
            categories.append(category)
        instruction = node_data.instruction or ""
        input_text = query
        memory_str = ""
        if memory:
            memory_str = llm_utils.fetch_memory_text(
                memory=memory,
                max_token_limit=max_token_limit,
                message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
            )

        prompt_messages: list[LLMNodeChatModelMessage] = []
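        # Chat mode builds: a system prompt carrying the conversation history, two canned
        # user/assistant few-shot exchanges, and a final user message containing the actual
        # input text, the category list, and the classification instructions.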
        if model_mode == LLMMode.CHAT:
            system_prompt_messages = LLMNodeChatModelMessage(
                role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
            )
            prompt_messages.append(system_prompt_messages)
            user_prompt_message_1 = LLMNodeChatModelMessage(
                role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1
            )
            prompt_messages.append(user_prompt_message_1)
            assistant_prompt_message_1 = LLMNodeChatModelMessage(
                role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1
            )
            prompt_messages.append(assistant_prompt_message_1)
            user_prompt_message_2 = LLMNodeChatModelMessage(
                role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
            )
            prompt_messages.append(user_prompt_message_2)
            assistant_prompt_message_2 = LLMNodeChatModelMessage(
                role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2
            )
            prompt_messages.append(assistant_prompt_message_2)
            user_prompt_message_3 = LLMNodeChatModelMessage(
                role=PromptMessageRole.USER,
                text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(
                    input_text=input_text,
                    categories=json.dumps(categories, ensure_ascii=False),
                    classification_instructions=instruction,
                ),
            )
            prompt_messages.append(user_prompt_message_3)

            return prompt_messages
        elif model_mode == LLMMode.COMPLETION:
            return LLMNodeCompletionModelPromptTemplate(
                text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(
                    histories=memory_str,
                    input_text=input_text,
                    categories=json.dumps(categories, ensure_ascii=False),
                    classification_instructions=instruction,
                )
            )
        else:
            raise InvalidModelTypeError(f"Model mode {model_mode} is not supported.")