Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/app/v1/_lib/proxy/client-abort-listener.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/**
 * Subscribes `onAbort` to a client abort signal and hands back a function
 * that undoes the subscription.
 *
 * - When `signal` is `null`/`undefined`, nothing is registered and the
 *   returned function does nothing.
 * - When `signal` is already aborted, `onAbort` is invoked synchronously,
 *   nothing is registered, and the returned function does nothing.
 * - Otherwise `onAbort` is registered as a one-shot `"abort"` listener.
 *
 * @param signal - The client abort signal, if any.
 * @param onAbort - Callback to run when the client aborts.
 * @returns An idempotent cleanup function that removes the listener; calling
 *   it more than once has no further effect.
 */
export function bindClientAbortListener(
  signal: AbortSignal | null | undefined,
  onAbort: () => void
): () => void {
  const noop = (): void => {};

  if (!signal) {
    return noop;
  }

  if (signal.aborted) {
    onAbort();
    return noop;
  }

  const target = signal;
  target.addEventListener("abort", onAbort, { once: true });
  let bound = true;

  return function unbind(): void {
    if (bound) {
      bound = false;
      // Unbind on normal completion as well, so the listener closure does not
      // keep holding on to the session and the request body.
      target.removeEventListener("abort", onAbort);
    }
  };
}
71 changes: 36 additions & 35 deletions src/app/v1/_lib/proxy/forwarder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import { GEMINI_PROTOCOL } from "../gemini/protocol";
import { HeaderProcessor, resolveAnthropicAuthHeaders } from "../headers";
import { buildProxyUrl } from "../url";
import { rectifyBillingHeader } from "./billing-header-rectifier";
import { bindClientAbortListener } from "./client-abort-listener";
import { deriveClientSafeUpstreamErrorMessage } from "./client-error-message";
import { isStandardProxyEndpointPath } from "./endpoint-family-catalog";
import { resolveEndpointPolicy, shouldEnforceStrictEndpointPoolPolicy } from "./endpoint-policy";
Expand Down Expand Up @@ -3927,7 +3928,9 @@ export class ProxyForwarder {
provider: Provider,
useOriginalSession: boolean
): Promise<boolean> => {
if (settled || winnerCommitted || launchedProviderIds.has(provider.id)) return false;
if (settled || winnerCommitted || noMoreProviders || launchedProviderIds.has(provider.id)) {
return false;
}

launchedProviderIds.add(provider.id);

Expand Down Expand Up @@ -4021,42 +4024,40 @@ export class ProxyForwarder {
return true;
};

if (session.clientAbortSignal) {
session.clientAbortSignal.addEventListener(
"abort",
() => {
if (settled || winnerCommitted) return;
noMoreProviders = true;
lastError = new ProxyError("Request aborted by client", 499, undefined, true);
lastErrorCategory = ErrorCategory.CLIENT_ABORT;
for (const attempt of Array.from(attempts)) {
if (!attempt.settled) {
session.addProviderToChain(attempt.provider, {
...attempt.endpointAudit,
reason: "client_abort",
attemptNumber: attempt.sequence,
errorMessage: "Client aborted request",
modelRedirect: getAttemptModelRedirect(attempt),
});
}
}
abortAllAttempts(undefined, "client_abort");
void finishIfExhausted();
},
{ once: true }
);
}
const cleanupClientAbortListener = bindClientAbortListener(session.clientAbortSignal, () => {
if (settled || winnerCommitted) return;
noMoreProviders = true;
lastError = new ProxyError("Request aborted by client", 499, undefined, true);
lastErrorCategory = ErrorCategory.CLIENT_ABORT;
for (const attempt of Array.from(attempts)) {
if (!attempt.settled) {
session.addProviderToChain(attempt.provider, {
...attempt.endpointAudit,
reason: "client_abort",
attemptNumber: attempt.sequence,
errorMessage: "Client aborted request",
modelRedirect: getAttemptModelRedirect(attempt),
});
}
}
abortAllAttempts(undefined, "client_abort");
void finishIfExhausted();
});

const initialLaunched = await startAttempt(initialProvider, true);
if (!initialLaunched) {
await launchAlternative();
}
await finishIfExhausted();
const result = await resultPromise;
if (result.error) {
throw result.error;
try {
const initialLaunched = await startAttempt(initialProvider, true);
if (!initialLaunched) {
await launchAlternative();
}
await finishIfExhausted();
const result = await resultPromise;
if (result.error) {
throw result.error;
}
return result.response as Response;
} finally {
cleanupClientAbortListener();
}
return result.response as Response;
}

private static async resolveStreamingHedgeEndpoint(
Expand Down
63 changes: 27 additions & 36 deletions src/app/v1/_lib/proxy/response-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import type { LongContextPricingSpecialSetting } from "@/types/special-settings"
import { GeminiAdapter } from "../gemini/adapter";
import type { GeminiResponse } from "../gemini/types";
import { extractActualResponseModelForProvider } from "./actual-response-model";
import { bindClientAbortListener } from "./client-abort-listener";
import { isClientAbortError, isTransportError } from "./errors";
import type { ProxySession } from "./session";
import { consumeDeferredStreamingFinalization } from "./stream-finalization";
Expand Down Expand Up @@ -1073,6 +1074,10 @@ export class ProxyResponseHandler {
// 使用 AsyncTaskManager 管理后台处理任务
const taskId = `non-stream-${messageContext?.id || `unknown-${Date.now()}`}`;
const abortController = new AbortController();
const cleanupClientAbortListener = bindClientAbortListener(session.clientAbortSignal, () => {
AsyncTaskManager.cancel(taskId);
abortController.abort();
});

const processingPromise = (async () => {
const finalizeNonStreamAbort = async (): Promise<void> => {
Expand Down Expand Up @@ -1502,6 +1507,7 @@ export class ProxyResponseHandler {
});
}
} finally {
cleanupClientAbortListener();
releaseSessionAgent(session);
AsyncTaskManager.cleanup(taskId);
}
Expand All @@ -1526,14 +1532,6 @@ export class ProxyResponseHandler {
});
});

