diff --git a/src/session.spec.ts b/src/session.spec.ts index 1670c61..4807662 100644 --- a/src/session.spec.ts +++ b/src/session.spec.ts @@ -1,3 +1,4 @@ +import { createRequire } from 'node:module'; import { LoaderFunctionArgs, Session as ReactRouterSession, redirect } from 'react-router'; import { AuthenticationResponse, type User } from '@workos-inc/node'; import * as ironSession from 'iron-session'; @@ -819,6 +820,142 @@ describe('session', () => { }); }); + describe('JWKS caching', () => { + const createLoaderArgs = (request: Request): LoaderFunctionArgs => + ({ + request, + params: {}, + context: {}, + }) as LoaderFunctionArgs; + + const mockSessionData = { + accessToken: 'valid.jwt.token', + refreshToken: 'refresh.token', + user: { + id: 'user-1', + email: 'test@example.com', + }, + impersonator: null, + }; + + type IsolatedModules = { + authkitLoader: typeof authkitLoader; + createRemoteJWKSet: jest.Mock; + jwtVerify: jest.Mock; + getJwksUrl: jest.Mock; + }; + + // Each test gets its own freshly-loaded copy of ./session.js so the + // module-level JWKS cache never leaks across tests (or out of this + // describe block). This guards against subtle ordering bugs where a later + // test would depend on cache state set up here. + function loadIsolated(): IsolatedModules { + let isolated!: IsolatedModules; + // jest.isolateModules only scopes synchronous CJS require() calls; a + // dynamic `await import()` would need --experimental-vm-modules. Use + // createRequire to avoid a bare `require` keyword in source while still + // participating in Jest's module loader hooks. + const isolatedRequire = createRequire(__filename); + jest.isolateModules(() => { + const joseModule = isolatedRequire('jose') as typeof import('jose'); + const workosModule = isolatedRequire('./workos.js') as typeof import('./workos.js'); + const sessionStorageModule = isolatedRequire('./sessionStorage.js') as typeof import('./sessionStorage.js'); + const ironSessionModule = isolatedRequire('iron-session') as typeof import('iron-session'); + const sessionModule = isolatedRequire('./session.js') as typeof import('./session.js'); + + const wos = workosModule.getWorkOS(); + const getJwksUrlMock = wos.userManagement.getJwksUrl as jest.Mock; + const createRemoteJWKSetMock = joseModule.createRemoteJWKSet as jest.Mock; + const jwtVerifyMock = joseModule.jwtVerify as jest.Mock; + const decodeJwtMock = joseModule.decodeJwt as jest.Mock; + const getSessionStorageMock = sessionStorageModule.getSessionStorage as jest.Mock; + const unsealDataMock = ironSessionModule.unsealData as jest.Mock; + + const isolatedGetSession = jest.fn().mockResolvedValue( + createMockSession({ + has: jest.fn().mockReturnValue(true), + get: jest.fn().mockReturnValue('encrypted-jwt'), + set: jest.fn(), + }), + ); + getSessionStorageMock.mockResolvedValue({ + cookieName: 'wos-cookie', + getSession: isolatedGetSession, + destroySession: jest.fn().mockResolvedValue('destroyed-session-cookie'), + commitSession: jest.fn(), + }); + unsealDataMock.mockResolvedValue({ + ...mockSessionData, + headers: { 'Set-Cookie': 'session-cookie' }, + }); + getJwksUrlMock.mockImplementation((clientId: string) => `https://auth.workos.com/oauth/jwks/${clientId}`); + // Real createRemoteJWKSet returns a getKey function used by jwtVerify. + // The mock needs to return a truthy value so the module-level cache + // check in session.ts treats it as populated. + createRemoteJWKSetMock.mockReturnValue(jest.fn()); + jwtVerifyMock.mockResolvedValue({ + payload: {}, + protectedHeader: {}, + key: new TextEncoder().encode('test-key'), + }); + decodeJwtMock.mockReturnValue({ + sid: 'test-session-id', + org_id: 'org-123', + role: 'admin', + roles: ['admin'], + permissions: ['read', 'write'], + entitlements: ['premium'], + feature_flags: [], + }); + + isolated = { + authkitLoader: sessionModule.authkitLoader, + createRemoteJWKSet: createRemoteJWKSetMock, + jwtVerify: jwtVerifyMock, + getJwksUrl: getJwksUrlMock, + }; + }); + return isolated; + } + + it('reuses the cached JWKS instance across multiple verifyAccessToken calls', async () => { + const { authkitLoader, createRemoteJWKSet, jwtVerify } = loadIsolated(); + + // Prime the module-scoped cache. + await authkitLoader(createLoaderArgs(new Request('http://example.com/a', { headers: { Cookie: 'cookie' } }))); + createRemoteJWKSet.mockClear(); + + await authkitLoader(createLoaderArgs(new Request('http://example.com/b', { headers: { Cookie: 'cookie' } }))); + await authkitLoader(createLoaderArgs(new Request('http://example.com/c', { headers: { Cookie: 'cookie' } }))); + await authkitLoader(createLoaderArgs(new Request('http://example.com/d', { headers: { Cookie: 'cookie' } }))); + + expect(jwtVerify).toHaveBeenCalled(); + expect(createRemoteJWKSet).not.toHaveBeenCalled(); + }); + + it('rebuilds the JWKS instance when the JWKS URL changes', async () => { + const { authkitLoader, createRemoteJWKSet, getJwksUrl } = loadIsolated(); + + // Populate the cache with the default URL. + await authkitLoader(createLoaderArgs(new Request('http://example.com/a', { headers: { Cookie: 'cookie' } }))); + createRemoteJWKSet.mockClear(); + + // Same URL → no rebuild. + await authkitLoader(createLoaderArgs(new Request('http://example.com/b', { headers: { Cookie: 'cookie' } }))); + expect(createRemoteJWKSet).not.toHaveBeenCalled(); + + // URL changes (e.g. consumer re-configures with a different clientId) → + // the cache must be invalidated and a new JWKS instance created. + getJwksUrl.mockImplementation(() => 'https://auth.workos.com/oauth/jwks/other-client'); + await authkitLoader(createLoaderArgs(new Request('http://example.com/c', { headers: { Cookie: 'cookie' } }))); + expect(createRemoteJWKSet).toHaveBeenCalledTimes(1); + + // Still the same new URL → still cached. + await authkitLoader(createLoaderArgs(new Request('http://example.com/d', { headers: { Cookie: 'cookie' } }))); + expect(createRemoteJWKSet).toHaveBeenCalledTimes(1); + }); + }); + describe('saveSession', () => { const sessionData = { accessToken: 'new.valid.token', diff --git a/src/session.ts b/src/session.ts index ae34023..7e288c8 100644 --- a/src/session.ts +++ b/src/session.ts @@ -595,8 +595,20 @@ export async function getSessionFromCookie(cookie: string, session?: SessionData } } +let cachedJWKS: ReturnType | undefined; +let cachedJWKSUrl: string | undefined; + +function getJWKS(): ReturnType { + const jwksUrl = getWorkOS().userManagement.getJwksUrl(getConfig('clientId')); + if (!cachedJWKS || cachedJWKSUrl !== jwksUrl) { + cachedJWKS = createRemoteJWKSet(new URL(jwksUrl)); + cachedJWKSUrl = jwksUrl; + } + return cachedJWKS; +} + async function verifyAccessToken(accessToken: string) { - const JWKS = createRemoteJWKSet(new URL(getWorkOS().userManagement.getJwksUrl(getConfig('clientId')))); + const JWKS = getJWKS(); try { await jwtVerify(accessToken, JWKS); return true;