Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .changeset/workflow-control-ownership.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
"@voltagent/server-core": patch
"@voltagent/server-hono": patch
"@voltagent/serverless-hono": patch
"@voltagent/server-elysia": patch
---

Validate workflow ownership before suspend and cancel control routes act on an execution.

Fixes #1316.
14 changes: 13 additions & 1 deletion packages/server-core/src/handlers/workflow.handlers.spec.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { ServerProviderDeps, WorkflowStateEntry } from "@voltagent/core";
import { describe, expect, it, vi } from "vitest";
import { handleListWorkflowRuns } from "./workflow.handlers";
import { createWorkflowControlRequestBody, handleListWorkflowRuns } from "./workflow.handlers";

function createWorkflowState(
id: string,
Expand Down Expand Up @@ -190,3 +190,15 @@ describe("handleListWorkflowRuns", () => {
);
});
});

describe("createWorkflowControlRequestBody", () => {
it("rejects primitive JSON bodies and annotates object bodies with the route workflow id", () => {
expect(createWorkflowControlRequestBody("invalid", "wf-1")).toBeUndefined();
expect(createWorkflowControlRequestBody(1, "wf-1")).toBeUndefined();
expect(createWorkflowControlRequestBody(["invalid"], "wf-1")).toBeUndefined();
expect(createWorkflowControlRequestBody({ reason: "pause" }, "wf-1")).toEqual({
__workflowId: "wf-1",
reason: "pause",
});
});
});
56 changes: 54 additions & 2 deletions packages/server-core/src/handlers/workflow.handlers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -630,13 +630,51 @@ export async function handleAttachWorkflowStream(
}
}

async function isWorkflowExecutionOwnedByRoute(
body: WorkflowControlRequestBody | undefined,
executionId: string,
deps: ServerProviderDeps,
) {
const workflowId = body?.__workflowId;
if (typeof workflowId !== "string" || workflowId.trim().length === 0) {
return false;
}

return (
(
await deps.workflowRegistry
.getWorkflow(workflowId)
?.workflow.memory.getWorkflowState(executionId)
)?.workflowId === workflowId
);
}

export type WorkflowControlRequestBody = Record<string, unknown> & {
__workflowId: string;
reason?: string;
};

export function createWorkflowControlRequestBody(
body: unknown,
workflowId: string,
): WorkflowControlRequestBody | undefined {
if (!body || typeof body !== "object" || Array.isArray(body)) {
return undefined;
}

return {
...(body as Record<string, unknown>),
__workflowId: workflowId,
};
}

/**
* Handler for suspending a workflow
* Returns suspension result
*/
export async function handleSuspendWorkflow(
executionId: string,
body: any,
body: WorkflowControlRequestBody | undefined,
deps: ServerProviderDeps,
logger: Logger,
): Promise<ApiResponse> {
Expand All @@ -650,6 +688,13 @@ export async function handleSuspendWorkflow(
};
}

if (!(await isWorkflowExecutionOwnedByRoute(body, executionId, deps))) {
return {
success: false,
error: "Workflow execution not found or already completed",
};
}

const suspendController = deps.workflowRegistry.activeExecutions.get(executionId);

