Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions apps/backend/src/__tests__/readReceipts.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,13 @@ function makeSocket(userId: string) {

function makeIo() {
const roomEmitted: { event: string; data: unknown }[] = [];
const emitFn = vi.fn((event: string, data: unknown) => {
roomEmitted.push({ event, data });
});
const io = {
to: vi.fn(() => ({
emit: vi.fn((event: string, data: unknown) => {
roomEmitted.push({ event, data });
}),
emit: emitFn,
volatile: { emit: emitFn },
})),
roomEmitted,
};
Expand Down
94 changes: 83 additions & 11 deletions apps/backend/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,21 @@ import { registerMessagingHandlers } from './socket/messaging.js';
import { app } from './app.js';
import { redis as appRedis } from './lib/redis.js';
import { setSocketServer } from './lib/socket.js';
import { setOnline, setOffline, refreshPresence } from './services/presence.js';
import { setOnline, setOffline } from './services/presence.js';
import { startHeartbeatTimer, clearHeartbeatTimer } from './services/heartbeat.js';
import {
registerDeviceSocket,
unregisterDeviceSocket,
isDeviceRevoked,
startDeviceRevocationListener,
} from './services/deviceRevocation.js';
import {
checkRateLimit,
checkPayloadSize,
recordViolation,
clearViolations,
} from './services/rateLimit.js';
import { registerForBackpressure, unregisterForBackpressure } from './services/backpressure.js';
import {
buildRpcFetcher,
buildTreasuryRpcFetcher,
Expand All @@ -36,8 +50,57 @@ io.use(socketAuthMiddleware);

io.on('connection', async (socket: AuthSocket) => {
const userId = socket.auth!.userId;
const deviceId = socket.auth!.deviceId;
console.log('User connected:', userId, socket.id);

// Register socket for device-revocation tracking (cross-instance via Redis pub/sub).
if (appRedis) {
registerDeviceSocket(deviceId, socket.id);
}

// Start the server-side heartbeat watchdog (90 s timeout).
startHeartbeatTimer(socket, userId, deviceId, appRedis, io);

// Per-socket middleware: intercept every incoming event before handlers.
const EXCLUDED_EVENTS = new Set(['heartbeat']);
socket.use(async ([event, ...args], next) => {
// Skip internal heartbeat pings.
if (EXCLUDED_EVENTS.has(event)) {
return next();
}

// Reject events from a device that was revoked mid-session.
if (isDeviceRevoked(deviceId)) {
socket.emit('error', { event: 'device_revoked', message: 'Device has been revoked' });
socket.disconnect(true);
return;
}

// Enforce maximum payload size (configurable via MAX_PAYLOAD_SIZE env).
const payloadArgs = args.filter((a) => typeof a !== 'function');
const { valid, size } = checkPayloadSize(payloadArgs);
if (!valid) {
socket.emit('error', {
event: 'payload_too_large',
message: `Payload size ${size} exceeds limit`,
});
return;
}

// Per-socket rate limiting (configurable via SOCKET_RATE_LIMIT_PER_SEC env).
const { allowed } = await checkRateLimit(appRedis, socket.id);
if (!allowed) {
const violations = recordViolation(socket.id);
socket.emit('error', { event: 'rate_limited', message: 'Rate limit exceeded' });
if (violations >= 3) {
socket.disconnect(true);
}
return;
}

next();
});

// Auto-join all conversation rooms so the socket receives new_message events
// for every conversation the user belongs to (needed for unread badge tracking).
const memberships = await db.query.conversationMembers.findMany({
Expand All @@ -51,21 +114,23 @@ io.on('connection', async (socket: AuthSocket) => {
if (appRedis) {
await setOnline(appRedis, userId, socket.id);
for (const m of memberships) {
io.to(m.conversationId).emit('user_online', { userId });
io.to(m.conversationId).emit('presence_update', { userId, online: true });
io.to(m.conversationId).volatile.emit('user_online', { userId });
io.to(m.conversationId).volatile.emit('presence_update', { userId, online: true });
}
}

socket.on('heartbeat', async () => {
if (appRedis) {
await refreshPresence(appRedis, userId);
}
});

registerMessagingHandlers(io, socket);

// Monitor send-buffer to detect slow/stalled consumers.
registerForBackpressure(socket);

socket.on('disconnect', async () => {
console.log('User disconnected:', userId);
clearHeartbeatTimer(socket.id);
unregisterDeviceSocket(socket.id);
unregisterForBackpressure(socket);
clearViolations(socket.id);

if (appRedis) {
const fullyOffline = await setOffline(appRedis, userId, socket.id);
if (fullyOffline) {
Expand All @@ -74,8 +139,8 @@ io.on('connection', async (socket: AuthSocket) => {
columns: { conversationId: true },
});
for (const m of memberships) {
io.to(m.conversationId).emit('user_offline', { userId });
io.to(m.conversationId).emit('presence_update', { userId, online: false });
io.to(m.conversationId).volatile.emit('user_offline', { userId });
io.to(m.conversationId).volatile.emit('presence_update', { userId, online: false });
}
}
}
Expand Down Expand Up @@ -123,6 +188,13 @@ httpServer.listen(PORT, () => {
// Redis is unreachable; on failure we fall back to the in-process adapter.
void attachRedisAdapter();

// Subscribe to device_revoked:* channels so any gateway instance can
// disconnect a revoked device's sockets within seconds, even when the
// revocation was issued on a different node.
if (appRedis) {
void startDeviceRevocationListener(appRedis, appRedis);
}

// #46 — Stellar transfer event listener. Only spin up when the contract
// id is configured so local-dev and unit-test runs don't try to talk to
// Soroban RPC. The listener never throws out of runForever, so a failed
Expand Down
86 changes: 86 additions & 0 deletions apps/backend/src/services/backpressure.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import type { AuthSocket } from '../middleware/socketAuth.js';

function getBufferThreshold(): number {
const val = process.env['SOCKET_BUFFER_THRESHOLD'];
if (val) {
const parsed = parseInt(val, 10);
if (!isNaN(parsed) && parsed > 0) return parsed;
}
return 65536;
}

function getShedThreshold(): number {
const val = process.env['SOCKET_SHED_THRESHOLD'];
if (val) {
const parsed = parseInt(val, 10);
if (!isNaN(parsed) && parsed > 0) return parsed;
}
return 32768;
}

const shedSockets = new Set<string>();
const socketsToMonitor = new Set<AuthSocket>();
let monitorInterval: ReturnType<typeof setInterval> | null = null;

export function registerForBackpressure(socket: AuthSocket): void {
socketsToMonitor.add(socket);
if (!monitorInterval) {
monitorInterval = setInterval(checkBuffers, 5000);
}
}

export function unregisterForBackpressure(socket: AuthSocket): void {
socketsToMonitor.delete(socket);
shedSockets.delete(socket.id);
if (socketsToMonitor.size === 0 && monitorInterval) {
clearInterval(monitorInterval);
monitorInterval = null;
}
}

export function isSocketShed(socketId: string): boolean {
return shedSockets.has(socketId);
}

function checkBuffers(): void {
const disconnectThreshold = getBufferThreshold();
const shedThreshold = getShedThreshold();

for (const socket of socketsToMonitor) {
const buffered = getBufferedAmount(socket);

if (buffered > disconnectThreshold) {
console.warn(
`Socket ${socket.id} buffer ${buffered} exceeds disconnect threshold ${disconnectThreshold}, disconnecting`,
);
shedSockets.add(socket.id);
socket.disconnect(true);
} else if (buffered > shedThreshold) {
if (!shedSockets.has(socket.id)) {
console.warn(
`Socket ${socket.id} buffer ${buffered} exceeds shed threshold ${shedThreshold}, shedding`,
);
shedSockets.add(socket.id);
}
} else {
if (shedSockets.has(socket.id)) {
shedSockets.delete(socket.id);
}
}
}
}

function getBufferedAmount(socket: AuthSocket): number {
try {
const conn = socket.conn as unknown as {
transport?: { socket?: { bufferedAmount?: number } };
};
const ws = conn.transport?.socket;
if (ws && typeof ws.bufferedAmount === 'number') {
return ws.bufferedAmount;
}
} catch {
// Ignore errors accessing internal transport
}
return 0;
}
76 changes: 76 additions & 0 deletions apps/backend/src/services/deviceRevocation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import type { Redis } from 'ioredis';
import { getSocketServer } from '../lib/socket.js';
import type { AuthSocket } from '../middleware/socketAuth.js';
import { setOffline } from './presence.js';

const deviceSockets = new Map<string, Set<string>>();
const socketDevice = new Map<string, string>();
const revokedMidSession = new Set<string>();

export function registerDeviceSocket(deviceId: string, socketId: string): void {
let sockets = deviceSockets.get(deviceId);
if (!sockets) {
sockets = new Set();
deviceSockets.set(deviceId, sockets);
}
sockets.add(socketId);
socketDevice.set(socketId, deviceId);
}

export function unregisterDeviceSocket(socketId: string): void {
const deviceId = socketDevice.get(socketId);
if (deviceId) {
const sockets = deviceSockets.get(deviceId);
if (sockets) {
sockets.delete(socketId);
if (sockets.size === 0) {
deviceSockets.delete(deviceId);
}
}
socketDevice.delete(socketId);
}
}

export function isDeviceRevoked(deviceId: string): boolean {
return revokedMidSession.has(deviceId);
}

export function markDeviceRevoked(deviceId: string): void {
revokedMidSession.add(deviceId);
}

export async function startDeviceRevocationListener(
redis: Redis,
appRedis: Redis | null,
): Promise<void> {
if (redis.status !== 'ready' && redis.status !== 'connect') {
await redis.connect();
}

await redis.psubscribe('device_revoked:*');

redis.on('pmessage', async (_pattern: string, channel: string, _message: string) => {
const deviceId = channel.replace('device_revoked:', '');
markDeviceRevoked(deviceId);

console.log(`Device revoked mid-session: ${deviceId}`);

const socketIds = deviceSockets.get(deviceId);
if (socketIds) {
const io = getSocketServer();
for (const socketId of [...socketIds]) {
if (io) {
const socket = io.sockets.sockets.get(socketId) as AuthSocket | undefined;
if (socket) {
if (appRedis && socket.auth) {
await setOffline(appRedis, socket.auth.userId, socketId);
}
socket.emit('device_revoked', { message: 'This device has been revoked' });
socket.disconnect(true);
}
}
unregisterDeviceSocket(socketId);
}
}
});
}
73 changes: 73 additions & 0 deletions apps/backend/src/services/heartbeat.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import type { Server } from 'socket.io';
import type { Redis } from 'ioredis';
import type { AuthSocket } from '../middleware/socketAuth.js';
import { db } from '../db/index.js';
import { devices } from '../db/schema.js';
import { eq } from 'drizzle-orm';
import { refreshPresence, markDeviceOffline } from './presence.js';

const HEARTBEAT_TIMEOUT_MS = 90_000;
const LAST_SEEN_THROTTLE_MS = 30_000;

const timers = new Map<string, ReturnType<typeof setTimeout>>();
const lastSeenAt = new Map<string, number>();

export function startHeartbeatTimer(
socket: AuthSocket,
userId: string,
deviceId: string,
redis: Redis | null,
io: Server,
): void {
const schedule = () => {
clearTimeout(timers.get(socket.id));
const timer = setTimeout(async () => {
timers.delete(socket.id);
console.log(`Heartbeat timeout for device ${deviceId} (user ${userId})`);

if (redis) {
await markDeviceOffline(redis, userId);
}

if (socket.connected) {
for (const room of socket.rooms) {
if (room !== socket.id) {
io.to(room).volatile.emit('user_offline', { userId });
io.to(room).volatile.emit('presence_update', { userId, online: false });
}
}
socket.disconnect(true);
}
}, HEARTBEAT_TIMEOUT_MS);
timers.set(socket.id, timer);
};

schedule();

socket.on('heartbeat', async () => {
clearTimeout(timers.get(socket.id));
timers.delete(socket.id);

if (redis) {
await refreshPresence(redis, userId);
}

const now = Date.now();
const last = lastSeenAt.get(deviceId) ?? 0;
if (now - last >= LAST_SEEN_THROTTLE_MS) {
lastSeenAt.set(deviceId, now);
try {
await db.update(devices).set({ updatedAt: new Date() }).where(eq(devices.id, deviceId));
} catch {
// Non-critical update; ignore errors.
}
}

schedule();
});
}

export function clearHeartbeatTimer(socketId: string): void {
clearTimeout(timers.get(socketId));
timers.delete(socketId);
}
Loading
Loading