-
Notifications
You must be signed in to change notification settings - Fork 4.1k
Expand file tree
/
Copy pathwebsocket_executor.py
More file actions
executable file
·76 lines (63 loc) · 2.92 KB
/
websocket_executor.py
File metadata and controls
executable file
·76 lines (63 loc) · 2.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""GraphExecutor variant that reports results over WebSocket."""
import asyncio
from typing import List, Optional

from utils.logger import WorkflowLogger
from workflow.graph import GraphExecutor
from workflow.graph_context import GraphContext
from server.services.attachment_service import AttachmentService
from server.services.artifact_dispatcher import ArtifactDispatcher
from server.services.prompt_channel import WebPromptChannel
from server.services.session_store import WorkflowSessionStore
from server.services.session_execution import SessionExecutionController
from workflow.hooks.workspace_artifact import WorkspaceArtifact, WorkspaceArtifactHook
class WebSocketGraphExecutor(GraphExecutor):
    """GraphExecutor subclass that emits events via WebSocket.

    Wires server-side session services into the base ``GraphExecutor``:

    * ``_create_logger`` returns a ``WebSocketLogger`` bound to the
      websocket manager and session id.
    * The workspace hook factory builds a ``WorkspaceArtifactHook``. The hook
      prompts through a ``WebPromptChannel`` and forwards emitted artifacts to
      an ``ArtifactDispatcher``.
    """

    def __init__(
        self,
        graph: GraphContext,
        session_id: str,
        session_controller: SessionExecutionController,
        attachment_service: AttachmentService,
        websocket_manager,
        session_store: WorkflowSessionStore,
        cancel_event=None,
        task_prompt: Optional[str] = None,
    ):
        """Store the session services and initialise the base executor.

        Args:
            graph: Graph definition to execute.
            session_id: Session identifier. It is forwarded to the base
                executor, the logger, the prompt channel and the artifact
                dispatcher.
            session_controller: Controller handed to each ``WebPromptChannel``.
            attachment_service: Service handed to each ``WebPromptChannel``.
            websocket_manager: Object used by the logger, prompt channel and
                artifact dispatcher to reach the client. Its type is not
                visible in this file.
            session_store: Store handed to the ``ArtifactDispatcher``.
            cancel_event: Forwarded unchanged to ``GraphExecutor``.
                NOTE(review): presumably a ``threading.Event``; confirm.
            task_prompt: Optional prompt text passed to the ``WebSocketLogger``.
        """
        # These attributes are set before ``super().__init__``, presumably
        # because the base initializer may call ``_create_logger``, which
        # reads them. NOTE(review): confirm against GraphExecutor.__init__.
        self.session_id = session_id
        self.session_controller = session_controller
        self.attachment_service = attachment_service
        self.websocket_manager = websocket_manager
        self.session_store = session_store
        # NOTE(review): never read in this class; get_results() returns
        # self.outputs. Kept in case external callers access it.
        self.results = {}
        self.task_prompt = task_prompt
        self.artifact_dispatcher = ArtifactDispatcher(session_id, session_store, websocket_manager)

        def hook_factory(runtime_context):
            """Build a workspace-artifact hook for one runtime context.

            ``runtime_context`` must expose ``attachment_store``. That store
            is shared by the prompt channel and the hook.
            """
            prompt_channel = WebPromptChannel(
                session_id=session_id,
                session_controller=session_controller,
                websocket_manager=websocket_manager,
                attachment_service=attachment_service,
                attachment_store=runtime_context.attachment_store,
            )
            return WorkspaceArtifactHook(
                attachment_store=runtime_context.attachment_store,
                emit_callback=self._handle_workspace_artifacts,
                prompt_channel=prompt_channel,
            )

        super().__init__(
            graph,
            session_id=session_id,
            workspace_hook_factory=hook_factory,
            cancel_event=cancel_event,
        )

    def _create_logger(self) -> WorkflowLogger:
        """Return a ``WebSocketLogger`` for this session and graph."""
        # Deferred import. NOTE(review): presumably this avoids a circular
        # import with server.services; confirm before hoisting it.
        from server.services.websocket_logger import WebSocketLogger
        return WebSocketLogger(
            self.websocket_manager,
            self.session_id,
            self.graph.name,
            self.graph.log_level,
            task_prompt=self.task_prompt,
            graph_config=self.graph.config,
        )

    async def execute_graph_async(self, task_prompt):
        """Run the blocking ``_execute`` in the loop's default executor.

        Args:
            task_prompt: Forwarded to ``self._execute``.

        Returns:
            Whatever ``self._execute`` returns. Previously this value was
            discarded and the coroutine always returned ``None``.

        Cancelling this coroutine does not stop the worker thread. A
        cooperative stop presumably goes through ``cancel_event``.
        NOTE(review): confirm.
        """
        # Inside a coroutine the asyncio docs recommend get_running_loop();
        # get_event_loop() is discouraged here.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._execute, task_prompt)

    def get_results(self):
        """Return ``self.outputs``.

        NOTE(review): ``outputs`` is not assigned in this class; it is
        presumably populated by ``GraphExecutor`` during execution.
        """
        return self.outputs

    def _handle_workspace_artifacts(self, artifacts: List[WorkspaceArtifact]) -> None:
        """Forward artifacts emitted by the workspace hook to the dispatcher."""
        self.artifact_dispatcher.emit_workspace_artifacts(artifacts)