|
| 1 | +/** |
| 2 | + * @vitest-environment node |
| 3 | + */ |
| 4 | +import { beforeEach, describe, expect, it, vi } from 'vitest' |
| 5 | + |
| 6 | +const { |
| 7 | + MockMcpClient, |
| 8 | + mockListTools, |
| 9 | + mockConnect, |
| 10 | + mockDisconnect, |
| 11 | + mockGetWorkspaceServersRows, |
| 12 | + mockResolveEnvVars, |
| 13 | + mockValidateDomain, |
| 14 | + mockValidateSsrf, |
| 15 | + mockIsDomainAllowed, |
| 16 | +} = vi.hoisted(() => { |
| 17 | + const mockListTools = vi.fn() |
| 18 | + const mockConnect = vi.fn() |
| 19 | + const mockDisconnect = vi.fn() |
| 20 | + return { |
| 21 | + MockMcpClient: vi.fn().mockImplementation(() => ({ |
| 22 | + connect: mockConnect, |
| 23 | + disconnect: mockDisconnect, |
| 24 | + listTools: mockListTools, |
| 25 | + hasListChangedCapability: vi.fn(() => false), |
| 26 | + onClose: vi.fn(), |
| 27 | + getNegotiatedVersion: vi.fn(() => '2025-06-18'), |
| 28 | + })), |
| 29 | + mockListTools, |
| 30 | + mockConnect, |
| 31 | + mockDisconnect, |
| 32 | + mockGetWorkspaceServersRows: vi.fn(), |
| 33 | + mockResolveEnvVars: vi.fn(), |
| 34 | + mockValidateDomain: vi.fn(), |
| 35 | + mockValidateSsrf: vi.fn(), |
| 36 | + mockIsDomainAllowed: vi.fn(() => true), |
| 37 | + } |
| 38 | +}) |
| 39 | + |
| 40 | +vi.mock('@sim/db', () => { |
| 41 | + const setter = vi.fn().mockReturnValue({ where: vi.fn().mockResolvedValue(undefined) }) |
| 42 | + return { |
| 43 | + db: { |
| 44 | + select: vi.fn().mockReturnValue({ |
| 45 | + from: vi.fn().mockReturnValue({ |
| 46 | + where: (...args: unknown[]) => mockGetWorkspaceServersRows(...args), |
| 47 | + }), |
| 48 | + }), |
| 49 | + update: vi.fn().mockReturnValue({ set: setter }), |
| 50 | + insert: vi.fn(), |
| 51 | + delete: vi.fn(), |
| 52 | + }, |
| 53 | + } |
| 54 | +}) |
| 55 | + |
| 56 | +vi.mock('@/lib/mcp/client', () => ({ |
| 57 | + McpClient: MockMcpClient, |
| 58 | +})) |
| 59 | + |
| 60 | +vi.mock('@/lib/mcp/connection-manager', () => ({ |
| 61 | + mcpConnectionManager: null, |
| 62 | +})) |
| 63 | + |
| 64 | +vi.mock('@/lib/mcp/domain-check', () => ({ |
| 65 | + isMcpDomainAllowed: (...args: unknown[]) => mockIsDomainAllowed(...args), |
| 66 | + validateMcpDomain: (...args: unknown[]) => mockValidateDomain(...args), |
| 67 | + validateMcpServerSsrf: (...args: unknown[]) => mockValidateSsrf(...args), |
| 68 | +})) |
| 69 | + |
| 70 | +vi.mock('@/lib/mcp/oauth', () => ({ |
| 71 | + getOrCreateOauthRow: vi.fn(), |
| 72 | + loadPreregisteredClient: vi.fn(), |
| 73 | + SimMcpOauthProvider: vi.fn(), |
| 74 | + withMcpOauthRefreshLock: vi.fn(), |
| 75 | +})) |
| 76 | + |
| 77 | +vi.mock('@/lib/mcp/resolve-config', () => ({ |
| 78 | + resolveMcpConfigEnvVars: (...args: unknown[]) => mockResolveEnvVars(...args), |
| 79 | +})) |
| 80 | + |
| 81 | +import { mcpService } from '@/lib/mcp/service' |
| 82 | +import { McpOauthAuthorizationRequiredError } from '@/lib/mcp/types' |
| 83 | + |
| 84 | +const WORKSPACE_ID = 'workspace-test' |
| 85 | +const USER_ID = 'user-test' |
| 86 | + |
| 87 | +function dbRow(id: string, name: string, overrides: Record<string, unknown> = {}) { |
| 88 | + return { |
| 89 | + id, |
| 90 | + name, |
| 91 | + description: null, |
| 92 | + transport: 'streamable-http', |
| 93 | + url: `https://${id}.example.com/mcp`, |
| 94 | + authType: 'headers', |
| 95 | + workspaceId: WORKSPACE_ID, |
| 96 | + headers: {}, |
| 97 | + timeout: 30000, |
| 98 | + retries: 3, |
| 99 | + enabled: true, |
| 100 | + deletedAt: null, |
| 101 | + createdAt: new Date('2026-01-01T00:00:00Z'), |
| 102 | + updatedAt: new Date('2026-01-01T00:00:00Z'), |
| 103 | + ...overrides, |
| 104 | + } |
| 105 | +} |
| 106 | + |
| 107 | +function tool(name: string, serverId: string) { |
| 108 | + return { |
| 109 | + name, |
| 110 | + description: name, |
| 111 | + inputSchema: { type: 'object' }, |
| 112 | + serverId, |
| 113 | + serverName: serverId, |
| 114 | + } |
| 115 | +} |
| 116 | + |
| 117 | +describe('McpService.discoverTools per-server caching', () => { |
| 118 | + beforeEach(async () => { |
| 119 | + vi.clearAllMocks() |
| 120 | + mockIsDomainAllowed.mockReturnValue(true) |
| 121 | + mockValidateSsrf.mockResolvedValue('1.2.3.4') |
| 122 | + mockValidateDomain.mockImplementation(() => undefined) |
| 123 | + mockResolveEnvVars.mockImplementation((config: { url: string }) => |
| 124 | + Promise.resolve({ config: { ...config, url: config.url }, missingVars: [] }) |
| 125 | + ) |
| 126 | + mockConnect.mockResolvedValue(undefined) |
| 127 | + mockDisconnect.mockResolvedValue(undefined) |
| 128 | + // The McpService singleton holds cache state across imports. |
| 129 | + await mcpService.clearCache() |
| 130 | + }) |
| 131 | + |
| 132 | + it('caches each server independently after first discovery', async () => { |
| 133 | + mockGetWorkspaceServersRows.mockResolvedValue([dbRow('mcp-a', 'A'), dbRow('mcp-b', 'B')]) |
| 134 | + mockListTools |
| 135 | + .mockResolvedValueOnce([tool('a1', 'mcp-a')]) |
| 136 | + .mockResolvedValueOnce([tool('b1', 'mcp-b')]) |
| 137 | + |
| 138 | + const first = await mcpService.discoverTools(USER_ID, WORKSPACE_ID) |
| 139 | + expect(first.map((t) => t.name).sort()).toEqual(['a1', 'b1']) |
| 140 | + expect(mockListTools).toHaveBeenCalledTimes(2) |
| 141 | + |
| 142 | + mockListTools.mockClear() |
| 143 | + const second = await mcpService.discoverTools(USER_ID, WORKSPACE_ID) |
| 144 | + expect(second.map((t) => t.name).sort()).toEqual(['a1', 'b1']) |
| 145 | + expect(mockListTools).not.toHaveBeenCalled() |
| 146 | + }) |
| 147 | + |
| 148 | + it("one server failing does not poison another server's cache", async () => { |
| 149 | + mockGetWorkspaceServersRows.mockResolvedValue([dbRow('mcp-a', 'A'), dbRow('mcp-b', 'B')]) |
| 150 | + mockListTools |
| 151 | + .mockResolvedValueOnce([tool('a1', 'mcp-a')]) |
| 152 | + .mockRejectedValueOnce(new Error('Request timed out')) |
| 153 | + |
| 154 | + const first = await mcpService.discoverTools(USER_ID, WORKSPACE_ID) |
| 155 | + expect(first.map((t) => t.name)).toEqual(['a1']) |
| 156 | + |
| 157 | + mockListTools.mockClear() |
| 158 | + mockListTools.mockResolvedValueOnce([tool('b1', 'mcp-b')]) |
| 159 | + |
| 160 | + const second = await mcpService.discoverTools(USER_ID, WORKSPACE_ID) |
| 161 | + expect(second.map((t) => t.name).sort()).toEqual(['a1', 'b1']) |
| 162 | + expect(mockListTools).toHaveBeenCalledTimes(1) |
| 163 | + }) |
| 164 | + |
| 165 | + it("forceRefresh bypasses every server's cache", async () => { |
| 166 | + mockGetWorkspaceServersRows.mockResolvedValue([dbRow('mcp-a', 'A'), dbRow('mcp-b', 'B')]) |
| 167 | + mockListTools |
| 168 | + .mockResolvedValueOnce([tool('a1', 'mcp-a')]) |
| 169 | + .mockResolvedValueOnce([tool('b1', 'mcp-b')]) |
| 170 | + |
| 171 | + await mcpService.discoverTools(USER_ID, WORKSPACE_ID) |
| 172 | + expect(mockListTools).toHaveBeenCalledTimes(2) |
| 173 | + |
| 174 | + mockListTools.mockClear() |
| 175 | + mockListTools |
| 176 | + .mockResolvedValueOnce([tool('a2', 'mcp-a')]) |
| 177 | + .mockResolvedValueOnce([tool('b2', 'mcp-b')]) |
| 178 | + |
| 179 | + const refreshed = await mcpService.discoverTools(USER_ID, WORKSPACE_ID, true) |
| 180 | + expect(refreshed.map((t) => t.name).sort()).toEqual(['a2', 'b2']) |
| 181 | + expect(mockListTools).toHaveBeenCalledTimes(2) |
| 182 | + }) |
| 183 | + |
| 184 | + it('OAuth-pending is treated as a soft skip without poisoning cache', async () => { |
| 185 | + mockGetWorkspaceServersRows.mockResolvedValue([dbRow('mcp-a', 'A'), dbRow('mcp-b', 'B')]) |
| 186 | + mockListTools |
| 187 | + .mockResolvedValueOnce([tool('a1', 'mcp-a')]) |
| 188 | + .mockRejectedValueOnce(new McpOauthAuthorizationRequiredError('mcp-b', 'B')) |
| 189 | + |
| 190 | + const first = await mcpService.discoverTools(USER_ID, WORKSPACE_ID) |
| 191 | + expect(first.map((t) => t.name)).toEqual(['a1']) |
| 192 | + |
| 193 | + mockListTools.mockClear() |
| 194 | + mockListTools.mockRejectedValueOnce(new McpOauthAuthorizationRequiredError('mcp-b', 'B')) |
| 195 | + |
| 196 | + await mcpService.discoverTools(USER_ID, WORKSPACE_ID) |
| 197 | + expect(mockListTools).toHaveBeenCalledTimes(1) |
| 198 | + }) |
| 199 | + |
| 200 | + it('returns empty array immediately when workspace has no servers', async () => { |
| 201 | + mockGetWorkspaceServersRows.mockResolvedValue([]) |
| 202 | + |
| 203 | + const result = await mcpService.discoverTools(USER_ID, WORKSPACE_ID) |
| 204 | + expect(result).toEqual([]) |
| 205 | + expect(mockListTools).not.toHaveBeenCalled() |
| 206 | + expect(MockMcpClient).not.toHaveBeenCalled() |
| 207 | + }) |
| 208 | + |
| 209 | + it('clearCache(workspaceId) drops cached tools so next call re-fetches', async () => { |
| 210 | + mockGetWorkspaceServersRows.mockResolvedValue([dbRow('mcp-a', 'A')]) |
| 211 | + mockListTools.mockResolvedValueOnce([tool('a1', 'mcp-a')]) |
| 212 | + |
| 213 | + await mcpService.discoverTools(USER_ID, WORKSPACE_ID) |
| 214 | + expect(mockListTools).toHaveBeenCalledTimes(1) |
| 215 | + |
| 216 | + await mcpService.clearCache(WORKSPACE_ID) |
| 217 | + |
| 218 | + mockListTools.mockClear() |
| 219 | + mockListTools.mockResolvedValueOnce([tool('a1', 'mcp-a')]) |
| 220 | + await mcpService.discoverTools(USER_ID, WORKSPACE_ID) |
| 221 | + expect(mockListTools).toHaveBeenCalledTimes(1) |
| 222 | + }) |
| 223 | + |
| 224 | + it('isolates caches across workspaces', async () => { |
| 225 | + const otherWorkspaceId = 'workspace-other' |
| 226 | + mockGetWorkspaceServersRows |
| 227 | + .mockResolvedValueOnce([dbRow('mcp-a', 'A')]) |
| 228 | + .mockResolvedValueOnce([dbRow('mcp-a', 'A', { workspaceId: otherWorkspaceId })]) |
| 229 | + |
| 230 | + mockListTools |
| 231 | + .mockResolvedValueOnce([tool('a1', 'mcp-a')]) |
| 232 | + .mockResolvedValueOnce([tool('a-other', 'mcp-a')]) |
| 233 | + |
| 234 | + const first = await mcpService.discoverTools(USER_ID, WORKSPACE_ID) |
| 235 | + const second = await mcpService.discoverTools(USER_ID, otherWorkspaceId) |
| 236 | + |
| 237 | + expect(first.map((t) => t.name)).toEqual(['a1']) |
| 238 | + expect(second.map((t) => t.name)).toEqual(['a-other']) |
| 239 | + expect(mockListTools).toHaveBeenCalledTimes(2) |
| 240 | + }) |
| 241 | +}) |
0 commit comments