From cb1382f72a8c6ad69a025cd55692c3cb04965b1f Mon Sep 17 00:00:00 2001 From: Anubhav Singh Date: Sun, 28 Jun 2026 21:52:07 +0530 Subject: [PATCH] feat: add heartbeat watchdog, device revocation pub/sub, rate limiting, and backpressure - Heartbeat: server-side 90s timeout marks device offline and expires Redis TTLs when heartbeats stop. Throttled lastSeenAt bump via devices.updatedAt. - Device revocation: subscribe gateways to device_revoked:* Redis channel. On receipt, disconnect the device socket immediately, clear Redis mappings, and reject post-revocation events via socket middleware. - Rate limiting: per-socket event/sec counters via Redis with configurable SOCKET_RATE_LIMIT_PER_SEC env. Max payload enforcement via MAX_PAYLOAD_SIZE. Violations warn, 3rd violation disconnects. - Backpressure: monitor WebSocket bufferedAmount every 5s. Shed slow consumers above SOCKET_SHED_THRESHOLD, disconnect above SOCKET_BUFFER_THRESHOLD. Non-critical broadcasts use volatile emit for graceful degradation. --- .../src/__tests__/readReceipts.test.ts | 8 +- apps/backend/src/index.ts | 94 ++++++++++++++++--- apps/backend/src/services/backpressure.ts | 86 +++++++++++++++++ apps/backend/src/services/deviceRevocation.ts | 76 +++++++++++++++ apps/backend/src/services/heartbeat.ts | 73 ++++++++++++++ apps/backend/src/services/presence.ts | 11 ++- apps/backend/src/services/rateLimit.ts | 53 +++++++++++ apps/backend/src/socket/messaging.ts | 10 +- 8 files changed, 391 insertions(+), 20 deletions(-) create mode 100644 apps/backend/src/services/backpressure.ts create mode 100644 apps/backend/src/services/deviceRevocation.ts create mode 100644 apps/backend/src/services/heartbeat.ts create mode 100644 apps/backend/src/services/rateLimit.ts diff --git a/apps/backend/src/__tests__/readReceipts.test.ts b/apps/backend/src/__tests__/readReceipts.test.ts index c70280c..063b1c0 100644 --- a/apps/backend/src/__tests__/readReceipts.test.ts +++ b/apps/backend/src/__tests__/readReceipts.test.ts @@ -49,11 +49,13 @@ function makeSocket(userId: string) { function makeIo() { const roomEmitted: { event: string; data: unknown }[] = []; + const emitFn = vi.fn((event: string, data: unknown) => { + roomEmitted.push({ event, data }); + }); const io = { to: vi.fn(() => ({ - emit: vi.fn((event: string, data: unknown) => { - roomEmitted.push({ event, data }); - }), + emit: emitFn, + volatile: { emit: emitFn }, })), roomEmitted, }; diff --git a/apps/backend/src/index.ts b/apps/backend/src/index.ts index f8d60b7..a101c59 100644 --- a/apps/backend/src/index.ts +++ b/apps/backend/src/index.ts @@ -11,7 +11,21 @@ import { registerMessagingHandlers } from './socket/messaging.js'; import { app } from './app.js'; import { redis as appRedis } from './lib/redis.js'; import { setSocketServer } from './lib/socket.js'; -import { setOnline, setOffline, refreshPresence } from './services/presence.js'; +import { setOnline, setOffline } from './services/presence.js'; +import { startHeartbeatTimer, clearHeartbeatTimer } from './services/heartbeat.js'; +import { + registerDeviceSocket, + unregisterDeviceSocket, + isDeviceRevoked, + startDeviceRevocationListener, +} from './services/deviceRevocation.js'; +import { + checkRateLimit, + checkPayloadSize, + recordViolation, + clearViolations, +} from './services/rateLimit.js'; +import { registerForBackpressure, unregisterForBackpressure } from './services/backpressure.js'; import { buildRpcFetcher, buildTreasuryRpcFetcher, @@ -36,8 +50,57 @@ io.use(socketAuthMiddleware); io.on('connection', async (socket: AuthSocket) => { const userId = socket.auth!.userId; + const deviceId = socket.auth!.deviceId; console.log('User connected:', userId, socket.id); + // Register socket for device-revocation tracking (cross-instance via Redis pub/sub). + if (appRedis) { + registerDeviceSocket(deviceId, socket.id); + } + + // Start the server-side heartbeat watchdog (90 s timeout). + startHeartbeatTimer(socket, userId, deviceId, appRedis, io); + + // Per-socket middleware: intercept every incoming event before handlers. + const EXCLUDED_EVENTS = new Set(['heartbeat']); + socket.use(async ([event, ...args], next) => { + // Skip internal heartbeat pings. + if (EXCLUDED_EVENTS.has(event)) { + return next(); + } + + // Reject events from a device that was revoked mid-session. + if (isDeviceRevoked(deviceId)) { + socket.emit('error', { event: 'device_revoked', message: 'Device has been revoked' }); + socket.disconnect(true); + return; + } + + // Enforce maximum payload size (configurable via MAX_PAYLOAD_SIZE env). + const payloadArgs = args.filter((a) => typeof a !== 'function'); + const { valid, size } = checkPayloadSize(payloadArgs); + if (!valid) { + socket.emit('error', { + event: 'payload_too_large', + message: `Payload size ${size} exceeds limit`, + }); + return; + } + + // Per-socket rate limiting (configurable via SOCKET_RATE_LIMIT_PER_SEC env). + const { allowed } = await checkRateLimit(appRedis, socket.id); + if (!allowed) { + const violations = recordViolation(socket.id); + socket.emit('error', { event: 'rate_limited', message: 'Rate limit exceeded' }); + if (violations >= 3) { + socket.disconnect(true); + } + return; + } + + next(); + }); + // Auto-join all conversation rooms so the socket receives new_message events // for every conversation the user belongs to (needed for unread badge tracking). const memberships = await db.query.conversationMembers.findMany({ @@ -51,21 +114,23 @@ io.on('connection', async (socket: AuthSocket) => { if (appRedis) { await setOnline(appRedis, userId, socket.id); for (const m of memberships) { - io.to(m.conversationId).emit('user_online', { userId }); - io.to(m.conversationId).emit('presence_update', { userId, online: true }); + io.to(m.conversationId).volatile.emit('user_online', { userId }); + io.to(m.conversationId).volatile.emit('presence_update', { userId, online: true }); } } - socket.on('heartbeat', async () => { - if (appRedis) { - await refreshPresence(appRedis, userId); - } - }); - registerMessagingHandlers(io, socket); + // Monitor send-buffer to detect slow/stalled consumers. + registerForBackpressure(socket); + socket.on('disconnect', async () => { console.log('User disconnected:', userId); + clearHeartbeatTimer(socket.id); + unregisterDeviceSocket(socket.id); + unregisterForBackpressure(socket); + clearViolations(socket.id); + if (appRedis) { const fullyOffline = await setOffline(appRedis, userId, socket.id); if (fullyOffline) { @@ -74,8 +139,8 @@ io.on('connection', async (socket: AuthSocket) => { columns: { conversationId: true }, }); for (const m of memberships) { - io.to(m.conversationId).emit('user_offline', { userId }); - io.to(m.conversationId).emit('presence_update', { userId, online: false }); + io.to(m.conversationId).volatile.emit('user_offline', { userId }); + io.to(m.conversationId).volatile.emit('presence_update', { userId, online: false }); } } } @@ -123,6 +188,13 @@ httpServer.listen(PORT, () => { // Redis is unreachable; on failure we fall back to the in-process adapter. void attachRedisAdapter(); +// Subscribe to device_revoked:* channels so any gateway instance can +// disconnect a revoked device's sockets within seconds, even when the +// revocation was issued on a different node. +if (appRedis) { + void startDeviceRevocationListener(appRedis, appRedis); +} + // #46 — Stellar transfer event listener. Only spin up when the contract // id is configured so local-dev and unit-test runs don't try to talk to // Soroban RPC. The listener never throws out of runForever, so a failed diff --git a/apps/backend/src/services/backpressure.ts b/apps/backend/src/services/backpressure.ts new file mode 100644 index 0000000..c0f09c4 --- /dev/null +++ b/apps/backend/src/services/backpressure.ts @@ -0,0 +1,86 @@ +import type { AuthSocket } from '../middleware/socketAuth.js'; + +function getBufferThreshold(): number { + const val = process.env['SOCKET_BUFFER_THRESHOLD']; + if (val) { + const parsed = parseInt(val, 10); + if (!isNaN(parsed) && parsed > 0) return parsed; + } + return 65536; +} + +function getShedThreshold(): number { + const val = process.env['SOCKET_SHED_THRESHOLD']; + if (val) { + const parsed = parseInt(val, 10); + if (!isNaN(parsed) && parsed > 0) return parsed; + } + return 32768; +} + +const shedSockets = new Set(); +const socketsToMonitor = new Set(); +let monitorInterval: ReturnType | null = null; + +export function registerForBackpressure(socket: AuthSocket): void { + socketsToMonitor.add(socket); + if (!monitorInterval) { + monitorInterval = setInterval(checkBuffers, 5000); + } +} + +export function unregisterForBackpressure(socket: AuthSocket): void { + socketsToMonitor.delete(socket); + shedSockets.delete(socket.id); + if (socketsToMonitor.size === 0 && monitorInterval) { + clearInterval(monitorInterval); + monitorInterval = null; + } +} + +export function isSocketShed(socketId: string): boolean { + return shedSockets.has(socketId); +} + +function checkBuffers(): void { + const disconnectThreshold = getBufferThreshold(); + const shedThreshold = getShedThreshold(); + + for (const socket of socketsToMonitor) { + const buffered = getBufferedAmount(socket); + + if (buffered > disconnectThreshold) { + console.warn( + `Socket ${socket.id} buffer ${buffered} exceeds disconnect threshold ${disconnectThreshold}, disconnecting`, + ); + shedSockets.add(socket.id); + socket.disconnect(true); + } else if (buffered > shedThreshold) { + if (!shedSockets.has(socket.id)) { + console.warn( + `Socket ${socket.id} buffer ${buffered} exceeds shed threshold ${shedThreshold}, shedding`, + ); + shedSockets.add(socket.id); + } + } else { + if (shedSockets.has(socket.id)) { + shedSockets.delete(socket.id); + } + } + } +} + +function getBufferedAmount(socket: AuthSocket): number { + try { + const conn = socket.conn as unknown as { + transport?: { socket?: { bufferedAmount?: number } }; + }; + const ws = conn.transport?.socket; + if (ws && typeof ws.bufferedAmount === 'number') { + return ws.bufferedAmount; + } + } catch { + // Ignore errors accessing internal transport + } + return 0; +} diff --git a/apps/backend/src/services/deviceRevocation.ts b/apps/backend/src/services/deviceRevocation.ts new file mode 100644 index 0000000..20f1ef0 --- /dev/null +++ b/apps/backend/src/services/deviceRevocation.ts @@ -0,0 +1,76 @@ +import type { Redis } from 'ioredis'; +import { getSocketServer } from '../lib/socket.js'; +import type { AuthSocket } from '../middleware/socketAuth.js'; +import { setOffline } from './presence.js'; + +const deviceSockets = new Map>(); +const socketDevice = new Map(); +const revokedMidSession = new Set(); + +export function registerDeviceSocket(deviceId: string, socketId: string): void { + let sockets = deviceSockets.get(deviceId); + if (!sockets) { + sockets = new Set(); + deviceSockets.set(deviceId, sockets); + } + sockets.add(socketId); + socketDevice.set(socketId, deviceId); +} + +export function unregisterDeviceSocket(socketId: string): void { + const deviceId = socketDevice.get(socketId); + if (deviceId) { + const sockets = deviceSockets.get(deviceId); + if (sockets) { + sockets.delete(socketId); + if (sockets.size === 0) { + deviceSockets.delete(deviceId); + } + } + socketDevice.delete(socketId); + } +} + +export function isDeviceRevoked(deviceId: string): boolean { + return revokedMidSession.has(deviceId); +} + +export function markDeviceRevoked(deviceId: string): void { + revokedMidSession.add(deviceId); +} + +export async function startDeviceRevocationListener( + redis: Redis, + appRedis: Redis | null, +): Promise { + if (redis.status !== 'ready' && redis.status !== 'connect') { + await redis.connect(); + } + + await redis.psubscribe('device_revoked:*'); + + redis.on('pmessage', async (_pattern: string, channel: string, _message: string) => { + const deviceId = channel.replace('device_revoked:', ''); + markDeviceRevoked(deviceId); + + console.log(`Device revoked mid-session: ${deviceId}`); + + const socketIds = deviceSockets.get(deviceId); + if (socketIds) { + const io = getSocketServer(); + for (const socketId of [...socketIds]) { + if (io) { + const socket = io.sockets.sockets.get(socketId) as AuthSocket | undefined; + if (socket) { + if (appRedis && socket.auth) { + await setOffline(appRedis, socket.auth.userId, socketId); + } + socket.emit('device_revoked', { message: 'This device has been revoked' }); + socket.disconnect(true); + } + } + unregisterDeviceSocket(socketId); + } + } + }); +} diff --git a/apps/backend/src/services/heartbeat.ts b/apps/backend/src/services/heartbeat.ts new file mode 100644 index 0000000..5adad87 --- /dev/null +++ b/apps/backend/src/services/heartbeat.ts @@ -0,0 +1,73 @@ +import type { Server } from 'socket.io'; +import type { Redis } from 'ioredis'; +import type { AuthSocket } from '../middleware/socketAuth.js'; +import { db } from '../db/index.js'; +import { devices } from '../db/schema.js'; +import { eq } from 'drizzle-orm'; +import { refreshPresence, markDeviceOffline } from './presence.js'; + +const HEARTBEAT_TIMEOUT_MS = 90_000; +const LAST_SEEN_THROTTLE_MS = 30_000; + +const timers = new Map>(); +const lastSeenAt = new Map(); + +export function startHeartbeatTimer( + socket: AuthSocket, + userId: string, + deviceId: string, + redis: Redis | null, + io: Server, +): void { + const schedule = () => { + clearTimeout(timers.get(socket.id)); + const timer = setTimeout(async () => { + timers.delete(socket.id); + console.log(`Heartbeat timeout for device ${deviceId} (user ${userId})`); + + if (redis) { + await markDeviceOffline(redis, userId); + } + + if (socket.connected) { + for (const room of socket.rooms) { + if (room !== socket.id) { + io.to(room).volatile.emit('user_offline', { userId }); + io.to(room).volatile.emit('presence_update', { userId, online: false }); + } + } + socket.disconnect(true); + } + }, HEARTBEAT_TIMEOUT_MS); + timers.set(socket.id, timer); + }; + + schedule(); + + socket.on('heartbeat', async () => { + clearTimeout(timers.get(socket.id)); + timers.delete(socket.id); + + if (redis) { + await refreshPresence(redis, userId); + } + + const now = Date.now(); + const last = lastSeenAt.get(deviceId) ?? 0; + if (now - last >= LAST_SEEN_THROTTLE_MS) { + lastSeenAt.set(deviceId, now); + try { + await db.update(devices).set({ updatedAt: new Date() }).where(eq(devices.id, deviceId)); + } catch { + // Non-critical update; ignore errors. + } + } + + schedule(); + }); +} + +export function clearHeartbeatTimer(socketId: string): void { + clearTimeout(timers.get(socketId)); + timers.delete(socketId); +} diff --git a/apps/backend/src/services/presence.ts b/apps/backend/src/services/presence.ts index ccda9cd..1013131 100644 --- a/apps/backend/src/services/presence.ts +++ b/apps/backend/src/services/presence.ts @@ -12,7 +12,7 @@ */ import type { Redis } from 'ioredis'; -const PRESENCE_TTL = 60; // seconds +const PRESENCE_TTL = 90; // seconds function presenceKey(userId: string): string { return `presence:${userId}`; @@ -54,6 +54,15 @@ export async function setOffline(redis: Redis, userId: string, socketId: string) return false; } +/** + * Forcefully mark a user offline by deleting their presence key. + * Used when a heartbeat timeout or device revocation occurs. + */ +export async function markDeviceOffline(redis: Redis, userId: string): Promise { + const key = presenceKey(userId); + await redis.del(key); +} + /** * Check if a user is currently online. */ diff --git a/apps/backend/src/services/rateLimit.ts b/apps/backend/src/services/rateLimit.ts new file mode 100644 index 0000000..efe6e21 --- /dev/null +++ b/apps/backend/src/services/rateLimit.ts @@ -0,0 +1,53 @@ +import type { Redis } from 'ioredis'; + +function getRateLimitPerSec(): number { + const val = process.env['SOCKET_RATE_LIMIT_PER_SEC']; + if (val) { + const parsed = parseInt(val, 10); + if (!isNaN(parsed) && parsed > 0) return parsed; + } + return 10; +} + +function getMaxPayloadSize(): number { + const val = process.env['MAX_PAYLOAD_SIZE']; + if (val) { + const parsed = parseInt(val, 10); + if (!isNaN(parsed) && parsed > 0) return parsed; + } + return 16384; +} + +const violationCount = new Map(); + +export async function checkRateLimit( + redis: Redis | null, + socketId: string, +): Promise<{ allowed: boolean; count: number }> { + const limit = getRateLimitPerSec(); + if (!redis) return { allowed: true, count: 0 }; + + const key = `rl:socket:${socketId}`; + const count = await redis.incr(key); + if (count === 1) { + await redis.expire(key, 1); + } + return { allowed: count <= limit, count }; +} + +export function checkPayloadSize(data: unknown): { valid: boolean; size: number } { + const maxSize = getMaxPayloadSize(); + const raw = JSON.stringify(data); + const size = Buffer.byteLength(raw, 'utf8'); + return { valid: size <= maxSize, size }; +} + +export function recordViolation(socketId: string): number { + const count = (violationCount.get(socketId) ?? 0) + 1; + violationCount.set(socketId, count); + return count; +} + +export function clearViolations(socketId: string): void { + violationCount.delete(socketId); +} diff --git a/apps/backend/src/socket/messaging.ts b/apps/backend/src/socket/messaging.ts index 17d3bab..8c92463 100644 --- a/apps/backend/src/socket/messaging.ts +++ b/apps/backend/src/socket/messaging.ts @@ -62,7 +62,7 @@ export function registerMessagingHandlers(io: Server, socket: AuthSocket): void .values({ conversationId, senderId: userId, content: content.trim() }) .returning(); - io.to(conversationId).emit('new_message', message); + io.to(conversationId).volatile.emit('new_message', message); const members = await db.query.conversationMembers.findMany({ where: eq(conversationMembers.conversationId, conversationId), @@ -162,7 +162,7 @@ export function registerMessagingHandlers(io: Server, socket: AuthSocket): void ), ); - io.to(conversationId).emit('read_receipt', { userId, lastReadMessageId }); + io.to(conversationId).volatile.emit('read_receipt', { userId, lastReadMessageId }); }, ); @@ -213,7 +213,7 @@ export function registerMessagingHandlers(io: Server, socket: AuthSocket): void return; } - socket.to(conversationId).emit('typing_start', { conversationId, userId }); + socket.to(conversationId).volatile.emit('typing_start', { conversationId, userId }); }); // ── typing_stop ───────────────────────────────────────────────────────────── @@ -234,7 +234,7 @@ export function registerMessagingHandlers(io: Server, socket: AuthSocket): void return; } - socket.to(conversationId).emit('typing_stop', { conversationId, userId }); + socket.to(conversationId).volatile.emit('typing_stop', { conversationId, userId }); }); // ── ask_assistant ────────────────────────────────────────────────────────── @@ -321,7 +321,7 @@ export function registerMessagingHandlers(io: Server, socket: AuthSocket): void }) .returning(); - io.to(conversationId).emit('new_message', replyMessage); + io.to(conversationId).volatile.emit('new_message', replyMessage); const members = await db.query.conversationMembers.findMany({ where: eq(conversationMembers.conversationId, conversationId),