From 537cec3cf0e6f750f3f9fa96350f9430eb6c2e4a Mon Sep 17 00:00:00 2001 From: ding113 Date: Sat, 25 Apr 2026 14:29:15 +0000 Subject: [PATCH 1/4] fix(rate-limit): release failed provider sessions Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- src/lib/rate-limit/service.ts | 30 +++++++ .../provider-session-release.test.ts | 82 +++++++++++++++++++ 2 files changed, 112 insertions(+) create mode 100644 tests/unit/lib/rate-limit/provider-session-release.test.ts diff --git a/src/lib/rate-limit/service.ts b/src/lib/rate-limit/service.ts index 3de962b46..9353551fc 100644 --- a/src/lib/rate-limit/service.ts +++ b/src/lib/rate-limit/service.ts @@ -856,6 +856,36 @@ export class RateLimitService { } } + /** + * Release a provider-level active session when a selected provider is abandoned. + * + * Provider concurrency is tracked before forwarding so fallback decisions can be atomic. + * If the provider later fails, the session must be removed immediately instead of waiting + * for TTL cleanup; otherwise outage storms inflate provider active_sessions ZSETs. + */ + static async releaseProviderSession(providerId: number, sessionId: string): Promise { + if (!Number.isInteger(providerId) || providerId <= 0 || sessionId.trim().length === 0) { + return; + } + + const redis = RateLimitService.redis; + if (!redis || redis.status !== "ready") { + return; + } + + const key = `provider:${providerId}:active_sessions`; + try { + await redis.zrem(key, sessionId); + logger.debug("[RateLimit] Released provider session", { providerId, sessionId }); + } catch (error) { + logger.error("[RateLimit] Failed to release provider session", { + providerId, + sessionId, + error, + }); + } + } + /** * 累加消费(请求结束后调用) * 5h 使用滚动窗口(ZSET),daily 根据模式选择滚动/固定窗口,周/月使用固定窗口(STRING) diff --git a/tests/unit/lib/rate-limit/provider-session-release.test.ts b/tests/unit/lib/rate-limit/provider-session-release.test.ts new file mode 100644 index 000000000..9102db6b6 --- /dev/null +++ b/tests/unit/lib/rate-limit/provider-session-release.test.ts @@ -0,0 +1,82 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +type RedisClientMock = { + status: string; + zrem: (key: string, member: string) => Promise; +}; + +let redisClientRef: RedisClientMock | null; +let zremMock: ReturnType Promise>>; + +vi.mock("server-only", () => ({})); + +vi.mock("@/lib/redis", () => ({ + getRedisClient: () => redisClientRef, +})); + +vi.mock("@/lib/logger", () => ({ + logger: { + debug: vi.fn(), + error: vi.fn(), + warn: vi.fn(), + }, +})); + +describe("RateLimitService.releaseProviderSession", () => { + beforeEach(() => { + vi.clearAllMocks(); + zremMock = vi.fn(async () => 1); + redisClientRef = { + status: "ready", + zrem: zremMock, + }; + }); + + it("应从供应商 active_sessions ZSET 中释放失败请求的 sessionId", async () => { + const { RateLimitService } = await import("@/lib/rate-limit/service"); + + await RateLimitService.releaseProviderSession(42, "sess_failed"); + + expect(zremMock).toHaveBeenCalledTimes(1); + expect(zremMock).toHaveBeenCalledWith("provider:42:active_sessions", "sess_failed"); + }); + + it("Redis 不可用或未 ready 时应静默跳过", async () => { + const { RateLimitService } = await import("@/lib/rate-limit/service"); + + redisClientRef = null; + await RateLimitService.releaseProviderSession(42, "sess_failed"); + + redisClientRef = { status: "connecting", zrem: zremMock }; + await RateLimitService.releaseProviderSession(42, "sess_failed"); + + expect(zremMock).not.toHaveBeenCalled(); + }); + + it("非法 providerId 或空 sessionId 不应触发 Redis 命令", async () => { + const { RateLimitService } = await import("@/lib/rate-limit/service"); + + await RateLimitService.releaseProviderSession(0, "sess_failed"); + await RateLimitService.releaseProviderSession(-1, "sess_failed"); + await RateLimitService.releaseProviderSession(42, " "); + + expect(zremMock).not.toHaveBeenCalled(); + }); + + it("释放失败时应记录日志但不向请求链路抛错", async () => { + const error = new Error("redis down"); + zremMock.mockRejectedValueOnce(error); + const { RateLimitService } = await import("@/lib/rate-limit/service"); + const { logger } = await import("@/lib/logger"); + + await expect( + RateLimitService.releaseProviderSession(42, "sess_failed") + ).resolves.toBeUndefined(); + + expect(logger.error).toHaveBeenCalledWith("[RateLimit] Failed to release provider session", { + providerId: 42, + sessionId: "sess_failed", + error, + }); + }); +}); From f454dcdba9b04e3b1cb5d51799336d6326a5165d Mon Sep 17 00:00:00 2001 From: ding113 Date: Sat, 25 Apr 2026 14:29:24 +0000 Subject: [PATCH 2/4] fix(proxy): release provider sessions on fallback Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- src/app/v1/_lib/proxy/forwarder.ts | 37 ++++++++++--- ...forwarder-provider-session-release.test.ts | 54 +++++++++++++++++++ 2 files changed, 84 insertions(+), 7 deletions(-) create mode 100644 tests/unit/proxy/proxy-forwarder-provider-session-release.test.ts diff --git a/src/app/v1/_lib/proxy/forwarder.ts b/src/app/v1/_lib/proxy/forwarder.ts index 1e97622db..01340e58b 100644 --- a/src/app/v1/_lib/proxy/forwarder.ts +++ b/src/app/v1/_lib/proxy/forwarder.ts @@ -26,6 +26,7 @@ import { getPreferredProviderEndpoints, } from "@/lib/provider-endpoints/endpoint-selector"; import { getGlobalAgentPool, getProxyAgentForProvider } from "@/lib/proxy-agent"; +import { RateLimitService } from "@/lib/rate-limit/service"; import { SessionManager } from "@/lib/session-manager"; import { detectUpstreamErrorFromSseOrJsonText, @@ -1077,7 +1078,7 @@ export class ProxyForwarder { }); } - failedProviderIds.push(currentProvider.id); + await ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); attemptCount = maxAttemptsPerProvider; } else { endpointCandidates.push({ endpointId: null, baseUrl: currentProvider.url }); @@ -1140,7 +1141,7 @@ export class ProxyForwarder { vendorId: currentProvider.providerVendorId, providerType: currentProvider.providerType, }); - failedProviderIds.push(currentProvider.id); + await ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); attemptCount = maxAttemptsPerProvider; } @@ -1708,7 +1709,7 @@ export class ProxyForwarder { const env = getEnvConfig(); // 无论是否计入熔断器,都要加入 failedProviderIds(避免重复选择同一供应商) - failedProviderIds.push(currentProvider.id); + await ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); if (env.ENABLE_CIRCUIT_BREAKER_ON_NETWORK_ERRORS) { logger.warn( @@ -1806,7 +1807,7 @@ export class ProxyForwarder { } // 重试耗尽:加入失败列表并切换供应商 - failedProviderIds.push(currentProvider.id); + await ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); break; // ⭐ 跳出内层循环,进入供应商切换逻辑 } @@ -1878,7 +1879,11 @@ export class ProxyForwarder { } } - failedProviderIds.push(currentProvider.id); + await ProxyForwarder.markProviderFailed( + session, + failedProviderIds, + currentProvider.id + ); break; // 跳出内层循环,进入供应商切换逻辑 } @@ -1927,7 +1932,11 @@ export class ProxyForwarder { currentProvider.providerVendorId, currentProvider.providerType ); - failedProviderIds.push(currentProvider.id); + await ProxyForwarder.markProviderFailed( + session, + failedProviderIds, + currentProvider.id + ); break; } @@ -2023,7 +2032,7 @@ export class ProxyForwarder { } // 加入失败列表并切换供应商 - failedProviderIds.push(currentProvider.id); + await ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); break; // 跳出内层循环,进入供应商切换逻辑 } } @@ -4250,6 +4259,20 @@ export class ProxyForwarder { await SessionManager.clearSessionProvider(session.sessionId); } + private static async markProviderFailed( + session: ProxySession, + failedProviderIds: number[], + providerId: number + ): Promise { + failedProviderIds.push(providerId); + + if (!session.sessionId) { + return; + } + + await RateLimitService.releaseProviderSession(providerId, session.sessionId); + } + private static buildAllProvidersUnavailableError(finalError?: Error | null): ProxyError { const safeClientMessageCandidate = finalError instanceof ProxyError && diff --git a/tests/unit/proxy/proxy-forwarder-provider-session-release.test.ts b/tests/unit/proxy/proxy-forwarder-provider-session-release.test.ts new file mode 100644 index 000000000..fd6175676 --- /dev/null +++ b/tests/unit/proxy/proxy-forwarder-provider-session-release.test.ts @@ -0,0 +1,54 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import type { ProxySession } from "@/app/v1/_lib/proxy/session"; + +const mocks = vi.hoisted(() => ({ + releaseProviderSession: vi.fn(async (_providerId: number, _sessionId: string) => {}), +})); + +vi.mock("@/lib/rate-limit/service", () => ({ + RateLimitService: { + releaseProviderSession: mocks.releaseProviderSession, + }, +})); + +describe("ProxyForwarder provider failure session release", () => { + beforeEach(() => { + mocks.releaseProviderSession.mockClear(); + }); + + it("标记供应商失败时应同步释放 provider active session", async () => { + const { ProxyForwarder } = await import("@/app/v1/_lib/proxy/forwarder"); + const forwarderInternals = ProxyForwarder as unknown as { + markProviderFailed: ( + session: ProxySession, + failedProviderIds: number[], + providerId: number + ) => Promise; + }; + const session = { sessionId: "sess_failed" } as unknown as ProxySession; + const failedProviderIds: number[] = []; + + await forwarderInternals.markProviderFailed(session, failedProviderIds, 42); + + expect(failedProviderIds).toEqual([42]); + expect(mocks.releaseProviderSession).toHaveBeenCalledWith(42, "sess_failed"); + }); + + it("没有 sessionId 时只记录失败供应商,不触发 Redis 释放", async () => { + const { ProxyForwarder } = await import("@/app/v1/_lib/proxy/forwarder"); + const forwarderInternals = ProxyForwarder as unknown as { + markProviderFailed: ( + session: ProxySession, + failedProviderIds: number[], + providerId: number + ) => Promise; + }; + const session = { sessionId: null } as unknown as ProxySession; + const failedProviderIds: number[] = []; + + await forwarderInternals.markProviderFailed(session, failedProviderIds, 42); + + expect(failedProviderIds).toEqual([42]); + expect(mocks.releaseProviderSession).not.toHaveBeenCalled(); + }); +}); From a9fcbd6225067515e8e7460ce5ecb3e3cbe6a847 Mon Sep 17 00:00:00 2001 From: ding113 Date: Sat, 25 Apr 2026 14:52:12 +0000 Subject: [PATCH 3/4] fix(proxy): release hedge loser sessions Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- src/app/v1/_lib/proxy/forwarder.ts | 36 +++++++++---------- .../proxy-forwarder-hedge-first-byte.test.ts | 14 ++++++++ ...forwarder-provider-session-release.test.ts | 27 +++++++++++--- 3 files changed, 55 insertions(+), 22 deletions(-) diff --git a/src/app/v1/_lib/proxy/forwarder.ts b/src/app/v1/_lib/proxy/forwarder.ts index 01340e58b..4e3ff5b76 100644 --- a/src/app/v1/_lib/proxy/forwarder.ts +++ b/src/app/v1/_lib/proxy/forwarder.ts @@ -1078,7 +1078,7 @@ export class ProxyForwarder { }); } - await ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); + ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); attemptCount = maxAttemptsPerProvider; } else { endpointCandidates.push({ endpointId: null, baseUrl: currentProvider.url }); @@ -1141,7 +1141,7 @@ export class ProxyForwarder { vendorId: currentProvider.providerVendorId, providerType: currentProvider.providerType, }); - await ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); + ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); attemptCount = maxAttemptsPerProvider; } @@ -1709,7 +1709,7 @@ export class ProxyForwarder { const env = getEnvConfig(); // 无论是否计入熔断器,都要加入 failedProviderIds(避免重复选择同一供应商) - await ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); + ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); if (env.ENABLE_CIRCUIT_BREAKER_ON_NETWORK_ERRORS) { logger.warn( @@ -1807,7 +1807,7 @@ export class ProxyForwarder { } // 重试耗尽:加入失败列表并切换供应商 - await ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); + ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); break; // ⭐ 跳出内层循环,进入供应商切换逻辑 } @@ -1879,11 +1879,7 @@ export class ProxyForwarder { } } - await ProxyForwarder.markProviderFailed( - session, - failedProviderIds, - currentProvider.id - ); + ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); break; // 跳出内层循环,进入供应商切换逻辑 } @@ -1932,11 +1928,7 @@ export class ProxyForwarder { currentProvider.providerVendorId, currentProvider.providerType ); - await ProxyForwarder.markProviderFailed( - session, - failedProviderIds, - currentProvider.id - ); + ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); break; } @@ -2032,7 +2024,7 @@ export class ProxyForwarder { } // 加入失败列表并切换供应商 - await ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); + ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); break; // 跳出内层循环,进入供应商切换逻辑 } } @@ -3406,6 +3398,7 @@ export class ProxyForwarder { let lastError: Error | null = null; let lastErrorCategory: ErrorCategory | null = null; const attempts = new Set(); + const failedProviderIds: number[] = []; let resolveResult: ((result: { response?: Response; error?: Error }) => void) | null = null; const resultPromise = new Promise<{ response?: Response; error?: Error }>((resolve) => { @@ -3453,6 +3446,7 @@ export class ProxyForwarder { attemptNumber: attempt.sequence, modelRedirect: getAttemptModelRedirect(attempt), }); + ProxyForwarder.markProviderFailed(session, failedProviderIds, attempt.provider.id); } try { attempt.responseController?.abort(new Error(reason)); @@ -3776,6 +3770,7 @@ export class ProxyForwarder { attempt.thresholdTimer = null; } attempts.delete(attempt); + ProxyForwarder.markProviderFailed(session, failedProviderIds, attempt.provider.id); if (errorCategory === ErrorCategory.PROVIDER_ERROR && statusCode !== 404) { await recordFailure(attempt.provider.id, error); @@ -3940,6 +3935,7 @@ export class ProxyForwarder { } catch (endpointError) { lastError = endpointError as Error; lastErrorCategory = null; + ProxyForwarder.markProviderFailed(session, failedProviderIds, provider.id); await launchAlternative(); await finishIfExhausted(); return; @@ -4259,18 +4255,22 @@ export class ProxyForwarder { await SessionManager.clearSessionProvider(session.sessionId); } - private static async markProviderFailed( + private static markProviderFailed( session: ProxySession, failedProviderIds: number[], providerId: number - ): Promise { + ): void { + if (failedProviderIds.includes(providerId)) { + return; + } + failedProviderIds.push(providerId); if (!session.sessionId) { return; } - await RateLimitService.releaseProviderSession(providerId, session.sessionId); + void RateLimitService.releaseProviderSession(providerId, session.sessionId); } private static buildAllProvidersUnavailableError(finalError?: Error | null): ProxyError { diff --git a/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts b/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts index 942234e60..ee77b3d6f 100644 --- a/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts +++ b/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts @@ -20,6 +20,7 @@ const mocks = vi.hoisted(() => ({ recordEndpointFailure: vi.fn(async () => {}), isVendorTypeCircuitOpen: vi.fn(async () => false), recordVendorTypeAllEndpointsTimeout: vi.fn(async () => {}), + releaseProviderSession: vi.fn(async (_providerId: number, _sessionId: string) => {}), categorizeErrorAsync: vi.fn(async () => 0), getErrorDetectionResultAsync: vi.fn(async () => ({ matched: false })), getCachedSystemSettings: vi.fn(async () => ({ @@ -73,6 +74,12 @@ vi.mock("@/lib/vendor-type-circuit-breaker", () => ({ recordVendorTypeAllEndpointsTimeout: mocks.recordVendorTypeAllEndpointsTimeout, })); +vi.mock("@/lib/rate-limit/service", () => ({ + RateLimitService: { + releaseProviderSession: mocks.releaseProviderSession, + }, +})); + vi.mock("@/lib/session-manager", () => ({ SessionManager: { updateSessionBindingSmart: mocks.updateSessionBindingSmart, @@ -575,6 +582,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { redirectedModel: fireworksRedirect, billingModel: requestedModel, }); + expect(mocks.releaseProviderSession).toHaveBeenCalledWith(fireworks.id, "sess-hedge"); } finally { vi.useRealTimers(); } @@ -874,6 +882,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { true, null ); + expect(mocks.releaseProviderSession).toHaveBeenCalledWith(1, "sess-hedge"); } finally { vi.useRealTimers(); } @@ -1071,6 +1080,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { expect(mocks.recordFailure).not.toHaveBeenCalled(); expect(mocks.recordSuccess).not.toHaveBeenCalled(); expect(session.provider?.id).toBe(1); + expect(mocks.releaseProviderSession).toHaveBeenCalledWith(2, "sess-hedge"); } finally { vi.useRealTimers(); } @@ -1148,6 +1158,9 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { expect(mocks.recordFailure).not.toHaveBeenCalled(); expect(mocks.recordSuccess).not.toHaveBeenCalled(); expect(session.provider?.id).toBe(3); + expect(mocks.releaseProviderSession).toHaveBeenCalledWith(1, "sess-hedge"); + expect(mocks.releaseProviderSession).toHaveBeenCalledWith(2, "sess-hedge"); + expect(mocks.releaseProviderSession).not.toHaveBeenCalledWith(3, "sess-hedge"); } finally { vi.useRealTimers(); } @@ -1790,6 +1803,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { ); expect(winnerEntry).toBeDefined(); expect(winnerEntry!.reason).toBe("request_success"); + expect(mocks.releaseProviderSession).toHaveBeenCalledWith(1, "sess-hedge"); } finally { vi.useRealTimers(); } diff --git a/tests/unit/proxy/proxy-forwarder-provider-session-release.test.ts b/tests/unit/proxy/proxy-forwarder-provider-session-release.test.ts index fd6175676..13a68d3b4 100644 --- a/tests/unit/proxy/proxy-forwarder-provider-session-release.test.ts +++ b/tests/unit/proxy/proxy-forwarder-provider-session-release.test.ts @@ -23,17 +23,36 @@ describe("ProxyForwarder provider failure session release", () => { session: ProxySession, failedProviderIds: number[], providerId: number - ) => Promise; + ) => void; }; const session = { sessionId: "sess_failed" } as unknown as ProxySession; const failedProviderIds: number[] = []; - await forwarderInternals.markProviderFailed(session, failedProviderIds, 42); + forwarderInternals.markProviderFailed(session, failedProviderIds, 42); expect(failedProviderIds).toEqual([42]); expect(mocks.releaseProviderSession).toHaveBeenCalledWith(42, "sess_failed"); }); + it("重复标记同一供应商时只释放一次,避免 hedge 路径重复 ZREM", async () => { + const { ProxyForwarder } = await import("@/app/v1/_lib/proxy/forwarder"); + const forwarderInternals = ProxyForwarder as unknown as { + markProviderFailed: ( + session: ProxySession, + failedProviderIds: number[], + providerId: number + ) => void; + }; + const session = { sessionId: "sess_failed" } as unknown as ProxySession; + const failedProviderIds: number[] = []; + + forwarderInternals.markProviderFailed(session, failedProviderIds, 42); + forwarderInternals.markProviderFailed(session, failedProviderIds, 42); + + expect(failedProviderIds).toEqual([42]); + expect(mocks.releaseProviderSession).toHaveBeenCalledTimes(1); + }); + it("没有 sessionId 时只记录失败供应商,不触发 Redis 释放", async () => { const { ProxyForwarder } = await import("@/app/v1/_lib/proxy/forwarder"); const forwarderInternals = ProxyForwarder as unknown as { @@ -41,12 +60,12 @@ describe("ProxyForwarder provider failure session release", () => { session: ProxySession, failedProviderIds: number[], providerId: number - ) => Promise; + ) => void; }; const session = { sessionId: null } as unknown as ProxySession; const failedProviderIds: number[] = []; - await forwarderInternals.markProviderFailed(session, failedProviderIds, 42); + forwarderInternals.markProviderFailed(session, failedProviderIds, 42); expect(failedProviderIds).toEqual([42]); expect(mocks.releaseProviderSession).not.toHaveBeenCalled(); From 5375ce4eb51c14372fbc4f810b57cad7fc15fb34 Mon Sep 17 00:00:00 2001 From: ding113 Date: Sat, 25 Apr 2026 16:05:05 +0000 Subject: [PATCH 4/4] fix(rate-limit): preserve provider session refs Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- src/app/v1/_lib/proxy/forwarder.ts | 77 ++++++++--- src/app/v1/_lib/proxy/provider-selector.ts | 4 + src/app/v1/_lib/proxy/session.ts | 23 ++++ src/lib/rate-limit/service.ts | 43 ++++-- src/lib/redis/lua-scripts.ts | 64 ++++++++- src/lib/session-manager.ts | 1 + src/lib/session-tracker.ts | 81 ++++++++++- .../provider-session-release.test.ts | 37 +++-- .../unit/lib/rate-limit/service-extra.test.ts | 34 +++-- .../session-manager-terminate-session.test.ts | 1 + .../unit/lib/session-tracker-cleanup.test.ts | 10 ++ .../proxy-forwarder-hedge-first-byte.test.ts | 129 ++++++++++++++++-- ...forwarder-provider-session-release.test.ts | 45 +++++- 13 files changed, 476 insertions(+), 73 deletions(-) diff --git a/src/app/v1/_lib/proxy/forwarder.ts b/src/app/v1/_lib/proxy/forwarder.ts index 4e3ff5b76..204ccc722 100644 --- a/src/app/v1/_lib/proxy/forwarder.ts +++ b/src/app/v1/_lib/proxy/forwarder.ts @@ -3514,21 +3514,24 @@ export class ProxyForwarder { } launchingAlternative = (async () => { - const alternativeProvider = await ProxyForwarder.selectAlternative( - session, - Array.from(launchedProviderIds) - ); - if (!alternativeProvider) { - noMoreProviders = true; - // No alternative providers available — let in-flight attempt(s) continue. - // If all attempts already completed, settle with last error. - if (attempts.size === 0) { - await finishIfExhausted(); + while (!settled && !winnerCommitted && !noMoreProviders) { + const alternativeProvider = await ProxyForwarder.selectAlternative( + session, + Array.from(launchedProviderIds) + ); + if (!alternativeProvider) { + noMoreProviders = true; + // No alternative providers available — let in-flight attempt(s) continue. + // If all attempts already completed, settle with last error. + if (attempts.size === 0) { + await finishIfExhausted(); + } + return; } - return; - } - await startAttempt(alternativeProvider, false); + const launched = await startAttempt(alternativeProvider, false); + if (launched) return; + } })() .catch(async (error) => { const normalizedError = error instanceof Error ? error : new Error(String(error)); @@ -3920,11 +3923,38 @@ export class ProxyForwarder { settleSuccess(response); }; - const startAttempt = async (provider: Provider, useOriginalSession: boolean) => { - if (settled || winnerCommitted || launchedProviderIds.has(provider.id)) return; + const startAttempt = async ( + provider: Provider, + useOriginalSession: boolean + ): Promise => { + if (settled || winnerCommitted || launchedProviderIds.has(provider.id)) return false; launchedProviderIds.add(provider.id); + if (!useOriginalSession && session.sessionId) { + const limit = provider.limitConcurrentSessions || 0; + const checkResult = await RateLimitService.checkAndTrackProviderSession( + provider.id, + session.sessionId, + limit + ); + + if (!checkResult.allowed) { + ProxyForwarder.markProviderFailed(session, failedProviderIds, provider.id); + session.addProviderToChain(provider, { + reason: "concurrent_limit_failed", + circuitState: getCircuitState(provider.id), + attemptNumber: launchedProviderCount + 1, + errorMessage: checkResult.reason || "并发限制已达到", + }); + return false; + } + + if (checkResult.referenced) { + session.recordProviderSessionRef(provider.id); + } + } + let endpointSelection: { endpointId: number | null; baseUrl: string; @@ -3936,9 +3966,8 @@ export class ProxyForwarder { lastError = endpointError as Error; lastErrorCategory = null; ProxyForwarder.markProviderFailed(session, failedProviderIds, provider.id); - await launchAlternative(); await finishIfExhausted(); - return; + return false; } launchedProviderCount += 1; @@ -3989,6 +4018,7 @@ export class ProxyForwarder { armAttemptThreshold(attempt); runAttempt(attempt); + return true; }; if (session.clientAbortSignal) { @@ -4017,7 +4047,10 @@ export class ProxyForwarder { ); } - await startAttempt(initialProvider, true); + const initialLaunched = await startAttempt(initialProvider, true); + if (!initialLaunched) { + await launchAlternative(); + } await finishIfExhausted(); const result = await resultPromise; if (result.error) { @@ -4270,6 +4303,14 @@ export class ProxyForwarder { return; } + const providerSessionRefConsumer = ( + session as { consumeProviderSessionRef?: (providerId: number) => boolean } + ).consumeProviderSessionRef; + + if (!providerSessionRefConsumer?.call(session, providerId)) { + return; + } + void RateLimitService.releaseProviderSession(providerId, session.sessionId); } diff --git a/src/app/v1/_lib/proxy/provider-selector.ts b/src/app/v1/_lib/proxy/provider-selector.ts index 8bdd4ef5f..8b9fde2b9 100644 --- a/src/app/v1/_lib/proxy/provider-selector.ts +++ b/src/app/v1/_lib/proxy/provider-selector.ts @@ -295,6 +295,10 @@ export class ProxyProviderResolver { } // === 成功 === + if (checkResult.referenced) { + session.recordProviderSessionRef(session.provider.id); + } + logger.debug("ProviderSelector: Session tracked atomically", { sessionId: session.sessionId, providerName: session.provider.name, diff --git a/src/app/v1/_lib/proxy/session.ts b/src/app/v1/_lib/proxy/session.ts index 5e892c1c7..692e695d2 100644 --- a/src/app/v1/_lib/proxy/session.ts +++ b/src/app/v1/_lib/proxy/session.ts @@ -176,6 +176,10 @@ export class ProxySession { */ private providersSnapshot: Provider[] | null = null; + // 本请求已通过 Provider 并发检查获得的引用。 + // 失败切换 provider 时只能释放这里记录过的引用,避免 hedge/fallback 释放未 acquire 的 Redis 计数。 + private providerSessionRefs = new Set(); + private constructor(init: { startTime: number; method: string; @@ -313,6 +317,25 @@ export class ProxySession { } } + recordProviderSessionRef(providerId: number): void { + if (!this.providerSessionRefs) { + this.providerSessionRefs = new Set(); + } + + if (Number.isInteger(providerId) && providerId > 0) { + this.providerSessionRefs.add(providerId); + } + } + + consumeProviderSessionRef(providerId: number): boolean { + if (!this.providerSessionRefs?.has(providerId)) { + return false; + } + + this.providerSessionRefs.delete(providerId); + return true; + } + setCacheTtlResolved(ttl: CacheTtlResolved | null): void { this.cacheTtlResolved = ttl; } diff --git a/src/lib/rate-limit/service.ts b/src/lib/rate-limit/service.ts index 9353551fc..1e7c7f304 100644 --- a/src/lib/rate-limit/service.ts +++ b/src/lib/rate-limit/service.ts @@ -77,6 +77,7 @@ import { CHECK_AND_TRACK_SESSION, GET_COST_5H_ROLLING_WINDOW, GET_COST_DAILY_ROLLING_WINDOW, + RELEASE_PROVIDER_SESSION, TRACK_COST_5H_ROLLING_WINDOW, TRACK_COST_DAILY_ROLLING_WINDOW, } from "@/lib/redis/lua-scripts"; @@ -804,43 +805,52 @@ export class RateLimitService { * @param providerId - Provider ID * @param sessionId - Session ID * @param limit - 并发限制 - * @returns { allowed, count, tracked } - 是否允许、当前并发数、是否已追踪 + * @returns { allowed, count, tracked, referenced } - 是否允许、当前并发数、是否新追踪、是否获得释放引用 */ static async checkAndTrackProviderSession( providerId: number, sessionId: string, limit: number - ): Promise<{ allowed: boolean; count: number; tracked: boolean; reason?: string }> { + ): Promise<{ + allowed: boolean; + count: number; + tracked: boolean; + referenced: boolean; + reason?: string; + }> { if (limit <= 0) { - return { allowed: true, count: 0, tracked: false }; + return { allowed: true, count: 0, tracked: false, referenced: false }; } if (!RateLimitService.redis || RateLimitService.redis.status !== "ready") { logger.warn("[RateLimit] Redis not ready, Fail Open"); - return { allowed: true, count: 0, tracked: false }; + return { allowed: true, count: 0, tracked: false, referenced: false }; } try { const key = `provider:${providerId}:active_sessions`; + const refKey = `provider:${providerId}:active_session_refs`; const now = Date.now(); const result = (await RateLimitService.redis.eval( CHECK_AND_TRACK_SESSION, - 1, // KEYS count + 2, // KEYS count key, // KEYS[1] + refKey, // KEYS[2] sessionId, // ARGV[1] limit.toString(), // ARGV[2] now.toString(), // ARGV[3] SESSION_TTL_MS.toString() // ARGV[4] - )) as [number, number, number]; + )) as [number, number, number, number]; - const [allowed, count, tracked] = result; + const [allowed, count, tracked, referenced] = result; if (allowed === 0) { return { allowed: false, count, tracked: false, + referenced: false, reason: `供应商并发 Session 上限已达到(${count}/${limit})`, }; } @@ -849,10 +859,11 @@ export class RateLimitService { allowed: true, count, tracked: tracked === 1, // Lua 返回 1 表示新追踪,0 表示已存在 + referenced: referenced === 1, }; } catch (error) { logger.error("[RateLimit] Atomic check-and-track failed:", error); - return { allowed: true, count: 0, tracked: false }; // Fail Open + return { allowed: true, count: 0, tracked: false, referenced: false }; // Fail Open } } @@ -874,9 +885,21 @@ export class RateLimitService { } const key = `provider:${providerId}:active_sessions`; + const refKey = `provider:${providerId}:active_session_refs`; try { - await redis.zrem(key, sessionId); - logger.debug("[RateLimit] Released provider session", { providerId, sessionId }); + const [removed, remainingRefs] = (await redis.eval( + RELEASE_PROVIDER_SESSION, + 2, + key, + refKey, + sessionId + )) as [number, number]; + logger.debug("[RateLimit] Released provider session", { + providerId, + sessionId, + removed, + remainingRefs, + }); } catch (error) { logger.error("[RateLimit] Failed to release provider session", { providerId, diff --git a/src/lib/redis/lua-scripts.ts b/src/lib/redis/lua-scripts.ts index 402b702d0..7513e4119 100644 --- a/src/lib/redis/lua-scripts.ts +++ b/src/lib/redis/lua-scripts.ts @@ -14,18 +14,21 @@ * 4. If not exceeded, track new session (atomic operation) * * KEYS[1]: provider:${providerId}:active_sessions + * KEYS[2]: provider:${providerId}:active_session_refs * ARGV[1]: sessionId * ARGV[2]: limit (concurrency limit) * ARGV[3]: now (current timestamp, ms) * ARGV[4]: ttlMs (optional, cleanup window in ms, default 300000) * * Return: - * - {1, count, 1} - allowed (new tracking), returns new count and tracked=1 - * - {1, count, 0} - allowed (already tracked), returns current count and tracked=0 - * - {0, count, 0} - rejected (limit reached), returns current count and tracked=0 + * - {1, count, 1, 1} - allowed (new tracking), returns new count, tracked=1, referenced=1 + * - {1, count, 0, 1} - allowed (already tracked with refs), returns count, tracked=0, referenced=1 + * - {1, count, 0, 0} - allowed (legacy tracked without refs), returns count, tracked=0, referenced=0 + * - {0, count, 0, 0} - rejected (limit reached), returns current count and tracked=0 */ export const CHECK_AND_TRACK_SESSION = ` local provider_key = KEYS[1] +local ref_key = KEYS[2] local session_id = ARGV[1] local limit = tonumber(ARGV[2]) local now = tonumber(ARGV[3]) @@ -38,37 +41,86 @@ end -- 1. Cleanup expired sessions (TTL window ago) local cutoff = now - ttl +local expired_sessions = redis.call('ZRANGEBYSCORE', provider_key, '-inf', cutoff) redis.call('ZREMRANGEBYSCORE', provider_key, '-inf', cutoff) +for _, expired_session_id in ipairs(expired_sessions) do + redis.call('HDEL', ref_key, expired_session_id) +end -- 2. Check if session is already tracked local is_tracked = redis.call('ZSCORE', provider_key, session_id) +-- Direct cleanup paths may remove the ZSET member before this script sees the session again. +-- When the member is absent, discard any stale reference hash value before acquiring a new ref. +if not is_tracked then + redis.call('HDEL', ref_key, session_id) +end + +local existing_refs = tonumber(redis.call('HGET', ref_key, session_id) or '0') + -- 3. Get current concurrency count local current_count = redis.call('ZCARD', provider_key) -- 4. Check limit (exclude already tracked session) if limit > 0 and not is_tracked and current_count >= limit then - return {0, current_count, 0} -- {allowed=false, current_count, tracked=0} + return {0, current_count, 0, 0} -- {allowed=false, current_count, tracked=0, referenced=0} end -- 5. Track session (ZADD updates timestamp for existing members) redis.call('ZADD', provider_key, now, session_id) +local referenced = 0 +if not is_tracked or existing_refs > 0 then + redis.call('HINCRBY', ref_key, session_id, 1) + referenced = 1 +end + -- 6. Set TTL based on session TTL (at least 1h to cover active sessions) local ttl_seconds = math.floor(ttl / 1000) local expire_ttl = math.max(3600, ttl_seconds) redis.call('EXPIRE', provider_key, expire_ttl) +redis.call('EXPIRE', ref_key, expire_ttl) -- 7. Return success if is_tracked then -- Already tracked, count unchanged - return {1, current_count, 0} -- {allowed=true, count, tracked=0} + return {1, current_count, 0, referenced} -- {allowed=true, count, tracked=0, referenced} else -- New tracking, count +1 - return {1, current_count + 1, 1} -- {allowed=true, new_count, tracked=1} + return {1, current_count + 1, 1, referenced} -- {allowed=true, new_count, tracked=1, referenced=1} end `; +/** + * Release provider-level active session membership with per-session references. + * + * KEYS[1]: provider:${providerId}:active_sessions + * KEYS[2]: provider:${providerId}:active_session_refs + * ARGV[1]: sessionId + * + * Return: {removed, remainingRefs} + */ +export const RELEASE_PROVIDER_SESSION = ` +local provider_key = KEYS[1] +local ref_key = KEYS[2] +local session_id = ARGV[1] + +local current_refs = tonumber(redis.call('HGET', ref_key, session_id) or '0') +if current_refs <= 0 then + return {0, 0} +end + +local remaining_refs = current_refs - 1 +if remaining_refs > 0 then + redis.call('HSET', ref_key, session_id, remaining_refs) + return {0, remaining_refs} +end + +redis.call('HDEL', ref_key, session_id) +local removed = redis.call('ZREM', provider_key, session_id) +return {removed, remaining_refs} +`; + /** * Key/User 并发:原子性检查 + 追踪(修复竞态条件) * diff --git a/src/lib/session-manager.ts b/src/lib/session-manager.ts index bac9d11bd..1acafdd1b 100644 --- a/src/lib/session-manager.ts +++ b/src/lib/session-manager.ts @@ -2433,6 +2433,7 @@ export class SessionManager { if (providerId) { pipeline.zrem(`provider:${providerId}:active_sessions`, sessionId); + pipeline.hdel(`provider:${providerId}:active_session_refs`, sessionId); } if (keyId) { diff --git a/src/lib/session-tracker.ts b/src/lib/session-tracker.ts index 8690d2128..dd278a521 100644 --- a/src/lib/session-tracker.ts +++ b/src/lib/session-tracker.ts @@ -6,6 +6,13 @@ import { } from "@/lib/redis/active-session-keys"; import { getRedisClient } from "./redis"; +const PROVIDER_ACTIVE_SESSIONS_PATTERN = /^provider:(\d+):active_sessions$/; + +function getProviderActiveSessionRefsKey(activeSessionsKey: string): string | null { + const match = PROVIDER_ACTIVE_SESSIONS_PATTERN.exec(activeSessionsKey); + return match ? `provider:${match[1]}:active_session_refs` : null; +} + /** * Session 追踪器 - 统一管理活跃 Session 集合 * @@ -141,8 +148,11 @@ export class SessionTracker { pipeline.zadd(globalKey, now, sessionId); // 添加到 provider 级集合(ZSET) - pipeline.zadd(`provider:${providerId}:active_sessions`, now, sessionId); - pipeline.expire(`provider:${providerId}:active_sessions`, 3600); + const providerZSetKey = `provider:${providerId}:active_sessions`; + const providerRefKey = `provider:${providerId}:active_session_refs`; + pipeline.zadd(providerZSetKey, now, sessionId); + pipeline.expire(providerZSetKey, 3600); + pipeline.expire(providerRefKey, 3600); const results = await pipeline.exec(); @@ -190,25 +200,42 @@ export class SessionTracker { const pipeline = redis.pipeline(); const ttlSeconds = SessionTracker.SESSION_TTL_SECONDS; const providerZSetKey = `provider:${providerId}:active_sessions`; + const providerRefKey = `provider:${providerId}:active_session_refs`; const globalKey = getGlobalActiveSessionsKey(); const keyZSetKey = getKeyActiveSessionsKey(keyId); + let commandIndex = 0; + let cleanupExpiredSessionsResultIndex: number | null = null; pipeline.zadd(globalKey, now, sessionId); + commandIndex++; pipeline.zadd(keyZSetKey, now, sessionId); + commandIndex++; pipeline.zadd(providerZSetKey, now, sessionId); + commandIndex++; // Use dynamic TTL based on session TTL (at least 1h to cover active sessions) pipeline.expire(providerZSetKey, Math.max(3600, ttlSeconds)); + commandIndex++; + pipeline.expire(providerRefKey, Math.max(3600, ttlSeconds)); + commandIndex++; if (userId !== undefined) { pipeline.zadd(getUserActiveSessionsKey(userId), now, sessionId); + commandIndex++; } pipeline.expire(`session:${sessionId}:provider`, ttlSeconds); + commandIndex++; pipeline.expire(`session:${sessionId}:key`, ttlSeconds); + commandIndex++; pipeline.setex(`session:${sessionId}:last_seen`, ttlSeconds, now.toString()); + commandIndex++; if (Math.random() < SessionTracker.CLEANUP_PROBABILITY) { const cutoffMs = now - SessionTracker.SESSION_TTL_MS; + cleanupExpiredSessionsResultIndex = commandIndex; + pipeline.zrangebyscore(providerZSetKey, "-inf", cutoffMs); + commandIndex++; pipeline.zremrangebyscore(providerZSetKey, "-inf", cutoffMs); + commandIndex++; } const results = await pipeline.exec(); @@ -227,6 +254,18 @@ export class SessionTracker { } } + if (cleanupExpiredSessionsResultIndex !== null && results) { + const expiredResult = results[cleanupExpiredSessionsResultIndex]; + if (!expiredResult?.[0] && Array.isArray(expiredResult?.[1])) { + const expiredSessionIds = expiredResult[1].filter( + (value): value is string => typeof value === "string" && value.length > 0 + ); + if (expiredSessionIds.length > 0) { + await redis.hdel(providerRefKey, ...expiredSessionIds); + } + } + } + logger.trace("SessionTracker: Refreshed session", { sessionId }); } catch (error) { logger.error("SessionTracker: Failed to refresh session", { error }); @@ -397,6 +436,7 @@ export class SessionTracker { for (const providerId of providerIds) { const key = `provider:${providerId}:active_sessions`; // 清理过期 session + cleanupPipeline.zrangebyscore(key, "-inf", cutoffMs); cleanupPipeline.zremrangebyscore(key, "-inf", cutoffMs); // 获取剩余 session IDs cleanupPipeline.zrange(key, 0, -1); @@ -410,11 +450,22 @@ export class SessionTracker { // 收集需要验证的 session IDs const providerSessionMap = new Map(); const allSessionIds: string[] = []; + const expiredProviderSessions = new Map(); for (let i = 0; i < providerIds.length; i++) { const providerId = providerIds[i]; - // 每个 provider 有 2 个命令(zremrangebyscore + zrange) - const zrangeResult = cleanupResults[i * 2 + 1]; + // 每个 provider 有 3 个命令(zrangebyscore + zremrangebyscore + zrange) + const expiredResult = cleanupResults[i * 3]; + const zrangeResult = cleanupResults[i * 3 + 2]; + + if (expiredResult && expiredResult[0] === null && Array.isArray(expiredResult[1])) { + expiredProviderSessions.set( + providerId, + expiredResult[1].filter( + (value): value is string => typeof value === "string" && value.length > 0 + ) + ); + } if (zrangeResult && zrangeResult[0] === null) { const sessionIds = zrangeResult[1] as string[]; @@ -425,6 +476,21 @@ export class SessionTracker { } } + const refCleanupPipeline = redis.pipeline(); + let hasRefCleanup = false; + for (const [providerId, expiredSessionIds] of expiredProviderSessions) { + if (expiredSessionIds.length > 0) { + refCleanupPipeline.hdel( + `provider:${providerId}:active_session_refs`, + ...expiredSessionIds + ); + hasRefCleanup = true; + } + } + if (hasRefCleanup) { + await refCleanupPipeline.exec(); + } + // 如果没有 session,直接返回 if (allSessionIds.length === 0) { return result; @@ -533,7 +599,14 @@ export class SessionTracker { const cutoffMs = now - SessionTracker.SESSION_TTL_MS; // 1. 清理过期 session(5 分钟前) + const providerRefKey = getProviderActiveSessionRefsKey(key); + const expiredSessionIds = providerRefKey + ? await redis.zrangebyscore(key, "-inf", cutoffMs) + : []; await redis.zremrangebyscore(key, "-inf", cutoffMs); + if (providerRefKey && expiredSessionIds.length > 0) { + await redis.hdel(providerRefKey, ...expiredSessionIds); + } // 2. 获取剩余的 session ID const sessionIds = await redis.zrange(key, 0, -1); diff --git a/tests/unit/lib/rate-limit/provider-session-release.test.ts b/tests/unit/lib/rate-limit/provider-session-release.test.ts index 9102db6b6..217083be0 100644 --- a/tests/unit/lib/rate-limit/provider-session-release.test.ts +++ b/tests/unit/lib/rate-limit/provider-session-release.test.ts @@ -2,11 +2,11 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; type RedisClientMock = { status: string; - zrem: (key: string, member: string) => Promise; + eval: (...args: unknown[]) => Promise<[number, number]>; }; let redisClientRef: RedisClientMock | null; -let zremMock: ReturnType Promise>>; +let evalMock: ReturnType Promise<[number, number]>>>; vi.mock("server-only", () => ({})); @@ -25,20 +25,35 @@ vi.mock("@/lib/logger", () => ({ describe("RateLimitService.releaseProviderSession", () => { beforeEach(() => { vi.clearAllMocks(); - zremMock = vi.fn(async () => 1); + evalMock = vi.fn(async () => [1, 0]); redisClientRef = { status: "ready", - zrem: zremMock, + eval: evalMock, }; }); - it("应从供应商 active_sessions ZSET 中释放失败请求的 sessionId", async () => { + it("应通过引用计数脚本释放失败请求的 provider session", async () => { const { RateLimitService } = await import("@/lib/rate-limit/service"); await RateLimitService.releaseProviderSession(42, "sess_failed"); - expect(zremMock).toHaveBeenCalledTimes(1); - expect(zremMock).toHaveBeenCalledWith("provider:42:active_sessions", "sess_failed"); + expect(evalMock).toHaveBeenCalledTimes(1); + expect(evalMock).toHaveBeenCalledWith( + expect.any(String), + 2, + "provider:42:active_sessions", + "provider:42:active_session_refs", + "sess_failed" + ); + }); + + it("仍有并发引用时不应直接 ZREM active session", async () => { + evalMock.mockResolvedValueOnce([0, 1]); + const { RateLimitService } = await import("@/lib/rate-limit/service"); + + await RateLimitService.releaseProviderSession(42, "sess_failed"); + + expect(evalMock).toHaveBeenCalledTimes(1); }); it("Redis 不可用或未 ready 时应静默跳过", async () => { @@ -47,10 +62,10 @@ describe("RateLimitService.releaseProviderSession", () => { redisClientRef = null; await RateLimitService.releaseProviderSession(42, "sess_failed"); - redisClientRef = { status: "connecting", zrem: zremMock }; + redisClientRef = { status: "connecting", eval: evalMock }; await RateLimitService.releaseProviderSession(42, "sess_failed"); - expect(zremMock).not.toHaveBeenCalled(); + expect(evalMock).not.toHaveBeenCalled(); }); it("非法 providerId 或空 sessionId 不应触发 Redis 命令", async () => { @@ -60,12 +75,12 @@ describe("RateLimitService.releaseProviderSession", () => { await RateLimitService.releaseProviderSession(-1, "sess_failed"); await RateLimitService.releaseProviderSession(42, " "); - expect(zremMock).not.toHaveBeenCalled(); + expect(evalMock).not.toHaveBeenCalled(); }); it("释放失败时应记录日志但不向请求链路抛错", async () => { const error = new Error("redis down"); - zremMock.mockRejectedValueOnce(error); + evalMock.mockRejectedValueOnce(error); const { RateLimitService } = await import("@/lib/rate-limit/service"); const { logger } = await import("@/lib/logger"); diff --git a/tests/unit/lib/rate-limit/service-extra.test.ts b/tests/unit/lib/rate-limit/service-extra.test.ts index 235a22780..46d411f77 100644 --- a/tests/unit/lib/rate-limit/service-extra.test.ts +++ b/tests/unit/lib/rate-limit/service-extra.test.ts @@ -145,7 +145,7 @@ describe("RateLimitService - other quota paths", () => { const { RateLimitService } = await import("@/lib/rate-limit"); const result = await RateLimitService.checkAndTrackProviderSession(9, "sess", 0); - expect(result).toEqual({ allowed: true, count: 0, tracked: false }); + expect(result).toEqual({ allowed: true, count: 0, tracked: false, referenced: false }); }); it("checkAndTrackProviderSession:Redis 非 ready 时应 Fail Open", async () => { @@ -153,13 +153,13 @@ describe("RateLimitService - other quota paths", () => { redisClientRef.status = "end"; const result = await RateLimitService.checkAndTrackProviderSession(9, "sess", 2); - expect(result).toEqual({ allowed: true, count: 0, tracked: false }); + expect(result).toEqual({ allowed: true, count: 0, tracked: false, referenced: false }); }); it("checkAndTrackProviderSession:达到上限时应返回 not allowed", async () => { const { RateLimitService } = await import("@/lib/rate-limit"); - redisClientRef.eval.mockResolvedValueOnce([0, 2, 0]); + redisClientRef.eval.mockResolvedValueOnce([0, 2, 0, 0]); const result = await RateLimitService.checkAndTrackProviderSession(9, "sess", 2); expect(result.allowed).toBe(false); expect(result.reason).toContain("供应商并发 Session 上限已达到(2/2)"); @@ -168,27 +168,37 @@ describe("RateLimitService - other quota paths", () => { it("checkAndTrackProviderSession:未达到上限时应返回 allowed 且可标记 tracked", async () => { const { RateLimitService } = await import("@/lib/rate-limit"); - redisClientRef.eval.mockResolvedValueOnce([1, 1, 1]); + redisClientRef.eval.mockResolvedValueOnce([1, 1, 1, 1]); const result = await RateLimitService.checkAndTrackProviderSession(9, "sess", 2); - expect(result).toEqual({ allowed: true, count: 1, tracked: true }); + expect(result).toEqual({ allowed: true, count: 1, tracked: true, referenced: true }); + }); + + it("checkAndTrackProviderSession:旧 membership 无引用计数时不应返回 release 引用", async () => { + const { RateLimitService } = await import("@/lib/rate-limit"); + + redisClientRef.eval.mockResolvedValueOnce([1, 1, 0, 0]); + const result = await RateLimitService.checkAndTrackProviderSession(9, "sess", 2); + expect(result).toEqual({ allowed: true, count: 1, tracked: false, referenced: false }); }); it("checkAndTrackProviderSession: should pass SESSION_TTL_MS as ARGV[4] to Lua script", async () => { const { RateLimitService } = await import("@/lib/rate-limit"); - redisClientRef.eval.mockResolvedValueOnce([1, 1, 1]); + redisClientRef.eval.mockResolvedValueOnce([1, 1, 1, 1]); await RateLimitService.checkAndTrackProviderSession(9, "sess", 2); // Verify eval was called with the correct args including ARGV[4] = SESSION_TTL_MS expect(redisClientRef.eval).toHaveBeenCalledTimes(1); const evalCall = redisClientRef.eval.mock.calls[0]; - // evalCall: [script, numkeys, key, sessionId, limit, now, ttlMs] - // Indices: 0 1 2 3 4 5 6 - expect(evalCall.length).toBe(7); // script + 1 key + 5 ARGV - - // ARGV[4] (index 6) should be SESSION_TTL_MS derived from env (default 300s = 300000ms) - const ttlMsArg = evalCall[6]; + // evalCall: [script, numkeys, activeKey, refKey, sessionId, limit, now, ttlMs] + // Indices: 0 1 2 3 4 5 6 7 + expect(evalCall.length).toBe(8); // script + 2 keys + 4 ARGV + expect(evalCall[2]).toBe("provider:9:active_sessions"); + expect(evalCall[3]).toBe("provider:9:active_session_refs"); + + // ARGV[4] (index 7) should be SESSION_TTL_MS derived from env (default 300s = 300000ms) + const ttlMsArg = evalCall[7]; expect(ttlMsArg).toBe("300000"); }); diff --git a/tests/unit/lib/session-manager-terminate-session.test.ts b/tests/unit/lib/session-manager-terminate-session.test.ts index f4de279ac..f61889538 100644 --- a/tests/unit/lib/session-manager-terminate-session.test.ts +++ b/tests/unit/lib/session-manager-terminate-session.test.ts @@ -27,6 +27,7 @@ describe("SessionManager.terminateSession", () => { pipelineRef = { del: vi.fn(() => pipelineRef), zrem: vi.fn(() => pipelineRef), + hdel: vi.fn(() => pipelineRef), exec: vi.fn(async () => [[null, 1]]), }; diff --git a/tests/unit/lib/session-tracker-cleanup.test.ts b/tests/unit/lib/session-tracker-cleanup.test.ts index 554c6723e..e13f718ab 100644 --- a/tests/unit/lib/session-tracker-cleanup.test.ts +++ b/tests/unit/lib/session-tracker-cleanup.test.ts @@ -25,10 +25,18 @@ const makePipeline = () => { pipelineCalls.push(["zremrangebyscore", ...args]); return pipeline; }), + zrangebyscore: vi.fn((...args: unknown[]) => { + pipelineCalls.push(["zrangebyscore", ...args]); + return pipeline; + }), zrange: vi.fn((...args: unknown[]) => { pipelineCalls.push(["zrange", ...args]); return pipeline; }), + hdel: vi.fn((...args: unknown[]) => { + pipelineCalls.push(["hdel", ...args]); + return pipeline; + }), exists: vi.fn((...args: unknown[]) => { pipelineCalls.push(["exists", ...args]); return pipeline; @@ -72,6 +80,8 @@ describe("SessionTracker - TTL and cleanup", () => { exists: vi.fn(async () => 1), type: vi.fn(async () => "zset"), del: vi.fn(async () => 1), + hdel: vi.fn(async () => 0), + zrangebyscore: vi.fn(async () => []), zremrangebyscore: vi.fn(async () => 0), zrange: vi.fn(async () => []), pipeline: vi.fn(() => makePipeline()), diff --git a/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts b/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts index ee77b3d6f..7f8fbff8c 100644 --- a/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts +++ b/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts @@ -20,6 +20,12 @@ const mocks = vi.hoisted(() => ({ recordEndpointFailure: vi.fn(async () => {}), isVendorTypeCircuitOpen: vi.fn(async () => false), recordVendorTypeAllEndpointsTimeout: vi.fn(async () => {}), + checkAndTrackProviderSession: vi.fn(async () => ({ + allowed: true, + count: 1, + tracked: true, + referenced: true, + })), releaseProviderSession: vi.fn(async (_providerId: number, _sessionId: string) => {}), categorizeErrorAsync: vi.fn(async () => 0), getErrorDetectionResultAsync: vi.fn(async () => ({ matched: false })), @@ -76,6 +82,7 @@ vi.mock("@/lib/vendor-type-circuit-breaker", () => ({ vi.mock("@/lib/rate-limit/service", () => ({ RateLimitService: { + checkAndTrackProviderSession: mocks.checkAndTrackProviderSession, releaseProviderSession: mocks.releaseProviderSession, }, })); @@ -229,6 +236,11 @@ function createSession(clientAbortSignal: AbortSignal | null = null): ProxySessi return session as ProxySession; } +function setProviderWithSessionRef(session: ProxySession, provider: Provider): void { + session.setProvider(provider); + session.recordProviderSessionRef(provider.id); +} + function createStreamingResponse(params: { label: string; firstChunkDelayMs: number; @@ -317,6 +329,12 @@ function withThinkingBlocks(session: ProxySession): void { describe("ProxyForwarder - first-byte hedge scheduling", () => { beforeEach(() => { vi.clearAllMocks(); + mocks.checkAndTrackProviderSession.mockResolvedValue({ + allowed: true, + count: 1, + tracked: true, + referenced: true, + }); }); test("shadow session redirect should not overwrite initial provider redirect and winner should keep its own redirect", () => { @@ -513,7 +531,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { const session = createSession(); session.request.model = requestedModel; session.request.message.model = requestedModel; - session.setProvider(fireworks); + setProviderWithSessionRef(session, fireworks); session.addProviderToChain(fireworks, { reason: "initial_selection" }); mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(minimax); @@ -825,7 +843,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { const provider1 = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 }); const provider2 = createProvider({ id: 2, name: "p2", firstByteTimeoutStreamingMs: 100 }); const session = createSession(); - session.setProvider(provider1); + setProviderWithSessionRef(session, provider1); mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2); @@ -888,6 +906,99 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { } }); + test("hedge skips provider when concurrent session acquire is rejected", async () => { + vi.useFakeTimers(); + + try { + const provider1 = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 }); + const provider2 = createProvider({ + id: 2, + name: "p2", + firstByteTimeoutStreamingMs: 100, + limitConcurrentSessions: 1, + }); + const provider3 = createProvider({ id: 3, name: "p3", firstByteTimeoutStreamingMs: 100 }); + const session = createSession(); + setProviderWithSessionRef(session, provider1); + + mocks.pickRandomProviderWithExclusion + .mockResolvedValueOnce(provider2) + .mockResolvedValueOnce(provider3); + mocks.checkAndTrackProviderSession + .mockResolvedValueOnce({ + allowed: false, + count: 1, + tracked: false, + referenced: false, + reason: "供应商并发 Session 上限已达到(1/1)", + }) + .mockResolvedValueOnce({ allowed: true, count: 1, tracked: true, referenced: true }); + + const doForward = vi.spyOn( + ProxyForwarder as unknown as { + doForward: (...args: unknown[]) => Promise; + }, + "doForward" + ); + + const controller1 = new AbortController(); + const controller3 = new AbortController(); + + doForward.mockImplementationOnce(async (attemptSession) => { + const runtime = attemptSession as ProxySession & AttemptRuntime; + runtime.responseController = controller1; + runtime.clearResponseTimeout = vi.fn(); + return createStreamingResponse({ + label: "p1", + firstChunkDelayMs: 220, + controller: controller1, + }); + }); + + doForward.mockImplementationOnce(async (attemptSession) => { + const runtime = attemptSession as ProxySession & AttemptRuntime; + runtime.responseController = controller3; + runtime.clearResponseTimeout = vi.fn(); + return createStreamingResponse({ + label: "p3", + firstChunkDelayMs: 40, + controller: controller3, + }); + }); + + const responsePromise = ProxyForwarder.send(session); + + await vi.advanceTimersByTimeAsync(100); + expect(doForward).toHaveBeenCalledTimes(2); + expect(doForward).not.toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ id: 2 }), + expect.anything(), + expect.anything(), + expect.anything(), + expect.anything() + ); + + await vi.advanceTimersByTimeAsync(50); + const response = await responsePromise; + + expect(await response.text()).toContain('"provider":"p3"'); + expect(session.provider?.id).toBe(3); + expect(mocks.checkAndTrackProviderSession).toHaveBeenNthCalledWith(1, 2, "sess-hedge", 1); + expect(mocks.checkAndTrackProviderSession).toHaveBeenNthCalledWith(2, 3, "sess-hedge", 0); + expect(session.getProviderChain()).toEqual( + expect.arrayContaining([ + expect.objectContaining({ id: 2, reason: "concurrent_limit_failed" }), + expect.objectContaining({ id: 3, reason: "hedge_winner" }), + ]) + ); + expect(mocks.releaseProviderSession).toHaveBeenCalledWith(1, "sess-hedge"); + expect(mocks.releaseProviderSession).not.toHaveBeenCalledWith(2, "sess-hedge"); + } finally { + vi.useRealTimers(); + } + }); + test("高并发模式:hedge winner 成功后不应写 session provider 观测信息", async () => { vi.useFakeTimers(); @@ -896,7 +1007,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { const provider2 = createProvider({ id: 2, name: "p2", firstByteTimeoutStreamingMs: 100 }); const session = createSession(); session.setHighConcurrencyModeEnabled(true); - session.setProvider(provider1); + setProviderWithSessionRef(session, provider1); mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2); @@ -962,7 +1073,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { firstByteTimeoutStreamingMs: 100, }); const session = createSession(); - session.setProvider(provider1); + setProviderWithSessionRef(session, provider1); mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2); @@ -1031,7 +1142,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { const provider1 = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 }); const provider2 = createProvider({ id: 2, name: "p2", firstByteTimeoutStreamingMs: 100 }); const session = createSession(); - session.setProvider(provider1); + setProviderWithSessionRef(session, provider1); mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2); @@ -1094,7 +1205,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { const provider2 = createProvider({ id: 2, name: "p2", firstByteTimeoutStreamingMs: 100 }); const provider3 = createProvider({ id: 3, name: "p3", firstByteTimeoutStreamingMs: 100 }); const session = createSession(); - session.setProvider(provider1); + setProviderWithSessionRef(session, provider1); mocks.pickRandomProviderWithExclusion .mockResolvedValueOnce(provider2) @@ -1195,7 +1306,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { const session = createSession(clientAbortController.signal); session.request.model = requestedModel; session.request.message.model = requestedModel; - session.setProvider(provider1); + setProviderWithSessionRef(session, provider1); mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2); @@ -1335,7 +1446,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { }); const session = createSession(); session.requestUrl = new URL("https://example.com/v1/messages"); - session.setProvider(provider1); + setProviderWithSessionRef(session, provider1); mocks.getPreferredProviderEndpoints.mockRejectedValueOnce(new Error("Redis connection lost")); mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(null); @@ -1758,7 +1869,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { }); const session = createSession(); session.requestUrl = new URL("https://example.com/v1/messages"); - session.setProvider(provider1); + setProviderWithSessionRef(session, provider1); // Provider 1's strict endpoint resolution will fail mocks.getPreferredProviderEndpoints.mockRejectedValueOnce( diff --git a/tests/unit/proxy/proxy-forwarder-provider-session-release.test.ts b/tests/unit/proxy/proxy-forwarder-provider-session-release.test.ts index 13a68d3b4..98bf78997 100644 --- a/tests/unit/proxy/proxy-forwarder-provider-session-release.test.ts +++ b/tests/unit/proxy/proxy-forwarder-provider-session-release.test.ts @@ -11,12 +11,18 @@ vi.mock("@/lib/rate-limit/service", () => ({ }, })); +vi.mock("@/lib/rate-limit", () => ({ + RateLimitService: { + releaseProviderSession: mocks.releaseProviderSession, + }, +})); + describe("ProxyForwarder provider failure session release", () => { beforeEach(() => { mocks.releaseProviderSession.mockClear(); }); - it("标记供应商失败时应同步释放 provider active session", async () => { + it("标记供应商失败时仅释放本请求已获取的 provider session ref", async () => { const { ProxyForwarder } = await import("@/app/v1/_lib/proxy/forwarder"); const forwarderInternals = ProxyForwarder as unknown as { markProviderFailed: ( @@ -25,15 +31,43 @@ describe("ProxyForwarder provider failure session release", () => { providerId: number ) => void; }; - const session = { sessionId: "sess_failed" } as unknown as ProxySession; + const consumeProviderSessionRef = vi.fn(() => true); + const session = { + sessionId: "sess_failed", + consumeProviderSessionRef, + } as unknown as ProxySession; const failedProviderIds: number[] = []; forwarderInternals.markProviderFailed(session, failedProviderIds, 42); expect(failedProviderIds).toEqual([42]); + expect(consumeProviderSessionRef).toHaveBeenCalledWith(42); expect(mocks.releaseProviderSession).toHaveBeenCalledWith(42, "sess_failed"); }); + it("未获取 provider session ref 的 fallback/hedge provider 不应释放 Redis membership", async () => { + const { ProxyForwarder } = await import("@/app/v1/_lib/proxy/forwarder"); + const forwarderInternals = ProxyForwarder as unknown as { + markProviderFailed: ( + session: ProxySession, + failedProviderIds: number[], + providerId: number + ) => void; + }; + const consumeProviderSessionRef = vi.fn(() => false); + const session = { + sessionId: "sess_failed", + consumeProviderSessionRef, + } as unknown as ProxySession; + const failedProviderIds: number[] = []; + + forwarderInternals.markProviderFailed(session, failedProviderIds, 42); + + expect(failedProviderIds).toEqual([42]); + expect(consumeProviderSessionRef).toHaveBeenCalledWith(42); + expect(mocks.releaseProviderSession).not.toHaveBeenCalled(); + }); + it("重复标记同一供应商时只释放一次,避免 hedge 路径重复 ZREM", async () => { const { ProxyForwarder } = await import("@/app/v1/_lib/proxy/forwarder"); const forwarderInternals = ProxyForwarder as unknown as { @@ -43,13 +77,18 @@ describe("ProxyForwarder provider failure session release", () => { providerId: number ) => void; }; - const session = { sessionId: "sess_failed" } as unknown as ProxySession; + const consumeProviderSessionRef = vi.fn(() => true); + const session = { + sessionId: "sess_failed", + consumeProviderSessionRef, + } as unknown as ProxySession; const failedProviderIds: number[] = []; forwarderInternals.markProviderFailed(session, failedProviderIds, 42); forwarderInternals.markProviderFailed(session, failedProviderIds, 42); expect(failedProviderIds).toEqual([42]); + expect(consumeProviderSessionRef).toHaveBeenCalledTimes(1); expect(mocks.releaseProviderSession).toHaveBeenCalledTimes(1); });