@@ -12,7 +12,7 @@ jest.mock('express', () => {
1212 listen : jest . fn ( )
1313 } ;
1414 const express = jest . fn ( ( ) => mockApp ) ;
15- express . json = jest . fn ( ) ;
15+ ( express as any ) . json = jest . fn ( ) ;
1616 return express ;
1717} ) ;
1818
@@ -65,7 +65,8 @@ describe('HttpSseTransport', () => {
6565 } ;
6666
6767 mockTransport = {
68- handleMessage : jest . fn ( )
68+ handleMessage : jest . fn ( ) ,
69+ sessionId : 'mock-session-id'
6970 } ;
7071
7172 mockMcpServer = {
@@ -74,7 +75,7 @@ describe('HttpSseTransport', () => {
7475
7576 // Configure mocks
7677 mockExpress . mockReturnValue ( mockApp ) ;
77- mockExpress . json = jest . fn ( ) . mockReturnValue ( 'json-middleware' ) ;
78+ ( mockExpress as any ) . json = jest . fn ( ) . mockReturnValue ( 'json-middleware' ) ;
7879 mockCreateCorsMiddleware . mockReturnValue ( 'cors-middleware' ) ;
7980 mockSSEServerTransport . mockImplementation ( ( ) => mockTransport ) ;
8081
@@ -129,120 +130,218 @@ describe('HttpSseTransport', () => {
129130 describe ( 'SSE endpoint' , ( ) => {
130131 it ( 'should create SSE transport and connect to MCP server' , ( ) => {
131132 const transport = new HttpSseTransport ( 3000 , 'localhost' , '/sse' ) ;
132-
133+
133134 // Mock MCP server
134135 ( transport as any ) . mcpServer = mockMcpServer ;
135-
136- const mockReq = {
136+
137+ const mockReq = {
137138 ip : '127.0.0.1' ,
138139 on : jest . fn ( )
139140 } ;
140141 const mockRes = { } ;
141-
142+
142143 // Get the SSE endpoint handler
143144 const sseHandler = mockApp . get . mock . calls . find ( call => call [ 0 ] === '/sse' ) [ 1 ] ;
144145 sseHandler ( mockReq , mockRes ) ;
145-
146+
146147 expect ( consoleErrorSpy ) . toHaveBeenCalledWith ( 'SSE connection from 127.0.0.1' ) ;
147148 expect ( mockMcpServer . connect ) . toHaveBeenCalledWith ( mockTransport ) ;
148149 expect ( mockReq . on ) . toHaveBeenCalledWith ( 'close' , expect . any ( Function ) ) ;
149150 } ) ;
150151
151152 it ( 'should handle connection without MCP server' , ( ) => {
152153 new HttpSseTransport ( 3000 , 'localhost' , '/sse' ) ;
153-
154- const mockReq = {
154+
155+ const mockReq = {
155156 ip : '127.0.0.1' ,
156157 on : jest . fn ( )
157158 } ;
158159 const mockRes = { } ;
159-
160+
160161 // Get the SSE endpoint handler
161162 const sseHandler = mockApp . get . mock . calls . find ( call => call [ 0 ] === '/sse' ) [ 1 ] ;
162163 sseHandler ( mockReq , mockRes ) ;
163-
164+
164165 expect ( consoleErrorSpy ) . toHaveBeenCalledWith ( 'SSE connection from 127.0.0.1' ) ;
165166 expect ( mockMcpServer . connect ) . not . toHaveBeenCalled ( ) ;
166167 } ) ;
167168
168169 it ( 'should clean up connection on close' , ( ) => {
169170 new HttpSseTransport ( 3000 , 'localhost' , '/sse' ) ;
170-
171- const mockReq = {
171+
172+ const mockReq = {
172173 ip : '127.0.0.1' ,
173174 on : jest . fn ( )
174175 } ;
175176 const mockRes = { } ;
176-
177+
177178 // Get the SSE endpoint handler
178179 const sseHandler = mockApp . get . mock . calls . find ( call => call [ 0 ] === '/sse' ) [ 1 ] ;
179180 sseHandler ( mockReq , mockRes ) ;
180-
181+
181182 // Get the close handler
182183 const closeHandler = mockReq . on . mock . calls . find ( call => call [ 0 ] === 'close' ) [ 1 ] ;
183184 closeHandler ( ) ;
184-
185- expect ( consoleErrorSpy ) . toHaveBeenCalledWith ( expect . stringMatching ( / S S E c o n n e c t i o n c l o s e d f o r s e s s i o n \d + / ) ) ;
185+
186+ expect ( consoleErrorSpy ) . toHaveBeenCalledWith (
187+ `SSE connection closed for session ${ mockTransport . sessionId } `
188+ ) ;
189+ } ) ;
190+
191+ it ( 'should use the transport sessionId as the connections key' , ( ) => {
192+ const transport = new HttpSseTransport ( 3000 , 'localhost' , '/sse' ) ;
193+
194+ const mockReq = { ip : '127.0.0.1' , on : jest . fn ( ) } ;
195+ const mockRes = { } ;
196+
197+ const sseHandler = mockApp . get . mock . calls . find ( call => call [ 0 ] === '/sse' ) [ 1 ] ;
198+ sseHandler ( mockReq , mockRes ) ;
199+
200+ expect ( mockSSEServerTransport ) . toHaveBeenCalledWith ( '/sse/message' , mockRes ) ;
201+ expect ( ( transport as any ) . connections . has ( mockTransport . sessionId ) ) . toBe ( true ) ;
202+ } ) ;
203+
204+ it ( 'should assign unique session IDs to concurrent connections' , ( ) => {
205+ const mockTransport1 = { handleMessage : jest . fn ( ) , sessionId : 'session-id-1' } ;
206+ const mockTransport2 = { handleMessage : jest . fn ( ) , sessionId : 'session-id-2' } ;
207+ mockSSEServerTransport
208+ . mockImplementationOnce ( ( ) => mockTransport1 )
209+ . mockImplementationOnce ( ( ) => mockTransport2 ) ;
210+
211+ const transport = new HttpSseTransport ( 3000 , 'localhost' , '/sse' ) ;
212+ ( transport as any ) . mcpServer = mockMcpServer ;
213+
214+ const mockReq1 = { ip : '127.0.0.1' , on : jest . fn ( ) } ;
215+ const mockReq2 = { ip : '127.0.0.2' , on : jest . fn ( ) } ;
216+
217+ const sseHandler = mockApp . get . mock . calls . find ( call => call [ 0 ] === '/sse' ) [ 1 ] ;
218+ sseHandler ( mockReq1 , { } ) ;
219+ sseHandler ( mockReq2 , { } ) ;
220+
221+ expect ( ( transport as any ) . connections . has ( 'session-id-1' ) ) . toBe ( true ) ;
222+ expect ( ( transport as any ) . connections . has ( 'session-id-2' ) ) . toBe ( true ) ;
223+ expect ( ( transport as any ) . connections . size ) . toBe ( 2 ) ;
186224 } ) ;
187225 } ) ;
188226
189227 describe ( 'Message endpoint' , ( ) => {
190228 it ( 'should handle message with active transport' , async ( ) => {
191229 const transport = new HttpSseTransport ( 3000 , 'localhost' , '/sse' ) ;
192-
230+
193231 // Add a connection to the transport
194232 ( transport as any ) . connections . set ( 'test-session' , mockTransport ) ;
195-
196- const mockReq = { body : { test : 'message' } } ;
197- const mockRes = {
233+
234+ const mockReq = { body : { test : 'message' } , query : { sessionId : 'test-session' } } ;
235+ const mockRes = {
198236 status : jest . fn ( ) . mockReturnThis ( ) ,
199237 end : jest . fn ( ) ,
200238 json : jest . fn ( )
201239 } ;
202-
240+
203241 // Get the message endpoint handler
204242 const messageHandler = mockApp . post . mock . calls . find ( call => call [ 0 ] === '/sse/message' ) [ 1 ] ;
205243 await messageHandler ( mockReq , mockRes ) ;
206-
244+
207245 expect ( mockTransport . handleMessage ) . toHaveBeenCalledWith ( { test : 'message' } ) ;
208246 expect ( mockRes . status ) . toHaveBeenCalledWith ( 200 ) ;
209247 expect ( mockRes . end ) . toHaveBeenCalled ( ) ;
210248 } ) ;
211249
212- it ( 'should return 404 when no active transport ' , async ( ) => {
250+ it ( 'should return 400 when sessionId query param is missing ' , async ( ) => {
213251 new HttpSseTransport ( 3000 , 'localhost' , '/sse' ) ;
214-
215- const mockReq = { body : { test : 'message' } } ;
216- const mockRes = {
252+
253+ const mockReq = { body : { test : 'message' } , query : { } } ;
254+ const mockRes = {
217255 status : jest . fn ( ) . mockReturnThis ( ) ,
218256 json : jest . fn ( )
219257 } ;
220-
258+
221259 // Get the message endpoint handler
222260 const messageHandler = mockApp . post . mock . calls . find ( call => call [ 0 ] === '/sse/message' ) [ 1 ] ;
223261 await messageHandler ( mockReq , mockRes ) ;
224-
262+
263+ expect ( mockRes . status ) . toHaveBeenCalledWith ( 400 ) ;
264+ expect ( mockRes . json ) . toHaveBeenCalledWith ( { error : 'Missing sessionId' } ) ;
265+ } ) ;
266+
267+ it ( 'should return 404 when sessionId does not match any active session' , async ( ) => {
268+ const transport = new HttpSseTransport ( 3000 , 'localhost' , '/sse' ) ;
269+ ( transport as any ) . connections . set ( 'real-session-id' , mockTransport ) ;
270+
271+ const mockReq = { body : { test : 'message' } , query : { sessionId : 'bogus-session-id' } } ;
272+ const mockRes = {
273+ status : jest . fn ( ) . mockReturnThis ( ) ,
274+ json : jest . fn ( )
275+ } ;
276+
277+ // Get the message endpoint handler
278+ const messageHandler = mockApp . post . mock . calls . find ( call => call [ 0 ] === '/sse/message' ) [ 1 ] ;
279+ await messageHandler ( mockReq , mockRes ) ;
280+
225281 expect ( mockRes . status ) . toHaveBeenCalledWith ( 404 ) ;
226- expect ( mockRes . json ) . toHaveBeenCalledWith ( { error : 'No active SSE connection' } ) ;
282+ expect ( mockRes . json ) . toHaveBeenCalledWith ( { error : 'Session not found' } ) ;
283+ expect ( mockTransport . handleMessage ) . not . toHaveBeenCalled ( ) ;
284+ } ) ;
285+
286+ it ( 'should route message only to the transport matching the sessionId' , async ( ) => {
287+ const transport = new HttpSseTransport ( 3000 , 'localhost' , '/sse' ) ;
288+
289+ const mockTransportA = { handleMessage : jest . fn ( ) } ;
290+ const mockTransportB = { handleMessage : jest . fn ( ) } ;
291+ ( transport as any ) . connections . set ( 'session-a' , mockTransportA ) ;
292+ ( transport as any ) . connections . set ( 'session-b' , mockTransportB ) ;
293+
294+ const mockReq = { body : { jsonrpc : '2.0' , method : 'ping' , id : 1 } , query : { sessionId : 'session-b' } } ;
295+ const mockRes = {
296+ status : jest . fn ( ) . mockReturnThis ( ) ,
297+ end : jest . fn ( ) ,
298+ json : jest . fn ( )
299+ } ;
300+
301+ // Get the message endpoint handler
302+ const messageHandler = mockApp . post . mock . calls . find ( call => call [ 0 ] === '/sse/message' ) [ 1 ] ;
303+ await messageHandler ( mockReq , mockRes ) ;
304+
305+ expect ( mockTransportB . handleMessage ) . toHaveBeenCalledWith ( { jsonrpc : '2.0' , method : 'ping' , id : 1 } ) ;
306+ expect ( mockTransportA . handleMessage ) . not . toHaveBeenCalled ( ) ;
307+ expect ( mockRes . status ) . toHaveBeenCalledWith ( 200 ) ;
308+ } ) ;
309+
310+ it ( 'should reject a request even when another valid session exists' , async ( ) => {
311+ const transport = new HttpSseTransport ( 3000 , 'localhost' , '/sse' ) ;
312+ ( transport as any ) . connections . set ( 'legitimate-session' , mockTransport ) ;
313+
314+ // Attacker sends request with no sessionId
315+ const mockReq = { body : { jsonrpc : '2.0' , method : 'tools/call' , id : 2 } , query : { } } ;
316+ const mockRes = {
317+ status : jest . fn ( ) . mockReturnThis ( ) ,
318+ json : jest . fn ( )
319+ } ;
320+
321+ const messageHandler = mockApp . post . mock . calls . find ( call => call [ 0 ] === '/sse/message' ) [ 1 ] ;
322+ await messageHandler ( mockReq , mockRes ) ;
323+
324+ expect ( mockRes . status ) . toHaveBeenCalledWith ( 400 ) ;
325+ expect ( mockTransport . handleMessage ) . not . toHaveBeenCalled ( ) ;
227326 } ) ;
228327
229328 it ( 'should handle transport errors' , async ( ) => {
230329 const transport = new HttpSseTransport ( 3000 , 'localhost' , '/sse' ) ;
231-
330+
232331 // Add a connection that will throw an error
233332 mockTransport . handleMessage . mockRejectedValue ( new Error ( 'Transport error' ) ) ;
234333 ( transport as any ) . connections . set ( 'test-session' , mockTransport ) ;
235-
236- const mockReq = { body : { test : 'message' } } ;
237- const mockRes = {
334+
335+ const mockReq = { body : { test : 'message' } , query : { sessionId : 'test-session' } } ;
336+ const mockRes = {
238337 status : jest . fn ( ) . mockReturnThis ( ) ,
239338 json : jest . fn ( )
240339 } ;
241-
340+
242341 // Get the message endpoint handler
243342 const messageHandler = mockApp . post . mock . calls . find ( call => call [ 0 ] === '/sse/message' ) [ 1 ] ;
244343 await messageHandler ( mockReq , mockRes ) ;
245-
344+
246345 expect ( consoleErrorSpy ) . toHaveBeenCalledWith ( 'Error handling message:' , expect . any ( Error ) ) ;
247346 expect ( mockRes . status ) . toHaveBeenCalledWith ( 500 ) ;
248347 expect ( mockRes . json ) . toHaveBeenCalledWith ( { error : 'Internal server error' } ) ;
@@ -418,37 +517,56 @@ describe('HttpSseTransport', () => {
418517 } ) ;
419518
420519 it ( 'should handle multiple connections and cleanup' , ( ) => {
520+ const mockTransport1 = { handleMessage : jest . fn ( ) , sessionId : 'session-id-1' } ;
521+ const mockTransport2 = { handleMessage : jest . fn ( ) , sessionId : 'session-id-2' } ;
522+ mockSSEServerTransport
523+ . mockImplementationOnce ( ( ) => mockTransport1 )
524+ . mockImplementationOnce ( ( ) => mockTransport2 ) ;
525+
421526 const transport = new HttpSseTransport ( 3000 , 'localhost' , '/sse' ) ;
422527 ( transport as any ) . mcpServer = mockMcpServer ;
423-
424- // Mock Date.now to return different values for different connections
425- const originalDateNow = Date . now ;
426- let callCount = 0 ;
427- Date . now = jest . fn ( ( ) => {
428- callCount ++ ;
429- return 1000 + callCount ; // Return different timestamps
430- } ) ;
431-
528+
432529 // Simulate multiple SSE connections
433530 const mockReq1 = { ip : '127.0.0.1' , on : jest . fn ( ) } ;
434531 const mockReq2 = { ip : '127.0.0.2' , on : jest . fn ( ) } ;
435- const mockRes1 = { } ;
436- const mockRes2 = { } ;
437-
532+
438533 const sseHandler = mockApp . get . mock . calls . find ( call => call [ 0 ] === '/sse' ) [ 1 ] ;
439- sseHandler ( mockReq1 , mockRes1 ) ;
440- sseHandler ( mockReq2 , mockRes2 ) ;
441-
534+ sseHandler ( mockReq1 , { } ) ;
535+ sseHandler ( mockReq2 , { } ) ;
536+
442537 expect ( ( transport as any ) . connections . size ) . toBe ( 2 ) ;
443-
538+
444539 // Close first connection
445540 const closeHandler1 = mockReq1 . on . mock . calls . find ( call => call [ 0 ] === 'close' ) [ 1 ] ;
446541 closeHandler1 ( ) ;
447-
542+
448543 expect ( ( transport as any ) . connections . size ) . toBe ( 1 ) ;
449-
450- // Restore Date.now
451- Date . now = originalDateNow ;
544+ expect ( ( transport as any ) . connections . has ( 'session-id-2' ) ) . toBe ( true ) ;
545+ } ) ;
546+
547+ it ( 'should not route a message to the first connected client when the request targets a second client' , async ( ) => {
548+ const mockTransport1 = { handleMessage : jest . fn ( ) , sessionId : 'session-id-1' } ;
549+ const mockTransport2 = { handleMessage : jest . fn ( ) , sessionId : 'session-id-2' } ;
550+ mockSSEServerTransport
551+ . mockImplementationOnce ( ( ) => mockTransport1 )
552+ . mockImplementationOnce ( ( ) => mockTransport2 ) ;
553+
554+ const transport = new HttpSseTransport ( 3000 , 'localhost' , '/sse' ) ;
555+ ( transport as any ) . mcpServer = mockMcpServer ;
556+
557+ const sseHandler = mockApp . get . mock . calls . find ( call => call [ 0 ] === '/sse' ) [ 1 ] ;
558+ sseHandler ( { ip : '127.0.0.1' , on : jest . fn ( ) } , { } ) ;
559+ sseHandler ( { ip : '127.0.0.2' , on : jest . fn ( ) } , { } ) ;
560+
561+ // POST targeting client 2 — client 1's transport must not be invoked
562+ const mockReq = { body : { jsonrpc : '2.0' , method : 'tools/call' , id : 99 } , query : { sessionId : 'session-id-2' } } ;
563+ const mockRes = { status : jest . fn ( ) . mockReturnThis ( ) , end : jest . fn ( ) , json : jest . fn ( ) } ;
564+
565+ const messageHandler = mockApp . post . mock . calls . find ( call => call [ 0 ] === '/sse/message' ) [ 1 ] ;
566+ await messageHandler ( mockReq , mockRes ) ;
567+
568+ expect ( mockTransport2 . handleMessage ) . toHaveBeenCalledWith ( { jsonrpc : '2.0' , method : 'tools/call' , id : 99 } ) ;
569+ expect ( mockTransport1 . handleMessage ) . not . toHaveBeenCalled ( ) ;
452570 } ) ;
453571 } ) ;
454572} ) ;
0 commit comments