From ab4e9a686fcad9ccdcb82fed092f5c6c40883138 Mon Sep 17 00:00:00 2001 From: nottherealalanturing Date: Sat, 27 Jun 2026 12:12:15 +0100 Subject: [PATCH] feat: comprehensive WebSocket gateway for real-time notifications --- .../gateway/notifications.gateway.ts | 114 ++++++++++++++++-- .../src/notifications/notifications.module.ts | 2 +- 2 files changed, 105 insertions(+), 11 deletions(-) diff --git a/backend/src/notifications/gateway/notifications.gateway.ts b/backend/src/notifications/gateway/notifications.gateway.ts index de6509c4..e1f4d229 100644 --- a/backend/src/notifications/gateway/notifications.gateway.ts +++ b/backend/src/notifications/gateway/notifications.gateway.ts @@ -3,25 +3,55 @@ import { WebSocketServer, OnGatewayConnection, OnGatewayDisconnect, + OnGatewayInit, } from '@nestjs/websockets'; import { Logger } from '@nestjs/common'; import { Server, Socket } from 'socket.io'; import { JwtService } from '@nestjs/jwt'; +import { JwtPayload } from '../../auth/interface/user.interface'; +import { UserRole } from '../../users/enums/userRoles.enum'; + +const ADMIN_ROLES = [UserRole.ADMIN, UserRole.SUPER_ADMIN]; + +interface RateLimitEntry { + count: number; + resetAt: number; +} @WebSocketGateway({ namespace: 'notifications', cors: { origin: '*' }, }) export class NotificationsGateway - implements OnGatewayConnection, OnGatewayDisconnect + implements OnGatewayConnection, OnGatewayDisconnect, OnGatewayInit { @WebSocketServer() server: Server; private readonly logger = new Logger(NotificationsGateway.name); + private readonly connectedClients = new Map(); + + private readonly rateLimitMap = new Map(); + + private readonly RATE_LIMIT_MAX = 30; + + private readonly RATE_LIMIT_WINDOW_MS = 10_000; + + private heartbeatInterval: ReturnType | null = null; + constructor(private readonly jwtService: JwtService) {} + afterInit(): void { + this.logger.log('WebSocket gateway initialized'); + + this.heartbeatInterval = setInterval(() => { + this.server.sockets?.sockets?.forEach((socket) => { + socket.emit('ping', { timestamp: Date.now() }); + }); + }, 25_000); + } + async handleConnection(client: Socket): Promise { try { const token = @@ -32,30 +62,94 @@ export class NotificationsGateway ); if (!token) { + this.logger.warn(`Connection rejected (no token): ${client.id}`); client.disconnect(); return; } - const payload = this.jwtService.verify<{ sub: string }>(token, { + const payload = this.jwtService.verify(token, { secret: process.env.JWT_SECRET, }); - // Join a room named after the user ID so we can target them - await client.join(`user:${payload.sub}`); - this.logger.log(`Client connected: ${client.id} (user ${payload.sub})`); - } catch { + const userId = payload.sub; + + this.connectedClients.set(client.id, userId); + + await client.join(`user:${userId}`); + + if (payload.role && ADMIN_ROLES.includes(payload.role as UserRole)) { + await client.join('admin'); + this.logger.log( + `Admin connected: ${client.id} (user ${userId}, role ${payload.role})`, + ); + } else { + this.logger.log( + `Client connected: ${client.id} (user ${userId}${payload.role ? `, role ${payload.role}` : ''})`, + ); + } + + client.emit('connected', { + clientId: client.id, + userId, + timestamp: Date.now(), + }); + } catch (error) { + this.logger.error( + `Connection rejected for ${client.id}: ${(error as Error).message}`, + ); client.disconnect(); } } handleDisconnect(client: Socket): void { - this.logger.log(`Client disconnected: ${client.id}`); + const userId = this.connectedClients.get(client.id); + if (userId) { + this.logger.log(`Client disconnected: ${client.id} (user ${userId})`); + this.connectedClients.delete(client.id); + } else { + this.logger.log(`Client disconnected: ${client.id}`); + } + this.rateLimitMap.delete(client.id); } - /** - * Push a notification event to a specific user. - */ sendToUser(userId: string, event: string, data: unknown): void { this.server.to(`user:${userId}`).emit(event, data); } + + sendToAll(event: string, data: unknown): void { + this.server.emit(event, data); + } + + sendToAdmins(event: string, data: unknown): void { + this.server.to('admin').emit(event, data); + } + + getConnectedUsersCount(): number { + return this.connectedClients.size; + } + + getConnectedUsers(): Map { + return new Map(this.connectedClients); + } + + isRateLimited(clientId: string): boolean { + const now = Date.now(); + const entry = this.rateLimitMap.get(clientId); + + if (!entry || now >= entry.resetAt) { + this.rateLimitMap.set(clientId, { + count: 1, + resetAt: now + this.RATE_LIMIT_WINDOW_MS, + }); + return false; + } + + entry.count += 1; + if (entry.count > this.RATE_LIMIT_MAX) { + this.logger.warn(`Rate limit exceeded for socket ${clientId}`); + return true; + } + + return false; + } } diff --git a/backend/src/notifications/notifications.module.ts b/backend/src/notifications/notifications.module.ts index e2dcd12f..98098cf5 100644 --- a/backend/src/notifications/notifications.module.ts +++ b/backend/src/notifications/notifications.module.ts @@ -27,6 +27,6 @@ import { FindNotificationsProvider } from './providers/find-notifications.provid CreateNotificationProvider, FindNotificationsProvider, ], - exports: [NotificationsService], + exports: [NotificationsService, NotificationsGateway], }) export class NotificationsModule {}