Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import type {
MalformedRunInfo,
WorkflowPort,
} from '../ports/workflow-port';
import type { StepUser } from '../types/execution-context';
import type { CollectionSchema } from '../types/validated/collection';
import type { StepOutcome } from '../types/validated/step-outcome';
import type { ToolConfig } from '@forestadmin/ai-proxy';
Expand Down Expand Up @@ -226,7 +225,7 @@ export default class ForestServerWorkflowPort implements WorkflowPort {
);
}

async hasRunAccess(runId: string, user: StepUser): Promise<boolean> {
async hasRunAccess(runId: string, user: { id: number }): Promise<boolean> {
return this.callPort('hasRunAccess', async () => {
const { hasAccess } = await ServerUtils.query<{ hasAccess: boolean }>(
this.options,
Expand Down
10 changes: 10 additions & 0 deletions packages/workflow-executor/src/http/bearer-claims.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import { z } from 'zod';

// Claims we require from the decoded bearer JWT payload (ctx.state.user). Non-strict on purpose:
// jsonwebtoken adds standard claims (iat/exp) and Forest may send extra ones — strict would reject
// every real token. Only `id` is consumed downstream (handleTrigger + hasRunAccess → ?userId=).
// `.int()` rejects NaN/Infinity/floats (a user id is an integer) — preserves the invariant the
// previous Number.isFinite guard enforced.
export const BearerClaimsSchema = z.object({ id: z.number().int() });

export type BearerClaims = z.infer<typeof BearerClaimsSchema>;
37 changes: 28 additions & 9 deletions packages/workflow-executor/src/http/executor-http-server.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import type { Logger } from '../ports/logger-port';
import type { WorkflowPort } from '../ports/workflow-port';
import type Runner from '../runner';
import type { StepUser } from '../types/execution-context';
import type { Server } from 'http';

import bodyParser from '@koa/bodyparser';
Expand All @@ -10,10 +9,11 @@ import http from 'http';
import Koa from 'koa';
import koaJwt from 'koa-jwt';

import { type BearerClaims, BearerClaimsSchema } from './bearer-claims';
import {
BadRequestHttpError,
ForbiddenHttpError,
ServiceUnavailableHttpError,
UnauthorizedHttpError,
toHttpError,
} from './http-errors';
import serializeStepForWire from './step-serializer';
Expand Down Expand Up @@ -100,6 +100,29 @@ export default class ExecutorHttpServer {
koaJwt({ secret: options.authSecret, cookie: 'forest_session_token', tokenKey: 'rawToken' }),
);

// koa-jwt only validates the token's signature/expiry, not its payload shape. Validate the
// claims once, here, so every handler downstream gets a user with a guaranteed numeric id.
this.app.use(async (ctx, next) => {
const claims = BearerClaimsSchema.safeParse(ctx.state.user);

if (!claims.success) {
// A token koa-jwt accepted (valid signature) but whose payload is malformed is rare and
// high-signal (token-issuance regression / version skew / forgery probe) — log it, unlike
// ordinary expired-token churn. Only the issue paths/codes, never the payload (PII).
this.logger.warn('Bearer token has invalid claims', {
method: ctx.method,
path: ctx.path,
issues: claims.error.issues.map(issue => ({ path: issue.path, code: issue.code })),
});

throw new UnauthorizedHttpError();
}

ctx.state.user = { ...ctx.state.user, ...claims.data };

await next();
});

const router = new Router();

// hasRunAccess authorization — only on GET (read-only route).
Expand Down Expand Up @@ -147,7 +170,7 @@ export default class ExecutorHttpServer {
}

private async hasRunAccessMiddleware(ctx: Koa.Context, next: Koa.Next): Promise<void> {
const user = ctx.state.user as StepUser;
const user = ctx.state.user as BearerClaims;
let allowed: boolean;

try {
Expand Down Expand Up @@ -178,12 +201,8 @@ export default class ExecutorHttpServer {

private async handleTrigger(ctx: Koa.Context): Promise<void> {
const { runId } = ctx.params;
const rawId = (ctx.state.user as { id?: unknown })?.id;
const bearerUserId = typeof rawId === 'number' ? rawId : Number(rawId);

if (!Number.isFinite(bearerUserId)) {
throw new BadRequestHttpError('Missing or invalid user id in token');
}
// Guaranteed a number by the bearer-claims middleware.
const bearerUserId = (ctx.state.user as BearerClaims).id;

const pendingData = (ctx.request.body as { pendingData?: unknown })?.pendingData;

Expand Down
6 changes: 4 additions & 2 deletions packages/workflow-executor/src/ports/workflow-port.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { AvailableStepExecution, StepUser } from '../types/execution-context';
import type { AvailableStepExecution } from '../types/execution-context';
import type { CollectionSchema } from '../types/validated/collection';
import type { StepOutcome } from '../types/validated/step-outcome';
import type { ToolConfig } from '@forestadmin/ai-proxy';
Expand Down Expand Up @@ -36,5 +36,7 @@ export interface WorkflowPort {
): Promise<AvailableRunDispatch | null>;
getCollectionSchema(collectionName: string, runId: string): Promise<CollectionSchema>;
getMcpServerConfigs(): Promise<Record<string, ToolConfig>>;
hasRunAccess(runId: string, user: StepUser): Promise<boolean>;
// Only the user id is needed (the access check is `?userId=`); kept narrow so callers don't
// have to produce a full StepUser.
hasRunAccess(runId: string, user: { id: number }): Promise<boolean>;
}
Original file line number Diff line number Diff line change
Expand Up @@ -1055,19 +1055,7 @@ describe('ForestServerWorkflowPort', () => {
it('propagates errors from hasRunAccess', async () => {
mockQuery.mockRejectedValue(new Error('Network error'));

await expect(
port.hasRunAccess('run-42', {
id: 1,
email: 'test@example.com',
firstName: 'Test',
lastName: 'User',
team: 'admin',
renderingId: 1,
role: 'admin',
permissionLevel: 'admin',
tags: {},
}),
).rejects.toThrow('Network error');
await expect(port.hasRunAccess('run-42', { id: 1 })).rejects.toThrow('Network error');
});

it('propagates errors from updateStepExecution', async () => {
Expand Down
32 changes: 32 additions & 0 deletions packages/workflow-executor/test/http/bearer-claims.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import { BearerClaimsSchema } from '../../src/http/bearer-claims';

describe('BearerClaimsSchema', () => {
it('accepts a payload with a numeric id', () => {
const result = BearerClaimsSchema.safeParse({ id: 1 });

expect(result.success).toBe(true);
});

it('tolerates standard JWT claims and extra Forest claims (not strict)', () => {
// jsonwebtoken adds iat/exp to the decoded payload — a strict schema would wrongly reject it.
const result = BearerClaimsSchema.safeParse({
id: 1,
iat: 1_700_000_000,
exp: 1_700_003_600,
email: 'admin@forest.com',
role: 'admin',
});

expect(result.success).toBe(true);
});

it.each([
['no id', { email: 'no-id@forest.com' }],
['non-numeric id', { id: 'user-42' }],
['null id', { id: null }],
['a non-integer id', { id: 1.5 }],
['empty payload', {}],
])('rejects a payload with %s', (_, payload) => {
expect(BearerClaimsSchema.safeParse(payload).success).toBe(false);
});
});
48 changes: 44 additions & 4 deletions packages/workflow-executor/test/http/executor-http-server.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,20 @@ describe('ExecutorHttpServer', () => {
});

describe('run access authorization', () => {
it('returns 401 on GET /runs/:runId when the token has invalid claims, before the access check', async () => {
const workflowPort = createMockWorkflowPort();
const server = createServer({ workflowPort });
const token = signToken({ email: 'no-id@example.com' });

const response = await request(server.callback)
.get('/runs/run-1')
.set('Authorization', `Bearer ${token}`);

expect(response.status).toBe(401);
expect(response.body).toEqual({ error: 'Unauthorized' });
expect(workflowPort.hasRunAccess).not.toHaveBeenCalled();
});

it('returns 403 when hasRunAccess returns false on GET /runs/:runId', async () => {
const workflowPort = createMockWorkflowPort({
hasRunAccess: jest.fn().mockResolvedValue(false),
Expand Down Expand Up @@ -430,18 +444,44 @@ describe('ExecutorHttpServer', () => {
});
});

it('returns 400 when token has no numeric id', async () => {
it('returns 401 and logs the invalid claims when the token carries no numeric id', async () => {
const logger = { info: jest.fn(), warn: jest.fn(), error: jest.fn() };
const runner = createMockRunner();
const server = createServer({ runner });
const server = createServer({ runner, logger });
const token = signToken({ email: 'no-id@example.com' });

const response = await request(server.callback)
.post('/runs/run-1/trigger')
.set('Authorization', `Bearer ${token}`);

expect(response.status).toBe(400);
expect(response.body).toEqual({ error: 'Missing or invalid user id in token' });
expect(response.status).toBe(401);
expect(response.body).toEqual({ error: 'Unauthorized' });
expect(runner.triggerPoll).not.toHaveBeenCalled();
// The malformed-but-signed token is surfaced (issue paths/codes only, never the payload).
expect(logger.warn).toHaveBeenCalledWith(
'Bearer token has invalid claims',
expect.objectContaining({
method: 'POST',
path: '/runs/run-1/trigger',
issues: [expect.objectContaining({ path: ['id'], code: 'invalid_type' })],
}),
);
});

it('accepts a token carrying extra claims beyond id (non-strict validation)', async () => {
const runner = createMockRunner();
const server = createServer({ runner });
const token = signToken({ id: 1, role: 'admin', team: 'ops' });

const response = await request(server.callback)
.post('/runs/run-1/trigger')
.set('Authorization', `Bearer ${token}`);

expect(response.status).toBe(200);
expect(runner.triggerPoll).toHaveBeenCalledWith('run-1', {
pendingData: undefined,
bearerUserId: 1,
});
});

it('passes pendingData from request body to runner.triggerPoll', async () => {
Expand Down
Loading