diff --git a/src/app/v1/_lib/proxy/client-abort-listener.ts b/src/app/v1/_lib/proxy/client-abort-listener.ts new file mode 100644 index 000000000..d47758907 --- /dev/null +++ b/src/app/v1/_lib/proxy/client-abort-listener.ts @@ -0,0 +1,25 @@ +export function bindClientAbortListener( + signal: AbortSignal | null | undefined, + onAbort: () => void +): () => void { + if (!signal) { + return () => {}; + } + + if (signal.aborted) { + onAbort(); + return () => {}; + } + + let cleaned = false; + signal.addEventListener("abort", onAbort, { once: true }); + + return () => { + if (cleaned) { + return; + } + cleaned = true; + // 正常完成时也要解绑,避免 listener 闭包继续持有 session 与请求体。 + signal.removeEventListener("abort", onAbort); + }; +} diff --git a/src/app/v1/_lib/proxy/forwarder.ts b/src/app/v1/_lib/proxy/forwarder.ts index 204ccc722..8b6f383f1 100644 --- a/src/app/v1/_lib/proxy/forwarder.ts +++ b/src/app/v1/_lib/proxy/forwarder.ts @@ -50,6 +50,7 @@ import { GEMINI_PROTOCOL } from "../gemini/protocol"; import { HeaderProcessor, resolveAnthropicAuthHeaders } from "../headers"; import { buildProxyUrl } from "../url"; import { rectifyBillingHeader } from "./billing-header-rectifier"; +import { bindClientAbortListener } from "./client-abort-listener"; import { deriveClientSafeUpstreamErrorMessage } from "./client-error-message"; import { isStandardProxyEndpointPath } from "./endpoint-family-catalog"; import { resolveEndpointPolicy, shouldEnforceStrictEndpointPoolPolicy } from "./endpoint-policy"; @@ -3927,7 +3928,9 @@ export class ProxyForwarder { provider: Provider, useOriginalSession: boolean ): Promise => { - if (settled || winnerCommitted || launchedProviderIds.has(provider.id)) return false; + if (settled || winnerCommitted || noMoreProviders || launchedProviderIds.has(provider.id)) { + return false; + } launchedProviderIds.add(provider.id); @@ -4021,42 +4024,40 @@ export class ProxyForwarder { return true; }; - if (session.clientAbortSignal) { - session.clientAbortSignal.addEventListener( - "abort", - () => { - if (settled || winnerCommitted) return; - noMoreProviders = true; - lastError = new ProxyError("Request aborted by client", 499, undefined, true); - lastErrorCategory = ErrorCategory.CLIENT_ABORT; - for (const attempt of Array.from(attempts)) { - if (!attempt.settled) { - session.addProviderToChain(attempt.provider, { - ...attempt.endpointAudit, - reason: "client_abort", - attemptNumber: attempt.sequence, - errorMessage: "Client aborted request", - modelRedirect: getAttemptModelRedirect(attempt), - }); - } - } - abortAllAttempts(undefined, "client_abort"); - void finishIfExhausted(); - }, - { once: true } - ); - } + const cleanupClientAbortListener = bindClientAbortListener(session.clientAbortSignal, () => { + if (settled || winnerCommitted) return; + noMoreProviders = true; + lastError = new ProxyError("Request aborted by client", 499, undefined, true); + lastErrorCategory = ErrorCategory.CLIENT_ABORT; + for (const attempt of Array.from(attempts)) { + if (!attempt.settled) { + session.addProviderToChain(attempt.provider, { + ...attempt.endpointAudit, + reason: "client_abort", + attemptNumber: attempt.sequence, + errorMessage: "Client aborted request", + modelRedirect: getAttemptModelRedirect(attempt), + }); + } + } + abortAllAttempts(undefined, "client_abort"); + void finishIfExhausted(); + }); - const initialLaunched = await startAttempt(initialProvider, true); - if (!initialLaunched) { - await launchAlternative(); - } - await finishIfExhausted(); - const result = await resultPromise; - if (result.error) { - throw result.error; + try { + const initialLaunched = await startAttempt(initialProvider, true); + if (!initialLaunched) { + await launchAlternative(); + } + await finishIfExhausted(); + const result = await resultPromise; + if (result.error) { + throw result.error; + } + return result.response as Response; + } finally { + cleanupClientAbortListener(); } - return result.response as Response; } private static async resolveStreamingHedgeEndpoint( diff --git a/src/app/v1/_lib/proxy/response-handler.ts b/src/app/v1/_lib/proxy/response-handler.ts index b7ac734cc..e2dd19edc 100644 --- a/src/app/v1/_lib/proxy/response-handler.ts +++ b/src/app/v1/_lib/proxy/response-handler.ts @@ -40,6 +40,7 @@ import type { LongContextPricingSpecialSetting } from "@/types/special-settings" import { GeminiAdapter } from "../gemini/adapter"; import type { GeminiResponse } from "../gemini/types"; import { extractActualResponseModelForProvider } from "./actual-response-model"; +import { bindClientAbortListener } from "./client-abort-listener"; import { isClientAbortError, isTransportError } from "./errors"; import type { ProxySession } from "./session"; import { consumeDeferredStreamingFinalization } from "./stream-finalization"; @@ -1073,6 +1074,10 @@ export class ProxyResponseHandler { // 使用 AsyncTaskManager 管理后台处理任务 const taskId = `non-stream-${messageContext?.id || `unknown-${Date.now()}`}`; const abortController = new AbortController(); + const cleanupClientAbortListener = bindClientAbortListener(session.clientAbortSignal, () => { + AsyncTaskManager.cancel(taskId); + abortController.abort(); + }); const processingPromise = (async () => { const finalizeNonStreamAbort = async (): Promise => { @@ -1502,6 +1507,7 @@ export class ProxyResponseHandler { }); } } finally { + cleanupClientAbortListener(); releaseSessionAgent(session); AsyncTaskManager.cleanup(taskId); } @@ -1526,14 +1532,6 @@ export class ProxyResponseHandler { }); }); - // 客户端断开时取消任务 - if (session.clientAbortSignal) { - session.clientAbortSignal.addEventListener("abort", () => { - AsyncTaskManager.cancel(taskId); - abortController.abort(); - }); - } - void persistNonStreamAfterSnapshot(finalResponse).catch((error) => { logger.error("[ResponseHandler] Failed to persist non-stream after snapshot", { error }); }); @@ -2128,6 +2126,26 @@ export class ProxyResponseHandler { // ⭐ 提升 idleTimeoutId 到外部作用域,以便客户端断开时能清除 let idleTimeoutId: NodeJS.Timeout | null = null; + const cleanupClientAbortListener = bindClientAbortListener(session.clientAbortSignal, () => { + logger.debug("ResponseHandler: Client disconnected, cleaning up", { + taskId, + providerId: provider.id, + messageId: messageContext.id, + }); + + // 客户端断开时清除 idle timeout,避免任务已取消后仍误触发。 + if (idleTimeoutId) { + clearTimeout(idleTimeoutId); + idleTimeoutId = null; + logger.debug("ResponseHandler: Idle timeout cleared due to client disconnect", { + taskId, + providerId: provider.id, + }); + } + + AsyncTaskManager.cancel(taskId); + abortController.abort(); + }); const processingPromise = (async () => { const reader = internalStream.getReader(); @@ -2757,6 +2775,7 @@ export class ProxyResponseHandler { } } finally { // 确保资源释放 + cleanupClientAbortListener(); clearIdleTimer(); // ⭐ 清除静默期计时器(防止泄漏) try { reader.releaseLock(); @@ -2791,34 +2810,6 @@ export class ProxyResponseHandler { }); }); - // 客户端断开时取消任务并清除 idle timer - if (session.clientAbortSignal) { - session.clientAbortSignal.addEventListener("abort", () => { - logger.debug("ResponseHandler: Client disconnected, cleaning up", { - taskId, - providerId: provider.id, - messageId: messageContext.id, - }); - - // ⭐ 1. 清除 idle timeout(避免误触发) - if (idleTimeoutId) { - clearTimeout(idleTimeoutId); - idleTimeoutId = null; - logger.debug("ResponseHandler: Idle timeout cleared due to client disconnect", { - taskId, - providerId: provider.id, - }); - } - - // 2. 取消后台任务 - AsyncTaskManager.cancel(taskId); - abortController.abort(); - - // 注意:不需要 streamController.error()(客户端已断开) - // 注意:不需要 responseController.abort()(上游会自然结束) - }); - } - // ⭐ 修复 Bun 运行时的 Transfer-Encoding 重复问题 // 清理上游的传输 headers,让 Response API 自动管理 const finalStreamHeaders = cleanResponseHeaders(response.headers); diff --git a/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts b/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts index 7f8fbff8c..0f41408b9 100644 --- a/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts +++ b/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts @@ -1919,4 +1919,59 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { vi.useRealTimers(); } }); + + test("removes streaming hedge client abort listener after winner response is returned", async () => { + const clientAbortController = new AbortController(); + const addSpy = vi.spyOn(clientAbortController.signal, "addEventListener"); + const removeSpy = vi.spyOn(clientAbortController.signal, "removeEventListener"); + const provider = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 }); + const session = createSession(clientAbortController.signal); + setProviderWithSessionRef(session, provider); + session.forwardedRequestBody = "x".repeat(512 * 1024); + + const doForward = vi.spyOn( + ProxyForwarder as unknown as { + doForward: (...args: unknown[]) => Promise; + }, + "doForward" + ); + const upstreamController = new AbortController(); + doForward.mockImplementationOnce(async (attemptSession) => { + const runtime = attemptSession as ProxySession & AttemptRuntime; + runtime.responseController = upstreamController; + runtime.clearResponseTimeout = vi.fn(); + return createStreamingResponse({ + label: "p1", + firstChunkDelayMs: 0, + controller: upstreamController, + }); + }); + + const response = await ProxyForwarder.send(session); + expect(await response.text()).toContain('"provider":"p1"'); + + const abortAddCalls = addSpy.mock.calls.filter(([type]) => type === "abort"); + expect(abortAddCalls).toHaveLength(1); + expect(removeSpy).toHaveBeenCalledWith("abort", abortAddCalls[0][1]); + }); + + test("pre-aborted client signal should settle hedge without launching upstream attempt", async () => { + const clientAbortController = new AbortController(); + clientAbortController.abort(new Error("client_cancelled")); + const addSpy = vi.spyOn(clientAbortController.signal, "addEventListener"); + const provider = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 }); + const session = createSession(clientAbortController.signal); + setProviderWithSessionRef(session, provider); + + const doForward = vi.spyOn( + ProxyForwarder as unknown as { + doForward: (...args: unknown[]) => Promise; + }, + "doForward" + ); + + await expect(ProxyForwarder.send(session)).rejects.toMatchObject({ statusCode: 499 }); + expect(doForward).not.toHaveBeenCalled(); + expect(addSpy.mock.calls.filter(([type]) => type === "abort")).toHaveLength(0); + }); }); diff --git a/tests/unit/proxy/response-handler-abort-listener-cleanup.test.ts b/tests/unit/proxy/response-handler-abort-listener-cleanup.test.ts new file mode 100644 index 000000000..c5dda43b9 --- /dev/null +++ b/tests/unit/proxy/response-handler-abort-listener-cleanup.test.ts @@ -0,0 +1,283 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy"; +import { ProxyResponseHandler } from "@/app/v1/_lib/proxy/response-handler"; +import type { ProxySession } from "@/app/v1/_lib/proxy/session"; +import type { Provider } from "@/types/provider"; + +const testState = vi.hoisted(() => ({ + asyncTasks: [] as Promise[], + cancelTask: vi.fn(), + cleanupTask: vi.fn(), +})); + +vi.mock("@/app/v1/_lib/proxy/response-fixer", () => ({ + ResponseFixer: { + process: async (_session: unknown, response: Response) => response, + }, +})); + +vi.mock("@/lib/async-task-manager", () => ({ + AsyncTaskManager: { + register: (_taskId: string, promise: Promise) => { + testState.asyncTasks.push(promise); + return new AbortController(); + }, + cleanup: testState.cleanupTask, + cancel: testState.cancelTask, + }, +})); + +vi.mock("@/lib/logger", () => ({ + logger: { + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + trace: vi.fn(), + }, +})); + +vi.mock("@/lib/price-sync/cloud-price-updater", () => ({ + requestCloudPriceTableSync: vi.fn(), +})); + +vi.mock("@/lib/proxy-status-tracker", () => ({ + ProxyStatusTracker: { + getInstance: () => ({ + endRequest: vi.fn(), + }), + }, +})); + +vi.mock("@/lib/rate-limit", () => ({ + RateLimitService: { + trackCost: vi.fn(), + trackUserDailyCost: vi.fn(), + decrementLeaseBudget: vi.fn(), + }, +})); + +vi.mock("@/lib/redis/live-chain-store", () => ({ + deleteLiveChain: vi.fn(), +})); + +vi.mock("@/lib/session-manager", () => ({ + SessionManager: { + clearSessionProvider: vi.fn(), + storeSessionResponse: vi.fn(), + updateSessionUsage: vi.fn(), + storeSessionRequestPhaseSnapshot: vi.fn(), + storeSessionResponsePhaseSnapshot: vi.fn(), + storeSessionUpstreamRequestMeta: vi.fn(), + storeSessionSpecialSettings: vi.fn(), + storeSessionRequestHeaders: vi.fn(), + storeSessionResponseHeaders: vi.fn(), + storeSessionUpstreamResponseMeta: vi.fn(), + }, +})); + +vi.mock("@/lib/session-tracker", () => ({ + SessionTracker: { + refreshSession: vi.fn(), + }, +})); + +vi.mock("@/lib/circuit-breaker", () => ({ + recordFailure: vi.fn(), +})); + +vi.mock("@/lib/endpoint-circuit-breaker", () => ({ + recordEndpointFailure: vi.fn(), + recordEndpointSuccess: vi.fn(), + resetEndpointCircuit: vi.fn(), +})); + +vi.mock("@/repository/message", () => ({ + updateMessageRequestCostWithBreakdown: vi.fn(), + updateMessageRequestDetails: vi.fn(), + updateMessageRequestDuration: vi.fn(), +})); + +async function drainAsyncTasks(): Promise { + while (testState.asyncTasks.length > 0) { + const tasks = testState.asyncTasks.splice(0); + await Promise.allSettled(tasks); + await new Promise((resolve) => setTimeout(resolve, 0)); + } +} + +function makeProvider(overrides: Partial = {}): Provider { + return { + id: 99, + name: "test-provider", + providerType: "openai", + baseUrl: "https://api.test.invalid", + priority: 1, + weight: 1, + costMultiplier: 1, + groupTag: "default", + isEnabled: true, + models: [], + createdAt: new Date(), + updatedAt: new Date(), + streamingIdleTimeoutMs: 0, + ...overrides, + } as Provider; +} + +function makeSession(clientAbortSignal: AbortSignal | null, stream: boolean): ProxySession { + const endpointPolicy = resolveEndpointPolicy("/v1/chat/completions"); + const provider = makeProvider(); + const session = { + request: { + model: "gpt-5.4", + log: "", + message: { + model: "gpt-5.4", + stream, + messages: [{ role: "user", content: "hello" }], + }, + }, + startTime: Date.now(), + method: "POST", + requestUrl: new URL("http://localhost/v1/chat/completions"), + headers: new Headers(), + headerLog: "", + userAgent: null, + context: {}, + clientAbortSignal, + forwardedRequestBody: "", + userName: "test-user", + authState: { + success: true, + user: { id: 1, name: "test-user" }, + key: { id: 2, name: "test-key" }, + apiKey: "test-key", + }, + provider, + messageContext: { + id: 123, + user: { id: 1, name: "test-user" }, + key: { id: 2, name: "test-key" }, + isSystemPrompt: false, + requireAuth: true, + createdAt: new Date(), + }, + sessionId: null, + requestSequence: 1, + originalFormat: "openai", + providerType: "openai", + originalModelName: "gpt-5.4", + originalUrlPathname: "/v1/chat/completions", + providerChain: [], + cacheTtlResolved: null, + context1mApplied: false, + specialSettings: [], + cachedPriceData: undefined, + cachedBillingModelSource: undefined, + endpointPolicy, + isHeaderModified: () => false, + getEndpointPolicy: () => endpointPolicy, + getContext1mApplied: () => false, + getGroupCostMultiplier: () => 1, + getOriginalModel: () => "gpt-5.4", + getCurrentModel: () => "gpt-5.4", + getProviderChain: () => [], + getSpecialSettings: () => [], + shouldPersistSessionDebugArtifacts: () => false, + shouldTrackSessionObservability: () => false, + getResolvedPricingByBillingSource: async () => null, + recordTtfb: vi.fn(), + ttfbMs: null, + addProviderToChain: vi.fn(), + clearResponseTimeout: vi.fn(), + releaseAgent: vi.fn(), + }; + + return session as unknown as ProxySession; +} + +describe("ProxyResponseHandler client abort listener cleanup", () => { + beforeEach(() => { + testState.asyncTasks = []; + testState.cancelTask.mockClear(); + testState.cleanupTask.mockClear(); + vi.restoreAllMocks(); + }); + + it("removes non-stream client abort listener after response processing completes", async () => { + const controller = new AbortController(); + const addSpy = vi.spyOn(controller.signal, "addEventListener"); + const removeSpy = vi.spyOn(controller.signal, "removeEventListener"); + const session = makeSession(controller.signal, false); + const upstreamResponse = new Response( + JSON.stringify({ + choices: [{ message: { content: "ok" } }], + }), + { + headers: { "content-type": "application/json" }, + } + ); + + const response = await ProxyResponseHandler.dispatch(session, upstreamResponse); + await response.text(); + await drainAsyncTasks(); + + const abortAddCalls = addSpy.mock.calls.filter(([type]) => type === "abort"); + expect(abortAddCalls).toHaveLength(1); + expect(removeSpy).toHaveBeenCalledWith("abort", abortAddCalls[0][1]); + }); + + it("removes stream client abort listener after stream processing completes", async () => { + const controller = new AbortController(); + const addSpy = vi.spyOn(controller.signal, "addEventListener"); + const removeSpy = vi.spyOn(controller.signal, "removeEventListener"); + const session = makeSession(controller.signal, true); + const upstreamResponse = new Response( + 'data: {"choices":[{"delta":{"content":"ok"}}]}\n\ndata: [DONE]\n\n', + { + headers: { "content-type": "text/event-stream" }, + } + ); + + const response = await ProxyResponseHandler.dispatch(session, upstreamResponse); + await response.text(); + await drainAsyncTasks(); + + const abortAddCalls = addSpy.mock.calls.filter(([type]) => type === "abort"); + expect(abortAddCalls).toHaveLength(1); + expect(removeSpy).toHaveBeenCalledWith("abort", abortAddCalls[0][1]); + }); + + it("uses no-op cleanup when client abort signal is null", async () => { + const session = makeSession(null, false); + const upstreamResponse = new Response(JSON.stringify({ choices: [] }), { + headers: { "content-type": "application/json" }, + }); + + const response = await ProxyResponseHandler.dispatch(session, upstreamResponse); + await response.text(); + await drainAsyncTasks(); + + expect(testState.cancelTask).not.toHaveBeenCalled(); + }); + + it("invokes cancel synchronously when client signal is already aborted", async () => { + const controller = new AbortController(); + controller.abort(); + const addSpy = vi.spyOn(controller.signal, "addEventListener"); + const removeSpy = vi.spyOn(controller.signal, "removeEventListener"); + const session = makeSession(controller.signal, false); + const upstreamResponse = new Response(JSON.stringify({ choices: [] }), { + headers: { "content-type": "application/json" }, + }); + + const response = await ProxyResponseHandler.dispatch(session, upstreamResponse); + await response.text(); + await drainAsyncTasks(); + + expect(addSpy.mock.calls.filter(([type]) => type === "abort")).toHaveLength(0); + expect(removeSpy.mock.calls.filter(([type]) => type === "abort")).toHaveLength(0); + expect(testState.cancelTask).toHaveBeenCalled(); + }); +});