Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 104 additions & 10 deletions backend/src/notifications/gateway/notifications.gateway.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,55 @@ import {
WebSocketServer,
OnGatewayConnection,
OnGatewayDisconnect,
OnGatewayInit,
} from '@nestjs/websockets';
import { Logger } from '@nestjs/common';
import { Server, Socket } from 'socket.io';
import { JwtService } from '@nestjs/jwt';
import { JwtPayload } from '../../auth/interface/user.interface';
import { UserRole } from '../../users/enums/userRoles.enum';

const ADMIN_ROLES = [UserRole.ADMIN, UserRole.SUPER_ADMIN];

interface RateLimitEntry {
count: number;
resetAt: number;
}

@WebSocketGateway({
namespace: 'notifications',
cors: { origin: '*' },
})
export class NotificationsGateway
implements OnGatewayConnection, OnGatewayDisconnect
implements OnGatewayConnection, OnGatewayDisconnect, OnGatewayInit
{
@WebSocketServer()
server: Server;

private readonly logger = new Logger(NotificationsGateway.name);

private readonly connectedClients = new Map<string, string>();

private readonly rateLimitMap = new Map<string, RateLimitEntry>();

private readonly RATE_LIMIT_MAX = 30;

private readonly RATE_LIMIT_WINDOW_MS = 10_000;

private heartbeatInterval: ReturnType<typeof setInterval> | null = null;

constructor(private readonly jwtService: JwtService) {}

afterInit(): void {
this.logger.log('WebSocket gateway initialized');

this.heartbeatInterval = setInterval(() => {
this.server.sockets?.sockets?.forEach((socket) => {
socket.emit('ping', { timestamp: Date.now() });
});
}, 25_000);
}

async handleConnection(client: Socket): Promise<void> {
try {
const token =
Expand All @@ -32,30 +62,94 @@ export class NotificationsGateway
);

if (!token) {
this.logger.warn(`Connection rejected (no token): ${client.id}`);
client.disconnect();
return;
}

const payload = this.jwtService.verify<{ sub: string }>(token, {
const payload = this.jwtService.verify<JwtPayload>(token, {
secret: process.env.JWT_SECRET,
});

// Join a room named after the user ID so we can target them
await client.join(`user:${payload.sub}`);
this.logger.log(`Client connected: ${client.id} (user ${payload.sub})`);
} catch {
const userId = payload.sub;

this.connectedClients.set(client.id, userId);

await client.join(`user:${userId}`);

if (payload.role && ADMIN_ROLES.includes(payload.role as UserRole)) {
await client.join('admin');
this.logger.log(
`Admin connected: ${client.id} (user ${userId}, role ${payload.role})`,
);
} else {
this.logger.log(
`Client connected: ${client.id} (user ${userId}${payload.role ? `, role ${payload.role}` : ''})`,
);
}

client.emit('connected', {
clientId: client.id,
userId,
timestamp: Date.now(),
});
} catch (error) {
this.logger.error(
`Connection rejected for ${client.id}: ${(error as Error).message}`,
);
client.disconnect();
}
}

handleDisconnect(client: Socket): void {
this.logger.log(`Client disconnected: ${client.id}`);
const userId = this.connectedClients.get(client.id);
if (userId) {
this.logger.log(`Client disconnected: ${client.id} (user ${userId})`);
this.connectedClients.delete(client.id);
} else {
this.logger.log(`Client disconnected: ${client.id}`);
}
this.rateLimitMap.delete(client.id);
}

/**
* Push a notification event to a specific user.
*/
sendToUser(userId: string, event: string, data: unknown): void {
this.server.to(`user:${userId}`).emit(event, data);
}

sendToAll(event: string, data: unknown): void {
this.server.emit(event, data);
}

sendToAdmins(event: string, data: unknown): void {
this.server.to('admin').emit(event, data);
}

getConnectedUsersCount(): number {
return this.connectedClients.size;
}

getConnectedUsers(): Map<string, string> {
return new Map(this.connectedClients);
}

isRateLimited(clientId: string): boolean {
const now = Date.now();
const entry = this.rateLimitMap.get(clientId);

if (!entry || now >= entry.resetAt) {
this.rateLimitMap.set(clientId, {
count: 1,
resetAt: now + this.RATE_LIMIT_WINDOW_MS,
});
return false;
}

entry.count += 1;
if (entry.count > this.RATE_LIMIT_MAX) {
this.logger.warn(`Rate limit exceeded for socket ${clientId}`);
return true;
}

return false;
}
}
2 changes: 1 addition & 1 deletion backend/src/notifications/notifications.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ import { FindNotificationsProvider } from './providers/find-notifications.provid
CreateNotificationProvider,
FindNotificationsProvider,
],
exports: [NotificationsService],
exports: [NotificationsService, NotificationsGateway],
})
export class NotificationsModule {}
Loading