Skip to content

Commit 9774062

Browse files
committed
feat(mcp): implement experimental tasks API for streaming workflow status to chat
Signed-off-by: betterclever <paliwal.pranjal83@gmail.com>
1 parent 557918a commit 9774062

2 files changed

Lines changed: 197 additions & 58 deletions

File tree

backend/src/studio-mcp/__tests__/studio-mcp.service.spec.ts

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
44
import type { AuthContext } from '../../auth/types';
55
import type { WorkflowsService } from '../../workflows/workflows.service';
66

7-
// Helper to access private _registeredTools on McpServer (plain object at runtime)
8-
type ToolHandler = (...args: unknown[]) => unknown;
9-
type RegisteredToolsMap = Record<string, { handler: ToolHandler }>;
7+
// Helper to access private _registeredTools and experimental tasks on McpServer (plain object at runtime)
8+
type RegisteredToolsMap = Record<string, any>;
9+
1010
function getRegisteredTools(server: McpServer): RegisteredToolsMap {
1111
return (server as unknown as { _registeredTools: RegisteredToolsMap })._registeredTools;
1212
}
@@ -60,12 +60,11 @@ describe('StudioMcpService Unit Tests', () => {
6060
expect(server).toBeInstanceOf(McpServer);
6161
});
6262

63-
it('registers all 9 expected tools', () => {
63+
it('registers all expected tools and tasks', () => {
6464
const server = service.createServer(mockAuthContext);
6565
const registeredTools = getRegisteredTools(server);
6666

6767
expect(registeredTools).toBeDefined();
68-
expect(Object.keys(registeredTools).length).toBe(9);
6968

7069
const toolNames = Object.keys(registeredTools).sort();
7170
expect(toolNames).toEqual([
@@ -110,16 +109,27 @@ describe('StudioMcpService Unit Tests', () => {
110109
expect(workflowsService.findById).toHaveBeenCalledWith(workflowId, mockAuthContext);
111110
});
112111

113-
it('run_workflow tool uses auth context passed at creation time', async () => {
112+
it('run_workflow task uses auth context passed at creation time', async () => {
114113
const workflowId = '11111111-1111-4111-8111-111111111111';
115114
const inputs = { key: 'value' };
116115

117116
const server = service.createServer(mockAuthContext);
118117
const registeredTools = getRegisteredTools(server);
119-
const runWorkflowTool = registeredTools['run_workflow'];
118+
const runWorkflowTask = registeredTools['run_workflow'];
119+
120+
expect(runWorkflowTask).toBeDefined();
121+
122+
// Need to mock the extra params for the experimental tasks
123+
const mockExtra = {
124+
taskStore: {
125+
createTask: jest.fn().mockResolvedValue({ taskId: 'mockTaskId', status: 'working' }),
126+
getTask: jest.fn().mockResolvedValue({ taskId: 'mockTaskId', status: 'working' }),
127+
updateTaskStatus: jest.fn().mockResolvedValue(true),
128+
storeTaskResult: jest.fn().mockResolvedValue(true),
129+
},
130+
};
120131

121-
expect(runWorkflowTool).toBeDefined();
122-
await runWorkflowTool.handler({ workflowId, inputs });
132+
await runWorkflowTask.handler.createTask({ workflowId, inputs }, mockExtra);
123133

124134
expect(workflowsService.run).toHaveBeenCalledWith(
125135
workflowId,
@@ -129,7 +139,7 @@ describe('StudioMcpService Unit Tests', () => {
129139
trigger: {
130140
type: 'api',
131141
sourceId: mockAuthContext.userId,
132-
label: 'Studio MCP',
142+
label: 'Studio MCP Task',
133143
},
134144
},
135145
);
@@ -230,12 +240,21 @@ describe('StudioMcpService Unit Tests', () => {
230240

231241
it('denies run_workflow when workflows.run is false', async () => {
232242
const server = service.createServer(restrictedAuth);
233-
const tools = getRegisteredTools(server);
234-
const result = (await tools['run_workflow'].handler({
235-
workflowId: '11111111-1111-4111-8111-111111111111',
236-
})) as { isError?: boolean; content: { text: string }[] };
237-
expect(result.isError).toBe(true);
238-
expect(result.content[0].text).toContain('workflows.run');
243+
const tasks = getRegisteredTools(server);
244+
245+
let errorThrown = false;
246+
try {
247+
await tasks['run_workflow'].handler.createTask(
248+
{
249+
workflowId: '11111111-1111-4111-8111-111111111111',
250+
},
251+
{} as any,
252+
);
253+
} catch (_e: any) {
254+
errorThrown = true;
255+
expect(_e.message).toContain('workflows.run');
256+
}
257+
expect(errorThrown).toBe(true);
239258
});
240259

241260
it('denies cancel_run when runs.cancel is false', async () => {
@@ -260,15 +279,28 @@ describe('StudioMcpService Unit Tests', () => {
260279
it('allows all tools when no apiKeyPermissions (non-API-key auth)', async () => {
261280
const server = service.createServer(mockAuthContext); // no apiKeyPermissions
262281
const tools = getRegisteredTools(server);
282+
const tasks = getRegisteredTools(server);
263283

264284
// All workflow/run tools should work without permission errors
265285
const listResult = (await tools['list_workflows'].handler({})) as { isError?: boolean };
266286
expect(listResult.isError).toBeUndefined();
267287

268-
const runResult = (await tools['run_workflow'].handler({
269-
workflowId: '11111111-1111-4111-8111-111111111111',
270-
})) as { isError?: boolean };
271-
expect(runResult.isError).toBeUndefined();
288+
const mockExtra = {
289+
taskStore: {
290+
createTask: jest.fn().mockResolvedValue({ taskId: 'mock', status: 'working' }),
291+
getTask: jest.fn().mockResolvedValue({ taskId: 'mock', status: 'working' }),
292+
updateTaskStatus: jest.fn().mockResolvedValue(true),
293+
storeTaskResult: jest.fn().mockResolvedValue(true),
294+
},
295+
};
296+
297+
const runResult = await tasks['run_workflow'].handler.createTask(
298+
{
299+
workflowId: '11111111-1111-4111-8111-111111111111',
300+
},
301+
mockExtra,
302+
);
303+
expect(runResult.task.taskId).toEqual('mock');
272304

273305
const cancelResult = (await tools['cancel_run'].handler({
274306
runId: 'test-run-id',
@@ -308,11 +340,11 @@ describe('StudioMcpService Unit Tests', () => {
308340
};
309341
const server = service.createServer(noPermsAuth);
310342
const tools = getRegisteredTools(server);
343+
const tasks = getRegisteredTools(server);
311344

312345
const gatedTools = [
313346
'list_workflows',
314347
'get_workflow',
315-
'run_workflow',
316348
'list_runs',
317349
'get_run_status',
318350
'get_run_result',
@@ -326,6 +358,20 @@ describe('StudioMcpService Unit Tests', () => {
326358
})) as { isError?: boolean };
327359
expect(result.isError).toBe(true);
328360
}
361+
362+
// Test run_workflow separately since it's a task now
363+
let errorThrown = false;
364+
try {
365+
await tasks['run_workflow'].handler.createTask(
366+
{
367+
workflowId: '11111111-1111-4111-8111-111111111111',
368+
},
369+
{} as any,
370+
);
371+
} catch (_e: any) {
372+
errorThrown = true;
373+
}
374+
expect(errorThrown).toBe(true);
329375
});
330376
});
331377

backend/src/studio-mcp/studio-mcp.service.ts

Lines changed: 130 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -138,57 +138,71 @@ export class StudioMcpService {
138138
},
139139
);
140140

141-
server.registerTool(
141+
const runWorkflowSchema = {
142+
workflowId: z.string().uuid(),
143+
inputs: z.record(z.string(), z.unknown()).optional(),
144+
versionId: z.string().uuid().optional(),
145+
};
146+
147+
server.experimental.tasks.registerToolTask(
142148
'run_workflow',
143149
{
144150
description:
145-
'Start a workflow execution. Returns the run ID and initial status. Use get_run_status to poll for completion.',
146-
inputSchema: {
147-
workflowId: z.string().uuid(),
148-
inputs: z.record(z.string(), z.unknown()).optional(),
149-
versionId: z.string().uuid().optional(),
150-
},
151+
'Start a workflow execution as a background task. The task handle can be monitored for status updates, and finally retrieved for the workflow result. Also supports legacy polling via get_run_status.',
152+
inputSchema: runWorkflowSchema,
153+
execution: { taskSupport: 'optional' },
151154
},
152-
async (args: {
153-
workflowId: string;
154-
inputs?: Record<string, unknown>;
155-
versionId?: string;
156-
}) => {
157-
const gate = this.checkPermission(auth, 'workflows.run');
158-
if (!gate.allowed) return gate.error;
159-
try {
155+
{
156+
createTask: async (args, extra) => {
157+
const gate = this.checkPermission(auth, 'workflows.run');
158+
if (!gate.allowed) throw new Error(gate.error.content[0].text);
159+
160+
const task = await extra.taskStore.createTask({ ttl: 12 * 60 * 60 * 1000 });
161+
160162
const handle = await this.workflowsService.run(
161163
args.workflowId,
162-
{ inputs: args.inputs ?? {}, versionId: args.versionId },
164+
{
165+
inputs: args.inputs ?? {},
166+
versionId: args.versionId,
167+
},
163168
auth,
164169
{
165170
trigger: {
166171
type: 'api',
167172
sourceId: auth.userId ?? 'api-key',
168-
label: 'Studio MCP',
173+
label: 'Studio MCP Task',
169174
},
170175
},
171176
);
172-
return {
173-
content: [
174-
{
175-
type: 'text' as const,
176-
text: JSON.stringify(
177-
{
178-
runId: handle.runId,
179-
workflowId: handle.workflowId,
180-
status: handle.status,
181-
workflowVersion: handle.workflowVersion,
182-
},
183-
null,
184-
2,
185-
),
186-
},
187-
],
188-
};
189-
} catch (error) {
190-
return this.errorResult(error);
191-
}
177+
178+
this.monitorWorkflowRun(
179+
handle.runId,
180+
handle.temporalRunId,
181+
task.taskId,
182+
extra.taskStore,
183+
server,
184+
auth,
185+
).catch((err) => {
186+
this.logger.error(`Error monitoring workflow run task for run ${handle.runId}: ${err}`);
187+
});
188+
189+
return { task };
190+
},
191+
getTask: async (args, extra) => {
192+
const gate = this.checkPermission(auth, 'runs.read');
193+
if (!gate.allowed) throw new Error(gate.error.content[0].text);
194+
const task = await extra.taskStore.getTask(extra.taskId);
195+
if (!task) {
196+
throw new Error(`Task ${extra.taskId} not found`);
197+
}
198+
return task;
199+
},
200+
getTaskResult: async (args, extra) => {
201+
const gate = this.checkPermission(auth, 'runs.read');
202+
if (!gate.allowed) throw new Error(gate.error.content[0].text);
203+
const result = await extra.taskStore.getTaskResult(extra.taskId);
204+
return result as any;
205+
},
192206
},
193207
);
194208
}
@@ -397,6 +411,85 @@ export class StudioMcpService {
397411
);
398412
}
399413

414+
private async monitorWorkflowRun(
415+
runId: string,
416+
temporalRunId: string | undefined,
417+
taskId: string,
418+
taskStore: any,
419+
server: McpServer,
420+
auth: AuthContext,
421+
): Promise<void> {
422+
const isTerminal = (status: string) =>
423+
['COMPLETED', 'FAILED', 'CANCELLED', 'TERMINATED', 'TIMED_OUT'].includes(status);
424+
425+
const mapStatus = (status: string): 'working' | 'completed' | 'cancelled' | 'failed' => {
426+
switch (status) {
427+
case 'RUNNING':
428+
case 'QUEUED':
429+
case 'AWAITING_INPUT':
430+
return 'working';
431+
case 'COMPLETED':
432+
return 'completed';
433+
case 'CANCELLED':
434+
case 'TERMINATED':
435+
case 'TIMED_OUT':
436+
return 'cancelled';
437+
case 'FAILED':
438+
return 'failed';
439+
default:
440+
return 'working';
441+
}
442+
};
443+
444+
while (true) {
445+
try {
446+
const runStatusPayload = await this.workflowsService.getRunStatus(
447+
runId,
448+
temporalRunId,
449+
auth,
450+
);
451+
const taskState = mapStatus(runStatusPayload.status);
452+
453+
await taskStore.updateTaskStatus(taskId, taskState, runStatusPayload.status);
454+
455+
if (isTerminal(runStatusPayload.status)) {
456+
let resultData: any;
457+
if (taskState === 'completed') {
458+
try {
459+
resultData = await this.workflowsService.getRunResult(runId, temporalRunId, auth);
460+
} catch (err) {
461+
resultData = { error: String(err) };
462+
}
463+
} else {
464+
resultData = runStatusPayload.failure || { reason: runStatusPayload.status };
465+
}
466+
467+
const resultPayload = {
468+
content: [{ type: 'text', text: JSON.stringify(resultData, null, 2) }],
469+
};
470+
471+
const storeStatus = taskState === 'completed' ? 'completed' : 'failed';
472+
await taskStore.storeTaskResult(taskId, storeStatus, resultPayload);
473+
break;
474+
}
475+
476+
await new Promise((res) => setTimeout(res, 2000));
477+
} catch (err) {
478+
this.logger.error(`Error monitoring task ${taskId} (run: ${runId}): ${err}`);
479+
try {
480+
await taskStore.updateTaskStatus(taskId, 'failed', String(err));
481+
await taskStore.storeTaskResult(taskId, 'failed', {
482+
content: [{ type: 'text', text: `Failed to monitor workflow run: ${String(err)}` }],
483+
isError: true,
484+
});
485+
} catch (_updateErr) {
486+
// Ignore
487+
}
488+
break;
489+
}
490+
}
491+
}
492+
400493
// ---------------------------------------------------------------------------
401494
// Helpers
402495
// ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)