Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions apps/backend/src/__tests__/presence.reconciliation.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { reconcileBoot, cleanupStaleSockets, setOffline } from '../services/presence.js';

// ── DB mock ────────────────────────────────────────────────────────────────
const { mockFindMany } = vi.hoisted(() => ({
mockFindMany: vi.fn(),
}));

vi.mock('../db/index.js', () => ({
db: {
query: {
conversationMembers: { findMany: mockFindMany },
},
},
}));

vi.mock('../db/schema.js', () => ({
conversationMembers: {
userId: 'userId',
conversationId: 'conversationId',
},
}));

vi.mock('drizzle-orm', () => ({
eq: vi.fn((col: unknown, val: unknown) => ({ col, val })),
}));

// ── Redis & Socket mock ────────────────────────────────────────────────────

describe('Presence Reconciliation & Gateway Boot (#...)', () => {
let mockRedis: any;
let mockIo: any;
let mockSocketsJoin: any;
let mockFetchSockets: any;

beforeEach(() => {
vi.clearAllMocks();

mockSocketsJoin = vi.fn();
mockFetchSockets = vi.fn().mockResolvedValue([]);

mockIo = {
in: vi.fn((sid: string) => ({
socketsJoin: mockSocketsJoin,
fetchSockets: () => mockFetchSockets(sid),
})),
};

mockRedis = {
scan: vi.fn(),
keys: vi.fn(),
smembers: vi.fn(),
srem: vi.fn(),
scard: vi.fn(),
del: vi.fn(),
};
});

describe('reconcileBoot', () => {
it('rebuilds room subscriptions from active Redis socket mappings on boot', async () => {
// redis.scan returns presence keys
mockRedis.scan
.mockResolvedValueOnce(['10', ['presence:user-1', 'presence:user-2']])
.mockResolvedValueOnce(['0', []]);

mockRedis.smembers.mockImplementation(async (key: string) => {
if (key === 'presence:user-1') return ['socket-1a', 'socket-1b'];
if (key === 'presence:user-2') return ['socket-2a'];
return [];
});

mockFindMany.mockImplementation(async ({ where }: any) => {
if (where.val === 'user-1') {
return [{ conversationId: 'room-alpha' }, { conversationId: 'room-beta' }];
}
if (where.val === 'user-2') {
return [{ conversationId: 'room-gamma' }];
}
return [];
});

await reconcileBoot(mockIo as any, mockRedis as any);

expect(mockRedis.scan).toHaveBeenCalledTimes(2);
expect(mockFindMany).toHaveBeenCalledTimes(2);

// user-1 sockets joined room-alpha & room-beta
expect(mockIo.in).toHaveBeenCalledWith('socket-1a');
expect(mockIo.in).toHaveBeenCalledWith('socket-1b');
expect(mockIo.in).toHaveBeenCalledWith('socket-2a');
expect(mockSocketsJoin).toHaveBeenCalledWith('room-alpha');
expect(mockSocketsJoin).toHaveBeenCalledWith('room-beta');
expect(mockSocketsJoin).toHaveBeenCalledWith('room-gamma');
});

it('falls back to redis.keys if redis.scan throws', async () => {
mockRedis.scan.mockRejectedValue(new Error('scan not supported'));
mockRedis.keys.mockResolvedValue(['presence:user-3']);
mockRedis.smembers.mockResolvedValue(['socket-3a']);
mockFindMany.mockResolvedValue([{ conversationId: 'room-delta' }]);

await reconcileBoot(mockIo as any, mockRedis as any);

expect(mockRedis.keys).toHaveBeenCalledWith('presence:*');
expect(mockSocketsJoin).toHaveBeenCalledWith('room-delta');
});
});

describe('cleanupStaleSockets', () => {
it('removes stale socket IDs from Redis presence set and deletes empty sets', async () => {
mockRedis.smembers.mockResolvedValue(['socket-dead', 'socket-alive']);

mockFetchSockets.mockImplementation(async (sid: string) => {
if (sid === 'socket-alive') return [{ id: 'socket-alive' }]; // still connected
return []; // dead socket
});

mockRedis.scard.mockResolvedValue(1);

await cleanupStaleSockets(mockIo as any, mockRedis as any, 'user-1');

expect(mockRedis.srem).toHaveBeenCalledWith('presence:user-1', 'socket-dead');
expect(mockRedis.srem).not.toHaveBeenCalledWith('presence:user-1', 'socket-alive');
expect(mockRedis.del).not.toHaveBeenCalled();
});

it('deletes presence key if all sockets were stale and removed', async () => {
mockRedis.smembers.mockResolvedValue(['socket-dead-1']);
mockFetchSockets.mockResolvedValue([]); // dead socket
mockRedis.scard.mockResolvedValue(0);

await cleanupStaleSockets(mockIo as any, mockRedis as any, 'user-2');

expect(mockRedis.srem).toHaveBeenCalledWith('presence:user-2', 'socket-dead-1');
expect(mockRedis.del).toHaveBeenCalledWith('presence:user-2');
});

it('ignores activeSocketId if passed', async () => {
mockRedis.smembers.mockResolvedValue(['socket-new']);

await cleanupStaleSockets(mockIo as any, mockRedis as any, 'user-3', 'socket-new');

expect(mockFetchSockets).not.toHaveBeenCalled();
expect(mockRedis.srem).not.toHaveBeenCalled();
});
});

describe('setOffline', () => {
it('removes socket ID and returns true when no sockets remain', async () => {
mockRedis.scard.mockResolvedValue(0);
const offline = await setOffline(mockRedis as any, 'user-1', 'socket-1');
expect(mockRedis.srem).toHaveBeenCalledWith('presence:user-1', 'socket-1');
expect(mockRedis.del).toHaveBeenCalledWith('presence:user-1');
expect(offline).toBe(true);
});

it('returns false when surviving connections remain', async () => {
mockRedis.scard.mockResolvedValue(1);
const offline = await setOffline(mockRedis as any, 'user-1', 'socket-1');
expect(mockRedis.srem).toHaveBeenCalledWith('presence:user-1', 'socket-1');
expect(mockRedis.del).not.toHaveBeenCalled();
expect(offline).toBe(false);
});
});
});
65 changes: 65 additions & 0 deletions apps/backend/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@ import { registerMessagingHandlers } from './socket/messaging.js';
import { app } from './app.js';
import { redis as appRedis } from './lib/redis.js';
import { setSocketServer } from './lib/socket.js';