// 客户端断开时取消任务
if (session.clientAbortSignal) {
session.clientAbortSignal.addEventListener("abort", () => {
AsyncTaskManager.cancel(taskId);
abortController.abort();
});
}

void persistNonStreamAfterSnapshot(finalResponse).catch((error) => {
logger.error("[ResponseHandler] Failed to persist non-stream after snapshot", { error });
});
Expand Down Expand Up @@ -2128,6 +2126,26 @@ export class ProxyResponseHandler {

// ⭐ 提升 idleTimeoutId 到外部作用域,以便客户端断开时能清除
let idleTimeoutId: NodeJS.Timeout | null = null;
const cleanupClientAbortListener = bindClientAbortListener(session.clientAbortSignal, () => {
logger.debug("ResponseHandler: Client disconnected, cleaning up", {
taskId,
providerId: provider.id,
messageId: messageContext.id,
});

// 客户端断开时清除 idle timeout,避免任务已取消后仍误触发。
if (idleTimeoutId) {
clearTimeout(idleTimeoutId);
idleTimeoutId = null;
logger.debug("ResponseHandler: Idle timeout cleared due to client disconnect", {
taskId,
providerId: provider.id,
});
}

AsyncTaskManager.cancel(taskId);
abortController.abort();
});

const processingPromise = (async () => {
const reader = internalStream.getReader();
Expand Down Expand Up @@ -2757,6 +2775,7 @@ export class ProxyResponseHandler {
}
} finally {
// 确保资源释放
cleanupClientAbortListener();
clearIdleTimer(); // ⭐ 清除静默期计时器(防止泄漏)
try {
reader.releaseLock();
Expand Down Expand Up @@ -2791,34 +2810,6 @@ export class ProxyResponseHandler {
});
});

// 客户端断开时取消任务并清除 idle timer
if (session.clientAbortSignal) {
session.clientAbortSignal.addEventListener("abort", () => {
logger.debug("ResponseHandler: Client disconnected, cleaning up", {
taskId,
providerId: provider.id,
messageId: messageContext.id,
});

// ⭐ 1. 清除 idle timeout(避免误触发)
if (idleTimeoutId) {
clearTimeout(idleTimeoutId);
idleTimeoutId = null;
logger.debug("ResponseHandler: Idle timeout cleared due to client disconnect", {
taskId,
providerId: provider.id,
});
}

// 2. 取消后台任务
AsyncTaskManager.cancel(taskId);
abortController.abort();

// 注意:不需要 streamController.error()(客户端已断开)
// 注意:不需要 responseController.abort()(上游会自然结束)
});
}

// ⭐ 修复 Bun 运行时的 Transfer-Encoding 重复问题
// 清理上游的传输 headers,让 Response API 自动管理
const finalStreamHeaders = cleanResponseHeaders(response.headers);
Expand Down
55 changes: 55 additions & 0 deletions tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1919,4 +1919,59 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => {
vi.useRealTimers();
}
});

