diff --git a/realtime/src/server.ts b/realtime/src/server.ts index 841faee..3a4ecaa 100644 --- a/realtime/src/server.ts +++ b/realtime/src/server.ts @@ -94,6 +94,10 @@ export const wss = new WebSocketServer({ const ipConnections = new Map(); const ipConnectionTimestamps = new Map(); const unauthorizedAccessCooldown = new Map(); +const unauthorizedCooldownExpirations: UnauthorizedCooldownExpiryEntry[] = []; +const MAX_UNAUTHORIZED_COOLDOWN_EXPIRATIONS_PER_CLEANUP = 512; +const UNAUTHORIZED_COOLDOWN_HEAP_REBUILD_MIN_SIZE = 1024; +const UNAUTHORIZED_COOLDOWN_HEAP_REBUILD_RATIO = 4; const ROOM_ID_PATTERN = /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i; interface UnauthorizedCooldownState { @@ -102,6 +106,11 @@ interface UnauthorizedCooldownState { suppressedAttempts: number; } +interface UnauthorizedCooldownExpiryEntry { + key: string; + denyUntil: number; +} + interface AccessCheckData { allowed: boolean; accessLevel: RealtimeAccessLevel | null; @@ -155,6 +164,122 @@ function buildUnauthorizedCooldownKey(roomId: string, clientIp: string, token: s return `${roomId}:${clientIp}:${getTokenFingerprint(token)}`; } +function swapUnauthorizedCooldownHeapEntries(a: number, b: number): void { + const tmp = unauthorizedCooldownExpirations[a]; + unauthorizedCooldownExpirations[a] = unauthorizedCooldownExpirations[b]; + unauthorizedCooldownExpirations[b] = tmp; +} + +function siftUnauthorizedCooldownExpirationUp(index: number): void { + let current = index; + + while (current > 0) { + const parent = Math.floor((current - 1) / 2); + if ( + unauthorizedCooldownExpirations[parent].denyUntil <= + unauthorizedCooldownExpirations[current].denyUntil + ) { + break; + } + + swapUnauthorizedCooldownHeapEntries(parent, current); + current = parent; + } +} + +function siftUnauthorizedCooldownExpirationDown(index: number): void { + let current = index; + const size = unauthorizedCooldownExpirations.length; + + while (true) { + const left = current * 2 + 1; + const right = left + 1; + let smallest = current; + + if ( + left < size && + unauthorizedCooldownExpirations[left].denyUntil < + unauthorizedCooldownExpirations[smallest].denyUntil + ) { + smallest = left; + } + + if ( + right < size && + unauthorizedCooldownExpirations[right].denyUntil < + unauthorizedCooldownExpirations[smallest].denyUntil + ) { + smallest = right; + } + + if (smallest === current) { + break; + } + + swapUnauthorizedCooldownHeapEntries(current, smallest); + current = smallest; + } +} + +function pushUnauthorizedCooldownExpiration(entry: UnauthorizedCooldownExpiryEntry): void { + unauthorizedCooldownExpirations.push(entry); + siftUnauthorizedCooldownExpirationUp(unauthorizedCooldownExpirations.length - 1); +} + +function peekUnauthorizedCooldownExpiration(): UnauthorizedCooldownExpiryEntry | null { + return unauthorizedCooldownExpirations[0] ?? null; +} + +function popUnauthorizedCooldownExpiration(): UnauthorizedCooldownExpiryEntry | null { + if (unauthorizedCooldownExpirations.length === 0) { + return null; + } + + const root = unauthorizedCooldownExpirations[0]; + const last = unauthorizedCooldownExpirations.pop(); + if (unauthorizedCooldownExpirations.length > 0 && last) { + unauthorizedCooldownExpirations[0] = last; + siftUnauthorizedCooldownExpirationDown(0); + } + + return root; +} + +function rebuildUnauthorizedCooldownExpirationHeap(): void { + unauthorizedCooldownExpirations.length = 0; + + for (const [key, state] of unauthorizedAccessCooldown.entries()) { + unauthorizedCooldownExpirations.push({ + key, + denyUntil: state.denyUntil, + }); + } + + for (let i = Math.floor(unauthorizedCooldownExpirations.length / 2) - 1; i >= 0; i -= 1) { + siftUnauthorizedCooldownExpirationDown(i); + } +} + +function maybeRebuildUnauthorizedCooldownExpirationHeap(): void { + if (unauthorizedCooldownExpirations.length < UNAUTHORIZED_COOLDOWN_HEAP_REBUILD_MIN_SIZE) { + return; + } + + if (unauthorizedAccessCooldown.size === 0) { + unauthorizedCooldownExpirations.length = 0; + return; + } + + if ( + unauthorizedCooldownExpirations.length <= + unauthorizedAccessCooldown.size * UNAUTHORIZED_COOLDOWN_HEAP_REBUILD_RATIO + ) { + return; + } + + rebuildUnauthorizedCooldownExpirationHeap(); +} + function getUnauthorizedCooldownState(key: string, now: number): UnauthorizedCooldownState | null { const existingState = unauthorizedAccessCooldown.get(key); if (!existingState) { @@ -173,6 +298,7 @@ function trackUnauthorizedAccess(key: string, now: number): UnauthorizedCooldown const existingState = getUnauthorizedCooldownState(key, now); if (existingState) { existingState.denyUntil = now + config.unauthorizedAccessCooldownMs; + pushUnauthorizedCooldownExpiration({ key, denyUntil: existingState.denyUntil }); return existingState; } @@ -183,6 +309,7 @@ function trackUnauthorizedAccess(key: string, now: number): UnauthorizedCooldown }; unauthorizedAccessCooldown.set(key, nextState); + pushUnauthorizedCooldownExpiration({ key, denyUntil: nextState.denyUntil }); return nextState; } @@ -222,11 +349,40 @@ function logUnauthorizedRejection( } function cleanupExpiredUnauthorizedCooldown(now = Date.now()): void { - for (const [key, state] of unauthorizedAccessCooldown.entries()) { - if (state.denyUntil <= now) { - unauthorizedAccessCooldown.delete(key); + let processedEntries = 0; + + while (processedEntries < MAX_UNAUTHORIZED_COOLDOWN_EXPIRATIONS_PER_CLEANUP) { + const nextExpiration = peekUnauthorizedCooldownExpiration(); + if (!nextExpiration || nextExpiration.denyUntil > now) { + break; + } + + const expiredEntry = popUnauthorizedCooldownExpiration(); + if (!expiredEntry) { + break; + } + + processedEntries += 1; + + const existingState = unauthorizedAccessCooldown.get(expiredEntry.key); + if (!existingState) { + continue; + } + + // The key may have been renewed after this heap entry was pushed. + if (existingState.denyUntil !== expiredEntry.denyUntil) { + continue; } + + unauthorizedAccessCooldown.delete(expiredEntry.key); } + + if (unauthorizedAccessCooldown.size === 0) { + unauthorizedCooldownExpirations.length = 0; + return; + } + + maybeRebuildUnauthorizedCooldownExpirationHeap(); } async function fetchAccess(token: string, roomId: string): Promise { diff --git a/realtime/tests/unit/server.test.ts b/realtime/tests/unit/server.test.ts index eb3975f..3ab2b0c 100644 --- a/realtime/tests/unit/server.test.ts +++ b/realtime/tests/unit/server.test.ts @@ -70,6 +70,7 @@ const waitForConnectionProcessing = async () => { describe('Server', () => { let server: any; let wss: any; + let cleanupInactiveRooms: any; let setupWSConnectionMock: any; let memoryUsageSpy: any; let fetchMock: jest.MockedFunction; @@ -103,6 +104,7 @@ describe('Server', () => { const serverModule = await import('../../src/server'); server = serverModule.server; wss = serverModule.wss; + cleanupInactiveRooms = serverModule.cleanupInactiveRooms; const yjsUtilsModule = await import('../../src/yjs-utils'); setupWSConnectionMock = yjsUtilsModule.setupWSConnection; @@ -263,6 +265,59 @@ describe('Server', () => { } }); + it('should not clear renewed cooldown state when stale expirations are cleaned', async () => { + jest.useFakeTimers(); + + try { + fetchMock.mockResolvedValue({ + ok: true, + json: async () => ({ + success: true, + data: { + allowed: false, + accessLevel: null, + owner: false, + }, + error: null, + }), + } as Response); + + const firstConn: any = new EventEmitter(); + firstConn.close = jest.fn(); + firstConn.readyState = WebSocket.OPEN; + wss.emit('connection', firstConn, mockReq); + await waitForConnectionProcessing(); + + expect(firstConn.close).toHaveBeenCalledWith(1008, 'Access denied'); + expect(fetchMock).toHaveBeenCalledTimes(1); + + await jest.advanceTimersByTimeAsync(15001); + + const secondConn: any = new EventEmitter(); + secondConn.close = jest.fn(); + secondConn.readyState = WebSocket.OPEN; + wss.emit('connection', secondConn, mockReq); + await waitForConnectionProcessing(); + + expect(secondConn.close).toHaveBeenCalledWith(1008, 'Access denied'); + expect(fetchMock).toHaveBeenCalledTimes(2); + + cleanupInactiveRooms(); + + const thirdConn: any = new EventEmitter(); + thirdConn.close = jest.fn(); + thirdConn.readyState = WebSocket.OPEN; + wss.emit('connection', thirdConn, mockReq); + await waitForConnectionProcessing(); + + // Third attempt is still blocked by the renewed cooldown, so no new access check runs. + expect(thirdConn.close).toHaveBeenCalledWith(1008, 'Access denied'); + expect(fetchMock).toHaveBeenCalledTimes(2); + } finally { + jest.useRealTimers(); + } + }); + it('should handle synchronous error in setupWSConnection', async () => { // Mock setupWSConnection to throw synchronously setupWSConnectionMock.mockImplementationOnce(() => { diff --git a/web/services/indexed-db.service.ts b/web/services/indexed-db.service.ts index fd23d5f..36f1e1f 100644 --- a/web/services/indexed-db.service.ts +++ b/web/services/indexed-db.service.ts @@ -7,7 +7,7 @@ const DOCUMENTS_STORE = 'documents'; class IndexedDBService { private static instance: IndexedDBService; private dbPromise: Promise | null = null; - private dbClosingPromise: Promise | null = null; + private dbSwitchPromise: Promise = Promise.resolve(); private isSupported: boolean = true; private currentUserId: string | null = null; @@ -35,22 +35,42 @@ class IndexedDBService { this.currentUserId = userId; - if (this.dbPromise && !this.dbClosingPromise) { - const dbPromiseToClose = this.dbPromise; + this.dbSwitchPromise = this.dbSwitchPromise + .then(async () => { + await this.closeCurrentConnection(); + }) + .catch((error) => { + console.warn('Failed to process IndexedDB user switch:', error); + }); + } + + private async awaitStableUserSwitch(): Promise { + while (true) { + const inFlightSwitch = this.dbSwitchPromise; + await inFlightSwitch; - this.dbClosingPromise = dbPromiseToClose - .then((db) => { - db.close(); - }) - .catch((error) => { - console.warn('Failed to close IndexedDB connection during user switch:', error); - }) - .finally(() => { - if (this.dbPromise === dbPromiseToClose) { - this.dbPromise = null; - } - this.dbClosingPromise = null; - }); + // A newer switch may have been queued while awaiting this one. + if (inFlightSwitch === this.dbSwitchPromise) { + return; + } + } + } + + private async closeCurrentConnection(): Promise { + const dbPromiseToClose = this.dbPromise; + if (!dbPromiseToClose) { + return; + } + + try { + const db = await dbPromiseToClose; + db.close(); + } catch (error) { + console.warn('Failed to close IndexedDB connection during user switch:', error); + } finally { + if (this.dbPromise === dbPromiseToClose) { + this.dbPromise = null; + } } } @@ -59,9 +79,7 @@ class IndexedDBService { throw new Error('IndexedDB not supported'); } - if (this.dbClosingPromise) { - await this.dbClosingPromise; - } + await this.awaitStableUserSwitch(); if (!this.dbPromise) { this.dbPromise = this.initDB(); @@ -208,11 +226,8 @@ class IndexedDBService { public async wipeDatabase(): Promise { try { - if (this.dbPromise) { - const db = await this.dbPromise; - db.close(); - this.dbPromise = null; - } + await this.awaitStableUserSwitch(); + await this.closeCurrentConnection(); await deleteDB(this.dbName); } catch (error) { console.error('Failed to wipe database:', error); diff --git a/web/tests/unit/services/indexed-db.service.test.ts b/web/tests/unit/services/indexed-db.service.test.ts index 24a87e1..839e743 100644 --- a/web/tests/unit/services/indexed-db.service.test.ts +++ b/web/tests/unit/services/indexed-db.service.test.ts @@ -226,6 +226,34 @@ describe('indexed-db.service', () => { expect(await indexedDBService.getDocument('doc-A')).toBeDefined(); expect(await indexedDBService.getDocument('doc-B')).toBeUndefined(); }); + + it('should honor the latest user context after rapid user switches', async () => { + indexedDBService.setUserId('rapid-user-A'); + await indexedDBService.saveDocument({ + id: 'doc-rapid-A', + meta: { title: 'A', createdAt: '', updatedAt: '' }, + yjsState: new Uint8Array([1]), + version: 1, + }); + + indexedDBService.setUserId('rapid-user-B'); + indexedDBService.setUserId('rapid-user-C'); + indexedDBService.setUserId('rapid-user-B'); + + await indexedDBService.saveDocument({ + id: 'doc-rapid-B', + meta: { title: 'B', createdAt: '', updatedAt: '' }, + yjsState: new Uint8Array([2]), + version: 1, + }); + + expect(indexedDBService.dbName).toBe('nextdocs-db_rapid-user-B'); + expect(await indexedDBService.getDocument('doc-rapid-B')).toBeDefined(); + + indexedDBService.setUserId('rapid-user-A'); + expect(await indexedDBService.getDocument('doc-rapid-B')).toBeUndefined(); + expect(await indexedDBService.getDocument('doc-rapid-A')).toBeDefined(); + }); }); describe('isAvailable', () => {