if (!suspendController) {
Expand Down Expand Up @@ -694,7 +739,7 @@ export async function handleSuspendWorkflow(
*/
export async function handleCancelWorkflow(
executionId: string,
body: any,
body: WorkflowControlRequestBody | undefined,
deps: ServerProviderDeps,
logger: Logger,
): Promise<ApiResponse> {
Expand All @@ -708,6 +753,13 @@ export async function handleCancelWorkflow(
};
}

if (!(await isWorkflowExecutionOwnedByRoute(body, executionId, deps))) {
return {
success: false,
error: "No active execution found or workflow already completed",
};
}

const suspendController = deps.workflowRegistry.activeExecutions.get(executionId);

if (!suspendController) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@ import type { Logger } from "@voltagent/internal";
import { describe, expect, it, vi } from "vitest";
import type { ErrorResponse } from "../types/responses";
import { isErrorResponse } from "../types/responses";
import { handleAttachWorkflowStream, handleStreamWorkflow } from "./workflow.handlers";
import {
handleAttachWorkflowStream,
handleCancelWorkflow,
handleStreamWorkflow,
handleSuspendWorkflow,
} from "./workflow.handlers";

type ParsedSSEEvent = {
id?: string;
Expand Down Expand Up @@ -48,7 +53,7 @@ function createDeps(options?: {
const stream = options?.streamFactory ? options.streamFactory : vi.fn();

const workflow = {
createSuspendController: vi.fn().mockReturnValue(createSuspendController()),
createSuspendController: vi.fn().mockImplementation(createSuspendController),
stream,
memory: {
getWorkflowState,
Expand Down Expand Up @@ -120,6 +125,65 @@ function assertErrorResponse(
expect(isErrorResponse(value)).toBe(true);
}

async function startBlockingWorkflowExecution(action: string) {
const logger = createLogger();
let releaseStream = () => {};
const streamBlocked = new Promise<void>((resolve) => {
releaseStream = resolve;
});
const executionId = `exec-${action}-1`;
const { deps } = createDeps({
workflowState: {
id: executionId,
workflowId: "wf-1",
workflowName: "Workflow 1",
status: "running",
createdAt: new Date(),
updatedAt: new Date(),
},
streamFactory: () => ({
executionId,
[Symbol.asyncIterator]: async function* () {
yield {
type: "workflow-start",
executionId,
from: "Workflow 1",
status: "running",
timestamp: new Date().toISOString(),
};
await streamBlocked;
},
result: Promise.resolve({ ok: true }),
status: Promise.resolve("completed"),
endAt: Promise.resolve(new Date("2026-01-01T00:00:00.000Z")),
}),
});

const streamResponse = await handleStreamWorkflow("wf-1", { input: {} }, deps, logger);
expect(isErrorResponse(streamResponse)).toBe(false);

if (isErrorResponse(streamResponse)) {
throw new Error("Expected active workflow stream");
}

const streamReader = streamResponse.getReader();
await readSSEEvent(streamReader);

const controller = deps.workflowRegistry.activeExecutions?.get(executionId) as ReturnType<
typeof createSuspendController
>;
expect(controller).toBeDefined();

return {
controller,
deps,
executionId,
logger,
releaseStream,
streamReader,
};
}

describe("workflow stream attach handler", () => {
it("returns 404 when workflow does not exist", async () => {
const logger = createLogger();
Expand Down Expand Up @@ -260,4 +324,63 @@ describe("workflow stream attach handler", () => {
const attachedFinal = await readSSEEvent(attachedReader);
expect(attachedFinal.data.type).toBe("workflow-result");
});

it.each([
{
action: "suspend",
error: "not found",
handler: handleSuspendWorkflow,
method: "suspend",
},
{
action: "cancel",
error: "No active execution found",
handler: handleCancelWorkflow,
method: "cancel",
},
])(
"rejects $action requests on the wrong workflow route",
async ({ action, error, handler, method }) => {
const { controller, deps, executionId, logger, releaseStream, streamReader } =
await startBlockingWorkflowExecution(action);

const wrongRouteResponse = await handler(
executionId,
{ __workflowId: "wf-2", reason: "wrong route" },
deps,
logger,
);

expect(wrongRouteResponse.success).toBe(false);
expect((wrongRouteResponse as ErrorResponse).error).toContain(error);
expect(controller[method]).not.toHaveBeenCalled();
expect(deps.workflowRegistry.activeExecutions?.has(executionId)).toBe(true);

const missingRouteResponse = await handler(
executionId,
{ reason: "missing route id" } as any,
deps,
logger,
);

expect(missingRouteResponse.success).toBe(false);
expect((missingRouteResponse as ErrorResponse).error).toContain(error);
expect(controller[method]).not.toHaveBeenCalled();
expect(deps.workflowRegistry.activeExecutions?.has(executionId)).toBe(true);

const correctRouteResponse = await handler(
executionId,
{ __workflowId: "wf-1", reason: "correct route" },
deps,
logger,
);

expect(correctRouteResponse.success).toBe(true);
expect(controller[method]).toHaveBeenCalledWith("correct route");
expect(deps.workflowRegistry.activeExecutions?.has(executionId)).toBe(false);

releaseStream();
await streamReader.cancel();
},
);
});
15 changes: 13 additions & 2 deletions packages/server-elysia/src/routes/workflow.routes.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type { ServerProviderDeps } from "@voltagent/core";
import type { Logger } from "@voltagent/internal";
import {
createWorkflowControlRequestBody,
handleAttachWorkflowStream,
handleCancelWorkflow,
handleExecuteWorkflow,
Expand Down Expand Up @@ -287,7 +288,12 @@ export function registerWorkflowRoutes(
app.post(
"/workflows/:id/executions/:executionId/suspend",
async ({ params, body, set }) => {
const response = await handleSuspendWorkflow(params.executionId, body, deps, logger);
const response = await handleSuspendWorkflow(
params.executionId,
createWorkflowControlRequestBody(body, params.id),
deps,
logger,
);
if (!response.success) {
const errorMessage = response.error || "";
set.status = errorMessage.includes("not found")
Expand Down Expand Up @@ -320,7 +326,12 @@ export function registerWorkflowRoutes(
app.post(
"/workflows/:id/executions/:executionId/cancel",
async ({ params, body, set }) => {
const response = await handleCancelWorkflow(params.executionId, body, deps, logger);
const response = await handleCancelWorkflow(
params.executionId,
createWorkflowControlRequestBody(body, params.id),
deps,
logger,
);
if (!response.success) {
const errorMessage = response.error || "";
set.status = errorMessage.includes("not found")
Expand Down
5 changes: 3 additions & 2 deletions packages/server-hono/src/routes/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type { ServerProviderDeps } from "@voltagent/core";
import type { Logger } from "@voltagent/internal";
import {
UPDATE_ROUTES,
createWorkflowControlRequestBody,
handleAttachWorkflowStream,
handleCancelWorkflow,
handleChatStream,
Expand Down Expand Up @@ -412,7 +413,7 @@ export function registerWorkflowRoutes(
if (!executionId) {
throw new Error("Missing execution id parameter");
}
const body = await c.req.json();
const body = createWorkflowControlRequestBody(await c.req.json(), c.req.param("id"));
const response = await handleSuspendWorkflow(executionId, body, deps, logger);
if (!response.success) {
return c.json(response, 500);
Expand All @@ -426,7 +427,7 @@ export function registerWorkflowRoutes(
if (!executionId) {
throw new Error("Missing execution id parameter");
}
const body = await c.req.json();
const body = createWorkflowControlRequestBody(await c.req.json(), c.req.param("id"));
const response = await handleCancelWorkflow(executionId, body, deps, logger);
if (!response.success) {
const errorMessage = response.error || "";
Expand Down
5 changes: 3 additions & 2 deletions packages/serverless-hono/src/routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import {
type TriggerHttpRequestContext,
UPDATE_ROUTES,
WORKFLOW_ROUTES,
createWorkflowControlRequestBody,
executeA2ARequest,
executeTriggerHandler,
getConversationMessagesHandler,
Expand Down Expand Up @@ -528,7 +529,7 @@ export function registerWorkflowRoutes(app: Hono, deps: ServerProviderDeps, logg

app.post(WORKFLOW_ROUTES.suspendWorkflow.path, async (c) => {
const executionId = c.req.param("executionId");
const body = await readJsonBody(c, logger);
const body = createWorkflowControlRequestBody(await readJsonBody(c, logger), c.req.param("id"));
if (!body) {
return c.json({ success: false, error: "Invalid JSON body" }, 400);
}
Expand All @@ -547,7 +548,7 @@ export function registerWorkflowRoutes(app: Hono, deps: ServerProviderDeps, logg

app.post(WORKFLOW_ROUTES.cancelWorkflow.path, async (c) => {
const executionId = c.req.param("executionId");
const body = await readJsonBody(c, logger);
const body = createWorkflowControlRequestBody(await readJsonBody(c, logger), c.req.param("id"));
if (!body) {
return c.json({ success: false, error: "Invalid JSON body" }, 400);
}
Expand Down
Loading