Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 222 additions & 0 deletions apps/backend/src/__tests__/deliveryPipeline.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
import { describe, it, expect, vi, beforeEach } from 'vitest';

// ── Mock DB ────────────────────────────────────────────────────────────────

const mockSelect = vi.fn();
const mockFrom = vi.fn();
const mockWhere = vi.fn();

vi.mock('../db/index.js', () => ({
db: {
select: mockSelect,
},
}));

vi.mock('../db/schema.js', () => ({
conversationMembers: { conversationId: 'conv_id', userId: 'user_id' },
userDevices: { userId: 'ud_user_id', revokedAt: 'ud_revoked_at', id: 'ud_id' },
messageEnvelopes: {
messageId: 'me_msg_id',
recipientDeviceId: 'me_rcpt_device_id',
id: 'me_id',
ciphertext: 'me_ciphertext',
},
}));

vi.mock('drizzle-orm', () => ({
and: vi.fn((...args: unknown[]) => ({ and: args })),
eq: vi.fn((col: unknown, val: unknown) => ({ eq: [col, val] })),
inArray: vi.fn((col: unknown, vals: unknown) => ({ inArray: [col, vals] })),
isNull: vi.fn((col: unknown) => ({ isNull: col })),
}));

// ── Helpers ────────────────────────────────────────────────────────────────

function makeIo() {
const emissions: Record<string, { event: string; data: unknown }[]> = {};

const emitFn = (room: string) =>
vi.fn((event: string, data: unknown) => {
emissions[room] ??= [];
emissions[room]!.push({ event, data });
});

const roomEmit: Record<string, ReturnType<typeof vi.fn>> = {};

const io = {
to: vi.fn((room: string) => {
roomEmit[room] ??= emitFn(room);
return { emit: roomEmit[room] };
}),
emissions,
roomEmit,
};
return io;
}

function baseMessage() {
return {
id: 'msg-1',
conversationId: 'conv-1',
senderId: 'user-a',
senderDeviceId: 'dev-a',
contentType: 'text/plain',
sequenceNumber: 1,
createdAt: new Date('2024-01-01'),
deletedAt: null,
ciphertext: 'base-ciphertext',
};
}

// ── Tests ──────────────────────────────────────────────────────────────────

