diff --git a/apps/backend/src/__tests__/deliveryPipeline.test.ts b/apps/backend/src/__tests__/deliveryPipeline.test.ts new file mode 100644 index 0000000..daea17c --- /dev/null +++ b/apps/backend/src/__tests__/deliveryPipeline.test.ts @@ -0,0 +1,222 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; + +// ── Mock DB ──────────────────────────────────────────────────────────────── + +const mockSelect = vi.fn(); +const mockFrom = vi.fn(); +const mockWhere = vi.fn(); + +vi.mock('../db/index.js', () => ({ + db: { + select: mockSelect, + }, +})); + +vi.mock('../db/schema.js', () => ({ + conversationMembers: { conversationId: 'conv_id', userId: 'user_id' }, + userDevices: { userId: 'ud_user_id', revokedAt: 'ud_revoked_at', id: 'ud_id' }, + messageEnvelopes: { + messageId: 'me_msg_id', + recipientDeviceId: 'me_rcpt_device_id', + id: 'me_id', + ciphertext: 'me_ciphertext', + }, +})); + +vi.mock('drizzle-orm', () => ({ + and: vi.fn((...args: unknown[]) => ({ and: args })), + eq: vi.fn((col: unknown, val: unknown) => ({ eq: [col, val] })), + inArray: vi.fn((col: unknown, vals: unknown) => ({ inArray: [col, vals] })), + isNull: vi.fn((col: unknown) => ({ isNull: col })), +})); + +// ── Helpers ──────────────────────────────────────────────────────────────── + +function makeIo() { + const emissions: Record = {}; + + const emitFn = (room: string) => + vi.fn((event: string, data: unknown) => { + emissions[room] ??= []; + emissions[room]!.push({ event, data }); + }); + + const roomEmit: Record> = {}; + + const io = { + to: vi.fn((room: string) => { + roomEmit[room] ??= emitFn(room); + return { emit: roomEmit[room] }; + }), + emissions, + roomEmit, + }; + return io; +} + +function baseMessage() { + return { + id: 'msg-1', + conversationId: 'conv-1', + senderId: 'user-a', + senderDeviceId: 'dev-a', + contentType: 'text/plain', + sequenceNumber: 1, + createdAt: new Date('2024-01-01'), + deletedAt: null, + ciphertext: 'base-ciphertext', + }; +} + +// ── Tests ────────────────────────────────────────────────────────────────── + +describe('deliverMessage', () => { + beforeEach(() => { + vi.clearAllMocks(); + + // Default chained select builder + mockWhere.mockResolvedValue([]); + mockFrom.mockReturnValue({ where: mockWhere }); + mockSelect.mockReturnValue({ from: mockFrom }); + }); + + it('emits message_envelope to each active device that has an envelope', async () => { + const members = [{ userId: 'user-b' }]; + const activeDevices = [{ id: 'dev-b', userId: 'user-b' }]; + const envelopes = [ + { id: 'env-1', recipientDeviceId: 'dev-b', ciphertext: 'encrypted-for-dev-b' }, + ]; + + mockWhere + .mockResolvedValueOnce(members) + .mockResolvedValueOnce(activeDevices) + .mockResolvedValueOnce(envelopes); + + const io = makeIo(); + const { deliverMessage } = await import('../services/deliveryPipeline.js'); + + await deliverMessage(io as never, baseMessage() as never, 'conv-1'); + + // Device-scoped emission + expect(io.to).toHaveBeenCalledWith('device:dev-b'); + const deviceEmit = io.roomEmit['device:dev-b']; + expect(deviceEmit).toHaveBeenCalledWith( + 'message_envelope', + expect.objectContaining({ + messageId: 'msg-1', + conversationId: 'conv-1', + envelopeId: 'env-1', + ciphertext: 'encrypted-for-dev-b', + }), + ); + }); + + it('emits new_message to conversation room without ciphertext', async () => { + const members = [{ userId: 'user-b' }]; + const activeDevices = [{ id: 'dev-b', userId: 'user-b' }]; + const envelopes = [ + { id: 'env-1', recipientDeviceId: 'dev-b', ciphertext: 'encrypted-for-dev-b' }, + ]; + + mockWhere + .mockResolvedValueOnce(members) + .mockResolvedValueOnce(activeDevices) + .mockResolvedValueOnce(envelopes); + + const io = makeIo(); + const { deliverMessage } = await import('../services/deliveryPipeline.js'); + + await deliverMessage(io as never, baseMessage() as never, 'conv-1'); + + expect(io.to).toHaveBeenCalledWith('conv-1'); + const roomEmit = io.roomEmit['conv-1']; + expect(roomEmit).toHaveBeenCalledWith( + 'new_message', + expect.objectContaining({ id: 'msg-1', ciphertext: null }), + ); + }); + + it('skips devices that have no envelope', async () => { + const members = [{ userId: 'user-b' }, { userId: 'user-c' }]; + const activeDevices = [ + { id: 'dev-b', userId: 'user-b' }, + { id: 'dev-c', userId: 'user-c' }, + ]; + // Only dev-b has an envelope; dev-c does not. + const envelopes = [{ id: 'env-1', recipientDeviceId: 'dev-b', ciphertext: 'ct-b' }]; + + mockWhere + .mockResolvedValueOnce(members) + .mockResolvedValueOnce(activeDevices) + .mockResolvedValueOnce(envelopes); + + const io = makeIo(); + const { deliverMessage } = await import('../services/deliveryPipeline.js'); + + await deliverMessage(io as never, baseMessage() as never, 'conv-1'); + + expect(io.to).toHaveBeenCalledWith('device:dev-b'); + expect(io.to).not.toHaveBeenCalledWith('device:dev-c'); + }); + + it('only emits new_message to room when no active devices exist', async () => { + const members = [{ userId: 'user-b' }]; + + mockWhere + .mockResolvedValueOnce(members) + .mockResolvedValueOnce([]); // no active devices + + const io = makeIo(); + const { deliverMessage } = await import('../services/deliveryPipeline.js'); + + await deliverMessage(io as never, baseMessage() as never, 'conv-1'); + + // Should emit new_message as fallback + expect(io.to).toHaveBeenCalledWith('conv-1'); + expect(io.roomEmit['conv-1']).toHaveBeenCalledWith('new_message', expect.anything()); + // No device-scoped emission + expect(Object.keys(io.roomEmit)).not.toContain('device:dev-b'); + }); + + it('returns early when there are no members', async () => { + mockWhere.mockResolvedValueOnce([]); // no members + + const io = makeIo(); + const { deliverMessage } = await import('../services/deliveryPipeline.js'); + + await deliverMessage(io as never, baseMessage() as never, 'conv-1'); + + expect(io.to).not.toHaveBeenCalled(); + }); + + it('delivers envelopes to multiple devices independently', async () => { + const members = [{ userId: 'user-b' }, { userId: 'user-c' }]; + const activeDevices = [ + { id: 'dev-b', userId: 'user-b' }, + { id: 'dev-c', userId: 'user-c' }, + ]; + const envelopes = [ + { id: 'env-1', recipientDeviceId: 'dev-b', ciphertext: 'ct-b' }, + { id: 'env-2', recipientDeviceId: 'dev-c', ciphertext: 'ct-c' }, + ]; + + mockWhere + .mockResolvedValueOnce(members) + .mockResolvedValueOnce(activeDevices) + .mockResolvedValueOnce(envelopes); + + const io = makeIo(); + const { deliverMessage } = await import('../services/deliveryPipeline.js'); + + await deliverMessage(io as never, baseMessage() as never, 'conv-1'); + + expect(io.roomEmit['device:dev-b']).toHaveBeenCalledWith( + 'message_envelope', + expect.objectContaining({ ciphertext: 'ct-b', envelopeId: 'env-1' }), + ); + expect(io.roomEmit['device:dev-c']).toHaveBeenCalledWith( + 'message_envelope', + expect.objectContaining({ ciphertext: 'ct-c', envelopeId: 'env-2' }), + ); + }); +}); diff --git a/apps/backend/src/index.ts b/apps/backend/src/index.ts index 3335e09..f4684fd 100644 --- a/apps/backend/src/index.ts +++ b/apps/backend/src/index.ts @@ -101,6 +101,11 @@ io.on('connection', async (socket: AuthSocket) => { next(); }); + // Join a device-scoped room so the delivery pipeline can push envelopes to + // exactly this device, even across horizontally-scaled instances via the + // Redis adapter. + await socket.join(`device:${deviceId}`); + // Auto-join all conversation rooms so the socket receives new_message events // for every conversation the user belongs to (needed for unread badge tracking). const memberships = await db.query.conversationMembers.findMany({ diff --git a/apps/backend/src/services/deliveryPipeline.ts b/apps/backend/src/services/deliveryPipeline.ts new file mode 100644 index 0000000..3730dfe --- /dev/null +++ b/apps/backend/src/services/deliveryPipeline.ts @@ -0,0 +1,102 @@ +import { and, eq, inArray, isNull } from 'drizzle-orm'; +import type { Server } from 'socket.io'; +import { db } from '../db/index.js'; +import { conversationMembers, messageEnvelopes, userDevices } from '../db/schema.js'; +import type { Message } from '../db/schema.js'; + +/** + * Room name for per-device targeting. Each socket joins this room on connect + * so that io.to(deviceRoom(id)) reaches exactly that device across all instances + * via the Redis adapter. + */ +export function deviceRoom(deviceId: string): string { + return `device:${deviceId}`; +} + +/** + * Deliver a persisted message to every active recipient device. + * + * Order of operations (persist-before-deliver is guaranteed by callers): + * 1. Re-validate members from conversation_members (not from room state). + * 2. Resolve active (non-revoked) devices for those members. + * 3. Load persisted envelopes — only devices that have one get delivered. + * 4. Emit message_envelope to each device's scoped room with its ciphertext. + * 5. Emit new_message to the conversation room so clients update their UI. + */ +export async function deliverMessage( + io: Server, + message: Message, + conversationId: string, +): Promise { + // Step 1: re-validate membership from the source of truth. + const members = await db + .select({ userId: conversationMembers.userId }) + .from(conversationMembers) + .where(eq(conversationMembers.conversationId, conversationId)); + + if (members.length === 0) return; + + const userIds = members.map((m) => m.userId); + + // Step 2: active devices only — revokedAt IS NULL. + const activeDevices = await db + .select({ id: userDevices.id, userId: userDevices.userId }) + .from(userDevices) + .where(and(inArray(userDevices.userId, userIds), isNull(userDevices.revokedAt))); + + if (activeDevices.length === 0) { + io.to(conversationId).emit('new_message', message); + return; + } + + const activeDeviceIds = activeDevices.map((d) => d.id); + + // Step 3: load envelopes already committed to the database. + const envelopes = await db + .select({ + id: messageEnvelopes.id, + recipientDeviceId: messageEnvelopes.recipientDeviceId, + ciphertext: messageEnvelopes.ciphertext, + }) + .from(messageEnvelopes) + .where( + and( + eq(messageEnvelopes.messageId, message.id), + inArray(messageEnvelopes.recipientDeviceId, activeDeviceIds), + ), + ); + + const envelopeByDevice = new Map(envelopes.map((e) => [e.recipientDeviceId, e])); + + // Step 4: push each device exactly its envelope. + for (const device of activeDevices) { + const envelope = envelopeByDevice.get(device.id); + if (!envelope) continue; + + io.to(deviceRoom(device.id)).emit('message_envelope', { + messageId: message.id, + conversationId, + senderId: message.senderId, + senderDeviceId: message.senderDeviceId, + contentType: message.contentType, + sequenceNumber: message.sequenceNumber, + createdAt: message.createdAt, + envelopeId: envelope.id, + ciphertext: envelope.ciphertext, + }); + } + + // Step 5: room-level notification so clients can update unread counts / UI. + // Ciphertext is intentionally omitted here; each device received it above. + io.to(conversationId).emit('new_message', { + id: message.id, + conversationId, + senderId: message.senderId, + senderDeviceId: message.senderDeviceId, + contentType: message.contentType, + sequenceNumber: message.sequenceNumber, + createdAt: message.createdAt, + deletedAt: message.deletedAt, + ciphertext: null, + }); +} diff --git a/apps/backend/src/socket/messaging.ts b/apps/backend/src/socket/messaging.ts index 47b1471..4ab3fcd 100644 --- a/apps/backend/src/socket/messaging.ts +++ b/apps/backend/src/socket/messaging.ts @@ -12,6 +12,7 @@ import type { AuthSocket } from '../middleware/socketAuth.js'; import { invalidateConversationCaches } from '../lib/conversationCache.js'; import { serializeMessage } from '../lib/messages.js'; import { redis } from '../lib/redis.js'; +import { deliverMessage } from '../services/deliveryPipeline.js'; const PAGE_SIZE = 30; @@ -130,7 +131,9 @@ export function registerMessagingHandlers(io: Server, socket: AuthSocket): void socket.emit('message_ack', { messageId, sequenceNumber: message.sequenceNumber }); } - io.to(conversationId).emit('new_message', message); + // Deliver: storage is guaranteed above; pipeline re-validates membership, + // resolves active devices, and pushes each device exactly its envelope. + await deliverMessage(io, message, conversationId); const members = await db.query.conversationMembers.findMany({ where: eq(conversationMembers.conversationId, conversationId),