Skip to content

Commit 3ccb223

Browse files
committed
fix(mcp-server): enforce session binding and response isolation in HTTP SSE transport
1 parent 9168186 commit 3ccb223

2 files changed

Lines changed: 186 additions & 63 deletions

File tree

packages/mcp-server/src/server/transport/http-sse.ts

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ export class HttpSseTransport implements McpTransport {
4040

4141
this.app.get(this.path, (req: Request, res: Response) => {
4242
console.error(`SSE connection from ${req.ip}`);
43-
43+
4444
// Create SSE transport - it will handle headers automatically
4545
const transport = new SSEServerTransport(`${this.path}/message`, res);
46-
const sessionId = Date.now().toString();
46+
const sessionId = transport.sessionId
4747
this.connections.set(sessionId, transport);
48-
48+
4949
// Connect transport to MCP server
5050
if (this.mcpServer) {
5151
this.mcpServer.connect(transport);
@@ -61,10 +61,15 @@ export class HttpSseTransport implements McpTransport {
6161
// Message endpoint for SSE transport
6262
this.app.post(`${this.path}/message`, async (req: Request, res: Response) => {
6363
try {
64-
// Find the first available transport (simple approach for now)
65-
const transport = Array.from(this.connections.values())[0];
64+
const sessionId = typeof req.query.sessionId === 'string' ? req.query.sessionId : undefined;
65+
if (!sessionId) {
66+
res.status(400).json({ error: 'Missing sessionId' });
67+
return;
68+
}
69+
70+
const transport = this.connections.get(sessionId);
6671
if (!transport) {
67-
res.status(404).json({ error: 'No active SSE connection' });
72+
res.status(404).json({ error: 'Session not found' });
6873
return;
6974
}
7075

packages/mcp-server/src/tests/server/transport/http-sse.test.ts

Lines changed: 175 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jest.mock('express', () => {
1212
listen: jest.fn()
1313
};
1414
const express = jest.fn(() => mockApp);
15-
express.json = jest.fn();
15+
(express as any).json = jest.fn();
1616
return express;
1717
});
1818

@@ -65,7 +65,8 @@ describe('HttpSseTransport', () => {
6565
};
6666

6767
mockTransport = {
68-
handleMessage: jest.fn()
68+
handleMessage: jest.fn(),
69+
sessionId: 'mock-session-id'
6970
};
7071

7172
mockMcpServer = {
@@ -74,7 +75,7 @@ describe('HttpSseTransport', () => {
7475

7576
// Configure mocks
7677
mockExpress.mockReturnValue(mockApp);
77-
mockExpress.json = jest.fn().mockReturnValue('json-middleware');
78+
(mockExpress as any).json = jest.fn().mockReturnValue('json-middleware');
7879
mockCreateCorsMiddleware.mockReturnValue('cors-middleware');
7980
mockSSEServerTransport.mockImplementation(() => mockTransport);
8081

@@ -129,120 +130,218 @@ describe('HttpSseTransport', () => {
129130
describe('SSE endpoint', () => {
130131
it('should create SSE transport and connect to MCP server', () => {
131132
const transport = new HttpSseTransport(3000, 'localhost', '/sse');
132-
133+
133134
// Mock MCP server
134135
(transport as any).mcpServer = mockMcpServer;
135-
136-
const mockReq = {
136+
137+
const mockReq = {
137138
ip: '127.0.0.1',
138139
on: jest.fn()
139140
};
140141
const mockRes = {};
141-
142+
142143
// Get the SSE endpoint handler
143144
const sseHandler = mockApp.get.mock.calls.find(call => call[0] === '/sse')[1];
144145
sseHandler(mockReq, mockRes);
145-
146+
146147
expect(consoleErrorSpy).toHaveBeenCalledWith('SSE connection from 127.0.0.1');
147148
expect(mockMcpServer.connect).toHaveBeenCalledWith(mockTransport);
148149
expect(mockReq.on).toHaveBeenCalledWith('close', expect.any(Function));
149150
});
150151

151152
it('should handle connection without MCP server', () => {
152153
new HttpSseTransport(3000, 'localhost', '/sse');
153-
154-
const mockReq = {
154+
155+
const mockReq = {
155156
ip: '127.0.0.1',
156157
on: jest.fn()
157158
};
158159
const mockRes = {};
159-
160+
160161
// Get the SSE endpoint handler
161162
const sseHandler = mockApp.get.mock.calls.find(call => call[0] === '/sse')[1];
162163
sseHandler(mockReq, mockRes);
163-
164+
164165
expect(consoleErrorSpy).toHaveBeenCalledWith('SSE connection from 127.0.0.1');
165166
expect(mockMcpServer.connect).not.toHaveBeenCalled();
166167
});
167168

168169
it('should clean up connection on close', () => {
169170
new HttpSseTransport(3000, 'localhost', '/sse');
170-
171-
const mockReq = {
171+
172+
const mockReq = {
172173
ip: '127.0.0.1',
173174
on: jest.fn()
174175
};
175176
const mockRes = {};
176-
177+
177178
// Get the SSE endpoint handler
178179
const sseHandler = mockApp.get.mock.calls.find(call => call[0] === '/sse')[1];
179180
sseHandler(mockReq, mockRes);
180-
181+
181182
// Get the close handler
182183
const closeHandler = mockReq.on.mock.calls.find(call => call[0] === 'close')[1];
183184
closeHandler();
184-
185-
expect(consoleErrorSpy).toHaveBeenCalledWith(expect.stringMatching(/SSE connection closed for session \d+/));
185+
186+
expect(consoleErrorSpy).toHaveBeenCalledWith(
187+
`SSE connection closed for session ${mockTransport.sessionId}`
188+
);
189+
});
190+
191+
it('should use the transport sessionId as the connections key', () => {
192+
const transport = new HttpSseTransport(3000, 'localhost', '/sse');
193+
194+
const mockReq = { ip: '127.0.0.1', on: jest.fn() };
195+
const mockRes = {};
196+
197+
const sseHandler = mockApp.get.mock.calls.find(call => call[0] === '/sse')[1];
198+
sseHandler(mockReq, mockRes);
199+
200+
expect(mockSSEServerTransport).toHaveBeenCalledWith('/sse/message', mockRes);
201+
expect((transport as any).connections.has(mockTransport.sessionId)).toBe(true);
202+
});
203+
204+
it('should assign unique session IDs to concurrent connections', () => {
205+
const mockTransport1 = { handleMessage: jest.fn(), sessionId: 'session-id-1' };
206+
const mockTransport2 = { handleMessage: jest.fn(), sessionId: 'session-id-2' };
207+
mockSSEServerTransport
208+
.mockImplementationOnce(() => mockTransport1)
209+
.mockImplementationOnce(() => mockTransport2);
210+
211+
const transport = new HttpSseTransport(3000, 'localhost', '/sse');
212+
(transport as any).mcpServer = mockMcpServer;
213+
214+
const mockReq1 = { ip: '127.0.0.1', on: jest.fn() };
215+
const mockReq2 = { ip: '127.0.0.2', on: jest.fn() };
216+
217+
const sseHandler = mockApp.get.mock.calls.find(call => call[0] === '/sse')[1];
218+
sseHandler(mockReq1, {});
219+
sseHandler(mockReq2, {});
220+
221+
expect((transport as any).connections.has('session-id-1')).toBe(true);
222+
expect((transport as any).connections.has('session-id-2')).toBe(true);
223+
expect((transport as any).connections.size).toBe(2);
186224
});
187225
});
188226

189227
describe('Message endpoint', () => {
190228
it('should handle message with active transport', async () => {
191229
const transport = new HttpSseTransport(3000, 'localhost', '/sse');
192-
230+
193231
// Add a connection to the transport
194232
(transport as any).connections.set('test-session', mockTransport);
195-
196-
const mockReq = { body: { test: 'message' } };
197-
const mockRes = {
233+
234+
const mockReq = { body: { test: 'message' }, query: { sessionId: 'test-session' } };
235+
const mockRes = {
198236
status: jest.fn().mockReturnThis(),
199237
end: jest.fn(),
200238
json: jest.fn()
201239
};
202-
240+
203241
// Get the message endpoint handler
204242
const messageHandler = mockApp.post.mock.calls.find(call => call[0] === '/sse/message')[1];
205243
await messageHandler(mockReq, mockRes);
206-
244+
207245
expect(mockTransport.handleMessage).toHaveBeenCalledWith({ test: 'message' });
208246
expect(mockRes.status).toHaveBeenCalledWith(200);
209247
expect(mockRes.end).toHaveBeenCalled();
210248
});
211249

212-
it('should return 404 when no active transport', async () => {
250+
it('should return 400 when sessionId query param is missing', async () => {
213251
new HttpSseTransport(3000, 'localhost', '/sse');
214-
215-
const mockReq = { body: { test: 'message' } };
216-
const mockRes = {
252+
253+
const mockReq = { body: { test: 'message' }, query: {} };
254+
const mockRes = {
217255
status: jest.fn().mockReturnThis(),
218256
json: jest.fn()
219257
};
220-
258+
221259
// Get the message endpoint handler
222260
const messageHandler = mockApp.post.mock.calls.find(call => call[0] === '/sse/message')[1];
223261
await messageHandler(mockReq, mockRes);
224-
262+
263+
expect(mockRes.status).toHaveBeenCalledWith(400);
264+
expect(mockRes.json).toHaveBeenCalledWith({ error: 'Missing sessionId' });
265+
});
266+
267+
it('should return 404 when sessionId does not match any active session', async () => {
268+
const transport = new HttpSseTransport(3000, 'localhost', '/sse');
269+
(transport as any).connections.set('real-session-id', mockTransport);
270+
271+
const mockReq = { body: { test: 'message' }, query: { sessionId: 'bogus-session-id' } };
272+
const mockRes = {
273+
status: jest.fn().mockReturnThis(),
274+
json: jest.fn()
275+
};
276+
277+
// Get the message endpoint handler
278+
const messageHandler = mockApp.post.mock.calls.find(call => call[0] === '/sse/message')[1];
279+
await messageHandler(mockReq, mockRes);
280+
225281
expect(mockRes.status).toHaveBeenCalledWith(404);
226-
expect(mockRes.json).toHaveBeenCalledWith({ error: 'No active SSE connection' });
282+
expect(mockRes.json).toHaveBeenCalledWith({ error: 'Session not found' });
283+
expect(mockTransport.handleMessage).not.toHaveBeenCalled();
284+
});
285+
286+
it('should route message only to the transport matching the sessionId', async () => {
287+
const transport = new HttpSseTransport(3000, 'localhost', '/sse');
288+
289+
const mockTransportA = { handleMessage: jest.fn() };
290+
const mockTransportB = { handleMessage: jest.fn() };
291+
(transport as any).connections.set('session-a', mockTransportA);
292+
(transport as any).connections.set('session-b', mockTransportB);
293+
294+
const mockReq = { body: { jsonrpc: '2.0', method: 'ping', id: 1 }, query: { sessionId: 'session-b' } };
295+
const mockRes = {
296+
status: jest.fn().mockReturnThis(),
297+
end: jest.fn(),
298+
json: jest.fn()
299+
};
300+
301+
// Get the message endpoint handler
302+
const messageHandler = mockApp.post.mock.calls.find(call => call[0] === '/sse/message')[1];
303+
await messageHandler(mockReq, mockRes);
304+
305+
expect(mockTransportB.handleMessage).toHaveBeenCalledWith({ jsonrpc: '2.0', method: 'ping', id: 1 });
306+
expect(mockTransportA.handleMessage).not.toHaveBeenCalled();
307+
expect(mockRes.status).toHaveBeenCalledWith(200);
308+
});
309+
310+
it('should reject a request even when another valid session exists', async () => {
311+
const transport = new HttpSseTransport(3000, 'localhost', '/sse');
312+
(transport as any).connections.set('legitimate-session', mockTransport);
313+
314+
// Attacker sends request with no sessionId
315+
const mockReq = { body: { jsonrpc: '2.0', method: 'tools/call', id: 2 }, query: {} };
316+
const mockRes = {
317+
status: jest.fn().mockReturnThis(),
318+
json: jest.fn()
319+
};
320+
321+
const messageHandler = mockApp.post.mock.calls.find(call => call[0] === '/sse/message')[1];
322+
await messageHandler(mockReq, mockRes);
323+
324+
expect(mockRes.status).toHaveBeenCalledWith(400);
325+
expect(mockTransport.handleMessage).not.toHaveBeenCalled();
227326
});
228327

229328
it('should handle transport errors', async () => {
230329
const transport = new HttpSseTransport(3000, 'localhost', '/sse');
231-
330+
232331
// Add a connection that will throw an error
233332
mockTransport.handleMessage.mockRejectedValue(new Error('Transport error'));
234333
(transport as any).connections.set('test-session', mockTransport);
235-
236-
const mockReq = { body: { test: 'message' } };
237-
const mockRes = {
334+
335+
const mockReq = { body: { test: 'message' }, query: { sessionId: 'test-session' } };
336+
const mockRes = {
238337
status: jest.fn().mockReturnThis(),
239338
json: jest.fn()
240339
};
241-
340+
242341
// Get the message endpoint handler
243342
const messageHandler = mockApp.post.mock.calls.find(call => call[0] === '/sse/message')[1];
244343
await messageHandler(mockReq, mockRes);
245-
344+
246345
expect(consoleErrorSpy).toHaveBeenCalledWith('Error handling message:', expect.any(Error));
247346
expect(mockRes.status).toHaveBeenCalledWith(500);
248347
expect(mockRes.json).toHaveBeenCalledWith({ error: 'Internal server error' });
@@ -418,37 +517,56 @@ describe('HttpSseTransport', () => {
418517
});
419518

420519
it('should handle multiple connections and cleanup', () => {
520+
const mockTransport1 = { handleMessage: jest.fn(), sessionId: 'session-id-1' };
521+
const mockTransport2 = { handleMessage: jest.fn(), sessionId: 'session-id-2' };
522+
mockSSEServerTransport
523+
.mockImplementationOnce(() => mockTransport1)
524+
.mockImplementationOnce(() => mockTransport2);
525+
421526
const transport = new HttpSseTransport(3000, 'localhost', '/sse');
422527
(transport as any).mcpServer = mockMcpServer;
423-
424-
// Mock Date.now to return different values for different connections
425-
const originalDateNow = Date.now;
426-
let callCount = 0;
427-
Date.now = jest.fn(() => {
428-
callCount++;
429-
return 1000 + callCount; // Return different timestamps
430-
});
431-
528+
432529
// Simulate multiple SSE connections
433530
const mockReq1 = { ip: '127.0.0.1', on: jest.fn() };
434531
const mockReq2 = { ip: '127.0.0.2', on: jest.fn() };
435-
const mockRes1 = {};
436-
const mockRes2 = {};
437-
532+
438533
const sseHandler = mockApp.get.mock.calls.find(call => call[0] === '/sse')[1];
439-
sseHandler(mockReq1, mockRes1);
440-
sseHandler(mockReq2, mockRes2);
441-
534+
sseHandler(mockReq1, {});
535+
sseHandler(mockReq2, {});
536+
442537
expect((transport as any).connections.size).toBe(2);
443-
538+
444539
// Close first connection
445540
const closeHandler1 = mockReq1.on.mock.calls.find(call => call[0] === 'close')[1];
446541
closeHandler1();
447-
542+
448543
expect((transport as any).connections.size).toBe(1);
449-
450-
// Restore Date.now
451-
Date.now = originalDateNow;
544+
expect((transport as any).connections.has('session-id-2')).toBe(true);
545+
});
546+
547+
it('should not route a message to the first connected client when the request targets a second client', async () => {
548+
const mockTransport1 = { handleMessage: jest.fn(), sessionId: 'session-id-1' };
549+
const mockTransport2 = { handleMessage: jest.fn(), sessionId: 'session-id-2' };
550+
mockSSEServerTransport
551+
.mockImplementationOnce(() => mockTransport1)
552+
.mockImplementationOnce(() => mockTransport2);
553+
554+
const transport = new HttpSseTransport(3000, 'localhost', '/sse');
555+
(transport as any).mcpServer = mockMcpServer;
556+
557+
const sseHandler = mockApp.get.mock.calls.find(call => call[0] === '/sse')[1];
558+
sseHandler({ ip: '127.0.0.1', on: jest.fn() }, {});
559+
sseHandler({ ip: '127.0.0.2', on: jest.fn() }, {});
560+
561+
// POST targeting client 2 — client 1's transport must not be invoked
562+
const mockReq = { body: { jsonrpc: '2.0', method: 'tools/call', id: 99 }, query: { sessionId: 'session-id-2' } };
563+
const mockRes = { status: jest.fn().mockReturnThis(), end: jest.fn(), json: jest.fn() };
564+
565+
const messageHandler = mockApp.post.mock.calls.find(call => call[0] === '/sse/message')[1];
566+
await messageHandler(mockReq, mockRes);
567+
568+
expect(mockTransport2.handleMessage).toHaveBeenCalledWith({ jsonrpc: '2.0', method: 'tools/call', id: 99 });
569+
expect(mockTransport1.handleMessage).not.toHaveBeenCalled();
452570
});
453571
});
454572
});

0 commit comments

Comments
 (0)