describe('deliverMessage', () => {
beforeEach(() => {
vi.clearAllMocks();

// Default chained select builder
mockWhere.mockResolvedValue([]);
mockFrom.mockReturnValue({ where: mockWhere });
mockSelect.mockReturnValue({ from: mockFrom });
});

it('emits message_envelope to each active device that has an envelope', async () => {
const members = [{ userId: 'user-b' }];
const activeDevices = [{ id: 'dev-b', userId: 'user-b' }];
const envelopes = [
{ id: 'env-1', recipientDeviceId: 'dev-b', ciphertext: 'encrypted-for-dev-b' },
];

mockWhere
.mockResolvedValueOnce(members)
.mockResolvedValueOnce(activeDevices)
.mockResolvedValueOnce(envelopes);

const io = makeIo();
const { deliverMessage } = await import('../services/deliveryPipeline.js');

await deliverMessage(io as never, baseMessage() as never, 'conv-1');

// Device-scoped emission
expect(io.to).toHaveBeenCalledWith('device:dev-b');
const deviceEmit = io.roomEmit['device:dev-b'];
expect(deviceEmit).toHaveBeenCalledWith(
'message_envelope',
expect.objectContaining({
messageId: 'msg-1',
conversationId: 'conv-1',
envelopeId: 'env-1',
ciphertext: 'encrypted-for-dev-b',
}),
);
});

it('emits new_message to conversation room without ciphertext', async () => {
const members = [{ userId: 'user-b' }];
const activeDevices = [{ id: 'dev-b', userId: 'user-b' }];
const envelopes = [
{ id: 'env-1', recipientDeviceId: 'dev-b', ciphertext: 'encrypted-for-dev-b' },
];

mockWhere
.mockResolvedValueOnce(members)
.mockResolvedValueOnce(activeDevices)
.mockResolvedValueOnce(envelopes);

const io = makeIo();
const { deliverMessage } = await import('../services/deliveryPipeline.js');

await deliverMessage(io as never, baseMessage() as never, 'conv-1');

expect(io.to).toHaveBeenCalledWith('conv-1');
const roomEmit = io.roomEmit['conv-1'];
expect(roomEmit).toHaveBeenCalledWith(
'new_message',
expect.objectContaining({ id: 'msg-1', ciphertext: null }),
);
});

it('skips devices that have no envelope', async () => {
const members = [{ userId: 'user-b' }, { userId: 'user-c' }];
const activeDevices = [
{ id: 'dev-b', userId: 'user-b' },
{ id: 'dev-c', userId: 'user-c' },
];
// Only dev-b has an envelope; dev-c does not.
const envelopes = [{ id: 'env-1', recipientDeviceId: 'dev-b', ciphertext: 'ct-b' }];

mockWhere
.mockResolvedValueOnce(members)
.mockResolvedValueOnce(activeDevices)
.mockResolvedValueOnce(envelopes);

const io = makeIo();
const { deliverMessage } = await import('../services/deliveryPipeline.js');

await deliverMessage(io as never, baseMessage() as never, 'conv-1');

expect(io.to).toHaveBeenCalledWith('device:dev-b');
expect(io.to).not.toHaveBeenCalledWith('device:dev-c');
});

it('only emits new_message to room when no active devices exist', async () => {
const members = [{ userId: 'user-b' }];

mockWhere
.mockResolvedValueOnce(members)
.mockResolvedValueOnce([]); // no active devices

const io = makeIo();
const { deliverMessage } = await import('../services/deliveryPipeline.js');

await deliverMessage(io as never, baseMessage() as never, 'conv-1');

// Should emit new_message as fallback
expect(io.to).toHaveBeenCalledWith('conv-1');
expect(io.roomEmit['conv-1']).toHaveBeenCalledWith('new_message', expect.anything());
// No device-scoped emission
expect(Object.keys(io.roomEmit)).not.toContain('device:dev-b');
});

it('returns early when there are no members', async () => {
mockWhere.mockResolvedValueOnce([]); // no members

const io = makeIo();
const { deliverMessage } = await import('../services/deliveryPipeline.js');

await deliverMessage(io as never, baseMessage() as never, 'conv-1');

expect(io.to).not.toHaveBeenCalled();
});

it('delivers envelopes to multiple devices independently', async () => {
const members = [{ userId: 'user-b' }, { userId: 'user-c' }];
const activeDevices = [
{ id: 'dev-b', userId: 'user-b' },
{ id: 'dev-c', userId: 'user-c' },
];
const envelopes = [
{ id: 'env-1', recipientDeviceId: 'dev-b', ciphertext: 'ct-b' },
{ id: 'env-2', recipientDeviceId: 'dev-c', ciphertext: 'ct-c' },
];

mockWhere
.mockResolvedValueOnce(members)
.mockResolvedValueOnce(activeDevices)
.mockResolvedValueOnce(envelopes);

const io = makeIo();
const { deliverMessage } = await import('../services/deliveryPipeline.js');

await deliverMessage(io as never, baseMessage() as never, 'conv-1');

expect(io.roomEmit['device:dev-b']).toHaveBeenCalledWith(
'message_envelope',
expect.objectContaining({ ciphertext: 'ct-b', envelopeId: 'env-1' }),
);
expect(io.roomEmit['device:dev-c']).toHaveBeenCalledWith(
'message_envelope',
expect.objectContaining({ ciphertext: 'ct-c', envelopeId: 'env-2' }),
);
});
});
5 changes: 5 additions & 0 deletions apps/backend/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ io.on('connection', async (socket: AuthSocket) => {
next();
});

// Join a device-scoped room so the delivery pipeline can push envelopes to
// exactly this device, even across horizontally-scaled instances via the
// Redis adapter.
await socket.join(`device:${deviceId}`);

