From 7c3dd6d563c7aa615ab0e61d898c1696d5af7361 Mon Sep 17 00:00:00 2001 From: sublime247 Date: Sat, 27 Jun 2026 12:35:59 +0100 Subject: [PATCH] feat(collaboration): add WebSocket message size limit validation to prevent memory exhaustion --- .env.example | 4 + package-lock.json | 35 +++ src/collaboration/collaboration.gateway.ts | 10 + src/collaboration/collaboration.module.ts | 11 +- .../ws-payload-size-guard.service.spec.ts | 244 ++++++++++++++++++ .../guards/ws-payload-size-guard.service.ts | 66 +++++ src/common/constants/time.constants.ts | 1 + src/main.ts | 26 ++ 8 files changed, 396 insertions(+), 1 deletion(-) create mode 100644 src/collaboration/guards/ws-payload-size-guard.service.spec.ts create mode 100644 src/collaboration/guards/ws-payload-size-guard.service.ts diff --git a/.env.example b/.env.example index 468e9ab2..d972f49d 100644 --- a/.env.example +++ b/.env.example @@ -326,6 +326,10 @@ REQUEST_BODY_LIMIT=1mb # Max file upload size (in bytes, default: 10MB) FILE_UPLOAD_MAX_BYTES=10485760 +# Max WebSocket message payload size (in bytes, default: 65536 = 64KB) +# Applies to both Socket.IO transport level and application-level validation +WS_MAX_PAYLOAD_BYTES=65536 + # HTTP request timeout (milliseconds) REQUEST_TIMEOUT=30000 diff --git a/package-lock.json b/package-lock.json index f31d3a7a..5a39a1b3 100644 --- a/package-lock.json +++ b/package-lock.json @@ -48,6 +48,7 @@ "@opentelemetry/propagator-jaeger": "^2.0.1", "@opentelemetry/resources": "^2.7.1", "@opentelemetry/sdk-node": "^0.203.0", + "@opentelemetry/semantic-conventions": "^1.41.1", "@segment/analytics-node": "^2.1.2", "@types/csurf": "^1.11.5", "@types/express-session": "^1.18.2", @@ -4867,6 +4868,24 @@ } } }, + "node_modules/@nestjs/schematics/node_modules/chokidar": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-4.0.3.tgz", + "integrity": "sha512-Qgzu8kfBvo+cA4962jnP1KkS6Dop5NS6g7R5LFYJr4b8Ub94PPQXUksCw9PvXoeXPRRddRNC5C1JQUR2SMGtnA==", + "dev": true, + "license": "MIT", + "optional": true, + "peer": true, + "dependencies": { + "readdirp": "^4.0.1" + }, + "engines": { + "node": ">= 14.16.0" + }, + "funding": { + "url": "https://paulmillr.com/funding/" + } + }, "node_modules/@nestjs/schematics/node_modules/jsonc-parser": { "version": "3.3.1", "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.3.1.tgz", @@ -4897,6 +4916,22 @@ "url": "https://github.com/sponsors/jonschlinkert" } }, + "node_modules/@nestjs/schematics/node_modules/readdirp": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-4.1.2.tgz", + "integrity": "sha512-GDhwkLfywWL2s6vEjyhri+eXmfH6j1L7JE27WhqLeYzoh/A3DBaYGEj2H/HFZCn/kMfim73FXxEJTw06WtxQwg==", + "dev": true, + "license": "MIT", + "optional": true, + "peer": true, + "engines": { + "node": ">= 14.18.0" + }, + "funding": { + "type": "individual", + "url": "https://paulmillr.com/funding/" + } + }, "node_modules/@nestjs/schematics/node_modules/rxjs": { "version": "7.8.1", "resolved": "https://registry.npmjs.org/rxjs/-/rxjs-7.8.1.tgz", diff --git a/src/collaboration/collaboration.gateway.ts b/src/collaboration/collaboration.gateway.ts index 1ec4effc..6d98a20e 100644 --- a/src/collaboration/collaboration.gateway.ts +++ b/src/collaboration/collaboration.gateway.ts @@ -13,6 +13,7 @@ import { JoinSessionDto, CollaborativeOperationDto, SyncRequestDto } from './dto import { OtCrdtService, Operation } from './ot-crdt.service'; import { PresenceService } from './presence.service'; import { ChangeHistoryService } from './change-history.service'; +import { WsPayloadSizeGuardService } from './guards/ws-payload-size-guard.service'; @WebSocketGateway({ namespace: '/collaboration', cors: { origin: '*' } }) export class CollaborationGateway implements OnGatewayDisconnect { @@ -27,6 +28,7 @@ export class CollaborationGateway implements OnGatewayDisconnect { private readonly otCrdt: OtCrdtService, private readonly presence: PresenceService, private readonly history: ChangeHistoryService, + private readonly payloadSizeGuard: WsPayloadSizeGuardService, ) {} handleDisconnect(client: Socket): void { @@ -44,6 +46,8 @@ export class CollaborationGateway implements OnGatewayDisconnect { @SubscribeMessage(COLLABORATION_EVENTS.JOIN_SESSION) handleJoin(@MessageBody() dto: JoinSessionDto, @ConnectedSocket() client: Socket) { + this.payloadSizeGuard.validate(dto); + client.join(dto.sessionId); this.socketMap.set(client.id, { sessionId: dto.sessionId, userId: dto.userId }); const presenceInfo = this.presence.join(dto.sessionId, dto.userId); @@ -70,6 +74,8 @@ export class CollaborationGateway implements OnGatewayDisconnect { @MessageBody() dto: CollaborativeOperationDto, @ConnectedSocket() client: Socket, ) { + this.payloadSizeGuard.validate(dto); + const incomingOp = dto.operation as Operation; const revision = this.otCrdt.nextRevision(dto.sessionId); const op: Operation = { ...incomingOp, sessionId: dto.sessionId, userId: dto.userId, revision }; @@ -101,6 +107,8 @@ export class CollaborationGateway implements OnGatewayDisconnect { @SubscribeMessage(COLLABORATION_EVENTS.REQUEST_SYNC) handleSync(@MessageBody() dto: SyncRequestDto) { + this.payloadSizeGuard.validate(dto); + const revision = this.otCrdt.currentRevision(dto.sessionId); const history = this.history.getLatest(dto.sessionId); @@ -112,6 +120,8 @@ export class CollaborationGateway implements OnGatewayDisconnect { @SubscribeMessage(COLLABORATION_EVENTS.RESOLVE_CONFLICT) handleConflict(@MessageBody() body: { op1: Operation; op2: Operation; sessionId: string }) { + this.payloadSizeGuard.validate(body); + const resolved = this.otCrdt.resolveConflict(body.op1, body.op2); this.server.to(body.sessionId).emit(COLLABORATION_EVENTS.CONFLICT_RESOLVED, { resolved }); return { event: COLLABORATION_EVENTS.CONFLICT_RESOLVED, data: { resolved } }; diff --git a/src/collaboration/collaboration.module.ts b/src/collaboration/collaboration.module.ts index 5c377a20..cc1c9bb8 100644 --- a/src/collaboration/collaboration.module.ts +++ b/src/collaboration/collaboration.module.ts @@ -1,11 +1,20 @@ import { Module } from '@nestjs/common'; +import { ConfigModule } from '@nestjs/config'; import { OtCrdtService } from './ot-crdt.service'; import { PresenceService } from './presence.service'; import { ChangeHistoryService } from './change-history.service'; import { CollaborationGateway } from './collaboration.gateway'; +import { WsPayloadSizeGuardService } from './guards/ws-payload-size-guard.service'; @Module({ - providers: [OtCrdtService, PresenceService, ChangeHistoryService, CollaborationGateway], + imports: [ConfigModule], + providers: [ + OtCrdtService, + PresenceService, + ChangeHistoryService, + CollaborationGateway, + WsPayloadSizeGuardService, + ], exports: [OtCrdtService, PresenceService, ChangeHistoryService], }) export class CollaborationModule {} diff --git a/src/collaboration/guards/ws-payload-size-guard.service.spec.ts b/src/collaboration/guards/ws-payload-size-guard.service.spec.ts new file mode 100644 index 00000000..30543a56 --- /dev/null +++ b/src/collaboration/guards/ws-payload-size-guard.service.spec.ts @@ -0,0 +1,244 @@ +import { Test, TestingModule } from '@nestjs/testing'; +import { ConfigService } from '@nestjs/config'; +import { WsException } from '@nestjs/websockets'; +import { + WsPayloadSizeGuardService, + WS_PAYLOAD_TOO_LARGE_CODE, + DEFAULT_WS_MAX_PAYLOAD_BYTES, +} from './ws-payload-size-guard.service'; + +describe('WsPayloadSizeGuardService', () => { + // --------------------------------------------------------------------------- + // Tests with default limit (64KB) + // --------------------------------------------------------------------------- + describe('with default limit', () => { + let service: WsPayloadSizeGuardService; + + beforeEach(async () => { + const module: TestingModule = await Test.createTestingModule({ + providers: [ + WsPayloadSizeGuardService, + { + provide: ConfigService, + useValue: { + get: jest.fn().mockReturnValue(undefined), + }, + }, + ], + }).compile(); + + service = module.get(WsPayloadSizeGuardService); + }); + + it('should use default limit of 64KB', () => { + expect(service.getMaxPayloadBytes()).toBe(DEFAULT_WS_MAX_PAYLOAD_BYTES); + expect(service.getMaxPayloadBytes()).toBe(65_536); + }); + + it('should accept a small payload', () => { + const payload = { sessionId: 'session-1', userId: 'user-1', data: 'hello' }; + expect(() => service.validate(payload)).not.toThrow(); + }); + + it('should accept an empty object', () => { + expect(() => service.validate({})).not.toThrow(); + }); + + it('should accept a payload just under the limit', () => { + // Create a payload that is just under 64KB + const padding = 'x'.repeat(60_000); + const payload = { data: padding }; + expect(() => service.validate(payload)).not.toThrow(); + }); + + it('should reject a payload exceeding 64KB', () => { + // Create a payload that is well over 64KB + const padding = 'x'.repeat(70_000); + const payload = { data: padding }; + + expect(() => service.validate(payload)).toThrow(WsException); + + try { + service.validate(payload); + } catch (error) { + expect(error).toBeInstanceOf(WsException); + const wsError = error as WsException; + const errorPayload = wsError.getError() as { code: string; message: string }; + expect(errorPayload.code).toBe(WS_PAYLOAD_TOO_LARGE_CODE); + expect(errorPayload.message).toContain('exceeds the maximum allowed size'); + } + }); + + it('should reject a large nested object', () => { + // Build a deeply nested object that exceeds 64KB when serialized + const largeArray = Array.from({ length: 5000 }, (_, i) => ({ + id: i, + content: `This is element number ${i} with some additional padding text to increase size`, + nested: { a: 'value', b: i * 100 }, + })); + const payload = { operations: largeArray }; + + expect(() => service.validate(payload)).toThrow(WsException); + }); + + it('should include PAYLOAD_TOO_LARGE code in the error', () => { + const padding = 'x'.repeat(70_000); + const payload = { data: padding }; + + try { + service.validate(payload); + fail('Expected WsException to be thrown'); + } catch (error) { + const wsError = error as WsException; + const errorPayload = wsError.getError() as { code: string; message: string }; + expect(errorPayload.code).toBe('PAYLOAD_TOO_LARGE'); + } + }); + + it('should include byte sizes in the error message', () => { + const padding = 'x'.repeat(70_000); + const payload = { data: padding }; + + try { + service.validate(payload); + fail('Expected WsException to be thrown'); + } catch (error) { + const wsError = error as WsException; + const errorPayload = wsError.getError() as { code: string; message: string }; + expect(errorPayload.message).toMatch(/\d+ bytes exceeds/); + expect(errorPayload.message).toMatch(/maximum allowed size of \d+ bytes/); + } + }); + }); + + // --------------------------------------------------------------------------- + // Tests with custom limit + // --------------------------------------------------------------------------- + describe('with custom limit', () => { + let service: WsPayloadSizeGuardService; + const customLimit = 1_024; // 1KB + + beforeEach(async () => { + const module: TestingModule = await Test.createTestingModule({ + providers: [ + WsPayloadSizeGuardService, + { + provide: ConfigService, + useValue: { + get: jest.fn((key: string) => { + if (key === 'WS_MAX_PAYLOAD_BYTES') return customLimit; + return undefined; + }), + }, + }, + ], + }).compile(); + + service = module.get(WsPayloadSizeGuardService); + }); + + it('should use the configured limit', () => { + expect(service.getMaxPayloadBytes()).toBe(customLimit); + }); + + it('should accept a payload under the custom limit', () => { + const payload = { key: 'value' }; + expect(() => service.validate(payload)).not.toThrow(); + }); + + it('should reject a payload over the custom 1KB limit', () => { + const padding = 'x'.repeat(2_000); + const payload = { data: padding }; + + expect(() => service.validate(payload)).toThrow(WsException); + }); + }); + + // --------------------------------------------------------------------------- + // Tests with realistic collaboration payloads + // --------------------------------------------------------------------------- + describe('with realistic collaboration payloads', () => { + let service: WsPayloadSizeGuardService; + + beforeEach(async () => { + const module: TestingModule = await Test.createTestingModule({ + providers: [ + WsPayloadSizeGuardService, + { + provide: ConfigService, + useValue: { + get: jest.fn().mockReturnValue(undefined), // Use default 64KB + }, + }, + ], + }).compile(); + + service = module.get(WsPayloadSizeGuardService); + }); + + it('should accept a normal collaborative operation', () => { + const payload = { + sessionId: 'session-abc-123', + userId: 'user-456', + resourceType: 'document', + operation: { + type: 'insert', + position: 42, + content: 'Hello, this is a normal collaborative edit.', + revision: 10, + }, + }; + expect(() => service.validate(payload)).not.toThrow(); + }); + + it('should accept a join-session message', () => { + const payload = { + sessionId: 'session-abc-123', + userId: 'user-456', + resourceType: 'document', + }; + expect(() => service.validate(payload)).not.toThrow(); + }); + + it('should accept a sync-request message', () => { + const payload = { + sessionId: 'session-abc-123', + userId: 'user-456', + }; + expect(() => service.validate(payload)).not.toThrow(); + }); + + it('should reject a malicious megabyte-scale operation payload', () => { + const maliciousPayload = { + sessionId: 'session-abc-123', + userId: 'attacker', + resourceType: 'document', + operation: { + type: 'insert', + position: 0, + // 1MB of content — clearly malicious + content: 'A'.repeat(1_048_576), + }, + }; + expect(() => service.validate(maliciousPayload)).toThrow(WsException); + }); + + it('should reject a payload with excessive metadata', () => { + const payload = { + sessionId: 'session-abc-123', + userId: 'user-456', + resourceType: 'document', + operation: { + type: 'insert', + position: 0, + content: 'small content', + // Attacker stuffs massive metadata + metadata: Object.fromEntries( + Array.from({ length: 5000 }, (_, i) => [`key-${i}`, `value-${'x'.repeat(20)}-${i}`]), + ), + }, + }; + expect(() => service.validate(payload)).toThrow(WsException); + }); + }); +}); diff --git a/src/collaboration/guards/ws-payload-size-guard.service.ts b/src/collaboration/guards/ws-payload-size-guard.service.ts new file mode 100644 index 00000000..40620f20 --- /dev/null +++ b/src/collaboration/guards/ws-payload-size-guard.service.ts @@ -0,0 +1,66 @@ +import { Injectable, Logger } from '@nestjs/common'; +import { WsException } from '@nestjs/websockets'; +import { ConfigService } from '@nestjs/config'; +import { BYTES } from '../../common/constants/time.constants'; + +/** + * Default maximum WebSocket payload size in bytes (64KB). + * Can be overridden via the WS_MAX_PAYLOAD_BYTES environment variable. + */ +export const DEFAULT_WS_MAX_PAYLOAD_BYTES = BYTES.SIXTY_FOUR_KB; + +/** + * Error code returned when a WebSocket message exceeds the size limit. + */ +export const WS_PAYLOAD_TOO_LARGE_CODE = 'PAYLOAD_TOO_LARGE'; + +/** + * Service that validates WebSocket message payload sizes. + * + * Provides application-level defense-in-depth on top of the transport-level + * `maxHttpBufferSize` configured in main.ts. This ensures that even if the + * transport-level limit is relaxed for other namespaces, collaboration + * handlers still enforce their own limit and return a proper WsException. + */ +@Injectable() +export class WsPayloadSizeGuardService { + private readonly logger = new Logger(WsPayloadSizeGuardService.name); + private readonly maxPayloadBytes: number; + + constructor(private readonly configService: ConfigService) { + this.maxPayloadBytes = + this.configService.get('WS_MAX_PAYLOAD_BYTES') ?? DEFAULT_WS_MAX_PAYLOAD_BYTES; + + this.logger.log( + `WebSocket payload size limit: ${this.maxPayloadBytes} bytes (${Math.round(this.maxPayloadBytes / 1024)}KB)`, + ); + } + + /** + * Validate that a payload does not exceed the configured size limit. + * + * @param payload - The message payload to validate (any shape). + * @throws WsException with code PAYLOAD_TOO_LARGE if the payload is too large. + */ + validate(payload: unknown): void { + const serialized = JSON.stringify(payload); + const byteLength = Buffer.byteLength(serialized, 'utf8'); + + if (byteLength > this.maxPayloadBytes) { + this.logger.warn( + `Rejected oversized WebSocket payload: ${byteLength} bytes (limit: ${this.maxPayloadBytes} bytes)`, + ); + throw new WsException({ + code: WS_PAYLOAD_TOO_LARGE_CODE, + message: `Payload size ${byteLength} bytes exceeds the maximum allowed size of ${this.maxPayloadBytes} bytes`, + }); + } + } + + /** + * Get the currently configured max payload size in bytes. + */ + getMaxPayloadBytes(): number { + return this.maxPayloadBytes; + } +} diff --git a/src/common/constants/time.constants.ts b/src/common/constants/time.constants.ts index 4ae1b14a..e8726cc7 100644 --- a/src/common/constants/time.constants.ts +++ b/src/common/constants/time.constants.ts @@ -31,6 +31,7 @@ export const TIME = { */ export const BYTES = { ONE_KB: 1_024, + SIXTY_FOUR_KB: 65_536, // 64 * 1024 ONE_MB_BYTES: 1_048_576, // 1024 * 1024 TEN_MB_BYTES: 10_485_760, // 10 * 1024 * 1024 } as const; diff --git a/src/main.ts b/src/main.ts index 2a8f1afc..392998e4 100644 --- a/src/main.ts +++ b/src/main.ts @@ -54,8 +54,34 @@ async function bootstrapWorker(): Promise { 10, ); + const wsMaxPayloadBytes = parseInt( + process.env.WS_MAX_PAYLOAD_BYTES || `${BYTES.SIXTY_FOUR_KB}`, + 10, + ); + const app = await NestFactory.create(AppModule, { rawBody: true }); + // ========================= + // WEBSOCKET PAYLOAD SIZE LIMIT + // ========================= + // Configure Socket.IO maxHttpBufferSize at the transport layer. + // Messages exceeding this limit are rejected before reaching any handler. + const { IoAdapter } = await import('@nestjs/platform-socket.io'); + + class SizeLimitedIoAdapter extends IoAdapter { + createIOServer(port: number, options?: any): any { + return super.createIOServer(port, { + ...options, + maxHttpBufferSize: wsMaxPayloadBytes, + }); + } + } + + app.useWebSocketAdapter(new SizeLimitedIoAdapter(app)); + logger.log( + `WebSocket maxHttpBufferSize set to ${wsMaxPayloadBytes} bytes (${Math.round(wsMaxPayloadBytes / 1024)}KB)`, + ); + // Get shutdown services const shutdownState = app.get(ShutdownStateService); const gracefulShutdown = app.get(GracefulShutdownService);