import {
setOnline,
setOffline,
refreshPresence,
reconcileBoot,
cleanupStaleSockets,
} from './services/presence.js';

import { setOnline, setOffline } from './services/presence.js';
import { startHeartbeatTimer, clearHeartbeatTimer } from './services/heartbeat.js';
import {
Expand All @@ -26,6 +35,7 @@ import {
clearViolations,
} from './services/rateLimit.js';
import { registerForBackpressure, unregisterForBackpressure } from './services/backpressure.js';

import {
buildRpcFetcher,
buildTreasuryRpcFetcher,
Expand All @@ -44,6 +54,27 @@ const io = new Server(httpServer, {
cors: { origin: '*' },
});

let isShuttingDown = false;

const handleShutdown = () => {
isShuttingDown = true;
};

process.on('SIGTERM', handleShutdown);
process.on('SIGINT', handleShutdown);

const origIoClose = io.close.bind(io);
io.close = ((fn?: () => void) => {
isShuttingDown = true;
return origIoClose(fn);
}) as typeof io.close;

const origHttpClose = httpServer.close.bind(httpServer);
httpServer.close = ((fn?: (err?: Error) => void) => {
isShuttingDown = true;
return origHttpClose(fn);
}) as typeof httpServer.close;

setSocketServer(io);

io.use(socketAuthMiddleware);
Expand Down Expand Up @@ -118,6 +149,7 @@ io.on('connection', async (socket: AuthSocket) => {
const presenceVisible = user?.presenceVisible ?? true;

if (appRedis) {
await cleanupStaleSockets(io, appRedis, userId, socket.id);
await setOnline(appRedis, userId, socket.id);
if (presenceVisible) {
for (const m of memberships) {
Expand All @@ -127,6 +159,26 @@ io.on('connection', async (socket: AuthSocket) => {
}
}


socket.on('heartbeat', async () => {
if (appRedis) {
await cleanupStaleSockets(io, appRedis, userId, socket.id);
await refreshPresence(appRedis, userId);
}
});

registerMessagingHandlers(io, socket);

socket.on('disconnect', async (reason: string) => {
console.log('User disconnected:', userId, reason);
if (
isShuttingDown ||
reason === 'server shutting down' ||
reason === 'server namespace disconnect'
) {
return;
}

registerMessagingHandlers(io, socket);

// Monitor send-buffer to detect slow/stalled consumers.
Expand All @@ -139,7 +191,9 @@ io.on('connection', async (socket: AuthSocket) => {
unregisterForBackpressure(socket);
clearViolations(socket.id);


if (appRedis) {
await cleanupStaleSockets(io, appRedis, userId, socket.id);
const fullyOffline = await setOffline(appRedis, userId, socket.id);
if (fullyOffline) {
const user = await db.query.users.findFirst({
Expand Down Expand Up @@ -192,6 +246,15 @@ async function attachRedisAdapter(): Promise<void> {
const message = err instanceof Error ? err.message : String(err);
console.warn(`[socket.io] Redis unavailable (${message}) — running in single-instance mode`);
await Promise.allSettled([pubClient.quit(), subClient.quit()]);
} finally {
if (appRedis) {
try {
await reconcileBoot(io, appRedis);
console.log('[presence] Boot reconciliation complete');
} catch (err) {
console.warn('[presence] Boot reconciliation failed:', err);
}
}
}
}

Expand Down Expand Up @@ -237,3 +300,5 @@ if (stellarRpcUrl && tokenTransferContractId) {
'[stellar-listener] STELLAR_RPC_URL or TOKEN_TRANSFER_CONTRACT_ID unset; listener disabled.',
);
}

export { httpServer, io };
5 changes: 5 additions & 0 deletions apps/backend/src/routes/treasury.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@

import { Router } from 'express';
import type { IRouter } from 'express';

import { Router, type IRouter } from 'express';

import { z } from 'zod';
import { requireAuth, type AuthRequest } from '../middleware/auth.js';
import { validate } from '../middleware/validate.js';
Expand Down
79 changes: 79 additions & 0 deletions apps/backend/src/services/presence.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
* - On disconnect: remove socketId from set, if set empty → user_offline
* - GET /users/:id/presence → { online: boolean }
*/
import type { Server } from 'socket.io';
import type { Redis } from 'ioredis';
import { eq } from 'drizzle-orm';
import { db } from '../db/index.js';
import { conversationMembers } from '../db/schema.js';

const PRESENCE_TTL = 90; // seconds

Expand Down Expand Up @@ -71,3 +75,78 @@ export async function isOnline(redis: Redis, userId: string): Promise<boolean> {
const count = await redis.scard(key);
return count > 0;
}

/**
* Remove any socket IDs in the user's presence set that are no longer
* connected anywhere in the Socket.IO cluster.
*/
export async function cleanupStaleSockets(
io: Server,
redis: Redis,
userId: string,
ignoredSocketId?: string,
): Promise<void> {
const key = presenceKey(userId);
const socketIds = await redis.smembers(key);
if (socketIds.length === 0) return;

await Promise.all(
socketIds.map(async (sid) => {
if (ignoredSocketId && sid === ignoredSocketId) return;
try {
const sockets = await io.in(sid).fetchSockets();
if (sockets.length === 0) {
await redis.srem(key, sid);
}
} catch (err) {
console.warn(`[presence] Failed to check socket status for ${sid}:`, err);
}
}),
);

const remaining = await redis.scard(key);
if (remaining === 0) {
await redis.del(key);
}
}

/**
* Rebuild room subscriptions from active Redis socket mappings on gateway boot.
*/
export async function reconcileBoot(io: Server, redis: Redis): Promise<void> {
let presenceKeys: string[];
try {
let cursor = '0';
presenceKeys = [];
do {
const [nextCursor, keys] = await redis.scan(cursor, 'MATCH', 'presence:*', 'COUNT', 100);
cursor = nextCursor;
presenceKeys.push(...keys);
} while (cursor !== '0');
} catch {
presenceKeys = await redis.keys('presence:*');
}

for (const key of presenceKeys) {
const userId = key.slice('presence:'.length);
if (!userId) continue;

const socketIds = await redis.smembers(key);
if (socketIds.length === 0) continue;

try {
const memberships = await db.query.conversationMembers.findMany({
where: eq(conversationMembers.userId, userId),
columns: { conversationId: true },
});

for (const socketId of socketIds) {
for (const m of memberships) {
io.in(socketId).socketsJoin(m.conversationId);
}
}
} catch (err) {
console.warn(`[presence] Failed to rebuild subscriptions for ${userId}:`, err);
}
}
}