-
Notifications
You must be signed in to change notification settings - Fork 4.1k
Expand file tree
/
Copy pathworkflow_run_service.py
More file actions
executable file
·298 lines (271 loc) · 11.3 KB
/
workflow_run_service.py
File metadata and controls
executable file
·298 lines (271 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
"""Service responsible for executing workflows for WebSocket sessions."""
import logging
from pathlib import Path
from typing import List, Optional, Union
from check.check import load_config
from entity.graph_config import GraphConfig
from entity.messages import Message
from entity.enums import LogLevel
from utils.exceptions import ValidationError, WorkflowCancelledError
from utils.visibility_bridge import _reset_sync
from utils.structured_logger import get_server_logger, LogType
from utils.task_input import TaskInputBuilder
from workflow.graph_context import GraphContext
from server.services.attachment_service import AttachmentService
from server.services.session_execution import SessionExecutionController
from server.services.session_store import SessionStatus, WorkflowSessionStore
from server.services.websocket_executor import WebSocketGraphExecutor
from server.services.workflow_storage import validate_workflow_filename
from server.settings import WARE_HOUSE_DIR, YAML_DIR
class WorkflowRunService:
    """Run YAML-defined workflows on behalf of WebSocket sessions.

    The service resolves the requested workflow file, registers the session
    in the session store, drives a ``WebSocketGraphExecutor`` and reports the
    outcome (started / completed / cancelled / error) to the client through
    the WebSocket manager passed into each call.
    """

    # Fallback text used whenever a cancellation carries no explicit reason.
    _DEFAULT_CANCEL_REASON = "Cancellation requested"

    def __init__(
        self,
        session_store: WorkflowSessionStore,
        session_controller: SessionExecutionController,
        attachment_service: AttachmentService,
    ) -> None:
        """Store the collaborating services and create a module-level logger.

        Args:
            session_store: Registry used to create, look up and update sessions.
            session_controller: Controller handed to the executor and asked to
                clean up once a run finishes.
            attachment_service: Service that prepares session workspaces and
                builds attachment blocks.
        """
        self.session_store = session_store
        self.session_controller = session_controller
        self.attachment_service = attachment_service
        self.logger = logging.getLogger(__name__)

    @classmethod
    def _resolve_cancel_reason(cls, session) -> str:
        """Return the reason recorded on *session*, or the default text.

        The default is used when there is no session or when its
        ``cancel_reason`` is unset/empty, so callers never propagate ``None``.
        """
        recorded = session.cancel_reason if session else None
        return recorded or cls._DEFAULT_CANCEL_REASON

    def request_cancel(self, session_id: str, *, reason: Optional[str] = None) -> bool:
        """Flag a session as cancelled and forward the request to its executor.

        Args:
            session_id: Identifier of the session to cancel.
            reason: Optional human-readable reason; a default text is used
                when it is omitted or empty.

        Returns:
            ``False`` when the session is unknown, ``True`` otherwise.
        """
        session = self.session_store.get_session(session_id)
        if not session:
            return False

        cancel_message = reason or self._DEFAULT_CANCEL_REASON
        session.cancel_reason = cancel_message
        if not session.cancel_event.is_set():
            session.cancel_event.set()
            self.logger.info("Cancellation requested for session %s", session_id)

        if session.executor:
            try:
                session.executor.request_cancel(cancel_message)
            except Exception as exc:
                # Best effort: the event is already set, so a failure to reach
                # the executor is logged rather than surfaced to the caller.
                self.logger.warning("Failed to propagate cancellation to executor for %s: %s", session_id, exc)

        self.session_store.update_session_status(session_id, SessionStatus.CANCELLED, error_message=cancel_message)
        return True

    async def start_workflow(
        self,
        session_id: str,
        yaml_file: str,
        task_prompt: str,
        websocket_manager,
        *,
        attachments: Optional[List[str]] = None,
        log_level: Optional[LogLevel] = None,
    ) -> None:
        """Validate the request, register the session and run the workflow.

        ``ValidationError`` and any other ``Exception`` raised along the way
        are caught here: they are logged, recorded on the session via
        ``set_session_error`` and sent to the client as an ``"error"`` message.

        Args:
            session_id: Identifier of the WebSocket session.
            yaml_file: Workflow file name, resolved inside ``YAML_DIR``.
            task_prompt: User prompt; may be blank only if attachments exist.
            websocket_manager: Object providing ``send_message`` and
                ``active_connections`` for the session.
            attachments: Optional attachment identifiers for the initial input.
            log_level: Optional log level override applied to the graph config.
        """
        normalized_yaml_name = (yaml_file or "").strip()
        try:
            yaml_path = self._resolve_yaml_path(normalized_yaml_name)
            normalized_yaml_name = yaml_path.name
            attachments = attachments or []

            # A run needs something to work on: either a prompt or attachments.
            if (not task_prompt or not task_prompt.strip()) and not attachments:
                raise ValidationError(
                    "Task prompt cannot be empty",
                    details={"task_prompt_provided": bool(task_prompt)},
                )

            self.attachment_service.prepare_session_workspace(session_id)
            self.session_store.create_session(
                yaml_file=normalized_yaml_name,
                task_prompt=task_prompt,
                session_id=session_id,
                attachments=attachments,
            )
            self.session_store.update_session_status(session_id, SessionStatus.RUNNING)

            await websocket_manager.send_message(
                session_id,
                {
                    "type": "workflow_started",
                    "data": {"yaml_file": normalized_yaml_name, "task_prompt": task_prompt},
                },
            )
            await self._execute_workflow_async(
                session_id,
                yaml_path,
                task_prompt,
                websocket_manager,
                attachments,
                log_level,
            )
        except ValidationError as exc:
            self.logger.error(str(exc))
            logger = get_server_logger()
            logger.error(
                "Workflow validation error",
                log_type=LogType.WORKFLOW,
                session_id=session_id,
                yaml_file=normalized_yaml_name,
                validation_details=getattr(exc, "details", None),
            )
            self.session_store.set_session_error(session_id, str(exc))
            await websocket_manager.send_message(
                session_id,
                {"type": "error", "data": {"message": str(exc)}},
            )
        except Exception as exc:
            # Lazy %-formatting renders the same text as the former f-string.
            self.logger.error("Error starting workflow for session %s: %s", session_id, exc)
            logger = get_server_logger()
            logger.log_exception(
                exc,
                "Error starting workflow",
                session_id=session_id,
                yaml_file=normalized_yaml_name,
            )
            self.session_store.set_session_error(session_id, str(exc))
            await websocket_manager.send_message(
                session_id,
                {
                    "type": "error",
                    "data": {"message": f"Failed to start workflow: {exc}"},
                },
            )

    async def _execute_workflow_async(
        self,
        session_id: str,
        yaml_path: Path,
        task_prompt: str,
        websocket_manager,
        attachments: List[str],
        log_level: Optional[LogLevel],
    ) -> None:
        """Build the graph for *yaml_path*, execute it and report the outcome.

        Cancellation, validation errors and other exceptions are each turned
        into a session status update, a client message and a structured log
        entry. The ``finally`` clause always detaches the executor and graph
        from the session and asks the session controller to clean up.

        Args:
            session_id: Identifier of the WebSocket session.
            yaml_path: Resolved path of the workflow definition.
            task_prompt: User prompt for the run.
            websocket_manager: Object providing ``send_message`` and
                ``active_connections``.
            attachments: Attachment identifiers (possibly empty).
            log_level: Optional override; ignored when falsy.
        """
        session = self.session_store.get_session(session_id)
        cancel_event = session.cancel_event if session else None
        try:
            design = load_config(yaml_path)
            graph_config = GraphConfig.from_definition(
                design.graph,
                name=f"session_{session_id}",
                output_root=WARE_HOUSE_DIR,
                source_path=str(yaml_path),
                vars=design.vars,
            )
            if log_level:
                graph_config.log_level = log_level
                graph_config.definition.log_level = log_level

            graph_context = GraphContext(config=graph_config)
            executor = WebSocketGraphExecutor(
                graph_context,
                session_id,
                self.session_controller,
                self.attachment_service,
                websocket_manager,
                self.session_store,
                cancel_event=cancel_event,
                task_prompt=task_prompt,
            )

            if session:
                session.graph = graph_context
                session.executor = executor
                # A cancel may have arrived before the executor existed;
                # replay it so the executor sees the request.
                if session.cancel_event.is_set():
                    executor.request_cancel(self._resolve_cancel_reason(session))

            task_input = self._build_initial_task_input(
                session_id,
                graph_context,
                task_prompt,
                attachments,
                executor.attachment_store,
            )
            _reset_sync()  # clear dashboard before every run
            await executor.execute_graph_async(task_input)

            # If cancellation was requested during execution but not raised inside threads,
            # treat the run as cancelled to avoid conflicting status.
            if cancel_event and cancel_event.is_set():
                # Fall back to the default text so the client never receives "None".
                reason = self._resolve_cancel_reason(session)
                raise WorkflowCancelledError(reason, workflow_id=graph_context.name)

            results = executor.get_results()
            self.session_store.complete_session(session_id, results)
            await websocket_manager.send_message(
                session_id,
                {
                    "type": "workflow_completed",
                    "data": {
                        "results": results,
                        "summary": graph_context.final_message(),
                        "token_usage": executor.token_tracker.get_token_usage(),
                    },
                },
            )
            logger = get_server_logger()
            logger.info(
                "Workflow execution completed successfully",
                log_type=LogType.WORKFLOW,
                session_id=session_id,
                yaml_path=str(yaml_path),
                result_count=len(results) if isinstance(results, dict) else 0,
            )
        except WorkflowCancelledError as exc:
            reason = str(exc)
            self.session_store.update_session_status(session_id, SessionStatus.CANCELLED, error_message=reason)
            await websocket_manager.send_message(
                session_id,
                {
                    "type": "workflow_cancelled",
                    "data": {"message": reason},
                },
            )
            logger = get_server_logger()
            logger.info(
                "Workflow execution cancelled",
                log_type=LogType.WORKFLOW,
                session_id=session_id,
                yaml_path=str(yaml_path),
                cancellation_reason=reason,
            )
        except ValidationError as exc:
            self.session_store.set_session_error(session_id, str(exc))
            await websocket_manager.send_message(
                session_id,
                {"type": "error", "data": {"message": str(exc)}},
            )
            logger = get_server_logger()
            logger.error(
                "Workflow validation error",
                log_type=LogType.WORKFLOW,
                session_id=session_id,
                yaml_path=str(yaml_path),
                validation_details=getattr(exc, "details", None),
            )
        except Exception as exc:
            self.session_store.set_session_error(session_id, str(exc))
            await websocket_manager.send_message(
                session_id,
                {"type": "error", "data": {"message": f"Workflow execution error: {exc}"}},
            )
            logger = get_server_logger()
            logger.log_exception(
                exc,
                f"Error executing workflow for session {session_id}",
                session_id=session_id,
                yaml_path=str(yaml_path),
            )
        finally:
            # Re-fetch the session: it may have been removed while the run was in flight.
            session_ref = self.session_store.get_session(session_id)
            if session_ref:
                session_ref.executor = None
                session_ref.graph = None
            self.session_controller.cleanup_session(session_id)
            # Drop the session only when its client connection is no longer registered.
            if session_id not in websocket_manager.active_connections:
                self.session_store.pop_session(session_id)

    def _build_initial_task_input(
        self,
        session_id: str,
        graph_context: GraphContext,
        prompt: str,
        attachment_ids: List[str],
        store,
    ) -> Union[List[Message], str]:
        """Return the first input handed to the graph.

        Without attachments the prompt is returned unchanged. Otherwise the
        attachment service builds blocks for the given ids (targeting *store*)
        and ``TaskInputBuilder`` combines them with the prompt.
        """
        if not attachment_ids:
            return prompt

        blocks = self.attachment_service.build_attachment_blocks(
            session_id,
            attachment_ids,
            target_store=store,
        )
        return TaskInputBuilder(store).build_from_blocks(prompt, blocks)

    def _resolve_yaml_path(self, yaml_filename: str) -> Path:
        """Validate and resolve YAML paths inside the configured directory.

        Raises:
            ValidationError: When no regular file with the validated name
                exists under ``YAML_DIR``. ``validate_workflow_filename`` may
                raise as well; its behavior is defined elsewhere.
        """
        safe_name = validate_workflow_filename(yaml_filename, require_yaml_extension=True)
        yaml_path = YAML_DIR / safe_name
        # is_file() rather than exists(): a directory with a matching name is
        # not a loadable workflow and should be reported as missing up front.
        if not yaml_path.is_file():
            raise ValidationError("YAML file not found", details={"yaml_file": safe_name})
        return yaml_path