Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/graphon/nodes/question_classifier/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
class ClassConfig(BaseModel):
id: str
name: str
label: str | None = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest adding some comments to the three fields, especially for name and field since they are quite similar and confusing.

The name field is actually the description of the correspond class, AFAIK.



class QuestionClassifierNodeData(BaseNodeData):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ def __init__(
def version(cls):
return "1"

@staticmethod
def _default_class_label(index: int) -> str:
return f"CLASS {index}"

def _run(self):
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
Expand Down Expand Up @@ -191,17 +195,25 @@ def _run(self):

category_name = rendered_classes[0].name
category_id = rendered_classes[0].id
category_label = rendered_classes[0].label or self._default_class_label(1)
if "<think>" in result_text:
result_text = re.sub(r"<think[^>]*>[\s\S]*?</think>", "", result_text, flags=re.IGNORECASE)
result_text_json = parse_and_check_json_markdown(result_text, [])
# result_text_json = json.loads(result_text.strip('```JSON\n'))
if "category_name" in result_text_json and "category_id" in result_text_json:
category_id_result = result_text_json["category_id"]
classes = rendered_classes
classes_map = {class_.id: class_.name for class_ in classes}
classes_map = {
class_.id: {
"name": class_.name,
"label": class_.label or self._default_class_label(index + 1),
}
for index, class_ in enumerate(classes)
}
category_ids = [_class.id for _class in classes]
if category_id_result in category_ids:
category_name = classes_map[category_id_result]
category_name = classes_map[category_id_result]["name"]
category_label = classes_map[category_id_result]["label"]
category_id = category_id_result
process_data = {
"model_mode": node_data.model.mode,
Expand All @@ -215,6 +227,7 @@ def _run(self):
}
outputs = {
"class_name": category_name,
"class_label": category_label,
"class_id": category_id,
"usage": jsonable_encoder(usage),
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from types import SimpleNamespace
from unittest.mock import MagicMock

from graphon.model_runtime.entities import ImagePromptMessageContent
from graphon.model_runtime.entities import ImagePromptMessageContent, LLMUsage
from graphon.node_events import ModelInvokeCompletedEvent
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
from graphon.nodes.protocols import HttpClientProtocol
from graphon.nodes.question_classifier import (
QuestionClassifierNode,
QuestionClassifierNodeData,
)
from graphon.nodes.question_classifier.question_classifier_node import llm_utils
from graphon.template_rendering import Jinja2TemplateRenderer
from tests.workflow_test_utils import build_test_graph_init_params

Expand All @@ -17,7 +19,7 @@ def test_init_question_classifier_node_data():
"title": "test classifier node",
"query_variable_selector": ["id", "name"],
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
"classes": [{"id": "1", "name": "class 1"}],
"classes": [{"id": "1", "name": "class 1", "label": "CLASS 1"}],
"instruction": "This is a test instruction",
"memory": {
"role_prefix": {"user": "Human:", "assistant": "AI:"},
Expand All @@ -32,6 +34,7 @@ def test_init_question_classifier_node_data():
assert node_data.query_variable_selector == ["id", "name"]
assert node_data.model.provider == "openai"
assert node_data.classes[0].id == "1"
assert node_data.classes[0].label == "CLASS 1"
assert node_data.instruction == "This is a test instruction"
assert node_data.memory is not None
assert node_data.memory.role_prefix is not None
Expand Down Expand Up @@ -64,6 +67,7 @@ def test_init_question_classifier_node_data_without_vision_config():
assert node_data.query_variable_selector == ["id", "name"]
assert node_data.model.provider == "openai"
assert node_data.classes[0].id == "1"
assert node_data.classes[0].label is None
assert node_data.instruction == "This is a test instruction"
assert node_data.memory is not None
assert node_data.memory.role_prefix is not None
Expand Down Expand Up @@ -124,3 +128,138 @@ def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(mon
)

assert fetch_prompt_messages.call_args.kwargs["template_renderer"] is template_renderer


def _create_question_classifier_node_for_run(
node_data: QuestionClassifierNodeData,
*,
variable_pool: MagicMock,
template_renderer: MagicMock,
) -> QuestionClassifierNode:
return QuestionClassifierNode(
id="node-id",
config={"id": "node-id", "data": node_data.model_dump(mode="json")},
graph_init_params=build_test_graph_init_params(
workflow_id="workflow-id",
graph_config={},
tenant_id="tenant-id",
app_id="app-id",
user_id="user-id",
),
graph_runtime_state=SimpleNamespace(variable_pool=variable_pool),
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(
provider="openai",
model_name="gpt-4o",
stop=(),
parameters={},
),
http_client=MagicMock(spec=HttpClientProtocol),
llm_file_saver=MagicMock(),
template_renderer=template_renderer,
)


def test_question_classifier_run_returns_class_label_separately(monkeypatch):
node_data = QuestionClassifierNodeData.model_validate(
{
"title": "test classifier node",
"query_variable_selector": ["start", "sys.query"],
"model": {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}},
"classes": [
{"id": "1", "name": "billing questions", "label": "Billing"},
{"id": "2", "name": "refund requests", "label": "Refund desk"},
],
"instruction": "Classify the question",
}
)
variable_pool = MagicMock()
variable_pool.get.return_value = SimpleNamespace(value="Where is my refund?")
variable_pool.convert_template.side_effect = lambda value: SimpleNamespace(text=value)
template_renderer = MagicMock(spec=Jinja2TemplateRenderer)
node = _create_question_classifier_node_for_run(
node_data,
variable_pool=variable_pool,
template_renderer=template_renderer,
)

monkeypatch.setattr(llm_utils, "resolve_completion_params_variables", lambda parameters, _: parameters)
monkeypatch.setattr(
llm_utils,
"fetch_prompt_messages",
MagicMock(return_value=([], None)),
)
monkeypatch.setattr(node, "_calculate_rest_token", MagicMock(return_value=1024))
monkeypatch.setattr(node, "_get_prompt_template", MagicMock(return_value=[]))
monkeypatch.setattr(
"graphon.nodes.question_classifier.question_classifier_node.LLMNode.invoke_llm",
lambda **_: iter(
[
ModelInvokeCompletedEvent(
text='{"category_id": "2", "category_name": "refund requests"}',
usage=LLMUsage.empty_usage(),
finish_reason="stop",
)
]
),
)

result = node._run()

assert result.outputs["class_name"] == "refund requests"
assert result.outputs["class_label"] == "Refund desk"
assert result.outputs["class_id"] == "2"
assert result.edge_source_handle == "2"


def test_question_classifier_run_falls_back_to_canonical_class_label(monkeypatch):
node_data = QuestionClassifierNodeData.model_validate(
{
"title": "test classifier node",
"query_variable_selector": ["start", "sys.query"],
"model": {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}},
"classes": [
{"id": "1", "name": "billing questions", "label": "Billing"},
{"id": "2", "name": "refund requests"},
],
"instruction": "Classify the question",
}
)
variable_pool = MagicMock()
variable_pool.get.return_value = SimpleNamespace(value="Where is my refund?")
variable_pool.convert_template.side_effect = lambda value: SimpleNamespace(text=value)
template_renderer = MagicMock(spec=Jinja2TemplateRenderer)
node = _create_question_classifier_node_for_run(
node_data,
variable_pool=variable_pool,
template_renderer=template_renderer,
)

monkeypatch.setattr(llm_utils, "resolve_completion_params_variables", lambda parameters, _: parameters)
monkeypatch.setattr(
llm_utils,
"fetch_prompt_messages",
MagicMock(return_value=([], None)),
)
monkeypatch.setattr(node, "_calculate_rest_token", MagicMock(return_value=1024))
monkeypatch.setattr(node, "_get_prompt_template", MagicMock(return_value=[]))
monkeypatch.setattr(
"graphon.nodes.question_classifier.question_classifier_node.LLMNode.invoke_llm",
lambda **_: iter(
[
ModelInvokeCompletedEvent(
text='{"category_id": "2", "category_name": "refund requests"}',
usage=LLMUsage.empty_usage(),
finish_reason="stop",
)
]
),
)

result = node._run()

assert result.outputs["class_name"] == "refund requests"
assert result.outputs["class_label"] == "CLASS 2"
assert result.outputs["class_id"] == "2"
assert result.edge_source_handle == "2"
4 changes: 4 additions & 0 deletions web/app/components/workflow/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ export const QUESTION_CLASSIFIER_OUTPUT_STRUCT = [
variable: 'class_name',
type: VarType.string,
},
{
variable: 'class_label',
type: VarType.string,
},
{
variable: 'usage',
type: VarType.object,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
'use client'
import type { FC, ReactNode } from 'react'
import * as React from 'react'
import { cn } from '@/utils/classnames'

type Props = {
title: string
content: ReactNode
titleClassName?: string
}

const InfoPanel: FC<Props> = ({
title,
content,
titleClassName,
}) => {
return (
<div>
<div className="flex flex-col gap-y-0.5 rounded-md bg-workflow-block-parma-bg px-[5px] py-[3px]">
<div className="system-2xs-semibold-uppercase uppercase text-text-secondary">
<div className={cn('system-2xs-semibold-uppercase uppercase text-text-secondary', titleClassName)}>
{title}
</div>
<div className="system-xs-regular break-words text-text-tertiary">
Expand Down
Loading
Loading