// Verifies that the "abort" listener registered on the client signal during
// ProxyForwarder.send is removed again once a response has been returned.
test("removes streaming hedge client abort listener after winner response is returned", async () => {
const clientAbortController = new AbortController();
// Spy on both registration and removal so the exact listener can be compared.
const addSpy = vi.spyOn(clientAbortController.signal, "addEventListener");
const removeSpy = vi.spyOn(clientAbortController.signal, "removeEventListener");
const provider = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 });
const session = createSession(clientAbortController.signal);
setProviderWithSessionRef(session, provider);
// 512 KiB request body. NOTE(review): presumably stands in for the large
// payload a leaked listener closure would keep alive — confirm intent.
session.forwardedRequestBody = "x".repeat(512 * 1024);

// Replace the doForward static so no real upstream call is made.
const doForward = vi.spyOn(
ProxyForwarder as unknown as {
doForward: (...args: unknown[]) => Promise<Response>;
},
"doForward"
);
const upstreamController = new AbortController();
// The single mocked attempt attaches its controller and a stubbed
// clearResponseTimeout to the attempt session, then returns a streaming
// response labelled "p1" with no first-chunk delay.
doForward.mockImplementationOnce(async (attemptSession) => {
const runtime = attemptSession as ProxySession & AttemptRuntime;
runtime.responseController = upstreamController;
runtime.clearResponseTimeout = vi.fn();
return createStreamingResponse({
label: "p1",
firstChunkDelayMs: 0,
controller: upstreamController,
});
});

const response = await ProxyForwarder.send(session);
expect(await response.text()).toContain('"provider":"p1"');

// Exactly one "abort" listener must have been added, and that same listener
// function must have been passed to removeEventListener.
const abortAddCalls = addSpy.mock.calls.filter(([type]) => type === "abort");
expect(abortAddCalls).toHaveLength(1);
expect(removeSpy).toHaveBeenCalledWith("abort", abortAddCalls[0][1]);
});

// Verifies that when the client signal is already aborted before
// ProxyForwarder.send is called, the call rejects with status 499 without
// invoking doForward and without registering any "abort" listener.
test("pre-aborted client signal should settle hedge without launching upstream attempt", async () => {
const clientAbortController = new AbortController();
// Abort first, then install the spy, so only listeners added during send are counted.
clientAbortController.abort(new Error("client_cancelled"));
const addSpy = vi.spyOn(clientAbortController.signal, "addEventListener");
const provider = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 });
const session = createSession(clientAbortController.signal);
setProviderWithSessionRef(session, provider);

// Spy on doForward (no mock implementation supplied here) to assert it is never called.
const doForward = vi.spyOn(
ProxyForwarder as unknown as {
doForward: (...args: unknown[]) => Promise<Response>;
},
"doForward"
);

await expect(ProxyForwarder.send(session)).rejects.toMatchObject({ statusCode: 499 });
expect(doForward).not.toHaveBeenCalled();
// No "abort" listener should have been attached to an already-aborted signal.
expect(addSpy.mock.calls.filter(([type]) => type === "abort")).toHaveLength(0);
});
});
Loading
Loading