// Auto-join all conversation rooms so the socket receives new_message events
// for every conversation the user belongs to (needed for unread badge tracking).
const memberships = await db.query.conversationMembers.findMany({
Expand Down
102 changes: 102 additions & 0 deletions apps/backend/src/services/deliveryPipeline.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import { and, eq, inArray, isNull } from 'drizzle-orm';
import type { Server } from 'socket.io';
import { db } from '../db/index.js';
import { conversationMembers, messageEnvelopes, userDevices } from '../db/schema.js';
import type { Message } from '../db/schema.js';

/**
* Room name for per-device targeting. Each socket joins this room on connect
* so that io.to(deviceRoom(id)) reaches exactly that device across all instances
* via the Redis adapter.
*/
export function deviceRoom(deviceId: string): string {
return `device:${deviceId}`;
}

/**
* Deliver a persisted message to every active recipient device.
*
* Order of operations (persist-before-deliver is guaranteed by callers):
* 1. Re-validate members from conversation_members (not from room state).
* 2. Resolve active (non-revoked) devices for those members.
* 3. Load persisted envelopes — only devices that have one get delivered.
* 4. Emit message_envelope to each device's scoped room with its ciphertext.
* 5. Emit new_message to the conversation room so clients update their UI.
*/
export async function deliverMessage(
io: Server,
message: Message,
conversationId: string,
): Promise<void> {
// Step 1: re-validate membership from the source of truth.
const members = await db
.select({ userId: conversationMembers.userId })
.from(conversationMembers)
.where(eq(conversationMembers.conversationId, conversationId));

if (members.length === 0) return;

const userIds = members.map((m) => m.userId);

// Step 2: active devices only — revokedAt IS NULL.
const activeDevices = await db
.select({ id: userDevices.id, userId: userDevices.userId })
.from(userDevices)
.where(and(inArray(userDevices.userId, userIds), isNull(userDevices.revokedAt)));

if (activeDevices.length === 0) {
io.to(conversationId).emit('new_message', message);
return;
}

const activeDeviceIds = activeDevices.map((d) => d.id);

// Step 3: load envelopes already committed to the database.
const envelopes = await db
.select({
id: messageEnvelopes.id,
recipientDeviceId: messageEnvelopes.recipientDeviceId,
ciphertext: messageEnvelopes.ciphertext,
})
.from(messageEnvelopes)
.where(
and(
eq(messageEnvelopes.messageId, message.id),
inArray(messageEnvelopes.recipientDeviceId, activeDeviceIds),
),
);

const envelopeByDevice = new Map(envelopes.map((e) => [e.recipientDeviceId, e]));

// Step 4: push each device exactly its envelope.
for (const device of activeDevices) {
const envelope = envelopeByDevice.get(device.id);
if (!envelope) continue;

io.to(deviceRoom(device.id)).emit('message_envelope', {
messageId: message.id,
conversationId,
senderId: message.senderId,
senderDeviceId: message.senderDeviceId,
contentType: message.contentType,
sequenceNumber: message.sequenceNumber,
createdAt: message.createdAt,
envelopeId: envelope.id,
ciphertext: envelope.ciphertext,
});
}

// Step 5: room-level notification so clients can update unread counts / UI.
// Ciphertext is intentionally omitted here; each device received it above.
io.to(conversationId).emit('new_message', {
id: message.id,
conversationId,
senderId: message.senderId,
senderDeviceId: message.senderDeviceId,
contentType: message.contentType,
sequenceNumber: message.sequenceNumber,
createdAt: message.createdAt,
deletedAt: message.deletedAt,
ciphertext: null,
});
}
5 changes: 4 additions & 1 deletion apps/backend/src/socket/messaging.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import type { AuthSocket } from '../middleware/socketAuth.js';
import { invalidateConversationCaches } from '../lib/conversationCache.js';
import { serializeMessage } from '../lib/messages.js';
import { redis } from '../lib/redis.js';
import { deliverMessage } from '../services/deliveryPipeline.js';

const PAGE_SIZE = 30;

Expand Down Expand Up @@ -130,7 +131,9 @@ export function registerMessagingHandlers(io: Server, socket: AuthSocket): void
socket.emit('message_ack', { messageId, sequenceNumber: message.sequenceNumber });
}

io.to(conversationId).emit('new_message', message);
// Deliver: storage is guaranteed above; pipeline re-validates membership,
// resolves active devices, and pushes each device exactly its envelope.
await deliverMessage(io, message, conversationId);

const members = await db.query.conversationMembers.findMany({
where: eq(conversationMembers.conversationId, conversationId),
Expand Down